Only enable prompt tracer if git python is installed

This commit is contained in:
Debanjum 2024-11-01 23:57:22 -07:00
parent 5b18dc96e0
commit 31b5fde163
2 changed files with 17 additions and 2 deletions

View file

@ -88,7 +88,6 @@ dependencies = [
"anthropic == 0.26.1", "anthropic == 0.26.1",
"docx2txt == 0.8", "docx2txt == 0.8",
"google-generativeai == 0.8.3", "google-generativeai == 0.8.3",
"gitpython ~= 3.1.43",
] ]
dynamic = ["version"] dynamic = ["version"]
@ -120,6 +119,7 @@ dev = [
"mypy >= 1.0.1", "mypy >= 1.0.1",
"black >= 23.1.0", "black >= 23.1.0",
"pre-commit >= 3.0.4", "pre-commit >= 3.0.4",
"gitpython ~= 3.1.43",
] ]
[tool.hatch.version] [tool.hatch.version]

View file

@ -17,7 +17,6 @@ import PIL.Image
import requests import requests
import tiktoken import tiktoken
import yaml import yaml
from git import Repo
from langchain.schema import ChatMessage from langchain.schema import ChatMessage
from llama_cpp.llama import Llama from llama_cpp.llama import Llama
from transformers import AutoTokenizer from transformers import AutoTokenizer
@ -39,6 +38,13 @@ from khoj.utils.helpers import (
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
try:
from git import Repo
except ImportError:
if in_debug_mode():
logger.warning("GitPython not installed. `pip install gitpython` to enable prompt tracer.")
model_to_prompt_size = { model_to_prompt_size = {
# OpenAI Models # OpenAI Models
"gpt-3.5-turbo": 12000, "gpt-3.5-turbo": 12000,
@ -510,6 +516,11 @@ def commit_conversation_trace(
Save trace of conversation step using git. Useful to visualize, compare and debug traces. Save trace of conversation step using git. Useful to visualize, compare and debug traces.
Returns the path to the repository. Returns the path to the repository.
""" """
try:
from git import Repo
except ImportError:
return None
# Serialize session, system message and response to yaml # Serialize session, system message and response to yaml
system_message_yaml = json.dumps(system_message, ensure_ascii=False, sort_keys=False) system_message_yaml = json.dumps(system_message, ensure_ascii=False, sort_keys=False)
response_yaml = json.dumps(response, ensure_ascii=False, sort_keys=False) response_yaml = json.dumps(response, ensure_ascii=False, sort_keys=False)
@ -617,6 +628,10 @@ def merge_message_into_conversation_trace(query: str, response: str, tracer: dic
Returns: Returns:
bool: True if merge was successful, False otherwise bool: True if merge was successful, False otherwise
""" """
try:
from git import Repo
except ImportError:
return False
try: try:
# Extract branch names # Extract branch names
msg_branch = f"m_{tracer['mid']}" msg_branch = f"m_{tracer['mid']}"