mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Save conversation traces to git for visualization
This commit is contained in:
parent
7e0a692d16
commit
10c8fd3b2a
2 changed files with 161 additions and 0 deletions
|
@ -119,6 +119,7 @@ dev = [
|
||||||
"mypy >= 1.0.1",
|
"mypy >= 1.0.1",
|
||||||
"black >= 23.1.0",
|
"black >= 23.1.0",
|
||||||
"pre-commit >= 3.0.4",
|
"pre-commit >= 3.0.4",
|
||||||
|
"gitpython ~= 3.1.43",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.hatch.version]
|
[tool.hatch.version]
|
||||||
|
|
|
@ -2,6 +2,7 @@ import base64
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import mimetypes
|
import mimetypes
|
||||||
|
import os
|
||||||
import queue
|
import queue
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
@ -12,6 +13,8 @@ from typing import Any, Dict, List, Optional
|
||||||
import PIL.Image
|
import PIL.Image
|
||||||
import requests
|
import requests
|
||||||
import tiktoken
|
import tiktoken
|
||||||
|
import yaml
|
||||||
|
from git import Repo
|
||||||
from langchain.schema import ChatMessage
|
from langchain.schema import ChatMessage
|
||||||
from llama_cpp.llama import Llama
|
from llama_cpp.llama import Llama
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
@ -344,3 +347,160 @@ def get_image_from_url(image_url: str, type="pil"):
|
||||||
except requests.exceptions.RequestException as e:
|
except requests.exceptions.RequestException as e:
|
||||||
logger.error(f"Failed to get image from URL {image_url}: {e}")
|
logger.error(f"Failed to get image from URL {image_url}: {e}")
|
||||||
return ImageWithType(content=None, type=None)
|
return ImageWithType(content=None, type=None)
|
||||||
|
|
||||||
|
|
||||||
|
def commit_conversation_trace(
    session: list[ChatMessage],
    response: str | list[dict],
    tracer: dict,
    system_message: str | list[dict] = "",
    repo_path: str = "/tmp/khoj_promptrace",
) -> str | None:
    """
    Save trace of conversation step using git. Useful to visualize, compare and debug traces.

    The chat session, response and system message are serialized to yaml, written to the
    files `query`, `response` and `system_prompt` in the repository and committed on a
    branch chain of user (`u_<uid>`) -> conversation (`c_<cid>`) -> message (`m_<mid>`).
    The message branch is only used when the tracer carries a `mid`.

    Tracing is best effort: any failure is logged and reported via the return value
    instead of being raised to the caller.

    Args:
        session: Chat messages of the conversation step. The last message is used as the query.
        response: Model response, as text or as a list of dictionaries.
        tracer: Chat metadata. The `uid`, `cid` and `mid` keys select the branches to commit to.
        system_message: System prompt used for the conversation step.
        repo_path: Path to the git repository. The `PROMPTRACE_DIR` environment variable,
            when set, takes precedence over this argument.

    Returns:
        The path to the repository, or None if the trace could not be saved.
    """
    # Infer repository path from environment variable or provided path
    repo_path = os.getenv("PROMPTRACE_DIR", repo_path) or "/tmp/promptrace"

    try:
        # Serialize session, system message and response to yaml.
        # Done inside the try block so a serialization failure cannot break the chat flow.
        yaml_options = {"allow_unicode": True, "sort_keys": False, "default_flow_style": False}
        system_message_yaml = yaml.dump(system_message, **yaml_options)
        response_yaml = yaml.dump(response, **yaml_options)
        formatted_session = [{"role": message.role, "content": message.content} for message in session]
        session_yaml = yaml.dump(formatted_session, **yaml_options)

        # Extract serialized query from chat session. An empty session has no query.
        query = ""
        if session:
            query = yaml.dump(session[-1].content, **yaml_options).strip().removeprefix("'").removesuffix("'")

        # Slice text, not list items, when truncating the response for the commit message
        response_text = response if isinstance(response, str) else response_yaml

        # Extract chat metadata for session
        uid, cid, mid = tracer.get("uid", "main"), tracer.get("cid", "main"), tracer.get("mid")

        # Prepare git repository
        os.makedirs(repo_path, exist_ok=True)
        repo = Repo.init(repo_path)

        # Remove post-commit hook if it exists
        post_commit_hook = os.path.join(repo_path, ".git", "hooks", "post-commit")
        if os.path.exists(post_commit_hook):
            os.remove(post_commit_hook)

        # Configure git user if not set
        if not repo.config_reader().has_option("user", "email"):
            with repo.config_writer() as git_config:
                git_config.set_value("user", "name", "Prompt Tracer")
                git_config.set_value("user", "email", "promptracer@khoj.dev")

        # Create an initial commit if the repository is newly created
        if not repo.head.is_valid():
            repo.index.commit("And then there was a trace")

        # Detach HEAD at the commit it currently points to and hard reset index and working tree.
        # NOTE(review): "HEAD~0" resolves to the current HEAD, not the repository's first commit.
        # Confirm whether new user branches were meant to start from the very first commit.
        current_commit = repo.commit("HEAD~0")
        repo.head.reference = current_commit
        repo.head.reset(index=True, working_tree=True)

        # Create or switch to the user, conversation and (optional) message branch in turn.
        # Each new branch is created from the branch checked out just before it.
        # The message branch is skipped when the tracer has no message id.
        branch_names = [f"u_{uid}", f"c_{cid}"]
        if mid:
            branch_names.append(f"m_{mid}")
        for branch_name in branch_names:
            if branch_name not in repo.branches:
                repo.create_head(branch_name)
            repo.heads[branch_name].checkout()

        # Include file with content to commit
        files_to_commit = {"query": session_yaml, "response": response_yaml, "system_prompt": system_message_yaml}

        # Write files and stage them
        for filename, content in files_to_commit.items():
            file_path = os.path.join(repo_path, filename)
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(content)
            repo.index.add([filename])

        # Create commit
        metadata_yaml = yaml.dump(tracer, **yaml_options)
        commit_message = f"""
{query[:250]}

Response:
---
{response_text[:500]}...

Metadata
---
{metadata_yaml}
""".strip()

        repo.index.commit(commit_message)

        logger.debug(f"Saved conversation trace to repo at {repo_path}")
        return repo_path
    except Exception as e:
        logger.error(f"Failed to add conversation trace to repo: {str(e)}")
        return None
