diff --git a/src/khoj/processor/conversation/anthropic/utils.py b/src/khoj/processor/conversation/anthropic/utils.py
index 4f6b68c2..3fee8b43 100644
--- a/src/khoj/processor/conversation/anthropic/utils.py
+++ b/src/khoj/processor/conversation/anthropic/utils.py
@@ -18,7 +18,12 @@ from khoj.processor.conversation.utils import (
     get_image_from_url,
 )
 from khoj.utils import state
-from khoj.utils.helpers import get_chat_usage_metrics, in_debug_mode, is_none_or_empty
+from khoj.utils.helpers import (
+    get_chat_usage_metrics,
+    in_debug_mode,
+    is_none_or_empty,
+    is_promptrace_enabled,
+)
 
 
 logger = logging.getLogger(__name__)
@@ -84,7 +89,7 @@ def anthropic_completion_with_backoff(
     # Save conversation trace
     tracer["chat_model"] = model_name
     tracer["temperature"] = temperature
-    if in_debug_mode() or state.verbose > 1:
+    if is_promptrace_enabled():
         commit_conversation_trace(messages, aggregated_response, tracer)
 
     return aggregated_response
@@ -156,7 +161,7 @@ def anthropic_llm_thread(
         # Save conversation trace
         tracer["chat_model"] = model_name
         tracer["temperature"] = temperature
-        if in_debug_mode() or state.verbose > 1:
+        if is_promptrace_enabled():
             commit_conversation_trace(messages, aggregated_response, tracer)
     except Exception as e:
         logger.error(f"Error in anthropic_llm_thread: {e}", exc_info=True)
diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py
index eb9b21b0..5f24362b 100644
--- a/src/khoj/processor/conversation/google/utils.py
+++ b/src/khoj/processor/conversation/google/utils.py
@@ -25,7 +25,12 @@ from khoj.processor.conversation.utils import (
     get_image_from_url,
 )
 from khoj.utils import state
-from khoj.utils.helpers import get_chat_usage_metrics, in_debug_mode, is_none_or_empty
+from khoj.utils.helpers import (
+    get_chat_usage_metrics,
+    in_debug_mode,
+    is_none_or_empty,
+    is_promptrace_enabled,
+)
 
 
 logger = logging.getLogger(__name__)
@@ -84,7 +89,7 @@ def gemini_completion_with_backoff(
     # Save conversation trace
     tracer["chat_model"] = model_name
     tracer["temperature"] = temperature
-    if in_debug_mode() or state.verbose > 1:
+    if is_promptrace_enabled():
         commit_conversation_trace(messages, response_text, tracer)
 
     return response_text
@@ -160,7 +165,7 @@ def gemini_llm_thread(
         # Save conversation trace
         tracer["chat_model"] = model_name
         tracer["temperature"] = temperature
-        if in_debug_mode() or state.verbose > 1:
+        if is_promptrace_enabled():
            commit_conversation_trace(messages, aggregated_response, tracer)
     except StopCandidateException as e:
         logger.warning(
diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py
index 998589dd..66660c43 100644
--- a/src/khoj/processor/conversation/offline/chat_model.py
+++ b/src/khoj/processor/conversation/offline/chat_model.py
@@ -19,7 +19,12 @@ from khoj.processor.conversation.utils import (
 )
 from khoj.utils import state
 from khoj.utils.constants import empty_escape_sequences
-from khoj.utils.helpers import ConversationCommand, in_debug_mode, is_none_or_empty
+from khoj.utils.helpers import (
+    ConversationCommand,
+    in_debug_mode,
+    is_none_or_empty,
+    is_promptrace_enabled,
+)
 from khoj.utils.rawconfig import LocationData
 from khoj.utils.yaml import yaml_dump
 
@@ -246,7 +251,7 @@ def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int
             g.send(response_delta)
 
         # Save conversation trace
-        if in_debug_mode() or state.verbose > 1:
+        if is_promptrace_enabled():
            commit_conversation_trace(messages, aggregated_response, tracer)
 
     finally:
@@ -287,7 +292,7 @@ def send_message_to_model_offline(
     # Streamed responses need to be saved by the calling function
     tracer["chat_model"] = model
     tracer["temperature"] = temperature
-    if in_debug_mode() or state.verbose > 1:
+    if is_promptrace_enabled():
         commit_conversation_trace(messages, response_text, tracer)
 
     return response_text
diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py
index ddc59d76..2f01be32 100644
--- a/src/khoj/processor/conversation/openai/utils.py
+++ b/src/khoj/processor/conversation/openai/utils.py
@@ -20,7 +20,11 @@ from khoj.processor.conversation.utils import (
     commit_conversation_trace,
 )
 from khoj.utils import state
-from khoj.utils.helpers import get_chat_usage_metrics, in_debug_mode
+from khoj.utils.helpers import (
+    get_chat_usage_metrics,
+    in_debug_mode,
+    is_promptrace_enabled,
+)
 
 
 logger = logging.getLogger(__name__)
@@ -97,7 +101,7 @@ def completion_with_backoff(
     # Save conversation trace
     tracer["chat_model"] = model
     tracer["temperature"] = temperature
-    if in_debug_mode() or state.verbose > 1:
+    if is_promptrace_enabled():
         commit_conversation_trace(messages, aggregated_response, tracer)
 
     return aggregated_response
@@ -208,7 +212,7 @@ def llm_thread(
         # Save conversation trace
         tracer["chat_model"] = model_name
         tracer["temperature"] = temperature
-        if in_debug_mode() or state.verbose > 1:
+        if is_promptrace_enabled():
             commit_conversation_trace(messages, aggregated_response, tracer)
     except Exception as e:
         logger.error(f"Error in llm_thread: {e}", exc_info=True)
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 155a109c..21a95a29 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -34,6 +34,7 @@ from khoj.utils.helpers import (
     ConversationCommand,
     in_debug_mode,
     is_none_or_empty,
+    is_promptrace_enabled,
     merge_dicts,
 )
 from khoj.utils.rawconfig import FileAttachment
@@ -292,7 +293,7 @@ def save_to_conversation_log(
         user_message=q,
     )
 
-    if os.getenv("PROMPTRACE_DIR"):
+    if is_promptrace_enabled():
         merge_message_into_conversation_trace(q, chat_response, tracer)
 
     logger.info(
@@ -591,7 +592,7 @@ def commit_conversation_trace(
         return None
 
     # Infer repository path from environment variable or provided path
-    repo_path = repo_path or os.getenv("PROMPTRACE_DIR")
+    repo_path = repo_path if not is_none_or_empty(repo_path) else os.getenv("PROMPTRACE_DIR")
     if not repo_path:
         return None
 
@@ -686,7 +687,7 @@ Metadata
     return None
 
 
-def merge_message_into_conversation_trace(query: str, response: str, tracer: dict, repo_path="/tmp/promptrace") -> bool:
+def merge_message_into_conversation_trace(query: str, response: str, tracer: dict, repo_path=None) -> bool:
     """
     Merge the message branch into its parent conversation branch.
 
@@ -709,7 +710,9 @@ def merge_message_into_conversation_trace(query: str, response: str, tracer: dic
         conv_branch = f"c_{tracer['cid']}"
 
         # Infer repository path from environment variable or provided path
-        repo_path = os.getenv("PROMPTRACE_DIR", repo_path)
+        repo_path = repo_path if not is_none_or_empty(repo_path) else os.getenv("PROMPTRACE_DIR")
+        if not repo_path:
+            return None
         repo = Repo(repo_path)
 
         # Checkout conversation branch
diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py
index 02cd7a92..187e9062 100644
--- a/src/khoj/utils/helpers.py
+++ b/src/khoj/utils/helpers.py
@@ -451,6 +451,12 @@ def in_debug_mode():
     return is_env_var_true("KHOJ_DEBUG")
 
 
+def is_promptrace_enabled():
+    """Check if Khoj is running with prompt tracing enabled.
+    Set PROMPTRACE_DIR environment variable to prompt tracing path to enable it."""
+    return not is_none_or_empty(os.getenv("PROMPTRACE_DIR"))
+
+
 def is_valid_url(url: str) -> bool:
     """Check if a string is a valid URL"""
     try:
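
Note: below is a minimal, self-contained sketch of the gating and path-fallback behavior this patch introduces. The inlined is_none_or_empty is a stand-in whose exact empty-checks may differ from khoj.utils.helpers, and resolve_trace_repo_path is a hypothetical helper name used only to illustrate the fallback pattern shared by commit_conversation_trace and merge_message_into_conversation_trace.

    import os


    def is_none_or_empty(item) -> bool:
        # Stand-in for khoj.utils.helpers.is_none_or_empty (assumption:
        # None and empty strings both count as "unset").
        return item is None or item == ""


    def is_promptrace_enabled() -> bool:
        # Mirrors the helper added in src/khoj/utils/helpers.py: prompt
        # tracing is on only when PROMPTRACE_DIR is set to a non-empty value.
        return not is_none_or_empty(os.getenv("PROMPTRACE_DIR"))


    def resolve_trace_repo_path(repo_path=None):
        # Hypothetical helper illustrating the fallback pattern in this patch:
        # prefer an explicit repo_path, else fall back to PROMPTRACE_DIR,
        # else return None to signal that tracing is disabled.
        repo_path = repo_path if not is_none_or_empty(repo_path) else os.getenv("PROMPTRACE_DIR")
        return repo_path or None


    if __name__ == "__main__":
        os.environ.pop("PROMPTRACE_DIR", None)
        assert not is_promptrace_enabled()  # unset -> tracing off
        assert resolve_trace_repo_path() is None

        os.environ["PROMPTRACE_DIR"] = "/tmp/promptrace"
        assert is_promptrace_enabled()  # set -> tracing on
        assert resolve_trace_repo_path() == "/tmp/promptrace"
        assert resolve_trace_repo_path("./my-traces") == "./my-traces"

The net effect of the patch: every call site that previously keyed tracing off debug mode or verbosity (in_debug_mode() or state.verbose > 1), or read PROMPTRACE_DIR directly, now goes through the single is_promptrace_enabled() check, and the hardcoded /tmp/promptrace default is replaced by an explicit None-with-fallback so tracing is skipped cleanly when no directory is configured.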