diff --git a/src/khoj/processor/conversation/offline/utils.py b/src/khoj/processor/conversation/offline/utils.py
index c60498ee..2a3781b3 100644
--- a/src/khoj/processor/conversation/offline/utils.py
+++ b/src/khoj/processor/conversation/offline/utils.py
@@ -21,6 +21,7 @@ def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf"):
 
     # Check if the model is already downloaded
     model_path = load_model_from_cache(repo_id, filename)
+    chat_model = None
     try:
         if model_path:
             chat_model = Llama(model_path, **kwargs)
diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py
index c7c38d46..844a64b8 100644
--- a/src/khoj/processor/conversation/openai/utils.py
+++ b/src/khoj/processor/conversation/openai/utils.py
@@ -101,8 +101,3 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None, model_
 
     chat(messages=messages)
     g.close()
-
-
-def extract_summaries(metadata):
-    """Extract summaries from metadata"""
-    return "".join([f'\n{session["summary"]}' for session in metadata])
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index ff1ca1e1..845ccb48 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -232,12 +232,17 @@ def truncate_messages(
         original_question = "\n".join(messages[0].content.split("\n")[-1:]) if type(messages[0].content) == str else ""
         original_question = f"\n{original_question}"
         original_question_tokens = len(encoder.encode(original_question))
-        remaining_tokens = max_prompt_size - original_question_tokens - system_message_tokens
-        truncated_message = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip()
+        remaining_tokens = max_prompt_size - system_message_tokens
+        if remaining_tokens > original_question_tokens:
+            remaining_tokens -= original_question_tokens
+            truncated_message = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip()
+            messages = [ChatMessage(content=truncated_message + original_question, role=messages[0].role)]
+        else:
+            truncated_message = encoder.decode(encoder.encode(original_question)[:remaining_tokens]).strip()
+            messages = [ChatMessage(content=truncated_message, role=messages[0].role)]
         logger.debug(
             f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message}"
         )
-        messages = [ChatMessage(content=truncated_message + original_question, role=messages[0].role)]
 
     return messages + [system_message] if system_message else messages
 
diff --git a/tests/test_conversation_utils.py b/tests/test_conversation_utils.py
index bc8c5315..7172978b 100644
--- a/tests/test_conversation_utils.py
+++ b/tests/test_conversation_utils.py
@@ -96,3 +96,23 @@ class TestTruncateMessage:
         assert final_tokens <= self.max_prompt_size
         assert len(chat_messages) == 1
         assert truncated_chat_history[0] != copy_big_chat_message
+
+    def test_truncate_single_large_question(self):
+        # Arrange
+        big_chat_message_content = " ".join(["hi"] * (self.max_prompt_size + 1))
+        big_chat_message = ChatMessageFactory.build(content=big_chat_message_content)
+        big_chat_message.role = "user"
+        copy_big_chat_message = big_chat_message.copy()
+        chat_messages = [big_chat_message]
+        initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
+
+        # Act
+        truncated_chat_history = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
+        final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
+
+        # Assert
+        # The original object has been modified. Verify certain properties
+        assert initial_tokens > self.max_prompt_size
+        assert final_tokens <= self.max_prompt_size
+        assert len(chat_messages) == 1
+        assert truncated_chat_history[0] != copy_big_chat_message
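
For context, a minimal self-contained sketch of what the reworked truncation branch in `truncate_messages` now does. This is not the khoj function itself: the `truncate` helper, its parameter names, and the use of tiktoken's `gpt-3.5-turbo` encoding are illustrative assumptions; only the two-branch token arithmetic mirrors the diff above.

```python
# Sketch of the patched truncation logic (hypothetical helper, not khoj's API).
# Assumes tiktoken is installed; encoder choice is an assumption for illustration.
import tiktoken

encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")


def truncate(current_message: str, original_question: str, max_prompt_size: int, system_tokens: int = 0) -> str:
    remaining_tokens = max_prompt_size - system_tokens
    question_tokens = len(encoder.encode(original_question))
    if remaining_tokens > question_tokens:
        # Normal case: keep the question intact and truncate the preceding history to fit.
        remaining_tokens -= question_tokens
        truncated = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip()
        return truncated + original_question
    # New case handled by the patch: the question alone exceeds the budget,
    # so truncate the question itself instead of slicing with a negative count.
    return encoder.decode(encoder.encode(original_question)[:remaining_tokens]).strip()


# A question longer than the 10-token budget is cut down to fit,
# rather than underflowing the remaining-token calculation.
print(truncate("some earlier chat history", "\n" + " ".join(["hi"] * 20), max_prompt_size=10))
```

The new `test_truncate_single_large_question` test exercises exactly this fallback path: a single user message larger than `max_prompt_size` should still come back within the limit.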