diff --git a/tests/test_conversation_utils.py b/tests/test_conversation_utils.py
index 7172978b..a9e4169b 100644
--- a/tests/test_conversation_utils.py
+++ b/tests/test_conversation_utils.py
@@ -1,26 +1,17 @@
-import factory
 import tiktoken
 from langchain.schema import ChatMessage
 
 from khoj.processor.conversation import utils
 
 
-class ChatMessageFactory(factory.Factory):
-    class Meta:
-        model = ChatMessage
-
-    content = factory.Faker("paragraph")
-    role = factory.Faker("name")
-
-
 class TestTruncateMessage:
-    max_prompt_size = 4096
+    max_prompt_size = 10
     model_name = "gpt-3.5-turbo"
     encoder = tiktoken.encoding_for_model(model_name)
 
     def test_truncate_message_all_small(self):
         # Arrange
-        chat_history = ChatMessageFactory.build_batch(500)
+        chat_history = generate_chat_history(50)
 
         # Act
         truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
@@ -28,15 +19,14 @@ class TestTruncateMessage:
 
         # Assert
         # The original object has been modified. Verify certain properties
-        assert len(chat_history) < 500
+        assert len(chat_history) < 50
         assert len(chat_history) > 1
         assert tokens <= self.max_prompt_size
 
     def test_truncate_message_first_large(self):
         # Arrange
-        chat_history = ChatMessageFactory.build_batch(25)
-        big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=2000))
-        big_chat_message.content = big_chat_message.content + "\n" + "Question?"
+        chat_history = generate_chat_history(5)
+        big_chat_message = ChatMessage(role="user", content=f"{generate_content(6)}\nQuestion?")
         copy_big_chat_message = big_chat_message.copy()
         chat_history.insert(0, big_chat_message)
         tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history])
@@ -53,10 +43,9 @@ class TestTruncateMessage:
 
     def test_truncate_message_last_large(self):
         # Arrange
-        chat_history = ChatMessageFactory.build_batch(25)
+        chat_history = generate_chat_history(5)
         chat_history[0].role = "system"  # Mark the first message as system message
-        big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=1000))
-        big_chat_message.content = big_chat_message.content + "\n" + "Question?"
+        big_chat_message = ChatMessage(role="user", content=f"{generate_content(11)}\nQuestion?")
         copy_big_chat_message = big_chat_message.copy()
 
         chat_history.insert(0, big_chat_message)
@@ -68,10 +57,10 @@ class TestTruncateMessage:
 
         # Assert
         # The original object has been modified. Verify certain properties.
-        assert len(truncated_chat_history) == (
-            len(chat_history) + 1
-        )  # Because the system_prompt is popped off from the chat_messages lsit
-        assert len(truncated_chat_history) < 26
+        assert (
+            len(truncated_chat_history) == len(chat_history) + 1
+        )  # Because the system_prompt is popped off from the chat_messages list
+        assert len(truncated_chat_history) < 10
         assert len(truncated_chat_history) > 1
         assert truncated_chat_history[0] != copy_big_chat_message
         assert initial_tokens > self.max_prompt_size
@@ -79,9 +68,7 @@ class TestTruncateMessage:
 
     def test_truncate_single_large_non_system_message(self):
         # Arrange
-        big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=2000))
-        big_chat_message.content = big_chat_message.content + "\n" + "Question?"
-        big_chat_message.role = "user"
+        big_chat_message = ChatMessage(role="user", content=f"{generate_content(11)}\nQuestion?")
         copy_big_chat_message = big_chat_message.copy()
         chat_messages = [big_chat_message]
         initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
@@ -100,8 +87,7 @@ class TestTruncateMessage:
     def test_truncate_single_large_question(self):
         # Arrange
         big_chat_message_content = " ".join(["hi"] * (self.max_prompt_size + 1))
-        big_chat_message = ChatMessageFactory.build(content=big_chat_message_content)
-        big_chat_message.role = "user"
+        big_chat_message = ChatMessage(role="user", content=big_chat_message_content)
         copy_big_chat_message = big_chat_message.copy()
         chat_messages = [big_chat_message]
         initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
@@ -116,3 +102,14 @@ class TestTruncateMessage:
         assert final_tokens <= self.max_prompt_size
         assert len(chat_messages) == 1
         assert truncated_chat_history[0] != copy_big_chat_message
+
+
+def generate_content(count):
+    return " ".join([f"{index}" for index, _ in enumerate(range(count))])
+
+
+def generate_chat_history(count):
+    return [
+        ChatMessage(role="user" if index % 2 == 0 else "assistant", content=f"{index}")
+        for index, _ in enumerate(range(count))
+    ]
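For reviewers, the assertions above encode the expected contract of `utils.truncate_messages`: it mutates the passed-in list, keeps the conversation within the token budget, preserves a system message (popping it off the working list, hence the `len(truncated_chat_history) == len(chat_history) + 1` check), and trims a single oversized user message by tokens rather than dropping it. Below is a minimal sketch of that contract, assuming a drop-oldest-then-trim strategy inferred from the tests; this is an illustrative reconstruction, not the actual khoj implementation, and `truncate_messages_sketch` is a hypothetical name:

```python
import tiktoken
from langchain.schema import ChatMessage


def truncate_messages_sketch(messages, max_prompt_size, model_name):
    """Illustrative only: fit `messages` into `max_prompt_size` tokens in place."""
    encoder = tiktoken.encoding_for_model(model_name)

    # Set any system message aside so it always survives truncation; removing it
    # from the working list is why the tests see one extra message in the result.
    system_message = next((m for m in messages if m.role == "system"), None)
    if system_message is not None:
        messages.remove(system_message)

    # Drop the oldest remaining messages (mutating the caller's list) while the
    # running token count exceeds the budget.
    tokens = sum(len(encoder.encode(message.content)) for message in messages)
    while tokens > max_prompt_size and len(messages) > 1:
        tokens -= len(encoder.encode(messages.pop(0).content))

    # If the one remaining message alone exceeds the budget, truncate its
    # content by tokens instead of dropping the user's question entirely.
    if messages and tokens > max_prompt_size:
        truncated = encoder.decode(encoder.encode(messages[0].content)[:max_prompt_size])
        messages[0] = ChatMessage(role=messages[0].role, content=truncated)

    # Prepending the system message is an assumption here; the tests only
    # require that it is retained in the returned history.
    return ([system_message] if system_message else []) + messages
```

Under this sketch, `test_truncate_message_all_small` ends with roughly `max_prompt_size` single-token messages left, and `test_truncate_single_large_question` keeps its lone message but with truncated content, matching the `len(chat_messages) == 1` and `!= copy_big_chat_message` assertions.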