Save conversation traces to git for visualization

This commit is contained in:
Debanjum Singh Solanky 2024-10-23 19:21:43 -07:00
parent 7e0a692d16
commit 10c8fd3b2a
2 changed files with 161 additions and 0 deletions

View file

@ -119,6 +119,7 @@ dev = [
"mypy >= 1.0.1",
"black >= 23.1.0",
"pre-commit >= 3.0.4",
"gitpython ~= 3.1.43",
]
[tool.hatch.version]

View file

@ -2,6 +2,7 @@ import base64
import logging
import math
import mimetypes
import os
import queue
from dataclasses import dataclass
from datetime import datetime
@ -12,6 +13,8 @@ from typing import Any, Dict, List, Optional
import PIL.Image
import requests
import tiktoken
import yaml
from git import Repo
from langchain.schema import ChatMessage
from llama_cpp.llama import Llama
from transformers import AutoTokenizer
@ -344,3 +347,160 @@ def get_image_from_url(image_url: str, type="pil"):
except requests.exceptions.RequestException as e:
logger.error(f"Failed to get image from URL {image_url}: {e}")
return ImageWithType(content=None, type=None)
def commit_conversation_trace(
    session: list[ChatMessage],
    response: str | list[dict],
    tracer: dict,
    system_message: str | list[dict] = "",
    repo_path: str = "/tmp/khoj_promptrace",
) -> Optional[str]:
    """
    Save trace of conversation step using git. Useful to visualize, compare and debug traces.

    Each call writes the chat session, the response and the system prompt as YAML
    files into a git repository and commits them on a branch derived from the
    tracer ids: `u_<uid>` -> `c_<cid>` -> `m_<mid>` (the message branch is only
    used when the tracer has a truthy `mid`).

    Args:
        session: Chat messages of the conversation so far. The last message is used as the query.
        response: Model response, either as text or as a list of structured parts.
        tracer: Trace metadata. Reads `uid` (default "main"), `cid` (default "main") and `mid`.
        system_message: System prompt used for this step.
        repo_path: Directory of the trace repository. The `PROMPTRACE_DIR` environment
            variable, when set, takes precedence over this argument.

    Returns:
        The path to the repository, or None if the trace could not be saved.
    """
    # Serialize session, system message and response to yaml
    yaml_options = {"allow_unicode": True, "sort_keys": False, "default_flow_style": False}
    system_message_yaml = yaml.dump(system_message, **yaml_options)
    response_yaml = yaml.dump(response, **yaml_options)
    formatted_session = [{"role": message.role, "content": message.content} for message in session]
    session_yaml = yaml.dump(formatted_session, **yaml_options)

    # Extract serialized query from chat session. An empty session yields an empty
    # query instead of raising, so tracing never breaks the calling chat flow.
    last_content = session[-1].content if session else ""
    query = yaml.dump(last_content, **yaml_options).strip().removeprefix("'").removesuffix("'")

    # Text used for the response preview in the commit message. Slicing a list
    # response would truncate by element rather than by character, so fall back
    # to its yaml serialization in that case.
    response_preview = response if isinstance(response, str) else response_yaml

    # Extract chat metadata for session
    uid, cid, mid = tracer.get("uid", "main"), tracer.get("cid", "main"), tracer.get("mid")

    # Infer repository path from environment variable or provided path
    repo_path = os.getenv("PROMPTRACE_DIR", repo_path) or "/tmp/promptrace"

    try:
        # Prepare git repository
        os.makedirs(repo_path, exist_ok=True)
        repo = Repo.init(repo_path)

        # Remove post-commit hook if it exists
        post_commit_hook = os.path.join(repo_path, ".git", "hooks", "post-commit")
        if os.path.exists(post_commit_hook):
            os.remove(post_commit_hook)

        # Configure git user if not set
        if not repo.config_reader().has_option("user", "email"):
            # The context manager releases the config lock even if a write fails
            with repo.config_writer() as git_config:
                git_config.set_value("user", "name", "Prompt Tracer")
                git_config.set_value("user", "email", "promptracer@khoj.dev")

        # Create an initial commit if the repository is newly created
        if not repo.head.is_valid():
            repo.index.commit("And then there was a trace")

        # Detach HEAD at its current commit and reset index and working tree to it.
        # NOTE(review): `HEAD~0` is the commit HEAD currently points at, not
        # necessarily the repository's first commit, so a new user branch forks
        # from wherever the previous trace left HEAD - confirm this is intended.
        initial_commit = repo.commit("HEAD~0")
        repo.head.reference = initial_commit
        repo.head.reset(index=True, working_tree=True)

        # Create or switch to user branch
        user_branch = f"u_{uid}"
        if user_branch not in repo.branches:
            repo.create_head(user_branch)
        repo.heads[user_branch].checkout()

        # Create or switch to conversation branch from user branch
        conv_branch = f"c_{cid}"
        if conv_branch not in repo.branches:
            repo.create_head(conv_branch)
        repo.heads[conv_branch].checkout()

        # Create or switch to message branch from conversation branch.
        # Without a message id the trace is committed on the conversation branch.
        if mid:
            msg_branch = f"m_{mid}"
            if msg_branch not in repo.branches:
                repo.create_head(msg_branch)
            repo.heads[msg_branch].checkout()

        # Include file with content to commit
        files_to_commit = {"query": session_yaml, "response": response_yaml, "system_prompt": system_message_yaml}

        # Write files and stage them
        for filename, content in files_to_commit.items():
            file_path = os.path.join(repo_path, filename)
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(content)
        repo.index.add(list(files_to_commit))

        # Create commit
        metadata_yaml = yaml.dump(tracer, **yaml_options)
        commit_message = f"""
{query[:250]}
Response:
---
{response_preview[:500]}...
Metadata
---
{metadata_yaml}
""".strip()
        repo.index.commit(commit_message)

        logger.debug(f"Saved conversation trace to repo at {repo_path}")
        return repo_path
    except Exception as e:
        # Tracing is best-effort: log the failure and let the caller carry on
        logger.error(f"Failed to add conversation trace to repo: {str(e)}", exc_info=True)
        return None
