diff --git a/src/khoj/processor/conversation/gpt4all/chat_model.py b/src/khoj/processor/conversation/gpt4all/chat_model.py
index e02b8dfc..1047ac3b 100644
--- a/src/khoj/processor/conversation/gpt4all/chat_model.py
+++ b/src/khoj/processor/conversation/gpt4all/chat_model.py
@@ -125,7 +125,7 @@ def converse_offline(
     # Get Conversation Primer appropriate to Conversation Type
     # TODO If compiled_references_message is too long, we need to truncate it.
     if compiled_references_message == "":
-        conversation_primer = prompts.conversation_llamav2.format(query=user_query)
+        conversation_primer = user_query
     else:
         conversation_primer = prompts.notes_conversation_llamav2.format(
             query=user_query, references=compiled_references_message
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index b739217c..af21db0b 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -102,7 +102,7 @@ def generate_chatml_messages_with_context(
     return messages[::-1]
 
 
-def truncate_messages(messages, max_prompt_size, model_name):
+def truncate_messages(messages: list[ChatMessage], max_prompt_size, model_name) -> list[ChatMessage]:
     """Truncate messages to fit within max prompt size supported by model"""
 
     if "llama" in model_name:
@@ -110,24 +110,27 @@
     else:
         encoder = tiktoken.encoding_for_model(model_name)
 
+    system_message = messages.pop()
+    system_message_tokens = len(encoder.encode(system_message.content))
+
     tokens = sum([len(encoder.encode(message.content)) for message in messages])
-    while tokens > max_prompt_size and len(messages) > 1:
+    while (tokens + system_message_tokens) > max_prompt_size and len(messages) > 1:
         messages.pop()
         tokens = sum([len(encoder.encode(message.content)) for message in messages])
 
-    # Truncate last message if still over max supported prompt size by model
-    if tokens > max_prompt_size:
-        last_message = "\n".join(messages[-1].content.split("\n")[:-1])
-        original_question = "\n".join(messages[-1].content.split("\n")[-1:])
+    # Truncate current message if still over max supported prompt size by model
+    if (tokens + system_message_tokens) > max_prompt_size:
+        current_message = "\n".join(messages[0].content.split("\n")[:-1])
+        original_question = "\n".join(messages[0].content.split("\n")[-1:])
         original_question_tokens = len(encoder.encode(original_question))
         remaining_tokens = max_prompt_size - original_question_tokens
-        truncated_message = encoder.decode(encoder.encode(last_message)[:remaining_tokens]).strip()
+        truncated_message = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip()
         logger.debug(
-            f"Truncate last message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message}"
+            f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message}"
         )
-        messages = [ChatMessage(content=truncated_message + original_question, role=messages[-1].role)]
+        messages = [ChatMessage(content=truncated_message + original_question, role=messages[0].role)]
 
-    return messages
+    return messages + [system_message]
 
 
 def reciprocal_conversation_to_chatml(message_pair):
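
To illustrate the behavioural change in `truncate_messages`, here is a minimal, self-contained sketch of the new system-message handling: the system prompt (kept at the end of the newest-first message list) is set aside, its tokens are counted against the budget while older chat turns are dropped, and it is re-appended afterwards so it can no longer be truncated away. The `SimpleChatMessage` class and whitespace tokenizer below are stand-ins invented for this sketch, not Khoj's actual `ChatMessage` class or tiktoken/Llama encoders, and the per-message truncation branch is omitted for brevity.

```python
# Minimal sketch, not Khoj's implementation: SimpleChatMessage and the
# whitespace "tokenizer" are stand-ins for the real message class and encoders.
from dataclasses import dataclass


@dataclass
class SimpleChatMessage:
    content: str
    role: str


def encode(text: str) -> list[str]:
    # Crude stand-in tokenizer: one token per whitespace-separated word.
    return text.split()


def truncate_with_system_message(
    messages: list[SimpleChatMessage], max_prompt_size: int
) -> list[SimpleChatMessage]:
    # Messages are ordered newest-first with the system prompt last,
    # matching the ordering truncate_messages receives in the diff above.
    system_message = messages.pop()
    system_message_tokens = len(encode(system_message.content))

    tokens = sum(len(encode(m.content)) for m in messages)
    # Drop the oldest chat turns first, always reserving room for the system prompt.
    while (tokens + system_message_tokens) > max_prompt_size and len(messages) > 1:
        messages.pop()
        tokens = sum(len(encode(m.content)) for m in messages)

    # Re-append the system prompt so it survives truncation.
    return messages + [system_message]


if __name__ == "__main__":
    chat = [
        SimpleChatMessage("what did I write about my trip?", "user"),  # current query
        SimpleChatMessage("you noted a packing list " + "item " * 100, "assistant"),  # long older turn
        SimpleChatMessage("summarize my journal", "user"),  # oldest turn
        SimpleChatMessage("You are a helpful personal assistant.", "system"),
    ]
    truncated = truncate_with_system_message(chat, max_prompt_size=50)
    print([m.role for m in truncated])  # ['user', 'system'] -- history dropped, system prompt kept
```

Running the sketch drops the older turns but keeps both the current query and the system prompt, which is the property the patch adds over the previous behaviour of truncating whatever message happened to sit last in the list.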