Standardize chat message truncation and serialization before print

Previously, chatml message content was just a string; now it can also be
a list of strings or a list of dicts.

- Use json serialization to handle these variations and truncate the
  result before printing for context (see the sketch below).
- Put the logic in a single function for use across chat models.
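
Note: chatml message content can now take any of three shapes, and json serialization collapses each of them into a single string that is safe to slice. A minimal sketch of the idea (the sample contents are illustrative, not taken from khoj):

import json

# The three shapes a chatml message's content may now take:
contents = [
    "plain string content",               # str
    ["chunk one", "chunk two"],           # list of strings
    [{"type": "text", "text": "hello"}],  # list of dicts
]

for content in contents:
    # json.dumps renders every shape as one string,
    # so a single slice truncates all of them uniformly.
    print(f"{json.dumps(content)[:70]}...")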
Author: Debanjum
Date:   2024-11-13 16:30:17 -08:00
Parent: f4e37209a2
Commit: 8851b5f78a

5 changed files with 15 additions and 10 deletions


@@ -17,6 +17,7 @@ from khoj.processor.conversation.utils import (
     clean_json,
     construct_structured_message,
     generate_chatml_messages_with_context,
+    messages_to_print,
 )
 from khoj.utils.helpers import ConversationCommand, is_none_or_empty
 from khoj.utils.rawconfig import LocationData
@@ -214,9 +215,7 @@ def converse_anthropic(
     )
     messages, system_prompt = format_messages_for_anthropic(messages, system_prompt)
-    truncated_messages = "\n".join({f"{content[:70]}..." for message in messages for content in message.content})
-    logger.debug(f"Conversation Context for Claude: {truncated_messages}")
+    logger.debug(f"Conversation Context for Claude: {messages_to_print(messages)}")
     # Get Response from Claude
     return anthropic_chat_completion_with_backoff(
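
Aside: the removed line built a set comprehension ({...} rather than [...]), so duplicate content blocks were silently dropped and print order was not guaranteed; it also assumed each block was sliceable text. A small illustration of the set pitfall (sample data, not khoj code):

# A set comprehension deduplicates and loses ordering.
blocks = ["same text", "same text", "other text"]
print("\n".join({f"{content[:70]}..." for content in blocks}))
# Prints at most two lines, in arbitrary order.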


@@ -17,6 +17,7 @@ from khoj.processor.conversation.utils import (
     clean_json,
     construct_structured_message,
     generate_chatml_messages_with_context,
+    messages_to_print,
 )
 from khoj.utils.helpers import ConversationCommand, is_none_or_empty
 from khoj.utils.rawconfig import LocationData
@@ -225,9 +226,7 @@ def converse_gemini(
     )
     messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
-    truncated_messages = "\n".join({f"{content[:70]}..." for message in messages for content in message.content})
-    logger.debug(f"Conversation Context for Gemini: {truncated_messages}")
+    logger.debug(f"Conversation Context for Gemini: {messages_to_print(messages)}")
     # Get Response from Google AI
     return gemini_chat_completion_with_backoff(


@@ -15,6 +15,7 @@ from khoj.processor.conversation.utils import (
     ThreadedGenerator,
     commit_conversation_trace,
     generate_chatml_messages_with_context,
+    messages_to_print,
 )
 from khoj.utils import state
 from khoj.utils.constants import empty_escape_sequences
@@ -222,8 +223,7 @@ def converse_offline(
         query_files=query_files,
     )
-    truncated_messages = "\n".join({f"{message.content[:70]}..." for message in messages})
-    logger.debug(f"Conversation Context for {model}: {truncated_messages}")
+    logger.debug(f"Conversation Context for {model}: {messages_to_print(messages)}")
     g = ThreadedGenerator(references, online_results, completion_func=completion_func)
     t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size, tracer))


@@ -15,6 +15,7 @@ from khoj.processor.conversation.utils import (
     clean_json,
     construct_structured_message,
     generate_chatml_messages_with_context,
+    messages_to_print,
 )
 from khoj.utils.helpers import ConversationCommand, is_none_or_empty
 from khoj.utils.rawconfig import LocationData
@@ -212,8 +213,7 @@ def converse(
         model_type=ChatModelOptions.ModelType.OPENAI,
         query_files=query_files,
     )
-    truncated_messages = "\n".join({f"{message.content[:70]}..." for message in messages})
-    logger.debug(f"Conversation Context for GPT: {truncated_messages}")
+    logger.debug(f"Conversation Context for GPT: {messages_to_print(messages)}")
     # Get Response from GPT
     return chat_completion_with_backoff(


@@ -736,3 +736,10 @@ Metadata
     except Exception as e:
         logger.error(f"Failed to merge message {msg_branch} into conversation {conv_branch}: {str(e)}", exc_info=True)
         return False
+
+
+def messages_to_print(messages: list[ChatMessage], max_length: int = 70) -> str:
+    """
+    Format, truncate messages to print
+    """
+    return "\n".join([f"{json.dumps(message.content)[:max_length]}..." for message in messages])