Update truncation test to reduce flakiness in cloud tests

Removed the dependency on faker and factory from the truncation tests, as their
randomly generated message content appears to be the source of the flakiness.
Debanjum Singh Solanky 2024-06-07 19:40:53 +05:30
parent dbb06466bf
commit 5f2442450c


@@ -1,26 +1,17 @@
-import factory
 import tiktoken
 from langchain.schema import ChatMessage
 
 from khoj.processor.conversation import utils
 
 
-class ChatMessageFactory(factory.Factory):
-    class Meta:
-        model = ChatMessage
-
-    content = factory.Faker("paragraph")
-    role = factory.Faker("name")
-
-
 class TestTruncateMessage:
-    max_prompt_size = 4096
+    max_prompt_size = 10
     model_name = "gpt-3.5-turbo"
     encoder = tiktoken.encoding_for_model(model_name)
 
     def test_truncate_message_all_small(self):
         # Arrange
-        chat_history = ChatMessageFactory.build_batch(500)
+        chat_history = generate_chat_history(50)
 
         # Act
         truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
@@ -28,15 +19,14 @@ class TestTruncateMessage:
 
         # Assert
         # The original object has been modified. Verify certain properties
-        assert len(chat_history) < 500
+        assert len(chat_history) < 50
         assert len(chat_history) > 1
         assert tokens <= self.max_prompt_size
 
     def test_truncate_message_first_large(self):
         # Arrange
-        chat_history = ChatMessageFactory.build_batch(25)
-        big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=2000))
-        big_chat_message.content = big_chat_message.content + "\n" + "Question?"
+        chat_history = generate_chat_history(5)
+        big_chat_message = ChatMessage(role="user", content=f"{generate_content(6)}\nQuestion?")
         copy_big_chat_message = big_chat_message.copy()
         chat_history.insert(0, big_chat_message)
         tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history])
@@ -53,10 +43,9 @@ class TestTruncateMessage:
 
     def test_truncate_message_last_large(self):
         # Arrange
-        chat_history = ChatMessageFactory.build_batch(25)
+        chat_history = generate_chat_history(5)
         chat_history[0].role = "system"  # Mark the first message as system message
-        big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=1000))
-        big_chat_message.content = big_chat_message.content + "\n" + "Question?"
+        big_chat_message = ChatMessage(role="user", content=f"{generate_content(11)}\nQuestion?")
         copy_big_chat_message = big_chat_message.copy()
         chat_history.insert(0, big_chat_message)
 
@@ -68,10 +57,10 @@ class TestTruncateMessage:
 
         # Assert
         # The original object has been modified. Verify certain properties.
-        assert len(truncated_chat_history) == (
-            len(chat_history) + 1
-        )  # Because the system_prompt is popped off from the chat_messages lsit
-        assert len(truncated_chat_history) < 26
+        assert (
+            len(truncated_chat_history) == len(chat_history) + 1
+        )  # Because the system_prompt is popped off from the chat_messages list
+        assert len(truncated_chat_history) < 10
         assert len(truncated_chat_history) > 1
         assert truncated_chat_history[0] != copy_big_chat_message
         assert initial_tokens > self.max_prompt_size
@@ -79,9 +68,7 @@ class TestTruncateMessage:
 
     def test_truncate_single_large_non_system_message(self):
         # Arrange
-        big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=2000))
-        big_chat_message.content = big_chat_message.content + "\n" + "Question?"
-        big_chat_message.role = "user"
+        big_chat_message = ChatMessage(role="user", content=f"{generate_content(11)}\nQuestion?")
         copy_big_chat_message = big_chat_message.copy()
         chat_messages = [big_chat_message]
         initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
@@ -100,8 +87,7 @@ class TestTruncateMessage:
     def test_truncate_single_large_question(self):
         # Arrange
         big_chat_message_content = " ".join(["hi"] * (self.max_prompt_size + 1))
-        big_chat_message = ChatMessageFactory.build(content=big_chat_message_content)
-        big_chat_message.role = "user"
+        big_chat_message = ChatMessage(role="user", content=big_chat_message_content)
         copy_big_chat_message = big_chat_message.copy()
         chat_messages = [big_chat_message]
         initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
@@ -116,3 +102,14 @@ class TestTruncateMessage:
         assert final_tokens <= self.max_prompt_size
         assert len(chat_messages) == 1
         assert truncated_chat_history[0] != copy_big_chat_message
+
+
+def generate_content(count):
+    return " ".join([f"{index}" for index, _ in enumerate(range(count))])
+
+
+def generate_chat_history(count):
+    return [
+        ChatMessage(role="user" if index % 2 == 0 else "assistant", content=f"{index}")
+        for index, _ in enumerate(range(count))
+    ]
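
The replacement helpers build message content from fixed numeric strings, so the token totals the assertions depend on are identical on every run. A rough usage sketch, reusing the helper definition from this diff and the utils.truncate_messages call signature the tests use:

# Usage sketch (helper copied from this diff; the truncate_messages
# signature is taken from the test calls above).
import tiktoken
from langchain.schema import ChatMessage

from khoj.processor.conversation import utils


def generate_chat_history(count):
    return [
        ChatMessage(role="user" if index % 2 == 0 else "assistant", content=f"{index}")
        for index, _ in enumerate(range(count))
    ]


encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")
chat_history = generate_chat_history(50)

# Numeric message bodies encode to a fixed token count, so this sum is
# deterministic across runs, unlike faker paragraphs.
print(sum(len(encoder.encode(message.content)) for message in chat_history))

truncated = utils.truncate_messages(chat_history, 10, "gpt-3.5-turbo")
print(sum(len(encoder.encode(message.content)) for message in truncated))  # expected <= 10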