|
||||||
|
|
||||||
|
|
||||||
|
def merge_message_into_conversation_trace(query: str, response: str, tracer: dict, repo_path=None) -> bool:
    """
    Merge the message branch into its parent conversation branch.

    The message branch `m_<mid>` is merged into the conversation branch `c_<cid>` with a
    non fast-forward merge commit and deleted afterwards. Failures are logged, not raised.

    Args:
        query: User query
        response: Assistant response
        tracer: Dictionary containing uid, cid and mid
        repo_path: Path to the git repository. The `PROMPTRACE_DIR` environment variable,
            when set, takes precedence over this argument.

    Returns:
        bool: True if merge was successful, False otherwise
    """
    # Extract branch names before the try block, so the failure log in the except
    # handler can always reference them, whichever step fails.
    mid = tracer.get("mid")
    msg_branch = f"m_{mid}" if mid else None
    # Same default as commit_conversation_trace uses when the tracer has no conversation id
    conv_branch = f"c_{tracer.get('cid', 'main')}"

    # Without a message id there is no message branch to merge
    if not msg_branch:
        logger.error(f"Failed to merge message {msg_branch} into conversation {conv_branch}: no message id in tracer")
        return False

    try:
        # Infer repository path from environment variable or provided path.
        # Mirrors the path resolution of commit_conversation_trace, whose repo_path
        # parameter defaults to "/tmp/khoj_promptrace", so both look at the same repository.
        repo_path = os.getenv("PROMPTRACE_DIR", repo_path or "/tmp/khoj_promptrace") or "/tmp/promptrace"
        repo = Repo(repo_path)

        # Checkout conversation branch
        repo.heads[conv_branch].checkout()

        # Create commit message
        metadata_yaml = yaml.dump(tracer, allow_unicode=True, sort_keys=False, default_flow_style=False)
        commit_message = f"""
{query[:250]}

Response:
---
{response[:500]}...

Metadata
---
{metadata_yaml}
""".strip()

        # Merge message branch into conversation branch
        repo.git.merge(msg_branch, no_ff=True, m=commit_message)

        # Delete message branch after merge
        repo.delete_head(msg_branch, force=True)

        logger.debug(f"Successfully merged {msg_branch} into {conv_branch}")
        return True
    except Exception as e:
        logger.error(f"Failed to merge message {msg_branch} into conversation {conv_branch}: {str(e)}")
        return False
|
||||||
|
|
Loading…
Reference in a new issue