Save conversation traces to git for visualization

This commit is contained in:
Debanjum Singh Solanky 2024-10-23 19:21:43 -07:00
parent 7e0a692d16
commit 10c8fd3b2a
2 changed files with 161 additions and 0 deletions

View file

@ -119,6 +119,7 @@ dev = [
"mypy >= 1.0.1", "mypy >= 1.0.1",
"black >= 23.1.0", "black >= 23.1.0",
"pre-commit >= 3.0.4", "pre-commit >= 3.0.4",
"gitpython ~= 3.1.43",
] ]
[tool.hatch.version] [tool.hatch.version]

View file

@ -2,6 +2,7 @@ import base64
import logging import logging
import math import math
import mimetypes import mimetypes
import os
import queue import queue
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
@ -12,6 +13,8 @@ from typing import Any, Dict, List, Optional
import PIL.Image import PIL.Image
import requests import requests
import tiktoken import tiktoken
import yaml
from git import Repo
from langchain.schema import ChatMessage from langchain.schema import ChatMessage
from llama_cpp.llama import Llama from llama_cpp.llama import Llama
from transformers import AutoTokenizer from transformers import AutoTokenizer
@ -344,3 +347,160 @@ def get_image_from_url(image_url: str, type="pil"):
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e:
logger.error(f"Failed to get image from URL {image_url}: {e}") logger.error(f"Failed to get image from URL {image_url}: {e}")
return ImageWithType(content=None, type=None) return ImageWithType(content=None, type=None)
def commit_conversation_trace(
    session: list[ChatMessage],
    response: str | list[dict],
    tracer: dict,
    system_message: str | list[dict] = "",
    repo_path: str = "/tmp/khoj_promptrace",
) -> str | None:
    """
    Save trace of conversation step using git. Useful to visualize, compare and debug traces.

    The session, response and system message are serialized to YAML and written to the
    files `query`, `response` and `system_prompt` in the repository working tree, then
    committed on a branch derived from the tracer metadata (`u_<uid>` -> `c_<cid>` ->
    optionally `m_<mid>`).

    Args:
        session: Chat messages of the conversation step. The last message is used as the query.
        response: Model response to record.
        tracer: Metadata for the trace. Reads the optional keys `uid`, `cid` (both default
            to "main") and `mid`. The whole dict is dumped into the commit message.
        system_message: System prompt to record.
        repo_path: Directory of the trace repository. The `PROMPTRACE_DIR` environment
            variable, when set, takes precedence over this argument.

    Returns:
        The path to the repository, or None if the trace could not be saved.
    """
    # Serialize session, system message and response to yaml
    system_message_yaml = yaml.dump(system_message, allow_unicode=True, sort_keys=False, default_flow_style=False)
    response_yaml = yaml.dump(response, allow_unicode=True, sort_keys=False, default_flow_style=False)
    formatted_session = [{"role": message.role, "content": message.content} for message in session]
    session_yaml = yaml.dump(formatted_session, allow_unicode=True, sort_keys=False, default_flow_style=False)

    # Extract serialized query from chat session.
    # An empty session has no query; fall back to an empty string instead of raising IndexError.
    last_content = session[-1].content if session else ""
    query = (
        yaml.dump(last_content, allow_unicode=True, sort_keys=False, default_flow_style=False)
        .strip()
        .removeprefix("'")
        .removesuffix("'")
    )

    # Extract chat metadata for session
    uid, cid, mid = tracer.get("uid", "main"), tracer.get("cid", "main"), tracer.get("mid")

    # Infer repository path from environment variable or provided path
    repo_path = os.getenv("PROMPTRACE_DIR", repo_path) or "/tmp/promptrace"

    try:
        # Prepare git repository
        os.makedirs(repo_path, exist_ok=True)
        repo = Repo.init(repo_path)

        # Remove post-commit hook if it exists
        hooks_dir = os.path.join(repo_path, ".git", "hooks")
        post_commit_hook = os.path.join(hooks_dir, "post-commit")
        if os.path.exists(post_commit_hook):
            os.remove(post_commit_hook)

        # Configure git user if not set
        if not repo.config_reader().has_option("user", "email"):
            repo.config_writer().set_value("user", "name", "Prompt Tracer").release()
            repo.config_writer().set_value("user", "email", "promptracer@khoj.dev").release()

        # Create an initial commit if the repository is newly created
        if not repo.head.is_valid():
            repo.index.commit("And then there was a trace")

        # Detach HEAD at the current commit and reset index and working tree to it.
        # NOTE(review): "HEAD~0" resolves to the current HEAD commit, not the root commit,
        # so new user branches start from wherever HEAD pointed last. Confirm this is intended.
        initial_commit = repo.commit("HEAD~0")
        repo.head.reference = initial_commit
        repo.head.reset(index=True, working_tree=True)

        # Create or switch to user branch
        user_branch = f"u_{uid}"
        if user_branch not in repo.branches:
            repo.create_head(user_branch)
        repo.heads[user_branch].checkout()

        # Create or switch to conversation branch from user branch
        conv_branch = f"c_{cid}"
        if conv_branch not in repo.branches:
            repo.create_head(conv_branch)
        repo.heads[conv_branch].checkout()

        # Create or switch to message branch from conversation branch.
        # Without a message id there is no message branch; commit on the conversation branch.
        msg_branch = f"m_{mid}" if mid else None
        if msg_branch:
            if msg_branch not in repo.branches:
                repo.create_head(msg_branch)
            repo.heads[msg_branch].checkout()

        # Include file with content to commit
        files_to_commit = {"query": session_yaml, "response": response_yaml, "system_prompt": system_message_yaml}

        # Write files and stage them
        for filename, content in files_to_commit.items():
            file_path = os.path.join(repo_path, filename)
            with open(file_path, "w", encoding="utf-8") as f:
                f.write(content)
            repo.index.add([filename])

        # Create commit
        metadata_yaml = yaml.dump(tracer, allow_unicode=True, sort_keys=False, default_flow_style=False)
        commit_message = f"""
{query[:250]}
Response:
---
{response[:500]}...
Metadata
---
{metadata_yaml}
""".strip()

        repo.index.commit(commit_message)

        logger.debug(f"Saved conversation trace to repo at {repo_path}")
        return repo_path
    except Exception as e:
        # Tracing is best-effort: log the failure and signal it with None rather than raising
        logger.error(f"Failed to add conversation trace to repo: {str(e)}")
        return None
