Add Prompt Tracer to Visualize, Analyze and Debug Khoj's Train of Thought (#951)

## Overview
Use git to capture prompt traces of khoj's train of thought. View, analyze and debug them using your favorite git client (e.g vscode, magit).

- Each commit captures an interaction with an LLM
  The commit writes the query, response and system message each to a separate file in the repo.
  The commit message captures the chat model, Khoj version and other metadata
- Each conversation turn can have multiple interactions with an LLM (e.g Khoj's train of thought)
- Each new conversation turn forks from and merges back into its conversation branch
- Each new conversation branches from the user branch
- Each new user branches from root commit on the main branch

## Usage
1. Set `KHOJ_DEBUG=true` or start khoj in very verbose mode with `khoj -vv` to turn on prompt tracing
2. Chat with Khoj as usual 
3. Open the promptrace git repo to view the generated prompt traces using your favorite git porcelain. 
   The Khoj prompt trace git repo is created at `/tmp/khoj_promptrace` by default. You can configure the prompt trace directory by setting the `PROMPTRACE_DIR`environment variable.

## Implementation
- Add utility functions to capture prompt traces using git (via `gitpython`)
- Make each model provider in Khoj commit their LLM interactions with promptrace
- Weave chat metadata from chat API through all chat actors and commit it to the prompt trace
This commit is contained in:
Debanjum 2024-11-01 11:33:54 -07:00 committed by GitHub
commit 1c920273dd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 439 additions and 57 deletions

View file

@ -119,6 +119,7 @@ dev = [
"mypy >= 1.0.1",
"black >= 23.1.0",
"pre-commit >= 3.0.4",
"gitpython ~= 3.1.43",
]
[tool.hatch.version]

View file

@ -34,6 +34,7 @@ def extract_questions_anthropic(
query_images: Optional[list[str]] = None,
vision_enabled: bool = False,
personality_context: Optional[str] = None,
tracer: dict = {},
):
"""
Infer search queries to retrieve relevant notes to answer user query
@ -89,6 +90,7 @@ def extract_questions_anthropic(
model_name=model,
temperature=temperature,
api_key=api_key,
tracer=tracer,
)
# Extract, Clean Message from Claude's Response
@ -110,7 +112,7 @@ def extract_questions_anthropic(
return questions
def anthropic_send_message_to_model(messages, api_key, model):
def anthropic_send_message_to_model(messages, api_key, model, tracer={}):
"""
Send message to model
"""
@ -122,6 +124,7 @@ def anthropic_send_message_to_model(messages, api_key, model):
system_prompt=system_prompt,
model_name=model,
api_key=api_key,
tracer=tracer,
)
@ -141,6 +144,7 @@ def converse_anthropic(
agent: Agent = None,
query_images: Optional[list[str]] = None,
vision_available: bool = False,
tracer: dict = {},
):
"""
Converse with user using Anthropic's Claude
@ -213,4 +217,5 @@ def converse_anthropic(
system_prompt=system_prompt,
completion_func=completion_func,
max_prompt_size=max_prompt_size,
tracer=tracer,
)

View file

@ -12,8 +12,13 @@ from tenacity import (
wait_random_exponential,
)
from khoj.processor.conversation.utils import ThreadedGenerator, get_image_from_url
from khoj.utils.helpers import is_none_or_empty
from khoj.processor.conversation.utils import (
ThreadedGenerator,
commit_conversation_trace,
get_image_from_url,
)
from khoj.utils import state
from khoj.utils.helpers import in_debug_mode, is_none_or_empty
logger = logging.getLogger(__name__)
@ -30,7 +35,7 @@ DEFAULT_MAX_TOKENS_ANTHROPIC = 3000
reraise=True,
)
def anthropic_completion_with_backoff(
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None, tracer={}
) -> str:
if api_key not in anthropic_clients:
client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key)
@ -58,6 +63,12 @@ def anthropic_completion_with_backoff(
for text in stream.text_stream:
aggregated_response += text
# Save conversation trace
tracer["chat_model"] = model_name
tracer["temperature"] = temperature
if in_debug_mode() or state.verbose > 1:
commit_conversation_trace(messages, aggregated_response, tracer)
return aggregated_response
@ -78,18 +89,19 @@ def anthropic_chat_completion_with_backoff(
max_prompt_size=None,
completion_func=None,
model_kwargs=None,
tracer={},
):
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
t = Thread(
target=anthropic_llm_thread,
args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs),
args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs, tracer),
)
t.start()
return g
def anthropic_llm_thread(
g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None
g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None, tracer={}
):
try:
if api_key not in anthropic_clients:
@ -102,6 +114,7 @@ def anthropic_llm_thread(
anthropic.types.MessageParam(role=message.role, content=message.content) for message in messages
]
aggregated_response = ""
with client.messages.stream(
messages=formatted_messages,
model=model_name, # type: ignore
@ -112,7 +125,14 @@ def anthropic_llm_thread(
**(model_kwargs or dict()),
) as stream:
for text in stream.text_stream:
aggregated_response += text
g.send(text)
# Save conversation trace
tracer["chat_model"] = model_name
tracer["temperature"] = temperature
if in_debug_mode() or state.verbose > 1:
commit_conversation_trace(messages, aggregated_response, tracer)
except Exception as e:
logger.error(f"Error in anthropic_llm_thread: {e}", exc_info=True)
finally:

View file

@ -35,6 +35,7 @@ def extract_questions_gemini(
query_images: Optional[list[str]] = None,
vision_enabled: bool = False,
personality_context: Optional[str] = None,
tracer: dict = {},
):
"""
Infer search queries to retrieve relevant notes to answer user query
@ -85,7 +86,7 @@ def extract_questions_gemini(
messages = [ChatMessage(content=prompt, role="user"), ChatMessage(content=system_prompt, role="system")]
response = gemini_send_message_to_model(
messages, api_key, model, response_type="json_object", temperature=temperature
messages, api_key, model, response_type="json_object", temperature=temperature, tracer=tracer
)
# Extract, Clean Message from Gemini's Response
@ -107,7 +108,9 @@ def extract_questions_gemini(
return questions
def gemini_send_message_to_model(messages, api_key, model, response_type="text", temperature=0, model_kwargs=None):
def gemini_send_message_to_model(
messages, api_key, model, response_type="text", temperature=0, model_kwargs=None, tracer={}
):
"""
Send message to model
"""
@ -125,6 +128,7 @@ def gemini_send_message_to_model(messages, api_key, model, response_type="text",
api_key=api_key,
temperature=temperature,
model_kwargs=model_kwargs,
tracer=tracer,
)
@ -145,6 +149,7 @@ def converse_gemini(
agent: Agent = None,
query_images: Optional[list[str]] = None,
vision_available: bool = False,
tracer={},
):
"""
Converse with user using Google's Gemini
@ -217,4 +222,5 @@ def converse_gemini(
api_key=api_key,
system_prompt=system_prompt,
completion_func=completion_func,
tracer=tracer,
)

View file

@ -19,8 +19,13 @@ from tenacity import (
wait_random_exponential,
)
from khoj.processor.conversation.utils import ThreadedGenerator, get_image_from_url
from khoj.utils.helpers import is_none_or_empty
from khoj.processor.conversation.utils import (
ThreadedGenerator,
commit_conversation_trace,
get_image_from_url,
)
from khoj.utils import state
from khoj.utils.helpers import in_debug_mode, is_none_or_empty
logger = logging.getLogger(__name__)
@ -35,7 +40,7 @@ MAX_OUTPUT_TOKENS_GEMINI = 8192
reraise=True,
)
def gemini_completion_with_backoff(
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, tracer={}
) -> str:
genai.configure(api_key=api_key)
model_kwargs = model_kwargs or dict()
@ -60,16 +65,23 @@ def gemini_completion_with_backoff(
try:
# Generate the response. The last message is considered to be the current prompt
aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"])
return aggregated_response.text
response = chat_session.send_message(formatted_messages[-1]["parts"])
response_text = response.text
except StopCandidateException as e:
response_message, _ = handle_gemini_response(e.args)
response_text, _ = handle_gemini_response(e.args)
# Respond with reason for stopping
logger.warning(
f"LLM Response Prevented for {model_name}: {response_message}.\n"
f"LLM Response Prevented for {model_name}: {response_text}.\n"
+ f"Last Message by {messages[-1].role}: {messages[-1].content}"
)
return response_message
# Save conversation trace
tracer["chat_model"] = model_name
tracer["temperature"] = temperature
if in_debug_mode() or state.verbose > 1:
commit_conversation_trace(messages, response_text, tracer)
return response_text
@retry(
@ -88,17 +100,20 @@ def gemini_chat_completion_with_backoff(
system_prompt,
completion_func=None,
model_kwargs=None,
tracer: dict = {},
):
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
t = Thread(
target=gemini_llm_thread,
args=(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs),
args=(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs, tracer),
)
t.start()
return g
def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None):
def gemini_llm_thread(
g, messages, system_prompt, model_name, temperature, api_key, model_kwargs=None, tracer: dict = {}
):
try:
genai.configure(api_key=api_key)
model_kwargs = model_kwargs or dict()
@ -117,16 +132,25 @@ def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_k
},
)
aggregated_response = ""
formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]
# all messages up to the last are considered to be part of the chat history
chat_session = model.start_chat(history=formatted_messages[0:-1])
# the last message is considered to be the current prompt
for chunk in chat_session.send_message(formatted_messages[-1]["parts"], stream=True):
message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
message = message or chunk.text
aggregated_response += message
g.send(message)
if stopped:
raise StopCandidateException(message)
# Save conversation trace
tracer["chat_model"] = model_name
tracer["temperature"] = temperature
if in_debug_mode() or state.verbose > 1:
commit_conversation_trace(messages, aggregated_response, tracer)
except StopCandidateException as e:
logger.warning(
f"LLM Response Prevented for {model_name}: {e.args[0]}.\n"

View file

@ -12,11 +12,12 @@ from khoj.processor.conversation import prompts
from khoj.processor.conversation.offline.utils import download_model
from khoj.processor.conversation.utils import (
ThreadedGenerator,
commit_conversation_trace,
generate_chatml_messages_with_context,
)
from khoj.utils import state
from khoj.utils.constants import empty_escape_sequences
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
from khoj.utils.helpers import ConversationCommand, in_debug_mode, is_none_or_empty
from khoj.utils.rawconfig import LocationData
logger = logging.getLogger(__name__)
@ -34,6 +35,7 @@ def extract_questions_offline(
max_prompt_size: int = None,
temperature: float = 0.7,
personality_context: Optional[str] = None,
tracer: dict = {},
) -> List[str]:
"""
Infer search queries to retrieve relevant notes to answer user query
@ -94,6 +96,7 @@ def extract_questions_offline(
max_prompt_size=max_prompt_size,
temperature=temperature,
response_type="json_object",
tracer=tracer,
)
finally:
state.chat_lock.release()
@ -146,6 +149,7 @@ def converse_offline(
location_data: LocationData = None,
user_name: str = None,
agent: Agent = None,
tracer: dict = {},
) -> Union[ThreadedGenerator, Iterator[str]]:
"""
Converse with user using Llama
@ -153,8 +157,9 @@ def converse_offline(
# Initialize Variables
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
compiled_references = "\n\n".join({f"# File: {item['file']}\n## {item['compiled']}\n" for item in references})
tracer["chat_model"] = model
compiled_references = "\n\n".join({f"# File: {item['file']}\n## {item['compiled']}\n" for item in references})
current_date = datetime.now()
if agent and agent.personality:
@ -215,13 +220,14 @@ def converse_offline(
logger.debug(f"Conversation Context for {model}: {truncated_messages}")
g = ThreadedGenerator(references, online_results, completion_func=completion_func)
t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size))
t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size, tracer))
t.start()
return g
def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None):
def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None, tracer: dict = {}):
stop_phrases = ["<s>", "INST]", "Notes:"]
aggregated_response = ""
state.chat_lock.acquire()
try:
@ -229,7 +235,14 @@ def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int
messages, loaded_model=model, stop=stop_phrases, max_prompt_size=max_prompt_size, streaming=True
)
for response in response_iterator:
g.send(response["choices"][0]["delta"].get("content", ""))
response_delta = response["choices"][0]["delta"].get("content", "")
aggregated_response += response_delta
g.send(response_delta)
# Save conversation trace
if in_debug_mode() or state.verbose > 1:
commit_conversation_trace(messages, aggregated_response, tracer)
finally:
state.chat_lock.release()
g.close()
@ -244,6 +257,7 @@ def send_message_to_model_offline(
stop=[],
max_prompt_size: int = None,
response_type: str = "text",
tracer: dict = {},
):
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
@ -251,7 +265,17 @@ def send_message_to_model_offline(
response = offline_chat_model.create_chat_completion(
messages_dict, stop=stop, stream=streaming, temperature=temperature, response_format={"type": response_type}
)
if streaming:
return response
else:
return response["choices"][0]["message"].get("content", "")
response_text = response["choices"][0]["message"].get("content", "")
# Save conversation trace for non-streaming responses
# Streamed responses need to be saved by the calling function
tracer["chat_model"] = model
tracer["temperature"] = temperature
if in_debug_mode() or state.verbose > 1:
commit_conversation_trace(messages, response_text, tracer)
return response_text

View file

@ -33,6 +33,7 @@ def extract_questions(
query_images: Optional[list[str]] = None,
vision_enabled: bool = False,
personality_context: Optional[str] = None,
tracer: dict = {},
):
"""
Infer search queries to retrieve relevant notes to answer user query
@ -82,7 +83,13 @@ def extract_questions(
messages = [ChatMessage(content=prompt, role="user")]
response = send_message_to_model(
messages, api_key, model, response_type="json_object", api_base_url=api_base_url, temperature=temperature
messages,
api_key,
model,
response_type="json_object",
api_base_url=api_base_url,
temperature=temperature,
tracer=tracer,
)
# Extract, Clean Message from GPT's Response
@ -103,7 +110,9 @@ def extract_questions(
return questions
def send_message_to_model(messages, api_key, model, response_type="text", api_base_url=None, temperature=0):
def send_message_to_model(
messages, api_key, model, response_type="text", api_base_url=None, temperature=0, tracer: dict = {}
):
"""
Send message to model
"""
@ -116,6 +125,7 @@ def send_message_to_model(messages, api_key, model, response_type="text", api_ba
temperature=temperature,
api_base_url=api_base_url,
model_kwargs={"response_format": {"type": response_type}},
tracer=tracer,
)
@ -137,6 +147,7 @@ def converse(
agent: Agent = None,
query_images: Optional[list[str]] = None,
vision_available: bool = False,
tracer: dict = {},
):
"""
Converse with user using OpenAI's ChatGPT
@ -207,4 +218,5 @@ def converse(
api_base_url=api_base_url,
completion_func=completion_func,
model_kwargs={"stop": ["Notes:\n["]},
tracer=tracer,
)

View file

@ -12,7 +12,12 @@ from tenacity import (
wait_random_exponential,
)
from khoj.processor.conversation.utils import ThreadedGenerator
from khoj.processor.conversation.utils import (
ThreadedGenerator,
commit_conversation_trace,
)
from khoj.utils import state
from khoj.utils.helpers import in_debug_mode
logger = logging.getLogger(__name__)
@ -33,7 +38,7 @@ openai_clients: Dict[str, openai.OpenAI] = {}
reraise=True,
)
def completion_with_backoff(
messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None
messages, model, temperature=0, openai_api_key=None, api_base_url=None, model_kwargs=None, tracer: dict = {}
) -> str:
client_key = f"{openai_api_key}--{api_base_url}"
client: openai.OpenAI | None = openai_clients.get(client_key)
@ -77,6 +82,12 @@ def completion_with_backoff(
elif delta_chunk.content:
aggregated_response += delta_chunk.content
# Save conversation trace
tracer["chat_model"] = model
tracer["temperature"] = temperature
if in_debug_mode() or state.verbose > 1:
commit_conversation_trace(messages, aggregated_response, tracer)
return aggregated_response
@ -103,26 +114,37 @@ def chat_completion_with_backoff(
api_base_url=None,
completion_func=None,
model_kwargs=None,
tracer: dict = {},
):
g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
t = Thread(
target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key, api_base_url, model_kwargs)
target=llm_thread,
args=(g, messages, model_name, temperature, openai_api_key, api_base_url, model_kwargs, tracer),
)
t.start()
return g
def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_base_url=None, model_kwargs=None):
def llm_thread(
g,
messages,
model_name,
temperature,
openai_api_key=None,
api_base_url=None,
model_kwargs=None,
tracer: dict = {},
):
try:
client_key = f"{openai_api_key}--{api_base_url}"
if client_key not in openai_clients:
client: openai.OpenAI = openai.OpenAI(
client = openai.OpenAI(
api_key=openai_api_key,
base_url=api_base_url,
)
openai_clients[client_key] = client
else:
client: openai.OpenAI = openai_clients[client_key]
client = openai_clients[client_key]
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
stream = True
@ -144,17 +166,29 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_ba
**(model_kwargs or dict()),
)
aggregated_response = ""
if not stream:
g.send(chat.choices[0].message.content)
aggregated_response = chat.choices[0].message.content
g.send(aggregated_response)
else:
for chunk in chat:
if len(chunk.choices) == 0:
continue
delta_chunk = chunk.choices[0].delta
text_chunk = ""
if isinstance(delta_chunk, str):
g.send(delta_chunk)
text_chunk = delta_chunk
elif delta_chunk.content:
g.send(delta_chunk.content)
text_chunk = delta_chunk.content
if text_chunk:
aggregated_response += text_chunk
g.send(text_chunk)
# Save conversation trace
tracer["chat_model"] = model_name
tracer["temperature"] = temperature
if in_debug_mode() or state.verbose > 1:
commit_conversation_trace(messages, aggregated_response, tracer)
except Exception as e:
logger.error(f"Error in llm_thread: {e}", exc_info=True)
finally:

