diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py
index 5e403c7b..a435f343 100644
--- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py
+++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py
@@ -34,6 +34,7 @@ def extract_questions_anthropic(
     query_images: Optional[list[str]] = None,
     vision_enabled: bool = False,
     personality_context: Optional[str] = None,
+    tracer: dict = {},
 ):
     """
     Infer search queries to retrieve relevant notes to answer user query
@@ -89,6 +90,7 @@ def extract_questions_anthropic(
         model_name=model,
         temperature=temperature,
         api_key=api_key,
+        tracer=tracer,
     )
 
     # Extract, Clean Message from Claude's Response
@@ -110,7 +112,7 @@ def extract_questions_anthropic(
     return questions
 
 
-def anthropic_send_message_to_model(messages, api_key, model):
+def anthropic_send_message_to_model(messages, api_key, model, tracer={}):
     """
     Send message to model
     """
@@ -122,6 +124,7 @@ def anthropic_send_message_to_model(messages, api_key, model):
         system_prompt=system_prompt,
         model_name=model,
         api_key=api_key,
+        tracer=tracer,
     )
 
 
@@ -141,6 +144,7 @@ def converse_anthropic(
     agent: Agent = None,
     query_images: Optional[list[str]] = None,
     vision_available: bool = False,
+    tracer: dict = {},
 ):
     """
     Converse with user using Anthropic's Claude
@@ -215,4 +219,5 @@ def converse_anthropic(
         system_prompt=system_prompt,
         completion_func=completion_func,
         max_prompt_size=max_prompt_size,
+        tracer=tracer,
     )
diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py
index a4a71a6d..6673555b 100644
--- a/src/khoj/processor/conversation/anthropic/utils.py
+++ b/src/khoj/processor/conversation/anthropic/utils.py
@@ -12,8 +12,13 @@ from tenacity import (
     wait_random_exponential,
 )
 
-from khoj.processor.conversation.utils import ThreadedGenerator, get_image_from_url
-from khoj.utils.helpers import is_none_or_empty
+from khoj.processor.conversation.utils import (
+    ThreadedGenerator,
+    commit_conversation_trace,
+    get_image_from_url,
+)
+from khoj.utils import state
+from khoj.utils.helpers import in_debug_mode, is_none_or_empty
 
 logger = logging.getLogger(__name__)
 
@@ -30,7 +35,7 @@ DEFAULT_MAX_TOKENS_ANTHROPIC = 3000
     reraise=True,
 )
 def anthropic_completion_with_backoff(
-    messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None
+    messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None, tracer={}
 ) -> str:
     if api_key not in anthropic_clients:
         client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key)
@@ -58,6 +63,12 @@ def anthropic_completion_with_backoff(
         for text in stream.text_stream:
             aggregated_response += text
 
+    # Save conversation trace
+    tracer["chat_model"] = model_name
+    tracer["temperature"] = temperature
+    if in_debug_mode() or state.verbose > 1:
+        commit_conversation_trace(messages, aggregated_response, tracer)
+
     return aggregated_response
 
 
@@ -78,18 +89,19 @@ def anthropic_chat_completion_with_backoff(
     max_prompt_size=None,
     completion_func=None,
     model_kwargs=None,
+    tracer={},
 ):
     g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
     t = Thread(
         target=anthropic_llm_thread,
-        args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs),
+        args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs, tracer),
     )
     t.start()
     return g
 
 
 def anthropic_llm_thread(
-    g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None
+    g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None, tracer={}
 ):
     try:
         if api_key not in anthropic_clients:
@@ -102,6 +114,7 @@ def anthropic_llm_thread(
             anthropic.types.MessageParam(role=message.role, content=message.content) for message in messages
         ]
 
+        aggregated_response = ""
         with client.messages.stream(
             messages=formatted_messages,
             model=model_name,  # type: ignore
@@ -112,7 +125,14 @@ def anthropic_llm_thread(
             **(model_kwargs or dict()),
         ) as stream:
             for text in stream.text_stream:
+                aggregated_response += text
                 g.send(text)
+
+        # Save conversation trace
+        tracer["chat_model"] = model_name
+        tracer["temperature"] = temperature
+        if in_debug_mode() or state.verbose > 1:
+            commit_conversation_trace(messages, aggregated_response, tracer)
     except Exception as e:
         logger.error(f"Error in anthropic_llm_thread: {e}", exc_info=True)
     finally:
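
For reviewers, a minimal usage sketch of the tracer flow this diff introduces. It is an illustration under stated assumptions, not part of the change: the seed keys in the dict and the model string are hypothetical, while the mutation of tracer["chat_model"] and tracer["temperature"] and the debug-gated commit_conversation_trace() call follow directly from the hunks above.

# Sketch only: assumes khoj is importable, an Anthropic API key is configured,
# and `messages` is a prebuilt list of chat messages with .role/.content,
# as used elsewhere in khoj.
from khoj.processor.conversation.anthropic.anthropic_chat import (
    anthropic_send_message_to_model,
)

tracer = {"uid": "demo-user"}  # hypothetical seed metadata; caller-defined keys are preserved
response = anthropic_send_message_to_model(
    messages=messages,
    api_key=api_key,
    model="claude-3-5-sonnet-20241022",  # illustrative model name
    tracer=tracer,  # mutated in place by anthropic_completion_with_backoff
)
# After the call, tracer["chat_model"] and tracer["temperature"] are set, and the
# full trace is committed via commit_conversation_trace() only when in_debug_mode()
# or state.verbose > 1, so ordinary requests skip the extra tracing I/O.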