Update truncation test to reduce flakiness in cloud tests

Removed the dependency on faker and factory for the truncation tests, as that
seems to be the source of the flakiness
Debanjum Singh Solanky 2024-06-07 19:40:53 +05:30
parent dbb06466bf
commit 5f2442450c

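Note on the fix: faker's paragraph provider returns randomly generated sentences, so the token count of each message differed between test runs and token-budget assertions could pass or fail by chance. A minimal sketch of the difference, using the same tiktoken encoder as the tests (the standalone Faker call here is illustrative, not taken from the diff):

import tiktoken
from faker import Faker

encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")

# Random content: the token count varies on every run, so whether a
# message crosses max_prompt_size is nondeterministic.
random_tokens = len(encoder.encode(Faker().paragraph(nb_sentences=2000)))

# Deterministic content, in the style of the new generate_content helper:
# the same token count on every run.
stable_tokens = len(encoder.encode(" ".join(str(i) for i in range(6))))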

@@ -1,26 +1,17 @@
-import factory
 import tiktoken
 from langchain.schema import ChatMessage
 from khoj.processor.conversation import utils


-class ChatMessageFactory(factory.Factory):
-    class Meta:
-        model = ChatMessage
-
-    content = factory.Faker("paragraph")
-    role = factory.Faker("name")
-
-
 class TestTruncateMessage:
-    max_prompt_size = 4096
+    max_prompt_size = 10
     model_name = "gpt-3.5-turbo"
     encoder = tiktoken.encoding_for_model(model_name)

     def test_truncate_message_all_small(self):
         # Arrange
-        chat_history = ChatMessageFactory.build_batch(500)
+        chat_history = generate_chat_history(50)

         # Act
         truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)

@@ -28,15 +19,14 @@ class TestTruncateMessage:
         # Assert
         # The original object has been modified. Verify certain properties
-        assert len(chat_history) < 500
+        assert len(chat_history) < 50
         assert len(chat_history) > 1
         assert tokens <= self.max_prompt_size

     def test_truncate_message_first_large(self):
         # Arrange
-        chat_history = ChatMessageFactory.build_batch(25)
-        big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=2000))
-        big_chat_message.content = big_chat_message.content + "\n" + "Question?"
+        chat_history = generate_chat_history(5)
+        big_chat_message = ChatMessage(role="user", content=f"{generate_content(6)}\nQuestion?")
         copy_big_chat_message = big_chat_message.copy()
         chat_history.insert(0, big_chat_message)
         tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history])

@@ -53,10 +43,9 @@ class TestTruncateMessage:
     def test_truncate_message_last_large(self):
         # Arrange
-        chat_history = ChatMessageFactory.build_batch(25)
+        chat_history = generate_chat_history(5)
         chat_history[0].role = "system"  # Mark the first message as system message
-        big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=1000))
-        big_chat_message.content = big_chat_message.content + "\n" + "Question?"
+        big_chat_message = ChatMessage(role="user", content=f"{generate_content(11)}\nQuestion?")
         copy_big_chat_message = big_chat_message.copy()
         chat_history.insert(0, big_chat_message)

@@ -68,10 +57,10 @@ class TestTruncateMessage:
         # Assert
         # The original object has been modified. Verify certain properties.
-        assert len(truncated_chat_history) == (
-            len(chat_history) + 1
-        )  # Because the system_prompt is popped off from the chat_messages lsit
-        assert len(truncated_chat_history) < 26
+        assert (
+            len(truncated_chat_history) == len(chat_history) + 1
+        )  # Because the system_prompt is popped off from the chat_messages list
+        assert len(truncated_chat_history) < 10
         assert len(truncated_chat_history) > 1
         assert truncated_chat_history[0] != copy_big_chat_message
         assert initial_tokens > self.max_prompt_size

@@ -79,9 +68,7 @@ class TestTruncateMessage:
     def test_truncate_single_large_non_system_message(self):
         # Arrange
-        big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=2000))
-        big_chat_message.content = big_chat_message.content + "\n" + "Question?"
-        big_chat_message.role = "user"
+        big_chat_message = ChatMessage(role="user", content=f"{generate_content(11)}\nQuestion?")
         copy_big_chat_message = big_chat_message.copy()
         chat_messages = [big_chat_message]
         initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])

@@ -100,8 +87,7 @@ class TestTruncateMessage:
     def test_truncate_single_large_question(self):
         # Arrange
         big_chat_message_content = " ".join(["hi"] * (self.max_prompt_size + 1))
-        big_chat_message = ChatMessageFactory.build(content=big_chat_message_content)
-        big_chat_message.role = "user"
+        big_chat_message = ChatMessage(role="user", content=big_chat_message_content)
         copy_big_chat_message = big_chat_message.copy()
         chat_messages = [big_chat_message]
         initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])

@@ -116,3 +102,14 @@ class TestTruncateMessage:
         assert final_tokens <= self.max_prompt_size
         assert len(chat_messages) == 1
         assert truncated_chat_history[0] != copy_big_chat_message
+
+
+def generate_content(count):
+    return " ".join([f"{index}" for index, _ in enumerate(range(count))])
+
+
+def generate_chat_history(count):
+    return [
+        ChatMessage(role="user" if index % 2 == 0 else "assistant", content=f"{index}")
+        for index, _ in enumerate(range(count))
+    ]
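
For reference, a quick usage sketch of the two new helpers (assumes the functions from the diff above are in scope, e.g. defined in the same test module):

from langchain.schema import ChatMessage  # used by generate_chat_history

print(generate_content(6))
# 0 1 2 3 4 5

history = generate_chat_history(3)
print([(message.role, message.content) for message in history])
# [('user', '0'), ('assistant', '1'), ('user', '2')]

Short numbered contents keep each message to roughly one token, which is why a max_prompt_size of 10 is enough to exercise truncation without the large faker-generated histories the old tests relied on.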