Fix toggling prompt tracer on/off in Khoj via PROMPTRACE_DIR env var

Previous changes to depend on just the PROMPTRACE_DIR env var instead
of KHOJ_DEBUG or verbosity flag was partial/incomplete.

This fix adds all the changes required to only depend on the
PROMPTRACE_DIR env var to enable/disable prompt tracing in Khoj.
This commit is contained in:
Debanjum 2024-11-21 13:58:51 -08:00
parent 4a40cf79c3
commit f434c3fab2
6 changed files with 44 additions and 16 deletions

View file

@ -18,7 +18,12 @@ from khoj.processor.conversation.utils import (
get_image_from_url, get_image_from_url,
) )
from khoj.utils import state from khoj.utils import state
from khoj.utils.helpers import get_chat_usage_metrics, in_debug_mode, is_none_or_empty from khoj.utils.helpers import (
get_chat_usage_metrics,
in_debug_mode,
is_none_or_empty,
is_promptrace_enabled,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -84,7 +89,7 @@ def anthropic_completion_with_backoff(
# Save conversation trace # Save conversation trace
tracer["chat_model"] = model_name tracer["chat_model"] = model_name
tracer["temperature"] = temperature tracer["temperature"] = temperature
if in_debug_mode() or state.verbose > 1: if is_promptrace_enabled():
commit_conversation_trace(messages, aggregated_response, tracer) commit_conversation_trace(messages, aggregated_response, tracer)
return aggregated_response return aggregated_response
@ -156,7 +161,7 @@ def anthropic_llm_thread(
# Save conversation trace # Save conversation trace
tracer["chat_model"] = model_name tracer["chat_model"] = model_name
tracer["temperature"] = temperature tracer["temperature"] = temperature
if in_debug_mode() or state.verbose > 1: if is_promptrace_enabled():
commit_conversation_trace(messages, aggregated_response, tracer) commit_conversation_trace(messages, aggregated_response, tracer)
except Exception as e: except Exception as e:
logger.error(f"Error in anthropic_llm_thread: {e}", exc_info=True) logger.error(f"Error in anthropic_llm_thread: {e}", exc_info=True)

View file

@ -25,7 +25,12 @@ from khoj.processor.conversation.utils import (
get_image_from_url, get_image_from_url,
) )
from khoj.utils import state from khoj.utils import state
from khoj.utils.helpers import get_chat_usage_metrics, in_debug_mode, is_none_or_empty from khoj.utils.helpers import (
get_chat_usage_metrics,
in_debug_mode,
is_none_or_empty,
is_promptrace_enabled,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -84,7 +89,7 @@ def gemini_completion_with_backoff(
# Save conversation trace # Save conversation trace
tracer["chat_model"] = model_name tracer["chat_model"] = model_name
tracer["temperature"] = temperature tracer["temperature"] = temperature
if in_debug_mode() or state.verbose > 1: if is_promptrace_enabled():
commit_conversation_trace(messages, response_text, tracer) commit_conversation_trace(messages, response_text, tracer)
return response_text return response_text
@ -160,7 +165,7 @@ def gemini_llm_thread(
# Save conversation trace # Save conversation trace
tracer["chat_model"] = model_name tracer["chat_model"] = model_name
tracer["temperature"] = temperature tracer["temperature"] = temperature
if in_debug_mode() or state.verbose > 1: if is_promptrace_enabled():
commit_conversation_trace(messages, aggregated_response, tracer) commit_conversation_trace(messages, aggregated_response, tracer)
except StopCandidateException as e: except StopCandidateException as e:
logger.warning( logger.warning(

View file

@ -19,7 +19,12 @@ from khoj.processor.conversation.utils import (
) )
from khoj.utils import state from khoj.utils import state
from khoj.utils.constants import empty_escape_sequences from khoj.utils.constants import empty_escape_sequences
from khoj.utils.helpers import ConversationCommand, in_debug_mode, is_none_or_empty from khoj.utils.helpers import (
ConversationCommand,
in_debug_mode,
is_none_or_empty,
is_promptrace_enabled,
)
from khoj.utils.rawconfig import LocationData from khoj.utils.rawconfig import LocationData
from khoj.utils.yaml import yaml_dump from khoj.utils.yaml import yaml_dump
@ -246,7 +251,7 @@ def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int
g.send(response_delta) g.send(response_delta)
# Save conversation trace # Save conversation trace
if in_debug_mode() or state.verbose > 1: if is_promptrace_enabled():
commit_conversation_trace(messages, aggregated_response, tracer) commit_conversation_trace(messages, aggregated_response, tracer)
finally: finally:
@ -287,7 +292,7 @@ def send_message_to_model_offline(
# Streamed responses need to be saved by the calling function # Streamed responses need to be saved by the calling function
tracer["chat_model"] = model tracer["chat_model"] = model
tracer["temperature"] = temperature tracer["temperature"] = temperature
if in_debug_mode() or state.verbose > 1: if is_promptrace_enabled():
commit_conversation_trace(messages, response_text, tracer) commit_conversation_trace(messages, response_text, tracer)
return response_text return response_text

View file

@ -20,7 +20,11 @@ from khoj.processor.conversation.utils import (
commit_conversation_trace, commit_conversation_trace,
) )
from khoj.utils import state from khoj.utils import state
from khoj.utils.helpers import get_chat_usage_metrics, in_debug_mode from khoj.utils.helpers import (
get_chat_usage_metrics,
in_debug_mode,
is_promptrace_enabled,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -97,7 +101,7 @@ def completion_with_backoff(
# Save conversation trace # Save conversation trace
tracer["chat_model"] = model tracer["chat_model"] = model
tracer["temperature"] = temperature tracer["temperature"] = temperature
if in_debug_mode() or state.verbose > 1: if is_promptrace_enabled():
commit_conversation_trace(messages, aggregated_response, tracer) commit_conversation_trace(messages, aggregated_response, tracer)
return aggregated_response return aggregated_response
@ -208,7 +212,7 @@ def llm_thread(
# Save conversation trace # Save conversation trace
tracer["chat_model"] = model_name tracer["chat_model"] = model_name
tracer["temperature"] = temperature tracer["temperature"] = temperature
if in_debug_mode() or state.verbose > 1: if is_promptrace_enabled():
commit_conversation_trace(messages, aggregated_response, tracer) commit_conversation_trace(messages, aggregated_response, tracer)
except Exception as e: except Exception as e:
logger.error(f"Error in llm_thread: {e}", exc_info=True) logger.error(f"Error in llm_thread: {e}", exc_info=True)

View file

@ -34,6 +34,7 @@ from khoj.utils.helpers import (
ConversationCommand, ConversationCommand,
in_debug_mode, in_debug_mode,
is_none_or_empty, is_none_or_empty,
is_promptrace_enabled,
merge_dicts, merge_dicts,
) )
from khoj.utils.rawconfig import FileAttachment from khoj.utils.rawconfig import FileAttachment
@ -292,7 +293,7 @@ def save_to_conversation_log(
user_message=q, user_message=q,
) )
if os.getenv("PROMPTRACE_DIR"): if is_promptrace_enabled():
merge_message_into_conversation_trace(q, chat_response, tracer) merge_message_into_conversation_trace(q, chat_response, tracer)
logger.info( logger.info(
@ -591,7 +592,7 @@ def commit_conversation_trace(
return None return None
# Infer repository path from environment variable or provided path # Infer repository path from environment variable or provided path
repo_path = repo_path or os.getenv("PROMPTRACE_DIR") repo_path = repo_path if not is_none_or_empty(repo_path) else os.getenv("PROMPTRACE_DIR")
if not repo_path: if not repo_path:
return None return None
@ -686,7 +687,7 @@ Metadata
return None return None
def merge_message_into_conversation_trace(query: str, response: str, tracer: dict, repo_path="/tmp/promptrace") -> bool: def merge_message_into_conversation_trace(query: str, response: str, tracer: dict, repo_path=None) -> bool:
""" """
Merge the message branch into its parent conversation branch. Merge the message branch into its parent conversation branch.
@ -709,7 +710,9 @@ def merge_message_into_conversation_trace(query: str, response: str, tracer: dic
conv_branch = f"c_{tracer['cid']}" conv_branch = f"c_{tracer['cid']}"
# Infer repository path from environment variable or provided path # Infer repository path from environment variable or provided path
repo_path = os.getenv("PROMPTRACE_DIR", repo_path) repo_path = repo_path if not is_none_or_empty(repo_path) else os.getenv("PROMPTRACE_DIR")
if not repo_path:
return None
repo = Repo(repo_path) repo = Repo(repo_path)
# Checkout conversation branch # Checkout conversation branch

View file

@ -451,6 +451,12 @@ def in_debug_mode():
return is_env_var_true("KHOJ_DEBUG") return is_env_var_true("KHOJ_DEBUG")
def is_promptrace_enabled():
"""Check if Khoj is running with prompt tracing enabled.
Set PROMPTRACE_DIR environment variable to prompt tracing path to enable it."""
return not is_none_or_empty(os.getenv("PROMPTRACE_DIR"))
def is_valid_url(url: str) -> bool: def is_valid_url(url: str) -> bool:
"""Check if a string is a valid URL""" """Check if a string is a valid URL"""
try: try: