From 8851b5f78a8a654814f6683832b367cdcc5edebf Mon Sep 17 00:00:00 2001
From: Debanjum
Date: Wed, 13 Nov 2024 16:30:17 -0800
Subject: [PATCH] Standardize chat message truncation and serialization before print

Previously chatml messages were just strings; now they can be lists of
strings or lists of dicts as well.

- Use json serialization to manage these variations and truncate the
  messages before printing them for context
- Put the logic in a single function for use across chat models
---
 .../processor/conversation/anthropic/anthropic_chat.py | 5 ++---
 src/khoj/processor/conversation/google/gemini_chat.py  | 5 ++---
 src/khoj/processor/conversation/offline/chat_model.py  | 4 ++--
 src/khoj/processor/conversation/openai/gpt.py          | 4 ++--
 src/khoj/processor/conversation/utils.py               | 7 +++++++
 5 files changed, 15 insertions(+), 10 deletions(-)

diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py
index 88b10465..57fc9b23 100644
--- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py
+++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py
@@ -17,6 +17,7 @@ from khoj.processor.conversation.utils import (
     clean_json,
     construct_structured_message,
     generate_chatml_messages_with_context,
+    messages_to_print,
 )
 from khoj.utils.helpers import ConversationCommand, is_none_or_empty
 from khoj.utils.rawconfig import LocationData
@@ -214,9 +215,7 @@ def converse_anthropic(
     )
 
     messages, system_prompt = format_messages_for_anthropic(messages, system_prompt)
-
-    truncated_messages = "\n".join({f"{content[:70]}..." for message in messages for content in message.content})
-    logger.debug(f"Conversation Context for Claude: {truncated_messages}")
+    logger.debug(f"Conversation Context for Claude: {messages_to_print(messages)}")
 
     # Get Response from Claude
     return anthropic_chat_completion_with_backoff(
diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py
index 0c25e3e9..5bb74da4 100644
--- a/src/khoj/processor/conversation/google/gemini_chat.py
+++ b/src/khoj/processor/conversation/google/gemini_chat.py
@@ -17,6 +17,7 @@ from khoj.processor.conversation.utils import (
     clean_json,
     construct_structured_message,
     generate_chatml_messages_with_context,
+    messages_to_print,
 )
 from khoj.utils.helpers import ConversationCommand, is_none_or_empty
 from khoj.utils.rawconfig import LocationData
@@ -225,9 +226,7 @@ def converse_gemini(
     )
 
     messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
-
-    truncated_messages = "\n".join({f"{content[:70]}..." for message in messages for content in message.content})
-    logger.debug(f"Conversation Context for Gemini: {truncated_messages}")
+    logger.debug(f"Conversation Context for Gemini: {messages_to_print(messages)}")
 
     # Get Response from Google AI
     return gemini_chat_completion_with_backoff(
diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py
index c41c847b..998589dd 100644
--- a/src/khoj/processor/conversation/offline/chat_model.py
+++ b/src/khoj/processor/conversation/offline/chat_model.py
@@ -15,6 +15,7 @@ from khoj.processor.conversation.utils import (
     ThreadedGenerator,
     commit_conversation_trace,
     generate_chatml_messages_with_context,
+    messages_to_print,
 )
 from khoj.utils import state
 from khoj.utils.constants import empty_escape_sequences
@@ -222,8 +223,7 @@ def converse_offline(
         query_files=query_files,
     )
 
-    truncated_messages = "\n".join({f"{message.content[:70]}..." for message in messages})
-    logger.debug(f"Conversation Context for {model}: {truncated_messages}")
+    logger.debug(f"Conversation Context for {model}: {messages_to_print(messages)}")
 
     g = ThreadedGenerator(references, online_results, completion_func=completion_func)
     t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size, tracer))
diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py
index c4cb0c67..13b53911 100644
--- a/src/khoj/processor/conversation/openai/gpt.py
+++ b/src/khoj/processor/conversation/openai/gpt.py
@@ -15,6 +15,7 @@ from khoj.processor.conversation.utils import (
     clean_json,
     construct_structured_message,
     generate_chatml_messages_with_context,
+    messages_to_print,
 )
 from khoj.utils.helpers import ConversationCommand, is_none_or_empty
 from khoj.utils.rawconfig import LocationData
@@ -212,8 +213,7 @@ def converse(
         model_type=ChatModelOptions.ModelType.OPENAI,
         query_files=query_files,
     )
-    truncated_messages = "\n".join({f"{message.content[:70]}..." for message in messages})
-    logger.debug(f"Conversation Context for GPT: {truncated_messages}")
+    logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}")
 
     # Get Response from GPT
     return chat_completion_with_backoff(
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 3fb341ea..6de921b1 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -736,3 +736,10 @@ Metadata
     except Exception as e:
         logger.error(f"Failed to merge message {msg_branch} into conversation {conv_branch}: {str(e)}", exc_info=True)
         return False
+
+
+def messages_to_print(messages: list[ChatMessage], max_length: int = 70) -> str:
+    """
+    Format and truncate messages for printing
+    """
+    return "\n".join([f"{json.dumps(message.content)[:max_length]}..." for message in messages])
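
Usage note (not part of the patch): a minimal, self-contained sketch of how
the new messages_to_print helper behaves with both content shapes. The
ChatMessage stub below is a hypothetical stand-in for the message class used
in utils.py; the helper only reads its content attribute, and json is assumed
to already be imported in that module.

    import json

    class ChatMessage:
        # Hypothetical stand-in: only the `content` attribute matters here.
        def __init__(self, content):
            self.content = content

    def messages_to_print(messages, max_length=70):
        # Same logic as the patched helper: json.dumps collapses each
        # message's content (plain string, list of strings, or list of
        # dicts) into a single line, which is then cut at max_length chars.
        return "\n".join([f"{json.dumps(message.content)[:max_length]}..." for message in messages])

    messages = [
        ChatMessage("hello"),
        ChatMessage([{"type": "text", "text": "hi there"}]),
    ]
    print(messages_to_print(messages))
    # "hello"...
    # [{"type": "text", "text": "hi there"}]...

Serializing before slicing is what lets one function cover every content
shape; it also replaces the earlier set comprehensions, which could
deduplicate identical truncated lines and reorder them between runs.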