def merge_message_into_conversation_trace(query: str, response: str, tracer: dict, repo_path=None) -> bool:
    """
    Merge the message branch into its parent conversation branch.

    Checks out the `c_<cid>` branch, merges `m_<mid>` into it with a non-fast-forward
    merge commit, then deletes the message branch.

    Args:
        query: User query. Its first 250 characters lead the merge commit message.
        response: Assistant response. Its first 500 characters are added to the merge commit message.
        tracer: Dictionary containing uid, cid and mid. `mid` and `cid` are required.
        repo_path: Path to the git repository. The `PROMPTRACE_DIR` environment variable,
            when set, takes precedence over this argument.

    Returns:
        bool: True if merge was successful, False otherwise
    """
    # Resolve branch names before entering the try block so the error handler below
    # can always reference them, even when opening the repository fails.
    if "mid" not in tracer or "cid" not in tracer:
        logger.error("Failed to merge message into conversation: tracer is missing 'mid' or 'cid'")
        return False
    msg_branch = f"m_{tracer['mid']}"
    conv_branch = f"c_{tracer['cid']}"

    try:
        # Infer repository path from environment variable or provided path
        repo_path = os.getenv("PROMPTRACE_DIR", repo_path) or "/tmp/promptrace"
        repo = Repo(repo_path)

        # Checkout conversation branch
        repo.heads[conv_branch].checkout()

        # Create commit message
        metadata_yaml = yaml.dump(tracer, allow_unicode=True, sort_keys=False, default_flow_style=False)
        commit_message = f"""
{query[:250]}
Response:
---
{response[:500]}...
Metadata
---
{metadata_yaml}
""".strip()

        # Merge message branch into conversation branch
        repo.git.merge(msg_branch, no_ff=True, m=commit_message)

        # Delete message branch after merge
        repo.delete_head(msg_branch, force=True)

        logger.debug(f"Successfully merged {msg_branch} into {conv_branch}")
        return True
    except Exception as e:
        # Merging is best-effort: log the failure and report it through the return value
        logger.error(f"Failed to merge message {msg_branch} into conversation {conv_branch}: {str(e)}")
        return False