mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Save conversation traces to git for visualization
This commit is contained in:
parent
7e0a692d16
commit
10c8fd3b2a
2 changed files with 161 additions and 0 deletions
|
@ -119,6 +119,7 @@ dev = [
|
|||
"mypy >= 1.0.1",
|
||||
"black >= 23.1.0",
|
||||
"pre-commit >= 3.0.4",
|
||||
"gitpython ~= 3.1.43",
|
||||
]
|
||||
|
||||
[tool.hatch.version]
|
||||
|
|
|
@ -2,6 +2,7 @@ import base64
|
|||
import logging
|
||||
import math
|
||||
import mimetypes
|
||||
import os
|
||||
import queue
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
|
@ -12,6 +13,8 @@ from typing import Any, Dict, List, Optional
|
|||
import PIL.Image
|
||||
import requests
|
||||
import tiktoken
|
||||
import yaml
|
||||
from git import Repo
|
||||
from langchain.schema import ChatMessage
|
||||
from llama_cpp.llama import Llama
|
||||
from transformers import AutoTokenizer
|
||||
|
@ -344,3 +347,160 @@ def get_image_from_url(image_url: str, type="pil"):
|
|||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"Failed to get image from URL {image_url}: {e}")
|
||||
return ImageWithType(content=None, type=None)
|
||||
|
||||
|
||||
def commit_conversation_trace(
|
||||
session: list[ChatMessage],
|
||||
response: str | list[dict],
|
||||
tracer: dict,
|
||||
system_message: str | list[dict] = "",
|
||||
repo_path: str = "/tmp/khoj_promptrace",
|
||||
) -> str:
|
||||
"""
|
||||
Save trace of conversation step using git. Useful to visualize, compare and debug traces.
|
||||
Returns the path to the repository.
|
||||
"""
|
||||
# Serialize session, system message and response to yaml
|
||||
system_message_yaml = yaml.dump(system_message, allow_unicode=True, sort_keys=False, default_flow_style=False)
|
||||
response_yaml = yaml.dump(response, allow_unicode=True, sort_keys=False, default_flow_style=False)
|
||||
formatted_session = [{"role": message.role, "content": message.content} for message in session]
|
||||
session_yaml = yaml.dump(formatted_session, allow_unicode=True, sort_keys=False, default_flow_style=False)
|
||||
query = (
|
||||
yaml.dump(session[-1].content, allow_unicode=True, sort_keys=False, default_flow_style=False)
|
||||
.strip()
|
||||
.removeprefix("'")
|
||||
.removesuffix("'")
|
||||
) # Extract serialized query from chat session
|
||||
|
||||
# Extract chat metadata for session
|
||||
uid, cid, mid = tracer.get("uid", "main"), tracer.get("cid", "main"), tracer.get("mid")
|
||||
|
||||
# Infer repository path from environment variable or provided path
|
||||
repo_path = os.getenv("PROMPTRACE_DIR", repo_path) or "/tmp/promptrace"
|
||||
|
||||
try:
|
||||
# Prepare git repository
|
||||
os.makedirs(repo_path, exist_ok=True)
|
||||
repo = Repo.init(repo_path)
|
||||
|
||||
# Remove post-commit hook if it exists
|
||||
hooks_dir = os.path.join(repo_path, ".git", "hooks")
|
||||
post_commit_hook = os.path.join(hooks_dir, "post-commit")
|
||||
if os.path.exists(post_commit_hook):
|
||||
os.remove(post_commit_hook)
|
||||
|
||||
# Configure git user if not set
|
||||
if not repo.config_reader().has_option("user", "email"):
|
||||
repo.config_writer().set_value("user", "name", "Prompt Tracer").release()
|
||||
repo.config_writer().set_value("user", "email", "promptracer@khoj.dev").release()
|
||||
|
||||
# Create an initial commit if the repository is newly created
|
||||
if not repo.head.is_valid():
|
||||
repo.index.commit("And then there was a trace")
|
||||
|
||||
# Check out the initial commit
|
||||
initial_commit = repo.commit("HEAD~0")
|
||||
repo.head.reference = initial_commit
|
||||
repo.head.reset(index=True, working_tree=True)
|
||||
|
||||
# Create or switch to user branch from initial commit
|
||||
user_branch = f"u_{uid}"
|
||||
if user_branch not in repo.branches:
|
||||
repo.create_head(user_branch)
|
||||
repo.heads[user_branch].checkout()
|
||||
|
||||
# Create or switch to conversation branch from user branch
|
||||
conv_branch = f"c_{cid}"
|
||||
if conv_branch not in repo.branches:
|
||||
repo.create_head(conv_branch)
|
||||
repo.heads[conv_branch].checkout()
|
||||
|
||||
# Create or switch to message branch from conversation branch
|
||||
msg_branch = f"m_{mid}" if mid else None
|
||||
if msg_branch and msg_branch not in repo.branches:
|
||||
repo.create_head(msg_branch)
|
||||
repo.heads[msg_branch].checkout()
|
||||
|
||||
# Include file with content to commit
|
||||
files_to_commit = {"query": session_yaml, "response": response_yaml, "system_prompt": system_message_yaml}
|
||||
|
||||
# Write files and stage them
|
||||
for filename, content in files_to_commit.items():
|
||||
file_path = os.path.join(repo_path, filename)
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
repo.index.add([filename])
|
||||
|
||||
# Create commit
|
||||
metadata_yaml = yaml.dump(tracer, allow_unicode=True, sort_keys=False, default_flow_style=False)
|
||||
commit_message = f"""
|
||||
{query[:250]}
|
||||
|
||||
Response:
|
||||
---
|
||||
{response[:500]}...
|
||||
|
||||
Metadata
|
||||
---
|
||||
{metadata_yaml}
|
||||
""".strip()
|
||||
|
||||
repo.index.commit(commit_message)
|
||||
|
||||
logger.debug(f"Saved conversation trace to repo at {repo_path}")
|
||||
return repo_path
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add conversation trace to repo: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
def merge_message_into_conversation_trace(query: str, response: str, tracer: dict, repo_path=None) -> bool:
|
||||
"""
|
||||
Merge the message branch into its parent conversation branch.
|
||||
|
||||
Args:
|
||||
query: User query
|
||||
response: Assistant response
|
||||
tracer: Dictionary containing uid, cid and mid
|
||||
repo_path: Path to the git repository
|
||||
|
||||
Returns:
|
||||
bool: True if merge was successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Infer repository path from environment variable or provided path
|
||||
repo_path = os.getenv("PROMPTRACE_DIR", repo_path) or "/tmp/promptrace"
|
||||
repo = Repo(repo_path)
|
||||
|
||||
# Extract branch names
|
||||
msg_branch = f"m_{tracer['mid']}"
|
||||
conv_branch = f"c_{tracer['cid']}"
|
||||
|
||||
# Checkout conversation branch
|
||||
repo.heads[conv_branch].checkout()
|
||||
|
||||
# Create commit message
|
||||
metadata_yaml = yaml.dump(tracer, allow_unicode=True, sort_keys=False, default_flow_style=False)
|
||||
commit_message = f"""
|
||||
{query[:250]}
|
||||
|
||||
Response:
|
||||
---
|
||||
{response[:500]}...
|
||||
|
||||
Metadata
|
||||
---
|
||||
{metadata_yaml}
|
||||
""".strip()
|
||||
|
||||
# Merge message branch into conversation branch
|
||||
repo.git.merge(msg_branch, no_ff=True, m=commit_message)
|
||||
|
||||
# Delete message branch after merge
|
||||
repo.delete_head(msg_branch, force=True)
|
||||
|
||||
logger.debug(f"Successfully merged {msg_branch} into {conv_branch}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to merge message {msg_branch} into conversation {conv_branch}: {str(e)}")
|
||||
return False
|
||||
|
|
Loading…
Reference in a new issue