diff --git a/pyproject.toml b/pyproject.toml index 93df0b42..12c7789c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -119,6 +119,7 @@ dev = [ "mypy >= 1.0.1", "black >= 23.1.0", "pre-commit >= 3.0.4", + "gitpython ~= 3.1.43", ] [tool.hatch.version] diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index 77ae17e7..cfefe676 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -34,6 +34,7 @@ def extract_questions_anthropic( query_images: Optional[list[str]] = None, vision_enabled: bool = False, personality_context: Optional[str] = None, + tracer: dict = {}, ): """ Infer search queries to retrieve relevant notes to answer user query @@ -89,6 +90,7 @@ def extract_questions_anthropic( model_name=model, temperature=temperature, api_key=api_key, + tracer=tracer, ) # Extract, Clean Message from Claude's Response @@ -110,7 +112,7 @@ def extract_questions_anthropic( return questions -def anthropic_send_message_to_model(messages, api_key, model): +def anthropic_send_message_to_model(messages, api_key, model, tracer={}): """ Send message to model """ @@ -122,6 +124,7 @@ def anthropic_send_message_to_model(messages, api_key, model): system_prompt=system_prompt, model_name=model, api_key=api_key, + tracer=tracer, ) @@ -142,6 +145,7 @@ def converse_anthropic( agent: Agent = None, query_images: Optional[list[str]] = None, vision_available: bool = False, + tracer: dict = {}, ): """ Converse with user using Anthropic's Claude @@ -220,4 +224,5 @@ def converse_anthropic( system_prompt=system_prompt, completion_func=completion_func, max_prompt_size=max_prompt_size, + tracer=tracer, ) diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py index a4a71a6d..6673555b 100644 --- a/src/khoj/processor/conversation/anthropic/utils.py +++ b/src/khoj/processor/conversation/anthropic/utils.py @@ -12,8 +12,13 @@ from tenacity import ( wait_random_exponential, ) -from khoj.processor.conversation.utils import ThreadedGenerator, get_image_from_url -from khoj.utils.helpers import is_none_or_empty +from khoj.processor.conversation.utils import ( + ThreadedGenerator, + commit_conversation_trace, + get_image_from_url, +) +from khoj.utils import state +from khoj.utils.helpers import in_debug_mode, is_none_or_empty logger = logging.getLogger(__name__) @@ -30,7 +35,7 @@ DEFAULT_MAX_TOKENS_ANTHROPIC = 3000 reraise=True, ) def anthropic_completion_with_backoff( - messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None + messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None, tracer={} ) -> str: if api_key not in anthropic_clients: client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key) @@ -58,6 +63,12 @@ def anthropic_completion_with_backoff( for text in stream.text_stream: aggregated_response += text + # Save conversation trace + tracer["chat_model"] = model_name + tracer["temperature"] = temperature + if in_debug_mode() or state.verbose > 1: + commit_conversation_trace(messages, aggregated_response, tracer) + return aggregated_response @@ -78,18 +89,19 @@ def anthropic_chat_completion_with_backoff( max_prompt_size=None, completion_func=None, model_kwargs=None, + tracer={}, ): g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func) t = Thread( 
target=anthropic_llm_thread, - args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs), + args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs, tracer), ) t.start() return g def anthropic_llm_thread( - g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None + g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None, tracer={} ): try: if api_key not in anthropic_clients: @@ -102,6 +114,7 @@ def anthropic_llm_thread( anthropic.types.MessageParam(role=message.role, content=message.content) for message in messages ] + aggregated_response = "" with client.messages.stream( messages=formatted_messages, model=model_name, # type: ignore @@ -112,7 +125,14 @@ def anthropic_llm_thread( **(model_kwargs or dict()), ) as stream: for text in stream.text_stream: + aggregated_response += text g.send(text) + + # Save conversation trace + tracer["chat_model"] = model_name + tracer["temperature"] = temperature + if in_debug_mode() or state.verbose > 1: + commit_conversation_trace(messages, aggregated_response, tracer) except Exception as e: logger.error(f"Error in anthropic_llm_thread: {e}", exc_info=True) finally: diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index 7cff27f0..f8c9c6b3 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -35,6 +35,7 @@ def extract_questions_gemini( query_images: Optional[list[str]] = None, vision_enabled: bool = False, personality_context: Optional[str] = None, + tracer: dict = {}, ): """ Infer search queries to retrieve relevant notes to answer user query @@ -85,7 +86,7 @@ def extract_questions_gemini( messages = [ChatMessage(content=prompt, role="user"), ChatMessage(content=system_prompt, role="system")] response = gemini_send_message_to_model( - messages, api_key, model, response_type="json_object", temperature=temperature + messages, api_key, model, response_type="json_object", temperature=temperature, tracer=tracer ) # Extract, Clean Message from Gemini's Response @@ -107,7 +108,9 @@ def extract_questions_gemini( return questions -def gemini_send_message_to_model(messages, api_key, model, response_type="text", temperature=0, model_kwargs=None): +def gemini_send_message_to_model( + messages, api_key, model, response_type="text", temperature=0, model_kwargs=None, tracer={} +): """ Send message to model """ @@ -125,6 +128,7 @@ def gemini_send_message_to_model(messages, api_key, model, response_type="text", api_key=api_key, temperature=temperature, model_kwargs=model_kwargs, + tracer=tracer, ) @@ -146,6 +150,7 @@ def converse_gemini( agent: Agent = None, query_images: Optional[list[str]] = None, vision_available: bool = False, + tracer={}, ): """ Converse with user using Google's Gemini @@ -224,4 +229,5 @@ def converse_gemini( api_key=api_key, system_prompt=system_prompt, completion_func=completion_func, + tracer=tracer, ) diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py index 964fe80b..7b848324 100644 --- a/src/khoj/processor/conversation/google/utils.py +++ b/src/khoj/processor/conversation/google/utils.py @@ -19,8 +19,13 @@ from tenacity import ( wait_random_exponential, ) -from khoj.processor.conversation.utils import ThreadedGenerator, get_image_from_url -from khoj.utils.helpers 
import is_none_or_empty +from khoj.processor.conversation.utils import ( + ThreadedGenerator, + commit_conversation_trace, + get_image_from_url, +) +from khoj.utils import state +from khoj.utils.helpers import in_debug_mode, is_none_or_empty logger = logging.getLogger(__name__) @@ -35,7 +40,7 @@ MAX_OUTPUT_TOKENS_GEMINI = 8192 reraise=True, ) def gemini_completion_with_backoff( - messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None + messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, tracer={} ) -> str: genai.configure(api_key=api_key) model_kwargs = model_kwargs or dict() @@ -60,16 +65,23 @@ def gemini_completion_with_backoff( try: # Generate the response. The last message is considered to be the current prompt - aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"]) - return aggregated_response.text + response = chat_session.send_message(formatted_messages[-1]["parts"]) + response_text = response.text except StopCandidateException as e: - response_message, _ = handle_gemini_response(e.args) + response_text, _ = handle_gemini_response(e.args) # Respond with reason for stopping logger.warning( - f"LLM Response Prevented for {model_name}: {response_message}.\n" + f"LLM Response Prevented for {model_name}: {response_text}.\n" + f"Last Message by {messages[-1].role}: {messages[-1].content}" ) - return response_message + + # Save conversation trace + tracer["chat_model"] = model_name + tracer["temperature"] = temperature + if in_debug_mode() or state.verbose > 1: + commit_conversation_trace(messages, response_text, tracer) + + return response_text @retry( @@ -88,17 +100,20 @@ def gemini_chat_completion_with_backoff( system_prompt, completion_func=None, model_kwargs=None, + tracer: dict = {}, ): g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func) t = Thread( target=gemini_llm_thread, - args=(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs), + args=(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs, tracer), ) t.start() return g -def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None): +def gemini_llm_thread( + g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None, tracer: dict = {} +): try: genai.configure(api_key=api_key) model_kwargs = model_kwargs or dict() @@ -117,16 +132,25 @@ def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_k }, ) + aggregated_response = "" formatted_messages = [{"role": message.role, "parts": message.content} for message in messages] + # all messages up to the last are considered to be part of the chat history chat_session = model.start_chat(history=formatted_messages[0:-1]) # the last message is considered to be the current prompt for chunk in chat_session.send_message(formatted_messages[-1]["parts"], stream=True): message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback) message = message or chunk.text + aggregated_response += message g.send(message) if stopped: raise StopCandidateException(message) + + # Save conversation trace + tracer["chat_model"] = model_name + tracer["temperature"] = temperature + if in_debug_mode() or state.verbose > 1: + commit_conversation_trace(messages, aggregated_response, tracer) except StopCandidateException as e: logger.warning( f"LLM Response Prevented for {model_name}: {e.args[0]}.\n" diff --git 
a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py index d9d99f21..2bbb0e2e 100644 --- a/src/khoj/processor/conversation/offline/chat_model.py +++ b/src/khoj/processor/conversation/offline/chat_model.py @@ -12,11 +12,12 @@ from khoj.processor.conversation import prompts from khoj.processor.conversation.offline.utils import download_model from khoj.processor.conversation.utils import ( ThreadedGenerator, + commit_conversation_trace, generate_chatml_messages_with_context, ) from khoj.utils import state from khoj.utils.constants import empty_escape_sequences -from khoj.utils.helpers import ConversationCommand, is_none_or_empty +from khoj.utils.helpers import ConversationCommand, in_debug_mode, is_none_or_empty from khoj.utils.rawconfig import LocationData logger = logging.getLogger(__name__) @@ -34,6 +35,7 @@ def extract_questions_offline( max_prompt_size: int = None, temperature: float = 0.7, personality_context: Optional[str] = None, + tracer: dict = {}, ) -> List[str]: """ Infer search queries to retrieve relevant notes to answer user query @@ -94,6 +96,7 @@ def extract_questions_offline( max_prompt_size=max_prompt_size, temperature=temperature, response_type="json_object", + tracer=tracer, ) finally: state.chat_lock.release() @@ -147,6 +150,7 @@ def converse_offline( location_data: LocationData = None, user_name: str = None, agent: Agent = None, + tracer: dict = {}, ) -> Union[ThreadedGenerator, Iterator[str]]: """ Converse with user using Llama @@ -155,6 +159,7 @@ def converse_offline( assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured" offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size) compiled_references_message = "\n\n".join({f"{item['compiled']}" for item in references}) + tracer["chat_model"] = model current_date = datetime.now() @@ -218,13 +223,14 @@ def converse_offline( logger.debug(f"Conversation Context for {model}: {truncated_messages}") g = ThreadedGenerator(references, online_results, completion_func=completion_func) - t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size)) + t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size, tracer)) t.start() return g -def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None): +def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None, tracer: dict = {}): stop_phrases = ["", "INST]", "Notes:"] + aggregated_response = "" state.chat_lock.acquire() try: @@ -232,7 +238,14 @@ def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int messages, loaded_model=model, stop=stop_phrases, max_prompt_size=max_prompt_size, streaming=True ) for response in response_iterator: - g.send(response["choices"][0]["delta"].get("content", "")) + response_delta = response["choices"][0]["delta"].get("content", "") + aggregated_response += response_delta + g.send(response_delta) + + # Save conversation trace + if in_debug_mode() or state.verbose > 1: + commit_conversation_trace(messages, aggregated_response, tracer) + finally: state.chat_lock.release() g.close() @@ -247,6 +260,7 @@ def send_message_to_model_offline( stop=[], max_prompt_size: int = None, response_type: str = "text", + tracer: dict = {}, ): assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured" offline_chat_model = loaded_model or 
download_model(model, max_tokens=max_prompt_size) @@ -254,7 +268,17 @@ def send_message_to_model_offline( response = offline_chat_model.create_chat_completion( messages_dict, stop=stop, stream=streaming, temperature=temperature, response_format={"type": response_type} ) + if streaming: return response - else: - return response["choices"][0]["message"].get("content", "") + + response_text = response["choices"][0]["message"].get("content", "") + + # Save conversation trace for non-streaming responses + # Streamed responses need to be saved by the calling function + tracer["chat_model"] = model + tracer["temperature"] = temperature + if in_debug_mode() or state.verbose > 1: + commit_conversation_trace(messages, response_text, tracer) + + return response_text diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index e4f71f5b..761d7ac0 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -33,6 +33,7 @@ def extract_questions( query_images: Optional[list[str]] = None, vision_enabled: bool = False, personality_context: Optional[str] = None, + tracer: dict = {}, ): """ Infer search queries to retrieve relevant notes to answer user query @@ -82,7 +83,13 @@ def extract_questions( messages = [ChatMessage(content=prompt, role="user")] response = send_message_to_model( - messages, api_key, model, response_type="json_object", api_base_url=api_base_url, temperature=temperature + messages, + api_key, + model, + response_type="json_object", + api_base_url=api_base_url, + temperature=temperature, + tracer=tracer, ) # Extract, Clean Message from GPT's Response @@ -103,7 +110,9 @@ def extract_questions( return questions -def send_message_to_model(messages, api_key, model, response_type="text", api_base_url=None, temperature=0): +def send_message_to_model( + messages, api_key, model, response_type="text", api_base_url=None, temperature=0, tracer: dict = {} +): """ Send message to model """ @@ -116,6 +125,7 @@ def send_message_to_model(messages, api_key, model, response_type="text", api_ba temperature=temperature, api_base_url=api_base_url, model_kwargs={"response_format": {"type": response_type}}, + tracer=tracer, ) @@ -138,6 +148,7 @@ def converse( agent: Agent = None, query_images: Optional[list[str]] = None, vision_available: bool = False, + tracer: dict = {}, ): """ Converse with user using OpenAI's ChatGPT @@ -214,4 +225,5 @@ def converse( api_base_url=api_base_url, completion_func=completion_func, model_kwargs={"stop": ["Notes:\n["]}, + tracer=tracer, ) diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index 878dbb9c..6e519f5a 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -12,7 +12,12 @@ from tenacity import ( wait_random_exponential, ) -from khoj.processor.conversation.utils import ThreadedGenerator +from khoj.processor.conversation.utils import ( + ThreadedGenerator, + commit_conversation_trace, +) +from khoj.utils import state +from khoj.utils.helpers import in_debug_mode logger = logging.getLogger(__name__) @@ -33,7 +38,7 @@ openai_clients: Dict[str, openai.OpenAI] = {} reraise=True, ) def completion_with_backoff( - messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None + messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None, tracer: dict = {} ) -> str: client_key = 
f"{openai_api_key}--{api_base_url}" client: openai.OpenAI | None = openai_clients.get(client_key) @@ -77,6 +82,12 @@ def completion_with_backoff( elif delta_chunk.content: aggregated_response += delta_chunk.content + # Save conversation trace + tracer["chat_model"] = model + tracer["temperature"] = temperature + if in_debug_mode() or state.verbose > 1: + commit_conversation_trace(messages, aggregated_response, tracer) + return aggregated_response @@ -103,26 +114,37 @@ def chat_completion_with_backoff( api_base_url=None, completion_func=None, model_kwargs=None, + tracer: dict = {}, ): g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func) t = Thread( - target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key, api_base_url, model_kwargs) + target=llm_thread, + args=(g, messages, model_name, temperature, openai_api_key, api_base_url, model_kwargs, tracer), ) t.start() return g -def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_base_url=None, model_kwargs=None): +def llm_thread( + g, + messages, + model_name, + temperature, + openai_api_key=None, + api_base_url=None, + model_kwargs=None, + tracer: dict = {}, +): try: client_key = f"{openai_api_key}--{api_base_url}" if client_key not in openai_clients: - client: openai.OpenAI = openai.OpenAI( + client = openai.OpenAI( api_key=openai_api_key, base_url=api_base_url, ) openai_clients[client_key] = client else: - client: openai.OpenAI = openai_clients[client_key] + client = openai_clients[client_key] formatted_messages = [{"role": message.role, "content": message.content} for message in messages] stream = True @@ -144,17 +166,29 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_ba **(model_kwargs or dict()), ) + aggregated_response = "" if not stream: - g.send(chat.choices[0].message.content) + aggregated_response = chat.choices[0].message.content + g.send(aggregated_response) else: for chunk in chat: if len(chunk.choices) == 0: continue delta_chunk = chunk.choices[0].delta + text_chunk = "" if isinstance(delta_chunk, str): - g.send(delta_chunk) + text_chunk = delta_chunk elif delta_chunk.content: - g.send(delta_chunk.content) + text_chunk = delta_chunk.content + if text_chunk: + aggregated_response += text_chunk + g.send(text_chunk) + + # Save conversation trace + tracer["chat_model"] = model_name + tracer["temperature"] = temperature + if in_debug_mode() or state.verbose > 1: + commit_conversation_trace(messages, aggregated_response, tracer) except Exception as e: logger.error(f"Error in llm_thread: {e}", exc_info=True) finally: diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index db985432..92192f52 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -2,6 +2,7 @@ import base64 import logging import math import mimetypes +import os import queue from dataclasses import dataclass from datetime import datetime @@ -13,6 +14,8 @@ from typing import Any, Dict, List, Optional import PIL.Image import requests import tiktoken +import yaml +from git import Repo from langchain.schema import ChatMessage from llama_cpp.llama import Llama from transformers import AutoTokenizer @@ -24,7 +27,7 @@ from khoj.search_filter.date_filter import DateFilter from khoj.search_filter.file_filter import FileFilter from khoj.search_filter.word_filter import WordFilter from khoj.utils import state -from khoj.utils.helpers import is_none_or_empty, merge_dicts 
+from khoj.utils.helpers import in_debug_mode, is_none_or_empty, merge_dicts logger = logging.getLogger(__name__) model_to_prompt_size = { @@ -178,6 +181,7 @@ def save_to_conversation_log( conversation_id: str = None, automation_id: str = None, query_images: List[str] = None, + tracer: Dict[str, Any] = {}, ): user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S") updated_conversation = message_to_log( @@ -204,6 +208,9 @@ def save_to_conversation_log( user_message=q, ) + if in_debug_mode() or state.verbose > 1: + merge_message_into_conversation_trace(q, chat_response, tracer) + logger.info( f""" Saved Conversation Turn @@ -415,3 +422,160 @@ def get_image_from_url(image_url: str, type="pil"): except requests.exceptions.RequestException as e: logger.error(f"Failed to get image from URL {image_url}: {e}") return ImageWithType(content=None, type=None) + + +def commit_conversation_trace( + session: list[ChatMessage], + response: str | list[dict], + tracer: dict, + system_message: str | list[dict] = "", + repo_path: str = "/tmp/khoj_promptrace", +) -> str: + """ + Save trace of conversation step using git. Useful to visualize, compare and debug traces. + Returns the path to the repository. + """ + # Serialize session, system message and response to yaml + system_message_yaml = yaml.dump(system_message, allow_unicode=True, sort_keys=False, default_flow_style=False) + response_yaml = yaml.dump(response, allow_unicode=True, sort_keys=False, default_flow_style=False) + formatted_session = [{"role": message.role, "content": message.content} for message in session] + session_yaml = yaml.dump(formatted_session, allow_unicode=True, sort_keys=False, default_flow_style=False) + query = ( + yaml.dump(session[-1].content, allow_unicode=True, sort_keys=False, default_flow_style=False) + .strip() + .removeprefix("'") + .removesuffix("'") + ) # Extract serialized query from chat session + + # Extract chat metadata for session + uid, cid, mid = tracer.get("uid", "main"), tracer.get("cid", "main"), tracer.get("mid") + + # Infer repository path from environment variable or provided path + repo_path = os.getenv("PROMPTRACE_DIR", repo_path) or "/tmp/promptrace" + + try: + # Prepare git repository + os.makedirs(repo_path, exist_ok=True) + repo = Repo.init(repo_path) + + # Remove post-commit hook if it exists + hooks_dir = os.path.join(repo_path, ".git", "hooks") + post_commit_hook = os.path.join(hooks_dir, "post-commit") + if os.path.exists(post_commit_hook): + os.remove(post_commit_hook) + + # Configure git user if not set + if not repo.config_reader().has_option("user", "email"): + repo.config_writer().set_value("user", "name", "Prompt Tracer").release() + repo.config_writer().set_value("user", "email", "promptracer@khoj.dev").release() + + # Create an initial commit if the repository is newly created + if not repo.head.is_valid(): + repo.index.commit("And then there was a trace") + + # Check out the initial commit + initial_commit = repo.commit("HEAD~0") + repo.head.reference = initial_commit + repo.head.reset(index=True, working_tree=True) + + # Create or switch to user branch from initial commit + user_branch = f"u_{uid}" + if user_branch not in repo.branches: + repo.create_head(user_branch) + repo.heads[user_branch].checkout() + + # Create or switch to conversation branch from user branch + conv_branch = f"c_{cid}" + if conv_branch not in repo.branches: + repo.create_head(conv_branch) + repo.heads[conv_branch].checkout() + + # Create or switch to message branch from conversation 
branch + msg_branch = f"m_{mid}" if mid else None + if msg_branch and msg_branch not in repo.branches: + repo.create_head(msg_branch) + repo.heads[msg_branch].checkout() + + # Include file with content to commit + files_to_commit = {"query": session_yaml, "response": response_yaml, "system_prompt": system_message_yaml} + + # Write files and stage them + for filename, content in files_to_commit.items(): + file_path = os.path.join(repo_path, filename) + with open(file_path, "w", encoding="utf-8") as f: + f.write(content) + repo.index.add([filename]) + + # Create commit + metadata_yaml = yaml.dump(tracer, allow_unicode=True, sort_keys=False, default_flow_style=False) + commit_message = f""" +{query[:250]} + +Response: +--- +{response[:500]}... + +Metadata +--- +{metadata_yaml} +""".strip() + + repo.index.commit(commit_message) + + logger.debug(f"Saved conversation trace to repo at {repo_path}") + return repo_path + except Exception as e: + logger.error(f"Failed to add conversation trace to repo: {str(e)}") + return None + + +def merge_message_into_conversation_trace(query: str, response: str, tracer: dict, repo_path=None) -> bool: + """ + Merge the message branch into its parent conversation branch. + + Args: + query: User query + response: Assistant response + tracer: Dictionary containing uid, cid and mid + repo_path: Path to the git repository + + Returns: + bool: True if merge was successful, False otherwise + """ + try: + # Infer repository path from environment variable or provided path + repo_path = os.getenv("PROMPTRACE_DIR", repo_path) or "/tmp/promptrace" + repo = Repo(repo_path) + + # Extract branch names + msg_branch = f"m_{tracer['mid']}" + conv_branch = f"c_{tracer['cid']}" + + # Checkout conversation branch + repo.heads[conv_branch].checkout() + + # Create commit message + metadata_yaml = yaml.dump(tracer, allow_unicode=True, sort_keys=False, default_flow_style=False) + commit_message = f""" +{query[:250]} + +Response: +--- +{response[:500]}... 
+ +Metadata +--- +{metadata_yaml} +""".strip() + + # Merge message branch into conversation branch + repo.git.merge(msg_branch, no_ff=True, m=commit_message) + + # Delete message branch after merge + repo.delete_head(msg_branch, force=True) + + logger.debug(f"Successfully merged {msg_branch} into {conv_branch}") + return True + except Exception as e: + logger.error(f"Failed to merge message {msg_branch} into conversation {conv_branch}: {str(e)}") + return False diff --git a/src/khoj/processor/image/generate.py b/src/khoj/processor/image/generate.py index 343e44b3..bdc00e09 100644 --- a/src/khoj/processor/image/generate.py +++ b/src/khoj/processor/image/generate.py @@ -28,6 +28,7 @@ async def text_to_image( send_status_func: Optional[Callable] = None, query_images: Optional[List[str]] = None, agent: Agent = None, + tracer: dict = {}, ): status_code = 200 image = None @@ -68,6 +69,7 @@ async def text_to_image( query_images=query_images, user=user, agent=agent, + tracer=tracer, ) if send_status_func: diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py index 739d4c70..329ca2ea 100644 --- a/src/khoj/processor/tools/online_search.py +++ b/src/khoj/processor/tools/online_search.py @@ -66,6 +66,7 @@ async def search_online( max_webpages_to_read: int = DEFAULT_MAX_WEBPAGES_TO_READ, query_images: List[str] = None, agent: Agent = None, + tracer: dict = {}, ): query += " ".join(custom_filters) if not is_internet_connected(): @@ -75,7 +76,7 @@ async def search_online( # Breakdown the query into subqueries to get the correct answer subqueries = await generate_online_subqueries( - query, conversation_history, location, user, query_images=query_images, agent=agent + query, conversation_history, location, user, query_images=query_images, agent=agent, tracer=tracer ) response_dict = {} @@ -113,7 +114,7 @@ async def search_online( async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"): yield {ChatEvent.STATUS: event} tasks = [ - read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent) + read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent, tracer=tracer) for link, data in webpages.items() ] results = await asyncio.gather(*tasks) @@ -155,6 +156,7 @@ async def read_webpages( send_status_func: Optional[Callable] = None, query_images: List[str] = None, agent: Agent = None, + tracer: dict = {}, ): "Infer web pages to read from the query and extract relevant information from them" logger.info(f"Inferring web pages to read") @@ -168,7 +170,7 @@ async def read_webpages( webpage_links_str = "\n- " + "\n- ".join(list(urls)) async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"): yield {ChatEvent.STATUS: event} - tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent) for url in urls] + tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent, tracer=tracer) for url in urls] results = await asyncio.gather(*tasks) response: Dict[str, Dict] = defaultdict(dict) @@ -194,7 +196,12 @@ async def read_webpage( async def read_webpage_and_extract_content( - subqueries: set[str], url: str, content: str = None, user: KhojUser = None, agent: Agent = None + subqueries: set[str], + url: str, + content: str = None, + user: KhojUser = None, + agent: Agent = None, + tracer: dict = {}, ) -> Tuple[set[str], str, Union[None, str]]: # Select the web scrapers to use for reading the web page web_scrapers = await 
ConversationAdapters.aget_enabled_webscrapers() @@ -216,7 +223,9 @@ async def read_webpage_and_extract_content( # Extract relevant information from the web page if is_none_or_empty(extracted_info): with timer(f"Extracting relevant information from web page at '{url}' took", logger): - extracted_info = await extract_relevant_info(subqueries, content, user=user, agent=agent) + extracted_info = await extract_relevant_info( + subqueries, content, user=user, agent=agent, tracer=tracer + ) # If we successfully extracted information, break the loop if not is_none_or_empty(extracted_info): diff --git a/src/khoj/processor/tools/run_code.py b/src/khoj/processor/tools/run_code.py index ecbc494d..40e157ff 100644 --- a/src/khoj/processor/tools/run_code.py +++ b/src/khoj/processor/tools/run_code.py @@ -35,6 +35,7 @@ async def run_code( query_images: List[str] = None, agent: Agent = None, sandbox_url: str = SANDBOX_URL, + tracer: dict = {}, ): # Generate Code if send_status_func: @@ -43,7 +44,14 @@ async def run_code( try: with timer("Chat actor: Generate programs to execute", logger): codes = await generate_python_code( - query, conversation_history, previous_iterations_history, location_data, user, query_images, agent + query, + conversation_history, + previous_iterations_history, + location_data, + user, + query_images, + agent, + tracer, ) except Exception as e: raise ValueError(f"Failed to generate code for {query} with error: {e}") @@ -72,6 +80,7 @@ async def generate_python_code( user: KhojUser, query_images: List[str] = None, agent: Agent = None, + tracer: dict = {}, ) -> List[str]: location = f"{location_data}" if location_data else "Unknown" username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else "" @@ -98,6 +107,7 @@ async def generate_python_code( query_images=query_images, response_type="json_object", user=user, + tracer=tracer, ) # Validate that the response is a non-empty, JSON-serializable list diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 7db47395..1e3fb092 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -351,6 +351,7 @@ async def extract_references_and_questions( send_status_func: Optional[Callable] = None, query_images: Optional[List[str]] = None, agent: Agent = None, + tracer: dict = {}, ): user = request.user.object if request.user.is_authenticated else None @@ -424,6 +425,7 @@ async def extract_references_and_questions( user=user, max_prompt_size=conversation_config.max_prompt_size, personality_context=personality_context, + tracer=tracer, ) elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI: openai_chat_config = conversation_config.openai_config @@ -441,6 +443,7 @@ async def extract_references_and_questions( query_images=query_images, vision_enabled=vision_enabled, personality_context=personality_context, + tracer=tracer, ) elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC: api_key = conversation_config.openai_config.api_key @@ -455,6 +458,7 @@ async def extract_references_and_questions( user=user, vision_enabled=vision_enabled, personality_context=personality_context, + tracer=tracer, ) elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE: api_key = conversation_config.openai_config.api_key @@ -470,6 +474,7 @@ async def extract_references_and_questions( user=user, vision_enabled=vision_enabled, personality_context=personality_context, + tracer=tracer, ) # Collate search results as context for GPT diff --git 
a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 5d811fe7..e80d215f 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -3,6 +3,7 @@ import base64 import json import logging import time +import uuid from datetime import datetime from functools import partial from typing import Any, Dict, List, Optional @@ -570,6 +571,12 @@ async def chat( event_delimiter = "␃🔚␗" q = unquote(q) nonlocal conversation_id + tracer: dict = { + "mid": f"{uuid.uuid4()}", + "cid": conversation_id, + "uid": user.id, + "khoj_version": state.khoj_version, + } uploaded_images: list[str] = [] if images: @@ -703,6 +710,7 @@ async def chat( user_name=user_name, location=location, file_filters=conversation.file_filters if conversation else [], + tracer=tracer, ): if isinstance(research_result, InformationCollectionIteration): if research_result.summarizedResult: @@ -732,9 +740,12 @@ async def chat( user=user, query_images=uploaded_images, agent=agent, + tracer=tracer, ) - mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_images, agent) + mode = await aget_relevant_output_modes( + q, meta_log, is_automated_task, user, uploaded_images, agent, tracer=tracer + ) async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"): yield result if mode not in conversation_commands: @@ -778,6 +789,7 @@ async def chat( query_images=uploaded_images, agent=agent, send_status_func=partial(send_event, ChatEvent.STATUS), + tracer=tracer, ): if isinstance(response, dict) and ChatEvent.STATUS in response: yield result[ChatEvent.STATUS] @@ -796,6 +808,7 @@ async def chat( client_application=request.user.client_app, conversation_id=conversation_id, query_images=uploaded_images, + tracer=tracer, ) return @@ -817,7 +830,7 @@ async def chat( if ConversationCommand.Automation in conversation_commands: try: automation, crontime, query_to_run, subject = await create_automation( - q, timezone, user, request.url, meta_log + q, timezone, user, request.url, meta_log, tracer=tracer ) except Exception as e: logger.error(f"Error scheduling task {q} for {user.email}: {e}") @@ -839,6 +852,7 @@ async def chat( inferred_queries=[query_to_run], automation_id=automation.id, query_images=uploaded_images, + tracer=tracer, ) async for result in send_llm_response(llm_response): yield result @@ -860,6 +874,7 @@ async def chat( partial(send_event, ChatEvent.STATUS), query_images=uploaded_images, agent=agent, + tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -905,6 +920,7 @@ async def chat( custom_filters, query_images=uploaded_images, agent=agent, + tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -930,6 +946,7 @@ async def chat( partial(send_event, ChatEvent.STATUS), query_images=uploaded_images, agent=agent, + tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -984,6 +1001,7 @@ async def chat( partial(send_event, ChatEvent.STATUS), query_images=uploaded_images, agent=agent, + tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -1010,6 +1028,7 @@ async def chat( send_status_func=partial(send_event, ChatEvent.STATUS), query_images=uploaded_images, agent=agent, + tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -1040,6 +1059,7 @@ async def chat( 
compiled_references=compiled_references, online_results=online_results, query_images=uploaded_images, + tracer=tracer, ) content_obj = { "intentType": intent_type, @@ -1068,6 +1088,7 @@ async def chat( user=user, agent=agent, send_status_func=partial(send_event, ChatEvent.STATUS), + tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -1095,6 +1116,7 @@ async def chat( compiled_references=compiled_references, online_results=online_results, query_images=uploaded_images, + tracer=tracer, ) async for result in send_llm_response(json.dumps(content_obj)): @@ -1120,6 +1142,7 @@ async def chat( user_name, researched_results, uploaded_images, + tracer, ) # Send Response diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index c9eece4a..0f5a7006 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -291,6 +291,7 @@ async def aget_relevant_information_sources( user: KhojUser, query_images: List[str] = None, agent: Agent = None, + tracer: dict = {}, ): """ Given a query, determine which of the available tools the agent should use in order to answer appropriately. @@ -327,6 +328,7 @@ async def aget_relevant_information_sources( relevant_tools_prompt, response_type="json_object", user=user, + tracer=tracer, ) try: @@ -368,6 +370,7 @@ async def aget_relevant_output_modes( user: KhojUser = None, query_images: List[str] = None, agent: Agent = None, + tracer: dict = {}, ): """ Given a query, determine which of the available tools the agent should use in order to answer appropriately. @@ -403,7 +406,9 @@ async def aget_relevant_output_modes( ) with timer("Chat actor: Infer output mode for chat response", logger): - response = await send_message_to_model_wrapper(relevant_mode_prompt, response_type="json_object", user=user) + response = await send_message_to_model_wrapper( + relevant_mode_prompt, response_type="json_object", user=user, tracer=tracer + ) try: response = response.strip() @@ -434,6 +439,7 @@ async def infer_webpage_urls( user: KhojUser, query_images: List[str] = None, agent: Agent = None, + tracer: dict = {}, ) -> List[str]: """ Infer webpage links from the given query @@ -458,7 +464,11 @@ async def infer_webpage_urls( with timer("Chat actor: Infer webpage urls to read", logger): response = await send_message_to_model_wrapper( - online_queries_prompt, query_images=query_images, response_type="json_object", user=user + online_queries_prompt, + query_images=query_images, + response_type="json_object", + user=user, + tracer=tracer, ) # Validate that the response is a non-empty, JSON-serializable list of URLs @@ -481,6 +491,7 @@ async def generate_online_subqueries( user: KhojUser, query_images: List[str] = None, agent: Agent = None, + tracer: dict = {}, ) -> List[str]: """ Generate subqueries from the given query @@ -505,7 +516,11 @@ async def generate_online_subqueries( with timer("Chat actor: Generate online search subqueries", logger): response = await send_message_to_model_wrapper( - online_queries_prompt, query_images=query_images, response_type="json_object", user=user + online_queries_prompt, + query_images=query_images, + response_type="json_object", + user=user, + tracer=tracer, ) # Validate that the response is a non-empty, JSON-serializable list @@ -524,7 +539,7 @@ async def generate_online_subqueries( async def schedule_query( - q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None + q: str, conversation_history: dict, user: KhojUser, query_images: 
List[str] = None, tracer: dict = {} ) -> Tuple[str, ...]: """ Schedule the date, time to run the query. Assume the server timezone is UTC. @@ -537,7 +552,7 @@ async def schedule_query( ) raw_response = await send_message_to_model_wrapper( - crontime_prompt, query_images=query_images, response_type="json_object", user=user + crontime_prompt, query_images=query_images, response_type="json_object", user=user, tracer=tracer ) # Validate that the response is a non-empty, JSON-serializable list @@ -552,7 +567,7 @@ async def schedule_query( async def extract_relevant_info( - qs: set[str], corpus: str, user: KhojUser = None, agent: Agent = None + qs: set[str], corpus: str, user: KhojUser = None, agent: Agent = None, tracer: dict = {} ) -> Union[str, None]: """ Extract relevant information for a given query from the target corpus @@ -575,6 +590,7 @@ async def extract_relevant_info( extract_relevant_information, prompts.system_prompt_extract_relevant_information, user=user, + tracer=tracer, ) return response.strip() @@ -586,6 +602,7 @@ async def extract_relevant_summary( query_images: List[str] = None, user: KhojUser = None, agent: Agent = None, + tracer: dict = {}, ) -> Union[str, None]: """ Extract relevant information for a given query from the target corpus @@ -613,6 +630,7 @@ async def extract_relevant_summary( prompts.system_prompt_extract_relevant_summary, user=user, query_images=query_images, + tracer=tracer, ) return response.strip() @@ -625,6 +643,7 @@ async def generate_summary_from_files( query_images: List[str] = None, agent: Agent = None, send_status_func: Optional[Callable] = None, + tracer: dict = {}, ): try: file_object = None @@ -653,6 +672,7 @@ async def generate_summary_from_files( query_images=query_images, user=user, agent=agent, + tracer=tracer, ) response_log = str(response) @@ -673,6 +693,7 @@ async def generate_excalidraw_diagram( user: KhojUser = None, agent: Agent = None, send_status_func: Optional[Callable] = None, + tracer: dict = {}, ): if send_status_func: async for event in send_status_func("**Enhancing the Diagramming Prompt**"): @@ -687,6 +708,7 @@ async def generate_excalidraw_diagram( query_images=query_images, user=user, agent=agent, + tracer=tracer, ) if send_status_func: @@ -697,6 +719,7 @@ async def generate_excalidraw_diagram( q=better_diagram_description_prompt, user=user, agent=agent, + tracer=tracer, ) yield better_diagram_description_prompt, excalidraw_diagram_description @@ -711,6 +734,7 @@ async def generate_better_diagram_description( query_images: List[str] = None, user: KhojUser = None, agent: Agent = None, + tracer: dict = {}, ) -> str: """ Generate a diagram description from the given query and context @@ -748,7 +772,7 @@ async def generate_better_diagram_description( with timer("Chat actor: Generate better diagram description", logger): response = await send_message_to_model_wrapper( - improve_diagram_description_prompt, query_images=query_images, user=user + improve_diagram_description_prompt, query_images=query_images, user=user, tracer=tracer ) response = response.strip() if response.startswith(('"', "'")) and response.endswith(('"', "'")): @@ -761,6 +785,7 @@ async def generate_excalidraw_diagram_from_description( q: str, user: KhojUser = None, agent: Agent = None, + tracer: dict = {}, ) -> str: personality_context = ( prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else "" @@ -772,7 +797,9 @@ async def generate_excalidraw_diagram_from_description( ) with timer("Chat actor: Generate excalidraw 
diagram", logger): - raw_response = await send_message_to_model_wrapper(message=excalidraw_diagram_generation, user=user) + raw_response = await send_message_to_model_wrapper( + message=excalidraw_diagram_generation, user=user, tracer=tracer + ) raw_response = raw_response.strip() raw_response = remove_json_codeblock(raw_response) response: Dict[str, str] = json.loads(raw_response) @@ -793,6 +820,7 @@ async def generate_better_image_prompt( query_images: Optional[List[str]] = None, user: KhojUser = None, agent: Agent = None, + tracer: dict = {}, ) -> str: """ Generate a better image prompt from the given query @@ -839,7 +867,9 @@ async def generate_better_image_prompt( ) with timer("Chat actor: Generate contextual image prompt", logger): - response = await send_message_to_model_wrapper(image_prompt, query_images=query_images, user=user) + response = await send_message_to_model_wrapper( + image_prompt, query_images=query_images, user=user, tracer=tracer + ) response = response.strip() if response.startswith(('"', "'")) and response.endswith(('"', "'")): response = response[1:-1] @@ -853,6 +883,7 @@ async def send_message_to_model_wrapper( response_type: str = "text", user: KhojUser = None, query_images: List[str] = None, + tracer: dict = {}, ): conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user) vision_available = conversation_config.vision_enabled @@ -899,6 +930,7 @@ async def send_message_to_model_wrapper( max_prompt_size=max_tokens, streaming=False, response_type=response_type, + tracer=tracer, ) elif model_type == ChatModelOptions.ModelType.OPENAI: @@ -922,6 +954,7 @@ async def send_message_to_model_wrapper( model=chat_model, response_type=response_type, api_base_url=api_base_url, + tracer=tracer, ) elif model_type == ChatModelOptions.ModelType.ANTHROPIC: api_key = conversation_config.openai_config.api_key @@ -940,6 +973,7 @@ async def send_message_to_model_wrapper( messages=truncated_messages, api_key=api_key, model=chat_model, + tracer=tracer, ) elif model_type == ChatModelOptions.ModelType.GOOGLE: api_key = conversation_config.openai_config.api_key @@ -955,7 +989,7 @@ async def send_message_to_model_wrapper( ) return gemini_send_message_to_model( - messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type + messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type, tracer=tracer ) else: raise HTTPException(status_code=500, detail="Invalid conversation config") @@ -966,6 +1000,7 @@ def send_message_to_model_wrapper_sync( system_message: str = "", response_type: str = "text", user: KhojUser = None, + tracer: dict = {}, ): conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config(user) @@ -998,6 +1033,7 @@ def send_message_to_model_wrapper_sync( max_prompt_size=max_tokens, streaming=False, response_type=response_type, + tracer=tracer, ) elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI: @@ -1012,7 +1048,11 @@ def send_message_to_model_wrapper_sync( ) openai_response = send_message_to_model( - messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type + messages=truncated_messages, + api_key=api_key, + model=chat_model, + response_type=response_type, + tracer=tracer, ) return openai_response @@ -1032,6 +1072,7 @@ def send_message_to_model_wrapper_sync( messages=truncated_messages, api_key=api_key, model=chat_model, + tracer=tracer, ) elif conversation_config.model_type == 
ChatModelOptions.ModelType.GOOGLE: @@ -1050,6 +1091,7 @@ def send_message_to_model_wrapper_sync( api_key=api_key, model=chat_model, response_type=response_type, + tracer=tracer, ) else: raise HTTPException(status_code=500, detail="Invalid conversation config") @@ -1071,6 +1113,7 @@ def generate_chat_response( user_name: Optional[str] = None, meta_research: str = "", query_images: Optional[List[str]] = None, + tracer: dict = {}, ) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]: # Initialize Variables chat_response = None @@ -1094,6 +1137,7 @@ def generate_chat_response( client_application=client_application, conversation_id=conversation_id, query_images=query_images, + tracer=tracer, ) conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation) @@ -1120,6 +1164,7 @@ def generate_chat_response( location_data=location_data, user_name=user_name, agent=agent, + tracer=tracer, ) elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI: @@ -1144,6 +1189,7 @@ def generate_chat_response( user_name=user_name, agent=agent, vision_available=vision_available, + tracer=tracer, ) elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC: @@ -1165,6 +1211,7 @@ def generate_chat_response( user_name=user_name, agent=agent, vision_available=vision_available, + tracer=tracer, ) elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE: api_key = conversation_config.openai_config.api_key @@ -1184,6 +1231,7 @@ def generate_chat_response( user_name=user_name, agent=agent, vision_available=vision_available, + tracer=tracer, ) metadata.update({"chat_model": conversation_config.chat_model}) @@ -1540,9 +1588,15 @@ def scheduled_chat( async def create_automation( - q: str, timezone: str, user: KhojUser, calling_url: URL, meta_log: dict = {}, conversation_id: str = None + q: str, + timezone: str, + user: KhojUser, + calling_url: URL, + meta_log: dict = {}, + conversation_id: str = None, + tracer: dict = {}, ): - crontime, query_to_run, subject = await schedule_query(q, meta_log, user) + crontime, query_to_run, subject = await schedule_query(q, meta_log, user, tracer=tracer) job = await schedule_automation(query_to_run, subject, crontime, timezone, q, user, calling_url, conversation_id) return job, crontime, query_to_run, subject diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 1beb1f69..8221fd5c 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -45,6 +45,7 @@ async def apick_next_tool( previous_iterations_history: str = None, max_iterations: int = 5, send_status_func: Optional[Callable] = None, + tracer: dict = {}, ): """ Given a query, determine which of the available tools the agent should use in order to answer appropriately. One at a time, and it's able to use subsequent iterations to refine the answer. 
@@ -93,6 +94,7 @@ async def apick_next_tool( response_type="json_object", user=user, query_images=query_images, + tracer=tracer, ) try: @@ -135,6 +137,7 @@ async def execute_information_collection( user_name: str = None, location: LocationData = None, file_filters: List[str] = [], + tracer: dict = {}, ): current_iteration = 0 MAX_ITERATIONS = 5 @@ -159,6 +162,7 @@ async def execute_information_collection( previous_iterations_history, MAX_ITERATIONS, send_status_func, + tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -180,6 +184,7 @@ async def execute_information_collection( send_status_func, query_images, agent=agent, + tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -211,6 +216,7 @@ async def execute_information_collection( max_webpages_to_read=0, query_images=query_images, agent=agent, + tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -228,6 +234,7 @@ async def execute_information_collection( send_status_func, query_images=query_images, agent=agent, + tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -258,6 +265,7 @@ async def execute_information_collection( send_status_func, query_images=query_images, agent=agent, + tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS]
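
Not part of the patch: a minimal sketch of how the prompt-trace repository produced by this change could be inspected, assuming the layout the diff sets up (a git repository at PROMPTRACE_DIR, defaulting to /tmp/khoj_promptrace, with user, conversation and message branches named u_<uid>, c_<cid> and m_<mid>, and per-turn files query, response and system_prompt committed by commit_conversation_trace). Traces are only written when in_debug_mode() or state.verbose > 1, so the repository may not exist on a default install. Since merge_message_into_conversation_trace merges each m_<mid> branch back into its c_<cid> branch after the turn is saved, iterating the conversation branches covers completed turns. The snippet uses GitPython, which this diff adds as a dev dependency; it is an illustration, not code from the PR.

# Illustrative only -- not part of the patch.
import os

from git import Repo

repo_path = os.getenv("PROMPTRACE_DIR", "/tmp/khoj_promptrace")
repo = Repo(repo_path)

# One branch per conversation id, per the naming scheme in commit_conversation_trace
conversation_branches = [head.name for head in repo.heads if head.name.startswith("c_")]

for branch in conversation_branches:
    print(f"Conversation branch: {branch}")
    # Each commit on a conversation branch is one traced chat turn
    for commit in repo.iter_commits(branch):
        try:
            query = (commit.tree / "query").data_stream.read().decode("utf-8")
            response = (commit.tree / "response").data_stream.read().decode("utf-8")
        except KeyError:
            # The initial commit ("And then there was a trace") carries no trace files
            continue
        print(f"- {commit.hexsha[:8]}: {query[:80]!r} -> {response[:80]!r}")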