diff --git a/pyproject.toml b/pyproject.toml
index 93df0b42..12c7789c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -119,6 +119,7 @@ dev = [
     "mypy >= 1.0.1",
     "black >= 23.1.0",
     "pre-commit >= 3.0.4",
+    "gitpython ~= 3.1.43",
 ]
 
 [tool.hatch.version]
diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py
index 493ca366..268e21aa 100644
--- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py
+++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py
@@ -34,6 +34,7 @@ def extract_questions_anthropic(
     query_images: Optional[list[str]] = None,
     vision_enabled: bool = False,
     personality_context: Optional[str] = None,
+    tracer: dict = {},
 ):
     """
     Infer search queries to retrieve relevant notes to answer user query
@@ -89,6 +90,7 @@ def extract_questions_anthropic(
         model_name=model,
         temperature=temperature,
         api_key=api_key,
+        tracer=tracer,
     )
 
     # Extract, Clean Message from Claude's Response
@@ -110,7 +112,7 @@ def extract_questions_anthropic(
     return questions
 
 
-def anthropic_send_message_to_model(messages, api_key, model):
+def anthropic_send_message_to_model(messages, api_key, model, tracer={}):
     """
     Send message to model
     """
@@ -122,6 +124,7 @@ def anthropic_send_message_to_model(messages, api_key, model):
         system_prompt=system_prompt,
         model_name=model,
         api_key=api_key,
+        tracer=tracer,
     )
@@ -141,6 +144,7 @@ def converse_anthropic(
     agent: Agent = None,
     query_images: Optional[list[str]] = None,
     vision_available: bool = False,
+    tracer: dict = {},
 ):
     """
     Converse with user using Anthropic's Claude
@@ -213,4 +217,5 @@ def converse_anthropic(
         system_prompt=system_prompt,
         completion_func=completion_func,
         max_prompt_size=max_prompt_size,
+        tracer=tracer,
     )
diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py
index a4a71a6d..6673555b 100644
--- a/src/khoj/processor/conversation/anthropic/utils.py
+++ b/src/khoj/processor/conversation/anthropic/utils.py
@@ -12,8 +12,13 @@ from tenacity import (
     wait_random_exponential,
 )
 
-from khoj.processor.conversation.utils import ThreadedGenerator, get_image_from_url
-from khoj.utils.helpers import is_none_or_empty
+from khoj.processor.conversation.utils import (
+    ThreadedGenerator,
+    commit_conversation_trace,
+    get_image_from_url,
+)
+from khoj.utils import state
+from khoj.utils.helpers import in_debug_mode, is_none_or_empty
 
 logger = logging.getLogger(__name__)
@@ -30,7 +35,7 @@ DEFAULT_MAX_TOKENS_ANTHROPIC = 3000
     reraise=True,
 )
 def anthropic_completion_with_backoff(
-    messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None
+    messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None, tracer={}
 ) -> str:
     if api_key not in anthropic_clients:
         client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key)
@@ -58,6 +63,12 @@ def anthropic_completion_with_backoff(
         for text in stream.text_stream:
             aggregated_response += text
 
+    # Save conversation trace
+    tracer["chat_model"] = model_name
+    tracer["temperature"] = temperature
+    if in_debug_mode() or state.verbose > 1:
+        commit_conversation_trace(messages, aggregated_response, tracer)
+
     return aggregated_response
@@ -78,18 +89,19 @@ def anthropic_chat_completion_with_backoff(
     max_prompt_size=None,
     completion_func=None,
     model_kwargs=None,
+    tracer={},
 ):
     g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
     t = Thread(
         target=anthropic_llm_thread,
-        args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs),
+        args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs, tracer),
     )
     t.start()
     return g
 
 
 def anthropic_llm_thread(
-    g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None
+    g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None, tracer={}
 ):
     try:
         if api_key not in anthropic_clients:
@@ -102,6 +114,7 @@ def anthropic_llm_thread(
             anthropic.types.MessageParam(role=message.role, content=message.content) for message in messages
         ]
 
+        aggregated_response = ""
         with client.messages.stream(
             messages=formatted_messages,
             model=model_name,  # type: ignore
@@ -112,7 +125,14 @@ def anthropic_llm_thread(
             **(model_kwargs or dict()),
         ) as stream:
             for text in stream.text_stream:
+                aggregated_response += text
                 g.send(text)
+
+        # Save conversation trace
+        tracer["chat_model"] = model_name
+        tracer["temperature"] = temperature
+        if in_debug_mode() or state.verbose > 1:
+            commit_conversation_trace(messages, aggregated_response, tracer)
     except Exception as e:
         logger.error(f"Error in anthropic_llm_thread: {e}", exc_info=True)
     finally:
diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py
index b3f89031..ae33d40d 100644
--- a/src/khoj/processor/conversation/google/gemini_chat.py
+++ b/src/khoj/processor/conversation/google/gemini_chat.py
@@ -35,6 +35,7 @@ def extract_questions_gemini(
     query_images: Optional[list[str]] = None,
     vision_enabled: bool = False,
     personality_context: Optional[str] = None,
+    tracer: dict = {},
 ):
     """
     Infer search queries to retrieve relevant notes to answer user query
@@ -85,7 +86,7 @@ def extract_questions_gemini(
     messages = [ChatMessage(content=prompt, role="user"), ChatMessage(content=system_prompt, role="system")]
 
     response = gemini_send_message_to_model(
-        messages, api_key, model, response_type="json_object", temperature=temperature
+        messages, api_key, model, response_type="json_object", temperature=temperature, tracer=tracer
     )
 
     # Extract, Clean Message from Gemini's Response
@@ -107,7 +108,9 @@ def extract_questions_gemini(
     return questions
 
 
-def gemini_send_message_to_model(messages, api_key, model, response_type="text", temperature=0, model_kwargs=None):
+def gemini_send_message_to_model(
+    messages, api_key, model, response_type="text", temperature=0, model_kwargs=None, tracer={}
+):
     """
     Send message to model
     """
@@ -125,6 +128,7 @@ def gemini_send_message_to_model(messages, api_key, model, response_type="text",
         api_key=api_key,
         temperature=temperature,
         model_kwargs=model_kwargs,
+        tracer=tracer,
     )
@@ -145,6 +149,7 @@ def converse_gemini(
     agent: Agent = None,
     query_images: Optional[list[str]] = None,
     vision_available: bool = False,
+    tracer={},
 ):
     """
     Converse with user using Google's Gemini
@@ -217,4 +222,5 @@ def converse_gemini(
         api_key=api_key,
         system_prompt=system_prompt,
         completion_func=completion_func,
+        tracer=tracer,
     )
diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py
index 964fe80b..7b848324 100644
--- a/src/khoj/processor/conversation/google/utils.py
+++ b/src/khoj/processor/conversation/google/utils.py
@@ -19,8 +19,13 @@ from tenacity import (
     wait_random_exponential,
 )
 
-from khoj.processor.conversation.utils import ThreadedGenerator, get_image_from_url
-from khoj.utils.helpers import is_none_or_empty
+from khoj.processor.conversation.utils import (
+    ThreadedGenerator,
+    commit_conversation_trace,
+    get_image_from_url,
+)
+from khoj.utils import state
+from khoj.utils.helpers import in_debug_mode, is_none_or_empty
 
 logger = logging.getLogger(__name__)
@@ -35,7 +40,7 @@ MAX_OUTPUT_TOKENS_GEMINI = 8192
     reraise=True,
 )
 def gemini_completion_with_backoff(
-    messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None
+    messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, tracer={}
 ) -> str:
     genai.configure(api_key=api_key)
     model_kwargs = model_kwargs or dict()
@@ -60,16 +65,23 @@ def gemini_completion_with_backoff(
     try:
         # Generate the response. The last message is considered to be the current prompt
-        aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"])
-        return aggregated_response.text
+        response = chat_session.send_message(formatted_messages[-1]["parts"])
+        response_text = response.text
     except StopCandidateException as e:
-        response_message, _ = handle_gemini_response(e.args)
+        response_text, _ = handle_gemini_response(e.args)
         # Respond with reason for stopping
         logger.warning(
-            f"LLM Response Prevented for {model_name}: {response_message}.\n"
+            f"LLM Response Prevented for {model_name}: {response_text}.\n"
             + f"Last Message by {messages[-1].role}: {messages[-1].content}"
         )
-        return response_message
+
+    # Save conversation trace
+    tracer["chat_model"] = model_name
+    tracer["temperature"] = temperature
+    if in_debug_mode() or state.verbose > 1:
+        commit_conversation_trace(messages, response_text, tracer)
+
+    return response_text
@@ -88,17 +100,20 @@ def gemini_chat_completion_with_backoff(
     system_prompt,
     completion_func=None,
     model_kwargs=None,
+    tracer: dict = {},
 ):
     g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
     t = Thread(
         target=gemini_llm_thread,
-        args=(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs),
+        args=(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs, tracer),
     )
     t.start()
     return g
 
 
-def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None):
+def gemini_llm_thread(
+    g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None, tracer: dict = {}
+):
     try:
         genai.configure(api_key=api_key)
         model_kwargs = model_kwargs or dict()
@@ -117,16 +132,25 @@ def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_k
             },
         )
 
+        aggregated_response = ""
         formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]
+        # all messages up to the last are considered to be part of the chat history
         chat_session = model.start_chat(history=formatted_messages[0:-1])
         # the last message is considered to be the current prompt
         for chunk in chat_session.send_message(formatted_messages[-1]["parts"], stream=True):
             message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
             message = message or chunk.text
+            aggregated_response += message
             g.send(message)
             if stopped:
                 raise StopCandidateException(message)
+
+        # Save conversation trace
+        tracer["chat_model"] = model_name
+        tracer["temperature"] = temperature
+        if in_debug_mode() or state.verbose > 1:
+            commit_conversation_trace(messages, aggregated_response, tracer)
     except StopCandidateException as e:
         logger.warning(
             f"LLM Response Prevented for {model_name}: {e.args[0]}.\n"
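Every chat backend touched above and below (Anthropic, Gemini, and further down the offline llama.cpp and OpenAI paths) repeats the same tracing pattern: accumulate the streamed response, stamp the model name and temperature onto the shared `tracer` dict, and persist the trace only when Khoj runs in debug mode or with verbosity above 1. A minimal sketch of that shared pattern, where `traced_completion` and its `text_stream` argument are hypothetical stand-ins for whichever provider-specific streaming call a backend uses:

```python
from typing import Iterable, List

from langchain.schema import ChatMessage

from khoj.processor.conversation.utils import commit_conversation_trace
from khoj.utils import state
from khoj.utils.helpers import in_debug_mode


def traced_completion(
    messages: List[ChatMessage], text_stream: Iterable[str], model_name: str, temperature: float, tracer: dict = {}
) -> str:
    """Illustrative helper showing the tracing pattern each backend weaves into its completion path."""
    # Accumulate streamed chunks so the full response can be committed to the trace
    aggregated_response = ""
    for chunk in text_stream:
        aggregated_response += chunk

    # Record which model and sampling temperature produced this response
    tracer["chat_model"] = model_name
    tracer["temperature"] = temperature

    # Persist the trace only when tracing is enabled (debug mode or verbosity > 1)
    if in_debug_mode() or state.verbose > 1:
        commit_conversation_trace(messages, aggregated_response, tracer)

    return aggregated_response
```

The helper is only a sketch; each backend below keeps its provider-specific client handling and simply threads these few lines into its existing streaming loop.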
diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py
index d9cbd507..2d2354ed 100644
--- a/src/khoj/processor/conversation/offline/chat_model.py
+++ b/src/khoj/processor/conversation/offline/chat_model.py
@@ -12,11 +12,12 @@ from khoj.processor.conversation import prompts
 from khoj.processor.conversation.offline.utils import download_model
 from khoj.processor.conversation.utils import (
     ThreadedGenerator,
+    commit_conversation_trace,
     generate_chatml_messages_with_context,
 )
 from khoj.utils import state
 from khoj.utils.constants import empty_escape_sequences
-from khoj.utils.helpers import ConversationCommand, is_none_or_empty
+from khoj.utils.helpers import ConversationCommand, in_debug_mode, is_none_or_empty
 from khoj.utils.rawconfig import LocationData
 
 logger = logging.getLogger(__name__)
@@ -34,6 +35,7 @@ def extract_questions_offline(
     max_prompt_size: int = None,
     temperature: float = 0.7,
     personality_context: Optional[str] = None,
+    tracer: dict = {},
 ) -> List[str]:
     """
     Infer search queries to retrieve relevant notes to answer user query
@@ -94,6 +96,7 @@ def extract_questions_offline(
             max_prompt_size=max_prompt_size,
             temperature=temperature,
             response_type="json_object",
+            tracer=tracer,
         )
     finally:
         state.chat_lock.release()
@@ -146,6 +149,7 @@ def converse_offline(
     location_data: LocationData = None,
     user_name: str = None,
     agent: Agent = None,
+    tracer: dict = {},
 ) -> Union[ThreadedGenerator, Iterator[str]]:
     """
     Converse with user using Llama
@@ -153,8 +157,9 @@ def converse_offline(
     # Initialize Variables
     assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
     offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
-    compiled_references = "\n\n".join({f"# File: {item['file']}\n## {item['compiled']}\n" for item in references})
+    tracer["chat_model"] = model
 
+    compiled_references = "\n\n".join({f"# File: {item['file']}\n## {item['compiled']}\n" for item in references})
     current_date = datetime.now()
 
     if agent and agent.personality:
@@ -215,13 +220,14 @@ def converse_offline(
     logger.debug(f"Conversation Context for {model}: {truncated_messages}")
 
     g = ThreadedGenerator(references, online_results, completion_func=completion_func)
-    t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size))
+    t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size, tracer))
     t.start()
    return g


-def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None):
+def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None, tracer: dict = {}):
     stop_phrases = ["", "INST]", "Notes:"]
+    aggregated_response = ""
 
     state.chat_lock.acquire()
     try:
@@ -229,7 +235,14 @@ def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int
             messages, loaded_model=model, stop=stop_phrases, max_prompt_size=max_prompt_size, streaming=True
         )
         for response in response_iterator:
-            g.send(response["choices"][0]["delta"].get("content", ""))
+            response_delta = response["choices"][0]["delta"].get("content", "")
+            aggregated_response += response_delta
+            g.send(response_delta)
+
+        # Save conversation trace
+        if in_debug_mode() or state.verbose > 1:
+            commit_conversation_trace(messages, aggregated_response, tracer)
+
     finally:
         state.chat_lock.release()
         g.close()
@@ -244,6 +257,7 @@ def send_message_to_model_offline(
     stop=[],
     max_prompt_size: int = None,
     response_type: str = "text",
+    tracer: dict = {},
 ):
     assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
     offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
@@ -251,7 +265,17 @@ def send_message_to_model_offline(
     response = offline_chat_model.create_chat_completion(
         messages_dict, stop=stop, stream=streaming, temperature=temperature, response_format={"type": response_type}
     )
+
     if streaming:
         return response
-    else:
-        return response["choices"][0]["message"].get("content", "")
+
+    response_text = response["choices"][0]["message"].get("content", "")
+
+    # Save conversation trace for non-streaming responses
+    # Streamed responses need to be saved by the calling function
+    tracer["chat_model"] = model
+    tracer["temperature"] = temperature
+    if in_debug_mode() or state.verbose > 1:
+        commit_conversation_trace(messages, response_text, tracer)
+
+    return response_text
diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py
index fa845ec1..3c4552d9 100644
--- a/src/khoj/processor/conversation/openai/gpt.py
+++ b/src/khoj/processor/conversation/openai/gpt.py
@@ -33,6 +33,7 @@ def extract_questions(
     query_images: Optional[list[str]] = None,
     vision_enabled: bool = False,
     personality_context: Optional[str] = None,
+    tracer: dict = {},
 ):
     """
     Infer search queries to retrieve relevant notes to answer user query
@@ -82,7 +83,13 @@ def extract_questions(
     messages = [ChatMessage(content=prompt, role="user")]
 
     response = send_message_to_model(
-        messages, api_key, model, response_type="json_object", api_base_url=api_base_url, temperature=temperature
+        messages,
+        api_key,
+        model,
+        response_type="json_object",
+        api_base_url=api_base_url,
+        temperature=temperature,
+        tracer=tracer,
     )
 
     # Extract, Clean Message from GPT's Response
@@ -103,7 +110,9 @@ def extract_questions(
     return questions
 
 
-def send_message_to_model(messages, api_key, model, response_type="text", api_base_url=None, temperature=0):
+def send_message_to_model(
+    messages, api_key, model, response_type="text", api_base_url=None, temperature=0, tracer: dict = {}
+):
     """
     Send message to model
     """
@@ -116,6 +125,7 @@ def send_message_to_model(messages, api_key, model, response_type="text", api_ba
         temperature=temperature,
         api_base_url=api_base_url,
         model_kwargs={"response_format": {"type": response_type}},
+        tracer=tracer,
     )
@@ -137,6 +147,7 @@ def converse(
     agent: Agent = None,
     query_images: Optional[list[str]] = None,
     vision_available: bool = False,
+    tracer: dict = {},
 ):
     """
     Converse with user using OpenAI's ChatGPT
@@ -207,4 +218,5 @@ def converse(
         api_base_url=api_base_url,
         completion_func=completion_func,
         model_kwargs={"stop": ["Notes:\n["]},
+        tracer=tracer,
     )
diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py
index 878dbb9c..6e519f5a 100644
--- a/src/khoj/processor/conversation/openai/utils.py
+++ b/src/khoj/processor/conversation/openai/utils.py
@@ -12,7 +12,12 @@ from tenacity import (
     wait_random_exponential,
 )
 
-from khoj.processor.conversation.utils import ThreadedGenerator
+from khoj.processor.conversation.utils import (
+    ThreadedGenerator,
+    commit_conversation_trace,
+)
+from khoj.utils import state
+from khoj.utils.helpers import in_debug_mode
 
 logger = logging.getLogger(__name__)
@@ -33,7 +38,7 @@ openai_clients: Dict[str, openai.OpenAI] = {}
     reraise=True,
 )
 def completion_with_backoff(
-    messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None
+    messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None, tracer: dict = {}
 ) -> str:
     client_key = f"{openai_api_key}--{api_base_url}"
     client: openai.OpenAI | None = openai_clients.get(client_key)
@@ -77,6 +82,12 @@ def completion_with_backoff(
         elif delta_chunk.content:
             aggregated_response += delta_chunk.content
 
+    # Save conversation trace
+    tracer["chat_model"] = model
+    tracer["temperature"] = temperature
+    if in_debug_mode() or state.verbose > 1:
+        commit_conversation_trace(messages, aggregated_response, tracer)
+
     return aggregated_response
@@ -103,26 +114,37 @@ def chat_completion_with_backoff(
     api_base_url=None,
     completion_func=None,
     model_kwargs=None,
+    tracer: dict = {},
 ):
     g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
     t = Thread(
-        target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key, api_base_url, model_kwargs)
+        target=llm_thread,
+        args=(g, messages, model_name, temperature, openai_api_key, api_base_url, model_kwargs, tracer),
     )
     t.start()
     return g
 
 
-def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_base_url=None, model_kwargs=None):
+def llm_thread(
+    g,
+    messages,
+    model_name,
+    temperature,
+    openai_api_key=None,
+    api_base_url=None,
+    model_kwargs=None,
+    tracer: dict = {},
+):
     try:
         client_key = f"{openai_api_key}--{api_base_url}"
         if client_key not in openai_clients:
-            client: openai.OpenAI = openai.OpenAI(
+            client = openai.OpenAI(
                 api_key=openai_api_key,
                 base_url=api_base_url,
             )
             openai_clients[client_key] = client
         else:
-            client: openai.OpenAI = openai_clients[client_key]
+            client = openai_clients[client_key]
 
         formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
         stream = True
@@ -144,17 +166,29 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_ba
             **(model_kwargs or dict()),
         )
 
+        aggregated_response = ""
         if not stream:
-            g.send(chat.choices[0].message.content)
+            aggregated_response = chat.choices[0].message.content
+            g.send(aggregated_response)
         else:
             for chunk in chat:
                 if len(chunk.choices) == 0:
                     continue
                 delta_chunk = chunk.choices[0].delta
+                text_chunk = ""
                 if isinstance(delta_chunk, str):
-                    g.send(delta_chunk)
+                    text_chunk = delta_chunk
                 elif delta_chunk.content:
-                    g.send(delta_chunk.content)
+                    text_chunk = delta_chunk.content
+                if text_chunk:
+                    aggregated_response += text_chunk
+                    g.send(text_chunk)
+
+        # Save conversation trace
+        tracer["chat_model"] = model_name
+        tracer["temperature"] = temperature
+        if in_debug_mode() or state.verbose > 1:
+            commit_conversation_trace(messages, aggregated_response, tracer)
     except Exception as e:
         logger.error(f"Error in llm_thread: {e}", exc_info=True)
     finally:
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 36e681c0..bc7a7858 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -2,6 +2,7 @@ import base64
 import logging
 import math
 import mimetypes
+import os
 import queue
 from dataclasses import dataclass
 from datetime import datetime
@@ -12,6 +13,8 @@ from typing import Any, Dict, List, Optional
 import PIL.Image
 import requests
 import tiktoken
+import yaml
+from git import Repo
 from langchain.schema import ChatMessage
 from llama_cpp.llama import Llama
 from transformers import AutoTokenizer
@@ -21,7 +24,7 @@ from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser
 from khoj.processor.conversation import prompts
 from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
 from khoj.utils import state
-from khoj.utils.helpers import is_none_or_empty, merge_dicts
+from khoj.utils.helpers import in_debug_mode, is_none_or_empty, merge_dicts
 
 logger = logging.getLogger(__name__)
 
 model_to_prompt_size = {
@@ -117,6 +120,7 @@ def save_to_conversation_log(
     conversation_id: str = None,
     automation_id: str = None,
     query_images: List[str] = None,
+    tracer: Dict[str, Any] = {},
 ):
     user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
     updated_conversation = message_to_log(
@@ -142,6 +146,9 @@ def save_to_conversation_log(
         user_message=q,
     )
 
+    if in_debug_mode() or state.verbose > 1:
+        merge_message_into_conversation_trace(q, chat_response, tracer)
+
     logger.info(
         f"""
 Saved Conversation Turn
@@ -354,3 +361,163 @@ def get_image_from_url(image_url: str, type="pil"):
     except requests.exceptions.RequestException as e:
         logger.error(f"Failed to get image from URL {image_url}: {e}")
         return ImageWithType(content=None, type=None)
+
+
+def commit_conversation_trace(
+    session: list[ChatMessage],
+    response: str | list[dict],
+    tracer: dict,
+    system_message: str | list[dict] = "",
+    repo_path: str = "/tmp/promptrace",
+) -> str:
+    """
+    Save trace of conversation step using git. Useful to visualize, compare and debug traces.
+    Returns the path to the repository.
+    """
+    # Serialize session, system message and response to yaml
+    system_message_yaml = yaml.dump(system_message, allow_unicode=True, sort_keys=False, default_flow_style=False)
+    response_yaml = yaml.dump(response, allow_unicode=True, sort_keys=False, default_flow_style=False)
+    formatted_session = [{"role": message.role, "content": message.content} for message in session]
+    session_yaml = yaml.dump(formatted_session, allow_unicode=True, sort_keys=False, default_flow_style=False)
+    query = (
+        yaml.dump(session[-1].content, allow_unicode=True, sort_keys=False, default_flow_style=False)
+        .strip()
+        .removeprefix("'")
+        .removesuffix("'")
+    )  # Extract serialized query from chat session
+
+    # Extract chat metadata for session
+    uid, cid, mid = tracer.get("uid", "main"), tracer.get("cid", "main"), tracer.get("mid")
+
+    # Infer repository path from environment variable or provided path
+    repo_path = os.getenv("PROMPTRACE_DIR", repo_path)
+
+    try:
+        # Prepare git repository
+        os.makedirs(repo_path, exist_ok=True)
+        repo = Repo.init(repo_path)
+
+        # Remove post-commit hook if it exists
+        hooks_dir = os.path.join(repo_path, ".git", "hooks")
+        post_commit_hook = os.path.join(hooks_dir, "post-commit")
+        if os.path.exists(post_commit_hook):
+            os.remove(post_commit_hook)
+
+        # Configure git user if not set
+        if not repo.config_reader().has_option("user", "email"):
+            repo.config_writer().set_value("user", "name", "Prompt Tracer").release()
+            repo.config_writer().set_value("user", "email", "promptracer@khoj.dev").release()
+
+        # Create an initial commit if the repository is newly created
+        if not repo.head.is_valid():
+            repo.index.commit("And then there was a trace")
+
+        # Check out the initial commit
+        initial_commit = repo.commit("HEAD~0")
+        repo.head.reference = initial_commit
+        repo.head.reset(index=True, working_tree=True)
+
+        # Create or switch to user branch from initial commit
+        user_branch = f"u_{uid}"
+        if user_branch not in repo.branches:
+            repo.create_head(user_branch)
+        repo.heads[user_branch].checkout()
+
+        # Create or switch to conversation branch from user branch
+        conv_branch = f"c_{cid}"
+        if conv_branch not in repo.branches:
+            repo.create_head(conv_branch)
+        repo.heads[conv_branch].checkout()
+
+        # Create or switch to message branch from conversation branch
+        msg_branch = f"m_{mid}" if mid else None
+        if msg_branch and msg_branch not in repo.branches:
+            repo.create_head(msg_branch)
+        if msg_branch:
+            repo.heads[msg_branch].checkout()
+
+        # Include file with content to commit
+        files_to_commit = {"query": session_yaml, "response": response_yaml, "system_prompt": system_message_yaml}
+
+        # Write files and stage them
+        for filename, content in files_to_commit.items():
+            file_path = os.path.join(repo_path, filename)
+            # Unescape special characters in content for better readability
+            content = content.strip().replace("\\n", "\n").replace("\\t", "\t")
+            with open(file_path, "w", encoding="utf-8") as f:
+                f.write(content)
+            repo.index.add([filename])
+
+        # Create commit
+        metadata_yaml = yaml.dump(tracer, allow_unicode=True, sort_keys=False, default_flow_style=False)
+        commit_message = f"""
+{query[:250]}
+
+Response:
+---
+{response[:500]}...
+
+Metadata
+---
+{metadata_yaml}
+""".strip()
+
+        repo.index.commit(commit_message)
+
+        logger.debug(f"Saved conversation trace to repo at {repo_path}")
+        return repo_path
+    except Exception as e:
+        logger.error(f"Failed to add conversation trace to repo: {str(e)}", exc_info=True)
+        return None
+
+
+def merge_message_into_conversation_trace(query: str, response: str, tracer: dict, repo_path="/tmp/promptrace") -> bool:
+    """
+    Merge the message branch into its parent conversation branch.
+
+    Args:
+        query: User query
+        response: Assistant response
+        tracer: Dictionary containing uid, cid and mid
+        repo_path: Path to the git repository
+
+    Returns:
+        bool: True if merge was successful, False otherwise
+    """
+    try:
+        # Extract branch names
+        msg_branch = f"m_{tracer['mid']}"
+        conv_branch = f"c_{tracer['cid']}"
+
+        # Infer repository path from environment variable or provided path
+        repo_path = os.getenv("PROMPTRACE_DIR", repo_path)
+        repo = Repo(repo_path)
+
+        # Checkout conversation branch
+        repo.heads[conv_branch].checkout()
+
+        # Create commit message
+        metadata_yaml = yaml.dump(tracer, allow_unicode=True, sort_keys=False, default_flow_style=False)
+        commit_message = f"""
+{query[:250]}
+
+Response:
+---
+{response[:500]}...
+
+Metadata
+---
+{metadata_yaml}
+""".strip()
+
+        # Merge message branch into conversation branch
+        repo.git.merge(msg_branch, no_ff=True, m=commit_message)
+
+        # Delete message branch after merge
+        repo.delete_head(msg_branch, force=True)
+
+        logger.debug(f"Successfully merged {msg_branch} into {conv_branch}")
+        return True
+    except Exception as e:
+        logger.error(f"Failed to merge message {msg_branch} into conversation {conv_branch}: {str(e)}", exc_info=True)
+        return False
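Because traces land in a plain git repository, they can be browsed with ordinary tooling (for example `git -C /tmp/promptrace log --graph --all`) or programmatically. A small GitPython sketch, assuming tracing has already written some branches; the `c_main` branch name is illustrative, substitute your own conversation id:

```python
import os

from git import Repo

# Same default location and PROMPTRACE_DIR override used by commit_conversation_trace
repo = Repo(os.getenv("PROMPTRACE_DIR", "/tmp/promptrace"))

# Per-user, per-conversation and per-message trace branches: u_<uid>, c_<cid>, m_<mid>
print([branch.name for branch in repo.branches])

# Walk one conversation branch, message branches are merged into it after each turn
for commit in repo.iter_commits("c_main"):
    # The first line of each trace commit message is the (truncated) user query
    print(commit.message.splitlines()[0])
    # Each trace commit carries the serialized query, response and system_prompt files
    for blob in commit.tree.blobs:
        print(f"  {blob.path}: {len(blob.data_stream.read())} bytes")
```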
diff --git a/src/khoj/processor/image/generate.py b/src/khoj/processor/image/generate.py
index 343e44b3..bdc00e09 100644
--- a/src/khoj/processor/image/generate.py
+++ b/src/khoj/processor/image/generate.py
@@ -28,6 +28,7 @@ async def text_to_image(
     send_status_func: Optional[Callable] = None,
     query_images: Optional[List[str]] = None,
     agent: Agent = None,
+    tracer: dict = {},
 ):
     status_code = 200
     image = None
@@ -68,6 +69,7 @@ async def text_to_image(
         query_images=query_images,
         user=user,
         agent=agent,
+        tracer=tracer,
     )
 
     if send_status_func:
diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py
index fdf1ba9f..9afb3d67 100644
--- a/src/khoj/processor/tools/online_search.py
+++ b/src/khoj/processor/tools/online_search.py
@@ -64,6 +64,7 @@ async def search_online(
     custom_filters: List[str] = [],
     query_images: List[str] = None,
     agent: Agent = None,
+    tracer: dict = {},
 ):
     query += " ".join(custom_filters)
     if not is_internet_connected():
@@ -73,7 +74,7 @@ async def search_online(
 
     # Breakdown the query into subqueries to get the correct answer
     subqueries = await generate_online_subqueries(
-        query, conversation_history, location, user, query_images=query_images, agent=agent
+        query, conversation_history, location, user, query_images=query_images, agent=agent, tracer=tracer
     )
     response_dict = {}
 
@@ -111,7 +112,7 @@ async def search_online(
         async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
             yield {ChatEvent.STATUS: event}
     tasks = [
-        read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent)
+        read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent, tracer=tracer)
         for link, data in webpages.items()
     ]
     results = await asyncio.gather(*tasks)
@@ -153,6 +154,7 @@ async def read_webpages(
     send_status_func: Optional[Callable] = None,
     query_images: List[str] = None,
     agent: Agent = None,
+    tracer: dict = {},
 ):
     "Infer web pages to read from the query and extract relevant information from them"
     logger.info(f"Inferring web pages to read")
@@ -166,7 +168,7 @@ async def read_webpages(
         webpage_links_str = "\n- " + "\n- ".join(list(urls))
         async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
             yield {ChatEvent.STATUS: event}
-    tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent) for url in urls]
+    tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent, tracer=tracer) for url in urls]
     results = await asyncio.gather(*tasks)
 
     response: Dict[str, Dict] = defaultdict(dict)
@@ -192,7 +194,12 @@ async def read_webpage(
 
 
 async def read_webpage_and_extract_content(
-    subqueries: set[str], url: str, content: str = None, user: KhojUser = None, agent: Agent = None
+    subqueries: set[str],
+    url: str,
+    content: str = None,
+    user: KhojUser = None,
+    agent: Agent = None,
+    tracer: dict = {},
 ) -> Tuple[set[str], str, Union[None, str]]:
     # Select the web scrapers to use for reading the web page
     web_scrapers = await ConversationAdapters.aget_enabled_webscrapers()
@@ -214,7 +221,9 @@ async def read_webpage_and_extract_content(
             # Extract relevant information from the web page
             if is_none_or_empty(extracted_info):
                 with timer(f"Extracting relevant information from web page at '{url}' took", logger):
-                    extracted_info = await extract_relevant_info(subqueries, content, user=user, agent=agent)
+                    extracted_info = await extract_relevant_info(
+                        subqueries, content, user=user, agent=agent, tracer=tracer
+                    )
 
             # If we successfully extracted information, break the loop
             if not is_none_or_empty(extracted_info):
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index f89ca87a..c1f218d7 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -350,6 +350,7 @@ async def extract_references_and_questions(
     send_status_func: Optional[Callable] = None,
     query_images: Optional[List[str]] = None,
     agent: Agent = None,
+    tracer: dict = {},
 ):
     user = request.user.object if request.user.is_authenticated else None
 
@@ -425,6 +426,7 @@ async def extract_references_and_questions(
                 user=user,
                 max_prompt_size=conversation_config.max_prompt_size,
                 personality_context=personality_context,
+                tracer=tracer,
             )
         elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
             openai_chat_config = conversation_config.openai_config
@@ -442,6 +444,7 @@ async def extract_references_and_questions(
                 query_images=query_images,
                 vision_enabled=vision_enabled,
                 personality_context=personality_context,
+                tracer=tracer,
             )
         elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
             api_key = conversation_config.openai_config.api_key
@@ -456,6 +459,7 @@ async def extract_references_and_questions(
                 user=user,
                 vision_enabled=vision_enabled,
                 personality_context=personality_context,
+                tracer=tracer,
             )
         elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
             api_key = conversation_config.openai_config.api_key
@@ -471,6 +475,7 @@ async def extract_references_and_questions(
                 user=user,
                 vision_enabled=vision_enabled,
                 personality_context=personality_context,
+                tracer=tracer,
             )
 
     # Collate search results as context for GPT
diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py
index 3cc541b1..62a1f3b9 100644
--- a/src/khoj/routers/api_chat.py
+++ b/src/khoj/routers/api_chat.py
@@ -3,6 +3,7 @@ import base64
 import json
 import logging
 import time
+import uuid
 from datetime import datetime
 from functools import partial
 from typing import Dict, Optional
@@ -563,6 +564,12 @@ async def chat(
         event_delimiter = "␃🔚␗"
         q = unquote(q)
         nonlocal conversation_id
+        tracer: dict = {
+            "mid": f"{uuid.uuid4()}",
+            "cid": conversation_id,
+            "uid": user.id,
+            "khoj_version": state.khoj_version,
+        }
 
         uploaded_images: list[str] = []
         if images:
@@ -682,6 +689,7 @@ async def chat(
                 user=user,
                 query_images=uploaded_images,
                 agent=agent,
+                tracer=tracer,
             )
             conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
             async for result in send_event(
@@ -689,7 +697,9 @@ async def chat(
             ):
                 yield result
 
-        mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_images, agent)
+        mode = await aget_relevant_output_modes(
+            q, meta_log, is_automated_task, user, uploaded_images, agent, tracer=tracer
+        )
         async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
             yield result
         if mode not in conversation_commands:
@@ -755,6 +765,7 @@ async def chat(
                 query_images=uploaded_images,
                 user=user,
                 agent=agent,
+                tracer=tracer,
             )
             response_log = str(response)
             async for result in send_llm_response(response_log):
@@ -774,6 +785,7 @@ async def chat(
                 client_application=request.user.client_app,
                 conversation_id=conversation_id,
                 query_images=uploaded_images,
+                tracer=tracer,
             )
             return
 
@@ -795,7 +807,7 @@ async def chat(
         if ConversationCommand.Automation in conversation_commands:
             try:
                 automation, crontime, query_to_run, subject = await create_automation(
-                    q, timezone, user, request.url, meta_log
+                    q, timezone, user, request.url, meta_log, tracer=tracer
                 )
             except Exception as e:
                 logger.error(f"Error scheduling task {q} for {user.email}: {e}")
@@ -817,6 +829,7 @@ async def chat(
                 inferred_queries=[query_to_run],
                 automation_id=automation.id,
                 query_images=uploaded_images,
+                tracer=tracer,
             )
             async for result in send_llm_response(llm_response):
                 yield result
@@ -838,6 +851,7 @@ async def chat(
                 partial(send_event, ChatEvent.STATUS),
                 query_images=uploaded_images,
                 agent=agent,
+                tracer=tracer,
             ):
                 if isinstance(result, dict) and ChatEvent.STATUS in result:
                     yield result[ChatEvent.STATUS]
@@ -882,6 +896,7 @@ async def chat(
                 custom_filters,
                 query_images=uploaded_images,
                 agent=agent,
+                tracer=tracer,
             ):
                 if isinstance(result, dict) and ChatEvent.STATUS in result:
                     yield result[ChatEvent.STATUS]
@@ -906,6 +921,7 @@ async def chat(
                 partial(send_event, ChatEvent.STATUS),
                 query_images=uploaded_images,
                 agent=agent,
+                tracer=tracer,
            ):
                 if isinstance(result, dict) and ChatEvent.STATUS in result:
                     yield result[ChatEvent.STATUS]
@@ -956,6 +972,7 @@ async def chat(
                 send_status_func=partial(send_event, ChatEvent.STATUS),
                 query_images=uploaded_images,
                 agent=agent,
+                tracer=tracer,
             ):
                 if isinstance(result, dict) and ChatEvent.STATUS in result:
                     yield result[ChatEvent.STATUS]
@@ -986,6 +1003,7 @@ async def chat(
                 compiled_references=compiled_references,
                 online_results=online_results,
                 query_images=uploaded_images,
+                tracer=tracer,
             )
             content_obj = {
                 "intentType": intent_type,
@@ -1014,6 +1032,7 @@ async def chat(
                 user=user,
                 agent=agent,
                 send_status_func=partial(send_event, ChatEvent.STATUS),
+                tracer=tracer,
             ):
                 if isinstance(result, dict) and ChatEvent.STATUS in result:
                     yield result[ChatEvent.STATUS]
@@ -1041,6 +1060,7 @@ async def chat(
                 compiled_references=compiled_references,
                 online_results=online_results,
                 query_images=uploaded_images,
+                tracer=tracer,
             )
 
             async for result in send_llm_response(json.dumps(content_obj)):
@@ -1064,6 +1084,7 @@ async def chat(
             location,
             user_name,
             uploaded_images,
+            tracer,
         )
 
         # Send Response
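The `tracer` dict seeded at the top of the chat endpoint is what ends up as branch names and commit metadata in the prompt trace repository. A sketch of that shape and of committing a trace directly; all values are illustrative, and in normal operation the chat backends call `commit_conversation_trace` themselves when debug mode or verbosity above 1 is enabled:

```python
import uuid

from langchain.schema import ChatMessage

from khoj.processor.conversation.utils import commit_conversation_trace

# Mirrors the tracer seeded per chat message in api_chat.py
tracer = {
    "mid": f"{uuid.uuid4()}",          # message id -> branch m_<mid>
    "cid": "example-conversation-id",  # conversation id -> branch c_<cid>
    "uid": 1,                          # user id -> branch u_<uid>
    "khoj_version": "1.x.x",
}

# Commit one chat turn as a trace; the message and response below are made up
messages = [ChatMessage(content="What is a prompt trace?", role="user")]
repo_path = commit_conversation_trace(messages, "A prompt trace records one chat turn as a git commit.", tracer)
print(f"Trace written to {repo_path}")  # /tmp/promptrace unless PROMPTRACE_DIR is set
```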
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index 6cc44c4f..1475c5cd 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -301,6 +301,7 @@ async def aget_relevant_information_sources(
     user: KhojUser,
     query_images: List[str] = None,
     agent: Agent = None,
+    tracer: dict = {},
 ):
     """
     Given a query, determine which of the available tools the agent should use in order to answer appropriately.
@@ -337,6 +338,7 @@ async def aget_relevant_information_sources(
             relevant_tools_prompt,
             response_type="json_object",
             user=user,
+            tracer=tracer,
         )
 
         try:
@@ -378,6 +380,7 @@ async def aget_relevant_output_modes(
     user: KhojUser = None,
     query_images: List[str] = None,
     agent: Agent = None,
+    tracer: dict = {},
 ):
     """
     Given a query, determine which of the available tools the agent should use in order to answer appropriately.
@@ -413,7 +416,9 @@ async def aget_relevant_output_modes(
     )
 
     with timer("Chat actor: Infer output mode for chat response", logger):
-        response = await send_message_to_model_wrapper(relevant_mode_prompt, response_type="json_object", user=user)
+        response = await send_message_to_model_wrapper(
+            relevant_mode_prompt, response_type="json_object", user=user, tracer=tracer
+        )
 
         try:
             response = response.strip()
@@ -444,6 +449,7 @@ async def infer_webpage_urls(
     user: KhojUser,
     query_images: List[str] = None,
     agent: Agent = None,
+    tracer: dict = {},
 ) -> List[str]:
     """
     Infer webpage links from the given query
@@ -468,7 +474,11 @@ async def infer_webpage_urls(
 
     with timer("Chat actor: Infer webpage urls to read", logger):
         response = await send_message_to_model_wrapper(
-            online_queries_prompt, query_images=query_images, response_type="json_object", user=user
+            online_queries_prompt,
+            query_images=query_images,
+            response_type="json_object",
+            user=user,
+            tracer=tracer,
         )
 
         # Validate that the response is a non-empty, JSON-serializable list of URLs
@@ -490,6 +500,7 @@ async def generate_online_subqueries(
     user: KhojUser,
     query_images: List[str] = None,
     agent: Agent = None,
+    tracer: dict = {},
 ) -> List[str]:
     """
     Generate subqueries from the given query
@@ -514,7 +525,11 @@ async def generate_online_subqueries(
 
     with timer("Chat actor: Generate online search subqueries", logger):
         response = await send_message_to_model_wrapper(
-            online_queries_prompt, query_images=query_images, response_type="json_object", user=user
+            online_queries_prompt,
+            query_images=query_images,
+            response_type="json_object",
+            user=user,
+            tracer=tracer,
         )
 
         # Validate that the response is a non-empty, JSON-serializable list
@@ -533,7 +548,7 @@ async def generate_online_subqueries(
 
 
 async def schedule_query(
-    q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None
+    q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None, tracer: dict = {}
 ) -> Tuple[str, ...]:
     """
     Schedule the date, time to run the query. Assume the server timezone is UTC.
@@ -546,7 +561,7 @@ async def schedule_query(
     )
 
     raw_response = await send_message_to_model_wrapper(
-        crontime_prompt, query_images=query_images, response_type="json_object", user=user
+        crontime_prompt, query_images=query_images, response_type="json_object", user=user, tracer=tracer
     )
 
     # Validate that the response is a non-empty, JSON-serializable list
@@ -561,7 +576,7 @@ async def schedule_query(
 
 
 async def extract_relevant_info(
-    qs: set[str], corpus: str, user: KhojUser = None, agent: Agent = None
+    qs: set[str], corpus: str, user: KhojUser = None, agent: Agent = None, tracer: dict = {}
 ) -> Union[str, None]:
     """
     Extract relevant information for a given query from the target corpus
@@ -584,6 +599,7 @@ async def extract_relevant_info(
         extract_relevant_information,
         prompts.system_prompt_extract_relevant_information,
         user=user,
+        tracer=tracer,
     )
     return response.strip()
 
@@ -595,6 +611,7 @@ async def extract_relevant_summary(
     query_images: List[str] = None,
     user: KhojUser = None,
     agent: Agent = None,
+    tracer: dict = {},
 ) -> Union[str, None]:
     """
     Extract relevant information for a given query from the target corpus
@@ -622,6 +639,7 @@ async def extract_relevant_summary(
             prompts.system_prompt_extract_relevant_summary,
             user=user,
             query_images=query_images,
+            tracer=tracer,
         )
         return response.strip()
 
@@ -636,6 +654,7 @@ async def generate_excalidraw_diagram(
     user: KhojUser = None,
     agent: Agent = None,
     send_status_func: Optional[Callable] = None,
+    tracer: dict = {},
 ):
     if send_status_func:
         async for event in send_status_func("**Enhancing the Diagramming Prompt**"):
@@ -650,6 +669,7 @@ async def generate_excalidraw_diagram(
         query_images=query_images,
         user=user,
         agent=agent,
+        tracer=tracer,
     )
 
     if send_status_func:
@@ -660,6 +680,7 @@ async def generate_excalidraw_diagram(
         q=better_diagram_description_prompt,
         user=user,
         agent=agent,
+        tracer=tracer,
     )
     yield better_diagram_description_prompt, excalidraw_diagram_description
 
@@ -674,6 +695,7 @@ async def generate_better_diagram_description(
     query_images: List[str] = None,
     user: KhojUser = None,
     agent: Agent = None,
+    tracer: dict = {},
 ) -> str:
     """
     Generate a diagram description from the given query and context
@@ -711,7 +733,7 @@ async def generate_better_diagram_description(
 
     with timer("Chat actor: Generate better diagram description", logger):
         response = await send_message_to_model_wrapper(
-            improve_diagram_description_prompt, query_images=query_images, user=user
+            improve_diagram_description_prompt, query_images=query_images, user=user, tracer=tracer
         )
         response = response.strip()
         if response.startswith(('"', "'")) and response.endswith(('"', "'")):
@@ -724,6 +746,7 @@ async def generate_excalidraw_diagram_from_description(
     q: str,
     user: KhojUser = None,
     agent: Agent = None,
+    tracer: dict = {},
 ) -> str:
     personality_context = (
         prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
@@ -735,7 +758,9 @@ async def generate_excalidraw_diagram_from_description(
     )
 
     with timer("Chat actor: Generate excalidraw diagram", logger):
-        raw_response = await send_message_to_model_wrapper(message=excalidraw_diagram_generation, user=user)
+        raw_response = await send_message_to_model_wrapper(
+            message=excalidraw_diagram_generation, user=user, tracer=tracer
+        )
         raw_response = raw_response.strip()
         raw_response = remove_json_codeblock(raw_response)
         response: Dict[str, str] = json.loads(raw_response)
@@ -756,6 +781,7 @@ async def generate_better_image_prompt(
     query_images: Optional[List[str]] = None,
     user: KhojUser = None,
     agent: Agent = None,
+    tracer: dict = {},
 ) -> str:
     """
     Generate a better image prompt from the given query
@@ -802,7 +828,9 @@ async def generate_better_image_prompt(
     )
 
     with timer("Chat actor: Generate contextual image prompt", logger):
-        response = await send_message_to_model_wrapper(image_prompt, query_images=query_images, user=user)
+        response = await send_message_to_model_wrapper(
+            image_prompt, query_images=query_images, user=user, tracer=tracer
+        )
         response = response.strip()
         if response.startswith(('"', "'")) and response.endswith(('"', "'")):
             response = response[1:-1]
@@ -816,6 +844,7 @@ async def send_message_to_model_wrapper(
     response_type: str = "text",
     user: KhojUser = None,
     query_images: List[str] = None,
+    tracer: dict = {},
 ):
     conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
     vision_available = conversation_config.vision_enabled
@@ -862,6 +891,7 @@ async def send_message_to_model_wrapper(
             max_prompt_size=max_tokens,
             streaming=False,
             response_type=response_type,
+            tracer=tracer,
         )
 
     elif model_type == ChatModelOptions.ModelType.OPENAI:
@@ -885,6 +915,7 @@ async def send_message_to_model_wrapper(
             model=chat_model,
             response_type=response_type,
             api_base_url=api_base_url,
+            tracer=tracer,
         )
     elif model_type == ChatModelOptions.ModelType.ANTHROPIC:
         api_key = conversation_config.openai_config.api_key
@@ -903,6 +934,7 @@ async def send_message_to_model_wrapper(
             messages=truncated_messages,
             api_key=api_key,
             model=chat_model,
+            tracer=tracer,
         )
     elif model_type == ChatModelOptions.ModelType.GOOGLE:
         api_key = conversation_config.openai_config.api_key
@@ -918,7 +950,7 @@ async def send_message_to_model_wrapper(
         )
 
         return gemini_send_message_to_model(
-            messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type
+            messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type, tracer=tracer
         )
     else:
         raise HTTPException(status_code=500, detail="Invalid conversation config")
@@ -929,6 +961,7 @@ def send_message_to_model_wrapper_sync(
     system_message: str = "",
     response_type: str = "text",
     user: KhojUser = None,
+    tracer: dict = {},
 ):
     conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config(user)
 
@@ -961,6 +994,7 @@ def send_message_to_model_wrapper_sync(
             max_prompt_size=max_tokens,
             streaming=False,
             response_type=response_type,
+            tracer=tracer,
         )
 
     elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
@@ -975,7 +1009,11 @@ def send_message_to_model_wrapper_sync(
         )
 
         openai_response = send_message_to_model(
-            messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type
+            messages=truncated_messages,
+            api_key=api_key,
+            model=chat_model,
+            response_type=response_type,
+            tracer=tracer,
         )
 
         return openai_response
@@ -995,6 +1033,7 @@ def send_message_to_model_wrapper_sync(
             messages=truncated_messages,
             api_key=api_key,
             model=chat_model,
+            tracer=tracer,
         )
 
     elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
@@ -1013,6 +1052,7 @@ def send_message_to_model_wrapper_sync(
             api_key=api_key,
             model=chat_model,
             response_type=response_type,
+            tracer=tracer,
         )
     else:
         raise HTTPException(status_code=500, detail="Invalid conversation config")
@@ -1032,6 +1072,7 @@ def generate_chat_response(
     location_data: LocationData = None,
     user_name: Optional[str] = None,
     query_images: Optional[List[str]] = None,
+    tracer: dict = {},
 ) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
     # Initialize Variables
     chat_response = None
 
@@ -1051,6 +1092,7 @@ def generate_chat_response(
             client_application=client_application,
             conversation_id=conversation_id,
             query_images=query_images,
+            tracer=tracer,
         )
 
         conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
@@ -1077,6 +1119,7 @@ def generate_chat_response(
                 location_data=location_data,
                 user_name=user_name,
                 agent=agent,
+                tracer=tracer,
             )
 
         elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
@@ -1100,6 +1143,7 @@ def generate_chat_response(
                 user_name=user_name,
                 agent=agent,
                 vision_available=vision_available,
+                tracer=tracer,
             )
 
         elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
@@ -1120,6 +1164,7 @@ def generate_chat_response(
                 user_name=user_name,
                 agent=agent,
                 vision_available=vision_available,
+                tracer=tracer,
            )
         elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
             api_key = conversation_config.openai_config.api_key
@@ -1139,6 +1184,7 @@ def generate_chat_response(
                 user_name=user_name,
                 agent=agent,
                 vision_available=vision_available,
+                tracer=tracer,
             )
 
         metadata.update({"chat_model": conversation_config.chat_model})
@@ -1495,9 +1541,15 @@ def scheduled_chat(
 
 
 async def create_automation(
-    q: str, timezone: str, user: KhojUser, calling_url: URL, meta_log: dict = {}, conversation_id: str = None
+    q: str,
+    timezone: str,
+    user: KhojUser,
+    calling_url: URL,
+    meta_log: dict = {},
+    conversation_id: str = None,
+    tracer: dict = {},
 ):
-    crontime, query_to_run, subject = await schedule_query(q, meta_log, user)
+    crontime, query_to_run, subject = await schedule_query(q, meta_log, user, tracer=tracer)
     job = await schedule_automation(query_to_run, subject, crontime, timezone, q, user, calling_url, conversation_id)
     return job, crontime, query_to_run, subject