diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 9cf9952f..a63a09c0 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -102,6 +102,7 @@ def generate_chatml_messages_with_context(
     # Return message in chronological order
    return messages[::-1]
 
+
 def truncate_message(messages, max_prompt_size, model_name):
     """Truncate messages to fit within max prompt size supported by model"""
     encoder = tiktoken.encoding_for_model(model_name)
@@ -112,8 +113,8 @@ def truncate_message(messages, max_prompt_size, model_name):
 
     # Truncate last message if still over max supported prompt size by model
     if tokens > max_prompt_size:
-        last_message = '\n'.join(messages[-1].content.split("\n")[:-1])
-        original_question = '\n'.join(messages[-1].content.split("\n")[-1:])
+        last_message = "\n".join(messages[-1].content.split("\n")[:-1])
+        original_question = "\n".join(messages[-1].content.split("\n")[-1:])
         original_question_tokens = len(encoder.encode(original_question))
         remaining_tokens = max_prompt_size - original_question_tokens
         truncated_message = encoder.decode(encoder.encode(last_message)[:remaining_tokens]).strip()
diff --git a/tests/test_conversation_utils.py b/tests/test_conversation_utils.py
index 24e97938..43f68884 100644
--- a/tests/test_conversation_utils.py
+++ b/tests/test_conversation_utils.py
@@ -3,16 +3,18 @@ from langchain.schema import ChatMessage
 import factory
 import tiktoken
 
+
 class ChatMessageFactory(factory.Factory):
     class Meta:
         model = ChatMessage
 
-    content = factory.Faker('paragraph')
-    role = factory.Faker('name')
+    content = factory.Faker("paragraph")
+    role = factory.Faker("name")
+
 
 class TestTruncateMessage:
     max_prompt_size = 4096
-    model_name = 'gpt-3.5-turbo'
+    model_name = "gpt-3.5-turbo"
     encoder = tiktoken.encoding_for_model(model_name)
 
     def test_truncate_message_all_small(self):
@@ -33,7 +35,7 @@ class TestTruncateMessage:
 
     def test_truncate_message_first_large(self):
         chat_messages = ChatMessageFactory.build_batch(25)
-        big_chat_message = ChatMessageFactory.build(content=factory.Faker('paragraph', nb_sentences=1000))
+        big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=1000))
         big_chat_message.content = big_chat_message.content + "\n" + "Question?"
         copy_big_chat_message = big_chat_message.copy()
         chat_messages.insert(0, big_chat_message)
@@ -53,10 +55,10 @@ class TestTruncateMessage:
 
     def test_truncate_message_last_large(self):
         chat_messages = ChatMessageFactory.build_batch(25)
-        big_chat_message = ChatMessageFactory.build(content=factory.Faker('paragraph', nb_sentences=1000))
+        big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=1000))
         big_chat_message.content = big_chat_message.content + "\n" + "Question?"
         copy_big_chat_message = big_chat_message.copy()
-
+
         chat_messages.append(big_chat_message)
         assert len(chat_messages) == 26
         tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
@@ -71,4 +73,3 @@ class TestTruncateMessage:
 
         tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
         assert tokens < self.max_prompt_size
-
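
For context: the hunks above show only the tail of `truncate_message`, so below is a minimal self-contained sketch of the final-message truncation step the tests exercise. The helper name `truncate_last_message` and the reassembly of the truncated body with the question are illustrative assumptions; the elided middle of the real function (whatever it does to get the earlier messages under budget) is not reproduced here.

```python
# Minimal sketch of the final-message truncation step, NOT the full
# truncate_message function. Assumes tiktoken and langchain are installed.
# The helper name and the reassembly line are illustrative assumptions.
import tiktoken
from langchain.schema import ChatMessage


def truncate_last_message(messages, max_prompt_size, model_name):
    """Trim the last message's body to budget, keeping its final line (the question) intact."""
    encoder = tiktoken.encoding_for_model(model_name)
    tokens = sum(len(encoder.encode(message.content)) for message in messages)
    if tokens > max_prompt_size:
        # Everything except the final line is trimmable context; the final
        # line is preserved as the user's original question.
        last_message = "\n".join(messages[-1].content.split("\n")[:-1])
        original_question = "\n".join(messages[-1].content.split("\n")[-1:])
        remaining_tokens = max_prompt_size - len(encoder.encode(original_question))
        # Cut the context down to the remaining token budget, in token space.
        truncated_message = encoder.decode(encoder.encode(last_message)[:remaining_tokens]).strip()
        messages[-1].content = truncated_message + "\n" + original_question  # assumed reassembly
    return messages


# Usage mirroring the tests: an oversized body with "Question?" on the last line.
message = ChatMessage(role="user", content=("Lorem ipsum. " * 5000) + "\n" + "Question?")
truncate_last_message([message], max_prompt_size=4096, model_name="gpt-3.5-turbo")
assert message.content.endswith("Question?")  # the question survives truncation
```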
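
A side note on the test scaffolding, in case factory_boy is unfamiliar: `ChatMessageFactory` builds langchain `ChatMessage` instances, and `factory.Faker(...)` declarations are resolved lazily at build time, including when passed as an override to `build(...)` as the tests do for the oversized message. A small standalone sketch, assuming factory_boy and Faker are installed:

```python
import factory
from langchain.schema import ChatMessage


class ChatMessageFactory(factory.Factory):
    class Meta:
        model = ChatMessage

    content = factory.Faker("paragraph")  # fresh fake paragraph per message
    role = factory.Faker("name")          # any string is accepted as a ChatMessage role


messages = ChatMessageFactory.build_batch(25)  # 25 small messages
big = ChatMessageFactory.build(                # override the default content declaration
    content=factory.Faker("paragraph", nb_sentences=1000)
)
assert len(big.content) > len(messages[0].content)
```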