diff --git a/pyproject.toml b/pyproject.toml index 93df0b42..12c7789c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -119,6 +119,7 @@ dev = [ "mypy >= 1.0.1", "black >= 23.1.0", "pre-commit >= 3.0.4", + "gitpython ~= 3.1.43", ] [tool.hatch.version] diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index fb6d1909..c9a6b234 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -2,6 +2,7 @@ import base64 import logging import math import mimetypes +import os import queue from dataclasses import dataclass from datetime import datetime @@ -12,6 +13,8 @@ from typing import Any, Dict, List, Optional import PIL.Image import requests import tiktoken +import yaml +from git import Repo from langchain.schema import ChatMessage from llama_cpp.llama import Llama from transformers import AutoTokenizer @@ -344,3 +347,160 @@ def get_image_from_url(image_url: str, type="pil"): except requests.exceptions.RequestException as e: logger.error(f"Failed to get image from URL {image_url}: {e}") return ImageWithType(content=None, type=None) + + +def commit_conversation_trace( + session: list[ChatMessage], + response: str | list[dict], + tracer: dict, + system_message: str | list[dict] = "", + repo_path: str = "/tmp/khoj_promptrace", +) -> str: + """ + Save trace of conversation step using git. Useful to visualize, compare and debug traces. + Returns the path to the repository. + """ + # Serialize session, system message and response to yaml + system_message_yaml = yaml.dump(system_message, allow_unicode=True, sort_keys=False, default_flow_style=False) + response_yaml = yaml.dump(response, allow_unicode=True, sort_keys=False, default_flow_style=False) + formatted_session = [{"role": message.role, "content": message.content} for message in session] + session_yaml = yaml.dump(formatted_session, allow_unicode=True, sort_keys=False, default_flow_style=False) + query = ( + yaml.dump(session[-1].content, allow_unicode=True, sort_keys=False, default_flow_style=False) + .strip() + .removeprefix("'") + .removesuffix("'") + ) # Extract serialized query from chat session + + # Extract chat metadata for session + uid, cid, mid = tracer.get("uid", "main"), tracer.get("cid", "main"), tracer.get("mid") + + # Infer repository path from environment variable or provided path + repo_path = os.getenv("PROMPTRACE_DIR", repo_path) or "/tmp/promptrace" + + try: + # Prepare git repository + os.makedirs(repo_path, exist_ok=True) + repo = Repo.init(repo_path) + + # Remove post-commit hook if it exists + hooks_dir = os.path.join(repo_path, ".git", "hooks") + post_commit_hook = os.path.join(hooks_dir, "post-commit") + if os.path.exists(post_commit_hook): + os.remove(post_commit_hook) + + # Configure git user if not set + if not repo.config_reader().has_option("user", "email"): + repo.config_writer().set_value("user", "name", "Prompt Tracer").release() + repo.config_writer().set_value("user", "email", "promptracer@khoj.dev").release() + + # Create an initial commit if the repository is newly created + if not repo.head.is_valid(): + repo.index.commit("And then there was a trace") + + # Check out the initial commit + initial_commit = repo.commit("HEAD~0") + repo.head.reference = initial_commit + repo.head.reset(index=True, working_tree=True) + + # Create or switch to user branch from initial commit + user_branch = f"u_{uid}" + if user_branch not in repo.branches: + repo.create_head(user_branch) + repo.heads[user_branch].checkout() + + # Create or switch to conversation branch from user branch + conv_branch = f"c_{cid}" + if conv_branch not in repo.branches: + repo.create_head(conv_branch) + repo.heads[conv_branch].checkout() + + # Create or switch to message branch from conversation branch + msg_branch = f"m_{mid}" if mid else None + if msg_branch and msg_branch not in repo.branches: + repo.create_head(msg_branch) + repo.heads[msg_branch].checkout() + + # Include file with content to commit + files_to_commit = {"query": session_yaml, "response": response_yaml, "system_prompt": system_message_yaml} + + # Write files and stage them + for filename, content in files_to_commit.items(): + file_path = os.path.join(repo_path, filename) + with open(file_path, "w", encoding="utf-8") as f: + f.write(content) + repo.index.add([filename]) + + # Create commit + metadata_yaml = yaml.dump(tracer, allow_unicode=True, sort_keys=False, default_flow_style=False) + commit_message = f""" +{query[:250]} + +Response: +--- +{response[:500]}... + +Metadata +--- +{metadata_yaml} +""".strip() + + repo.index.commit(commit_message) + + logger.debug(f"Saved conversation trace to repo at {repo_path}") + return repo_path + except Exception as e: + logger.error(f"Failed to add conversation trace to repo: {str(e)}") + return None + + +def merge_message_into_conversation_trace(query: str, response: str, tracer: dict, repo_path=None) -> bool: + """ + Merge the message branch into its parent conversation branch. + + Args: + query: User query + response: Assistant response + tracer: Dictionary containing uid, cid and mid + repo_path: Path to the git repository + + Returns: + bool: True if merge was successful, False otherwise + """ + try: + # Infer repository path from environment variable or provided path + repo_path = os.getenv("PROMPTRACE_DIR", repo_path) or "/tmp/promptrace" + repo = Repo(repo_path) + + # Extract branch names + msg_branch = f"m_{tracer['mid']}" + conv_branch = f"c_{tracer['cid']}" + + # Checkout conversation branch + repo.heads[conv_branch].checkout() + + # Create commit message + metadata_yaml = yaml.dump(tracer, allow_unicode=True, sort_keys=False, default_flow_style=False) + commit_message = f""" +{query[:250]} + +Response: +--- +{response[:500]}... + +Metadata +--- +{metadata_yaml} +""".strip() + + # Merge message branch into conversation branch + repo.git.merge(msg_branch, no_ff=True, m=commit_message) + + # Delete message branch after merge + repo.delete_head(msg_branch, force=True) + + logger.debug(f"Successfully merged {msg_branch} into {conv_branch}") + return True + except Exception as e: + logger.error(f"Failed to merge message {msg_branch} into conversation {conv_branch}: {str(e)}") + return False