View file

@ -2,6 +2,7 @@ import base64
import logging
import math
import mimetypes
import os
import queue
from dataclasses import dataclass
from datetime import datetime
@ -12,6 +13,8 @@ from typing import Any, Dict, List, Optional
import PIL.Image
import requests
import tiktoken
import yaml
from git import Repo
from langchain.schema import ChatMessage
from llama_cpp.llama import Llama
from transformers import AutoTokenizer
@ -21,7 +24,7 @@ from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
from khoj.utils import state
from khoj.utils.helpers import is_none_or_empty, merge_dicts
from khoj.utils.helpers import in_debug_mode, is_none_or_empty, merge_dicts
logger = logging.getLogger(__name__)
model_to_prompt_size = {
@ -117,6 +120,7 @@ def save_to_conversation_log(
conversation_id: str = None,
automation_id: str = None,
query_images: List[str] = None,
tracer: Dict[str, Any] = {},
):
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
updated_conversation = message_to_log(
@ -142,6 +146,9 @@ def save_to_conversation_log(
user_message=q,
)
if in_debug_mode() or state.verbose > 1:
merge_message_into_conversation_trace(q, chat_response, tracer)
logger.info(
f"""
Saved Conversation Turn
@ -354,3 +361,163 @@ def get_image_from_url(image_url: str, type="pil"):
except requests.exceptions.RequestException as e:
logger.error(f"Failed to get image from URL {image_url}: {e}")
return ImageWithType(content=None, type=None)
def commit_conversation_trace(
session: list[ChatMessage],
response: str | list[dict],
tracer: dict,
system_message: str | list[dict] = "",
repo_path: str = "/tmp/promptrace",
) -> str:
"""
Save trace of conversation step using git. Useful to visualize, compare and debug traces.
Returns the path to the repository.
"""
# Serialize session, system message and response to yaml
system_message_yaml = yaml.dump(system_message, allow_unicode=True, sort_keys=False, default_flow_style=False)
response_yaml = yaml.dump(response, allow_unicode=True, sort_keys=False, default_flow_style=False)
formatted_session = [{"role": message.role, "content": message.content} for message in session]
session_yaml = yaml.dump(formatted_session, allow_unicode=True, sort_keys=False, default_flow_style=False)
query = (
yaml.dump(session[-1].content, allow_unicode=True, sort_keys=False, default_flow_style=False)
.strip()
.removeprefix("'")
.removesuffix("'")
) # Extract serialized query from chat session
# Extract chat metadata for session
uid, cid, mid = tracer.get("uid", "main"), tracer.get("cid", "main"), tracer.get("mid")
# Infer repository path from environment variable or provided path
repo_path = os.getenv("PROMPTRACE_DIR", repo_path)
try:
# Prepare git repository
os.makedirs(repo_path, exist_ok=True)
repo = Repo.init(repo_path)
# Remove post-commit hook if it exists
hooks_dir = os.path.join(repo_path, ".git", "hooks")
post_commit_hook = os.path.join(hooks_dir, "post-commit")
if os.path.exists(post_commit_hook):
os.remove(post_commit_hook)
# Configure git user if not set
if not repo.config_reader().has_option("user", "email"):
repo.config_writer().set_value("user", "name", "Prompt Tracer").release()
repo.config_writer().set_value("user", "email", "promptracer@khoj.dev").release()
# Create an initial commit if the repository is newly created
if not repo.head.is_valid():
repo.index.commit("And then there was a trace")
# Check out the initial commit
initial_commit = repo.commit("HEAD~0")
repo.head.reference = initial_commit
repo.head.reset(index=True, working_tree=True)
# Create or switch to user branch from initial commit
user_branch = f"u_{uid}"
if user_branch not in repo.branches:
repo.create_head(user_branch)
repo.heads[user_branch].checkout()
# Create or switch to conversation branch from user branch
conv_branch = f"c_{cid}"
if conv_branch not in repo.branches:
repo.create_head(conv_branch)
repo.heads[conv_branch].checkout()
# Create or switch to message branch from conversation branch
msg_branch = f"m_{mid}" if mid else None
if msg_branch and msg_branch not in repo.branches:
repo.create_head(msg_branch)
if msg_branch:
repo.heads[msg_branch].checkout()
# Include file with content to commit
files_to_commit = {"query": session_yaml, "response": response_yaml, "system_prompt": system_message_yaml}
# Write files and stage them
for filename, content in files_to_commit.items():
file_path = os.path.join(repo_path, filename)
# Unescape special characters in content for better readability
content = content.strip().replace("\\n", "\n").replace("\\t", "\t")
with open(file_path, "w", encoding="utf-8") as f:
f.write(content)
repo.index.add([filename])
# Create commit
metadata_yaml = yaml.dump(tracer, allow_unicode=True, sort_keys=False, default_flow_style=False)
commit_message = f"""
{query[:250]}
Response:
---
{response[:500]}...
Metadata
---
{metadata_yaml}
""".strip()
repo.index.commit(commit_message)
logger.debug(f"Saved conversation trace to repo at {repo_path}")
return repo_path
except Exception as e:
logger.error(f"Failed to add conversation trace to repo: {str(e)}", exc_info=True)
return None
def merge_message_into_conversation_trace(query: str, response: str, tracer: dict, repo_path="/tmp/promptrace") -> bool:
"""
Merge the message branch into its parent conversation branch.
Args:
query: User query
response: Assistant response
tracer: Dictionary containing uid, cid and mid
repo_path: Path to the git repository
Returns:
bool: True if merge was successful, False otherwise
"""
try:
# Extract branch names
msg_branch = f"m_{tracer['mid']}"
conv_branch = f"c_{tracer['cid']}"
# Infer repository path from environment variable or provided path
repo_path = os.getenv("PROMPTRACE_DIR", repo_path)
repo = Repo(repo_path)
# Checkout conversation branch
repo.heads[conv_branch].checkout()
# Create commit message
metadata_yaml = yaml.dump(tracer, allow_unicode=True, sort_keys=False, default_flow_style=False)
commit_message = f"""
{query[:250]}
Response:
---
{response[:500]}...
Metadata
---
{metadata_yaml}
""".strip()
# Merge message branch into conversation branch
repo.git.merge(msg_branch, no_ff=True, m=commit_message)
# Delete message branch after merge
repo.delete_head(msg_branch, force=True)
logger.debug(f"Successfully merged {msg_branch} into {conv_branch}")
return True
except Exception as e:
logger.error(f"Failed to merge message {msg_branch} into conversation {conv_branch}: {str(e)}", exc_info=True)
return False

View file

@ -28,6 +28,7 @@ async def text_to_image(
send_status_func: Optional[Callable] = None,
query_images: Optional[List[str]] = None,
agent: Agent = None,
tracer: dict = {},
):
status_code = 200
image = None
@ -68,6 +69,7 @@ async def text_to_image(
query_images=query_images,
user=user,
agent=agent,
tracer=tracer,
)
if send_status_func:

View file

@ -64,6 +64,7 @@ async def search_online(
custom_filters: List[str] = [],
query_images: List[str] = None,
agent: Agent = None,
tracer: dict = {},
):
query += " ".join(custom_filters)
if not is_internet_connected():
@ -73,7 +74,7 @@ async def search_online(
# Breakdown the query into subqueries to get the correct answer
subqueries = await generate_online_subqueries(
query, conversation_history, location, user, query_images=query_images, agent=agent
query, conversation_history, location, user, query_images=query_images, agent=agent, tracer=tracer
)
response_dict = {}
@ -111,7 +112,7 @@ async def search_online(
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event}
tasks = [
read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent)
read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent, tracer=tracer)
for link, data in webpages.items()
]
results = await asyncio.gather(*tasks)
@ -153,6 +154,7 @@ async def read_webpages(
send_status_func: Optional[Callable] = None,
query_images: List[str] = None,
agent: Agent = None,
tracer: dict = {},
):
"Infer web pages to read from the query and extract relevant information from them"
logger.info(f"Inferring web pages to read")
@ -166,7 +168,7 @@ async def read_webpages(
webpage_links_str = "\n- " + "\n- ".join(list(urls))
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event}
tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent) for url in urls]
tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent, tracer=tracer) for url in urls]
results = await asyncio.gather(*tasks)
response: Dict[str, Dict] = defaultdict(dict)
@ -192,7 +194,12 @@ async def read_webpage(
async def read_webpage_and_extract_content(
subqueries: set[str], url: str, content: str = None, user: KhojUser = None, agent: Agent = None
subqueries: set[str],
url: str,
content: str = None,
user: KhojUser = None,
agent: Agent = None,
tracer: dict = {},
) -> Tuple[set[str], str, Union[None, str]]:
# Select the web scrapers to use for reading the web page
web_scrapers = await ConversationAdapters.aget_enabled_webscrapers()
@ -214,7 +221,9 @@ async def read_webpage_and_extract_content(
# Extract relevant information from the web page
if is_none_or_empty(extracted_info):
with timer(f"Extracting relevant information from web page at '{url}' took", logger):
extracted_info = await extract_relevant_info(subqueries, content, user=user, agent=agent)
extracted_info = await extract_relevant_info(
subqueries, content, user=user, agent=agent, tracer=tracer
)
# If we successfully extracted information, break the loop
if not is_none_or_empty(extracted_info):

View file

@ -350,6 +350,7 @@ async def extract_references_and_questions(
send_status_func: Optional[Callable] = None,
query_images: Optional[List[str]] = None,
agent: Agent = None,
tracer: dict = {},
):
user = request.user.object if request.user.is_authenticated else None
@ -425,6 +426,7 @@ async def extract_references_and_questions(
user=user,
max_prompt_size=conversation_config.max_prompt_size,
personality_context=personality_context,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
openai_chat_config = conversation_config.openai_config
@ -442,6 +444,7 @@ async def extract_references_and_questions(
query_images=query_images,
vision_enabled=vision_enabled,
personality_context=personality_context,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
api_key = conversation_config.openai_config.api_key
@ -456,6 +459,7 @@ async def extract_references_and_questions(
user=user,
vision_enabled=vision_enabled,
personality_context=personality_context,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
api_key = conversation_config.openai_config.api_key
@ -471,6 +475,7 @@ async def extract_references_and_questions(
user=user,
vision_enabled=vision_enabled,
personality_context=personality_context,
tracer=tracer,
)
# Collate search results as context for GPT

View file

@ -3,6 +3,7 @@ import base64
import json
import logging
import time
import uuid
from datetime import datetime
from functools import partial
from typing import Dict, Optional
@ -563,6 +564,12 @@ async def chat(
event_delimiter = "␃🔚␗"
q = unquote(q)
nonlocal conversation_id
tracer: dict = {
"mid": f"{uuid.uuid4()}",
"cid": conversation_id,
"uid": user.id,
"khoj_version": state.khoj_version,
}
uploaded_images: list[str] = []
if images:
@ -682,6 +689,7 @@ async def chat(
user=user,
query_images=uploaded_images,
agent=agent,
tracer=tracer,
)
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
async for result in send_event(
@ -689,7 +697,9 @@ async def chat(
):
yield result
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_images, agent)
mode = await aget_relevant_output_modes(
q, meta_log, is_automated_task, user, uploaded_images, agent, tracer=tracer
)
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
yield result
if mode not in conversation_commands:
@ -755,6 +765,7 @@ async def chat(
query_images=uploaded_images,
user=user,
agent=agent,
tracer=tracer,
)
response_log = str(response)
async for result in send_llm_response(response_log):
@ -774,6 +785,7 @@ async def chat(
client_application=request.user.client_app,
conversation_id=conversation_id,
query_images=uploaded_images,
tracer=tracer,
)
return
@ -795,7 +807,7 @@ async def chat(
if ConversationCommand.Automation in conversation_commands:
try:
automation, crontime, query_to_run, subject = await create_automation(
q, timezone, user, request.url, meta_log
q, timezone, user, request.url, meta_log, tracer=tracer
)
except Exception as e:
logger.error(f"Error scheduling task {q} for {user.email}: {e}")
@ -817,6 +829,7 @@ async def chat(
inferred_queries=[query_to_run],
automation_id=automation.id,
query_images=uploaded_images,
tracer=tracer,
)
async for result in send_llm_response(llm_response):
yield result
@ -838,6 +851,7 @@ async def chat(
partial(send_event, ChatEvent.STATUS),
query_images=uploaded_images,
agent=agent,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@ -882,6 +896,7 @@ async def chat(
custom_filters,
query_images=uploaded_images,
agent=agent,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@ -906,6 +921,7 @@ async def chat(
partial(send_event, ChatEvent.STATUS),
query_images=uploaded_images,
agent=agent,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@ -956,6 +972,7 @@ async def chat(
send_status_func=partial(send_event, ChatEvent.STATUS),
query_images=uploaded_images,
agent=agent,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@ -986,6 +1003,7 @@ async def chat(
compiled_references=compiled_references,
online_results=online_results,
query_images=uploaded_images,
tracer=tracer,
)
content_obj = {
"intentType": intent_type,
@ -1014,6 +1032,7 @@ async def chat(
user=user,
agent=agent,
send_status_func=partial(send_event, ChatEvent.STATUS),
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@ -1041,6 +1060,7 @@ async def chat(
compiled_references=compiled_references,
online_results=online_results,
query_images=uploaded_images,
tracer=tracer,
)
async for result in send_llm_response(json.dumps(content_obj)):
@ -1064,6 +1084,7 @@ async def chat(
location,
user_name,
uploaded_images,
tracer,
)
# Send Response

View file

@ -301,6 +301,7 @@ async def aget_relevant_information_sources(
user: KhojUser,
query_images: List[str] = None,
agent: Agent = None,
tracer: dict = {},
):
"""
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
@ -337,6 +338,7 @@ async def aget_relevant_information_sources(
relevant_tools_prompt,
response_type="json_object",
user=user,
tracer=tracer,
)
try:
@ -378,6 +380,7 @@ async def aget_relevant_output_modes(
user: KhojUser = None,
query_images: List[str] = None,
agent: Agent = None,
tracer: dict = {},
):
"""
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
@ -413,7 +416,9 @@ async def aget_relevant_output_modes(
)
with timer("Chat actor: Infer output mode for chat response", logger):
response = await send_message_to_model_wrapper(relevant_mode_prompt, response_type="json_object", user=user)
response = await send_message_to_model_wrapper(
relevant_mode_prompt, response_type="json_object", user=user, tracer=tracer
)
try:
response = response.strip()
@ -444,6 +449,7 @@ async def infer_webpage_urls(
user: KhojUser,
query_images: List[str] = None,
agent: Agent = None,
tracer: dict = {},
) -> List[str]:
"""
Infer webpage links from the given query
@ -468,7 +474,11 @@ async def infer_webpage_urls(
with timer("Chat actor: Infer webpage urls to read", logger):
response = await send_message_to_model_wrapper(
online_queries_prompt, query_images=query_images, response_type="json_object", user=user
online_queries_prompt,
query_images=query_images,
response_type="json_object",
user=user,
tracer=tracer,
)
# Validate that the response is a non-empty, JSON-serializable list of URLs
@ -490,6 +500,7 @@ async def generate_online_subqueries(
user: KhojUser,
query_images: List[str] = None,
agent: Agent = None,
tracer: dict = {},
) -> List[str]:
"""
Generate subqueries from the given query
@ -514,7 +525,11 @@ async def generate_online_subqueries(
with timer("Chat actor: Generate online search subqueries", logger):
response = await send_message_to_model_wrapper(
online_queries_prompt, query_images=query_images, response_type="json_object", user=user
online_queries_prompt,
query_images=query_images,
response_type="json_object",
user=user,
tracer=tracer,
)
# Validate that the response is a non-empty, JSON-serializable list
@ -533,7 +548,7 @@ async def generate_online_subqueries(
async def schedule_query(
q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None
q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None, tracer: dict = {}
) -> Tuple[str, ...]:
"""
Schedule the date, time to run the query. Assume the server timezone is UTC.
@ -546,7 +561,7 @@ async def schedule_query(
)
raw_response = await send_message_to_model_wrapper(
crontime_prompt, query_images=query_images, response_type="json_object", user=user
crontime_prompt, query_images=query_images, response_type="json_object", user=user, tracer=tracer
)
# Validate that the response is a non-empty, JSON-serializable list
@ -561,7 +576,7 @@ async def schedule_query(
async def extract_relevant_info(
qs: set[str], corpus: str, user: KhojUser = None, agent: Agent = None
qs: set[str], corpus: str, user: KhojUser = None, agent: Agent = None, tracer: dict = {}
) -> Union[str, None]:
"""
Extract relevant information for a given query from the target corpus
@ -584,6 +599,7 @@ async def extract_relevant_info(
extract_relevant_information,
prompts.system_prompt_extract_relevant_information,
user=user,
tracer=tracer,
)
return response.strip()
@ -595,6 +611,7 @@ async def extract_relevant_summary(
query_images: List[str] = None,
user: KhojUser = None,
agent: Agent = None,
tracer: dict = {},
) -> Union[str, None]:
"""
Extract relevant information for a given query from the target corpus
@ -622,6 +639,7 @@ async def extract_relevant_summary(
prompts.system_prompt_extract_relevant_summary,
user=user,
query_images=query_images,
tracer=tracer,
)
return response.strip()
@ -636,6 +654,7 @@ async def generate_excalidraw_diagram(
user: KhojUser = None,
agent: Agent = None,
send_status_func: Optional[Callable] = None,
tracer: dict = {},
):
if send_status_func:
async for event in send_status_func("**Enhancing the Diagramming Prompt**"):
@ -650,6 +669,7 @@ async def generate_excalidraw_diagram(
query_images=query_images,
user=user,
agent=agent,
tracer=tracer,
)
if send_status_func:
@ -660,6 +680,7 @@ async def generate_excalidraw_diagram(
q=better_diagram_description_prompt,
user=user,
agent=agent,
tracer=tracer,
)
yield better_diagram_description_prompt, excalidraw_diagram_description
@ -674,6 +695,7 @@ async def generate_better_diagram_description(
query_images: List[str] = None,
user: KhojUser = None,
agent: Agent = None,
tracer: dict = {},
) -> str:
"""
Generate a diagram description from the given query and context
@ -711,7 +733,7 @@ async def generate_better_diagram_description(
with timer("Chat actor: Generate better diagram description", logger):
response = await send_message_to_model_wrapper(
improve_diagram_description_prompt, query_images=query_images, user=user
improve_diagram_description_prompt, query_images=query_images, user=user, tracer=tracer
)
response = response.strip()
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
@ -724,6 +746,7 @@ async def generate_excalidraw_diagram_from_description(
q: str,
user: KhojUser = None,
agent: Agent = None,
tracer: dict = {},
) -> str:
personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
@ -735,7 +758,9 @@ async def generate_excalidraw_diagram_from_description(
)
with timer("Chat actor: Generate excalidraw diagram", logger):
raw_response = await send_message_to_model_wrapper(message=excalidraw_diagram_generation, user=user)
raw_response = await send_message_to_model_wrapper(
message=excalidraw_diagram_generation, user=user, tracer=tracer
)
raw_response = raw_response.strip()
raw_response = remove_json_codeblock(raw_response)
response: Dict[str, str] = json.loads(raw_response)
@ -756,6 +781,7 @@ async def generate_better_image_prompt(
query_images: Optional[List[str]] = None,
user: KhojUser = None,
agent: Agent = None,
tracer: dict = {},
) -> str:
"""
Generate a better image prompt from the given query
@ -802,7 +828,9 @@ async def generate_better_image_prompt(
)
with timer("Chat actor: Generate contextual image prompt", logger):
response = await send_message_to_model_wrapper(image_prompt, query_images=query_images, user=user)
response = await send_message_to_model_wrapper(
image_prompt, query_images=query_images, user=user, tracer=tracer
)
response = response.strip()
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
response = response[1:-1]
@ -816,6 +844,7 @@ async def send_message_to_model_wrapper(
response_type: str = "text",
user: KhojUser = None,
query_images: List[str] = None,
tracer: dict = {},
):
conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
vision_available = conversation_config.vision_enabled
@ -862,6 +891,7 @@ async def send_message_to_model_wrapper(
max_prompt_size=max_tokens,
streaming=False,
response_type=response_type,
tracer=tracer,
)
elif model_type == ChatModelOptions.ModelType.OPENAI:
@ -885,6 +915,7 @@ async def send_message_to_model_wrapper(
model=chat_model,
response_type=response_type,
api_base_url=api_base_url,
tracer=tracer,
)
elif model_type == ChatModelOptions.ModelType.ANTHROPIC:
api_key = conversation_config.openai_config.api_key
@ -903,6 +934,7 @@ async def send_message_to_model_wrapper(
messages=truncated_messages,
api_key=api_key,
model=chat_model,
tracer=tracer,
)
elif model_type == ChatModelOptions.ModelType.GOOGLE:
api_key = conversation_config.openai_config.api_key
@ -918,7 +950,7 @@ async def send_message_to_model_wrapper(
)
return gemini_send_message_to_model(
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type, tracer=tracer
)
else:
raise HTTPException(status_code=500, detail="Invalid conversation config")
@ -929,6 +961,7 @@ def send_message_to_model_wrapper_sync(
system_message: str = "",
response_type: str = "text",
user: KhojUser = None,
tracer: dict = {},
):
conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config(user)
@ -961,6 +994,7 @@ def send_message_to_model_wrapper_sync(
max_prompt_size=max_tokens,
streaming=False,
response_type=response_type,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
@ -975,7 +1009,11 @@ def send_message_to_model_wrapper_sync(
)
openai_response = send_message_to_model(
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type
messages=truncated_messages,
api_key=api_key,
model=chat_model,
response_type=response_type,
tracer=tracer,
)
return openai_response
@ -995,6 +1033,7 @@ def send_message_to_model_wrapper_sync(
messages=truncated_messages,
api_key=api_key,
model=chat_model,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
@ -1013,6 +1052,7 @@ def send_message_to_model_wrapper_sync(
api_key=api_key,
model=chat_model,
response_type=response_type,
tracer=tracer,
)
else:
raise HTTPException(status_code=500, detail="Invalid conversation config")
@ -1032,6 +1072,7 @@ def generate_chat_response(
location_data: LocationData = None,
user_name: Optional[str] = None,
query_images: Optional[List[str]] = None,
tracer: dict = {},
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
# Initialize Variables
chat_response = None
@ -1051,6 +1092,7 @@ def generate_chat_response(
client_application=client_application,
conversation_id=conversation_id,
query_images=query_images,
tracer=tracer,
)
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
@ -1077,6 +1119,7 @@ def generate_chat_response(
location_data=location_data,
user_name=user_name,
agent=agent,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
@ -1100,6 +1143,7 @@ def generate_chat_response(
user_name=user_name,
agent=agent,
vision_available=vision_available,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
@ -1120,6 +1164,7 @@ def generate_chat_response(
user_name=user_name,
agent=agent,
vision_available=vision_available,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
api_key = conversation_config.openai_config.api_key
@ -1139,6 +1184,7 @@ def generate_chat_response(
user_name=user_name,
agent=agent,
vision_available=vision_available,
tracer=tracer,
)
metadata.update({"chat_model": conversation_config.chat_model})
@ -1495,9 +1541,15 @@ def scheduled_chat(
async def create_automation(
q: str, timezone: str, user: KhojUser, calling_url: URL, meta_log: dict = {}, conversation_id: str = None
q: str,
timezone: str,
user: KhojUser,
calling_url: URL,
meta_log: dict = {},
conversation_id: str = None,
tracer: dict = {},
):
crontime, query_to_run, subject = await schedule_query(q, meta_log, user)
crontime, query_to_run, subject = await schedule_query(q, meta_log, user, tracer=tracer)
job = await schedule_automation(query_to_run, subject, crontime, timezone, q, user, calling_url, conversation_id)
return job, crontime, query_to_run, subject