def merge_message_into_conversation_trace(query: str, response: str, tracer: dict, repo_path=None) -> bool:
    """
    Merge the message branch into its parent conversation branch.

    The message branch `m_<mid>` is merged into `c_<cid>` with a non fast-forward
    merge commit summarizing the query, response and tracer metadata. The message
    branch is deleted after a successful merge.

    Args:
        query: User query
        response: Assistant response
        tracer: Dictionary containing uid, cid and mid
        repo_path: Path to the git repository. The `PROMPTRACE_DIR` environment
            variable, when set, takes precedence over this argument.

    Returns:
        bool: True if merge was successful, False otherwise
    """
    # Resolve branch names before entering the try block so the failure log can
    # always reference them, even when opening the repository fails. A missing
    # id yields a branch name that does not exist, which fails the lookup below
    # and is reported as an unsuccessful merge.
    msg_branch = f"m_{tracer.get('mid')}"
    conv_branch = f"c_{tracer.get('cid')}"

    try:
        # Infer repository path from environment variable or provided path
        repo_path = os.getenv("PROMPTRACE_DIR", repo_path) or "/tmp/promptrace"
        repo = Repo(repo_path)

        # Checkout conversation branch
        repo.heads[conv_branch].checkout()

        # Create commit message
        metadata_yaml = yaml.dump(tracer, allow_unicode=True, sort_keys=False, default_flow_style=False)
        commit_message = f"""
{query[:250]}
Response:
---
{response[:500]}...
Metadata
---
{metadata_yaml}
""".strip()

        # Merge message branch into conversation branch.
        # no_ff keeps a merge commit even when a fast-forward would be possible.
        repo.git.merge(msg_branch, no_ff=True, m=commit_message)

        # Delete message branch after merge
        repo.delete_head(msg_branch, force=True)

        logger.debug(f"Successfully merged {msg_branch} into {conv_branch}")
        return True
    except Exception as e:
        # Merging traces is best-effort: log the failure and report it to the caller
        logger.error(
            f"Failed to merge message {msg_branch} into conversation {conv_branch}: {str(e)}", exc_info=True
        )
        return False