From 48e5ac0169c83e56dbf0cb8ee3c04998c978efe6 Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Mon, 31 Jul 2023 16:36:31 -0700
Subject: [PATCH 1/2] Do not drop system message when truncating context to max prompt size

Previously the system message was getting dropped when the context size
with chat history exceeded the max prompt size supported by the chat
model.

Now only the previous chat messages are dropped, or the current message
is truncated, but the system message is kept to provide guidance to the
chat model.
---
 src/khoj/processor/conversation/utils.py | 24 ++++++++++++++----------
 1 file changed, 14 insertions(+), 10 deletions(-)

diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 5be8e8f7..81cac4b8 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -100,30 +100,34 @@ def generate_chatml_messages_with_context(
     return messages[::-1]
 
 
-def truncate_messages(messages, max_prompt_size, model_name):
+def truncate_messages(messages: list[ChatMessage], max_prompt_size, model_name) -> list[ChatMessage]:
     """Truncate messages to fit within max prompt size supported by model"""
     try:
         encoder = tiktoken.encoding_for_model(model_name)
     except KeyError:
         encoder = tiktoken.encoding_for_model("text-davinci-001")
+
+    system_message = messages.pop()
+    system_message_tokens = len(encoder.encode(system_message.content))
+
     tokens = sum([len(encoder.encode(message.content)) for message in messages])
-    while tokens > max_prompt_size and len(messages) > 1:
+    while (tokens + system_message_tokens) > max_prompt_size and len(messages) > 1:
         messages.pop()
         tokens = sum([len(encoder.encode(message.content)) for message in messages])
 
-    # Truncate last message if still over max supported prompt size by model
-    if tokens > max_prompt_size:
-        last_message = "\n".join(messages[-1].content.split("\n")[:-1])
-        original_question = "\n".join(messages[-1].content.split("\n")[-1:])
+    # Truncate current message if still over max supported prompt size by model
+    if (tokens + system_message_tokens) > max_prompt_size:
+        current_message = "\n".join(messages[0].content.split("\n")[:-1])
+        original_question = "\n".join(messages[0].content.split("\n")[-1:])
         original_question_tokens = len(encoder.encode(original_question))
         remaining_tokens = max_prompt_size - original_question_tokens
-        truncated_message = encoder.decode(encoder.encode(last_message)[:remaining_tokens]).strip()
+        truncated_message = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip()
         logger.debug(
-            f"Truncate last message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message}"
+            f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message}"
         )
-        messages = [ChatMessage(content=truncated_message + original_question, role=messages[-1].role)]
+        messages = [ChatMessage(content=truncated_message + original_question, role=messages[0].role)]
 
-    return messages
+    return messages + [system_message]
 
 
 def reciprocal_conversation_to_chatml(message_pair):

From ded606c7cb738de611f11f5e98a5822d1fa3bbb9 Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Mon, 31 Jul 2023 17:07:20 -0700
Subject: [PATCH 2/2] Fix format of user query during general conversation with Llama 2

---
 src/khoj/processor/conversation/gpt4all/chat_model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/khoj/processor/conversation/gpt4all/chat_model.py b/src/khoj/processor/conversation/gpt4all/chat_model.py
index c9e33c6a..20efa7a1 100644
--- a/src/khoj/processor/conversation/gpt4all/chat_model.py
+++ b/src/khoj/processor/conversation/gpt4all/chat_model.py
@@ -125,7 +125,7 @@ def converse_offline(
     # Get Conversation Primer appropriate to Conversation Type
     # TODO If compiled_references_message is too long, we need to truncate it.
     if compiled_references_message == "":
-        conversation_primer = prompts.conversation_llamav2.format(query=user_query)
+        conversation_primer = user_query
     else:
         conversation_primer = prompts.notes_conversation_llamav2.format(
             query=user_query, references=compiled_references_message
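
Note: the snippet below is a minimal usage sketch of the updated truncate_messages, not part of either patch. It illustrates the intended behaviour from the first patch: older chat history is dropped first, the current message is only token-truncated as a last resort, and the system message (kept at the end of the list handed in by generate_chatml_messages_with_context) is always re-appended. The ChatMessage import path, message contents, roles and max_prompt_size are assumptions made for illustration.

# Minimal sketch, assuming ChatMessage is langchain's schema class as used in khoj,
# and that khoj plus tiktoken are installed; message contents below are made up.
from langchain.schema import ChatMessage

from khoj.processor.conversation.utils import truncate_messages

messages = [
    # Current user message first, system message last, matching the ordering
    # that generate_chatml_messages_with_context hands to truncate_messages.
    ChatMessage(content="What did I write about yesterday?", role="user"),
    ChatMessage(content="Some very long chat history. " * 2000, role="assistant"),
    ChatMessage(content="You are Khoj, a helpful personal assistant.", role="system"),
]

truncated = truncate_messages(messages, max_prompt_size=4096, model_name="gpt-3.5-turbo")

# The oversized history message is dropped, while the system message is
# re-appended after truncation instead of being discarded.
assert truncated[-1].role == "system"

If even the current message alone exceeds the token budget, truncate_messages trims its earlier lines to fit while keeping its final line (the original question) intact, as shown in the first patch above.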