From 6fcd6a5659e9643905614bb238fc10ef30c8e0df Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Thu, 24 Oct 2024 14:15:53 -0700 Subject: [PATCH] Allow Gemini API calling functions to save conversation traces --- .../conversation/google/gemini_chat.py | 10 ++++- .../processor/conversation/google/utils.py | 44 ++++++++++++++----- 2 files changed, 42 insertions(+), 12 deletions(-) diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index 10af8b4d..4ff51c5e 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -35,6 +35,7 @@ def extract_questions_gemini( query_images: Optional[list[str]] = None, vision_enabled: bool = False, personality_context: Optional[str] = None, + tracer: dict = {}, ): """ Infer search queries to retrieve relevant notes to answer user query @@ -85,7 +86,7 @@ def extract_questions_gemini( messages = [ChatMessage(content=prompt, role="user"), ChatMessage(content=system_prompt, role="system")] response = gemini_send_message_to_model( - messages, api_key, model, response_type="json_object", temperature=temperature + messages, api_key, model, response_type="json_object", temperature=temperature, tracer=tracer ) # Extract, Clean Message from Gemini's Response @@ -107,7 +108,9 @@ def extract_questions_gemini( return questions -def gemini_send_message_to_model(messages, api_key, model, response_type="text", temperature=0, model_kwargs=None): +def gemini_send_message_to_model( + messages, api_key, model, response_type="text", temperature=0, model_kwargs=None, tracer={} +): """ Send message to model """ @@ -125,6 +128,7 @@ def gemini_send_message_to_model(messages, api_key, model, response_type="text", api_key=api_key, temperature=temperature, model_kwargs=model_kwargs, + tracer=tracer, ) @@ -145,6 +149,7 @@ def converse_gemini( agent: Agent = None, query_images: Optional[list[str]] = None, vision_available: bool = False, + tracer={}, ): """ Converse with user using Google's Gemini @@ -219,4 +224,5 @@ def converse_gemini( api_key=api_key, system_prompt=system_prompt, completion_func=completion_func, + tracer=tracer, ) diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py index 964fe80b..7b848324 100644 --- a/src/khoj/processor/conversation/google/utils.py +++ b/src/khoj/processor/conversation/google/utils.py @@ -19,8 +19,13 @@ from tenacity import ( wait_random_exponential, ) -from khoj.processor.conversation.utils import ThreadedGenerator, get_image_from_url -from khoj.utils.helpers import is_none_or_empty +from khoj.processor.conversation.utils import ( + ThreadedGenerator, + commit_conversation_trace, + get_image_from_url, +) +from khoj.utils import state +from khoj.utils.helpers import in_debug_mode, is_none_or_empty logger = logging.getLogger(__name__) @@ -35,7 +40,7 @@ MAX_OUTPUT_TOKENS_GEMINI = 8192 reraise=True, ) def gemini_completion_with_backoff( - messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None + messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, tracer={} ) -> str: genai.configure(api_key=api_key) model_kwargs = model_kwargs or dict() @@ -60,16 +65,23 @@ def gemini_completion_with_backoff( try: # Generate the response. The last message is considered to be the current prompt - aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"]) - return aggregated_response.text + response = chat_session.send_message(formatted_messages[-1]["parts"]) + response_text = response.text except StopCandidateException as e: - response_message, _ = handle_gemini_response(e.args) + response_text, _ = handle_gemini_response(e.args) # Respond with reason for stopping logger.warning( - f"LLM Response Prevented for {model_name}: {response_message}.\n" + f"LLM Response Prevented for {model_name}: {response_text}.\n" + f"Last Message by {messages[-1].role}: {messages[-1].content}" ) - return response_message + + # Save conversation trace + tracer["chat_model"] = model_name + tracer["temperature"] = temperature + if in_debug_mode() or state.verbose > 1: + commit_conversation_trace(messages, response_text, tracer) + + return response_text @retry( @@ -88,17 +100,20 @@ def gemini_chat_completion_with_backoff( system_prompt, completion_func=None, model_kwargs=None, + tracer: dict = {}, ): g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func) t = Thread( target=gemini_llm_thread, - args=(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs), + args=(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs, tracer), ) t.start() return g -def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None): +def gemini_llm_thread( + g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None, tracer: dict = {} +): try: genai.configure(api_key=api_key) model_kwargs = model_kwargs or dict() @@ -117,16 +132,25 @@ def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_k }, ) + aggregated_response = "" formatted_messages = [{"role": message.role, "parts": message.content} for message in messages] + # all messages up to the last are considered to be part of the chat history chat_session = model.start_chat(history=formatted_messages[0:-1]) # the last message is considered to be the current prompt for chunk in chat_session.send_message(formatted_messages[-1]["parts"], stream=True): message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback) message = message or chunk.text + aggregated_response += message g.send(message) if stopped: raise StopCandidateException(message) + + # Save conversation trace + tracer["chat_model"] = model_name + tracer["temperature"] = temperature + if in_debug_mode() or state.verbose > 1: + commit_conversation_trace(messages, aggregated_response, tracer) except StopCandidateException as e: logger.warning( f"LLM Response Prevented for {model_name}: {e.args[0]}.\n"