From e55e9a7b67e8eb820c5d30efde4a5a5b8359c158 Mon Sep 17 00:00:00 2001
From: sabaimran
Date: Mon, 31 Jul 2023 21:37:59 -0700
Subject: [PATCH] Fix unit tests and truncation logic

---
 src/khoj/processor/conversation/utils.py | 2 +-
 tests/test_conversation_utils.py         | 8 +++-----
 2 files changed, 4 insertions(+), 6 deletions(-)

diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index af21db0b..2a4a92f2 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -123,7 +123,7 @@ def truncate_messages(messages: list[ChatMessage], max_prompt_size, model_name)
     current_message = "\n".join(messages[0].content.split("\n")[:-1])
     original_question = "\n".join(messages[0].content.split("\n")[-1:])
     original_question_tokens = len(encoder.encode(original_question))
-    remaining_tokens = max_prompt_size - original_question_tokens
+    remaining_tokens = max_prompt_size - original_question_tokens - system_message_tokens
     truncated_message = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip()
     logger.debug(
         f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message}"
diff --git a/tests/test_conversation_utils.py b/tests/test_conversation_utils.py
index ac8a7665..8fc9e127 100644
--- a/tests/test_conversation_utils.py
+++ b/tests/test_conversation_utils.py
@@ -19,7 +19,6 @@ class TestTruncateMessage:

     def test_truncate_message_all_small(self):
         chat_messages = ChatMessageFactory.build_batch(500)
-        tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])

         prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
         tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
@@ -27,7 +26,6 @@ class TestTruncateMessage:
         # The original object has been modified. Verify certain properties
         assert len(chat_messages) < 500
         assert len(chat_messages) > 1
-        assert prompt == chat_messages
         assert tokens <= self.max_prompt_size

     def test_truncate_message_first_large(self):
@@ -52,14 +50,14 @@ class TestTruncateMessage:
         big_chat_message.content = big_chat_message.content + "\n" + "Question?"
         copy_big_chat_message = big_chat_message.copy()

-        chat_messages.append(big_chat_message)
+        chat_messages.insert(0, big_chat_message)
         tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])

         prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
         tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])

         # The original object has been modified. Verify certain properties
-        assert len(chat_messages) < 26
-        assert len(chat_messages) > 1
+        assert len(prompt) < 26
+        assert len(prompt) > 1
         assert prompt[0] != copy_big_chat_message
         assert tokens <= self.max_prompt_size