From ea0712424b006e693a8013b66244dc40df9fee78 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Wed, 23 Oct 2024 20:02:28 -0700 Subject: [PATCH] Commit conversation traces using user, chat, message branch hierarchy - Message train of thought forks and merges from its conversation branch - Conversation branches from user branch - User branches from root commit on the main branch - Weave chat tracer metadata from api endpoint through all chat actors and commit it to the prompt trace --- src/khoj/processor/conversation/utils.py | 6 +- src/khoj/processor/image/generate.py | 2 + src/khoj/processor/tools/online_search.py | 19 ++++-- src/khoj/routers/api.py | 5 ++ src/khoj/routers/api_chat.py | 25 +++++++- src/khoj/routers/helpers.py | 78 +++++++++++++++++++---- 6 files changed, 114 insertions(+), 21 deletions(-) diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index c9a6b234..710de9ff 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -23,7 +23,7 @@ from khoj.database.adapters import ConversationAdapters from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens from khoj.utils import state -from khoj.utils.helpers import is_none_or_empty, merge_dicts +from khoj.utils.helpers import in_debug_mode, is_none_or_empty, merge_dicts logger = logging.getLogger(__name__) model_to_prompt_size = { @@ -119,6 +119,7 @@ def save_to_conversation_log( conversation_id: str = None, automation_id: str = None, query_images: List[str] = None, + tracer: Dict[str, Any] = {}, ): user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S") updated_conversation = message_to_log( @@ -144,6 +145,9 @@ def save_to_conversation_log( user_message=q, ) + if in_debug_mode() or state.verbose > 1: + merge_message_into_conversation_trace(q, chat_response, tracer) + logger.info( f""" Saved Conversation Turn diff --git a/src/khoj/processor/image/generate.py b/src/khoj/processor/image/generate.py index 343e44b3..bdc00e09 100644 --- a/src/khoj/processor/image/generate.py +++ b/src/khoj/processor/image/generate.py @@ -28,6 +28,7 @@ async def text_to_image( send_status_func: Optional[Callable] = None, query_images: Optional[List[str]] = None, agent: Agent = None, + tracer: dict = {}, ): status_code = 200 image = None @@ -68,6 +69,7 @@ async def text_to_image( query_images=query_images, user=user, agent=agent, + tracer=tracer, ) if send_status_func: diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py index fdf1ba9f..9afb3d67 100644 --- a/src/khoj/processor/tools/online_search.py +++ b/src/khoj/processor/tools/online_search.py @@ -64,6 +64,7 @@ async def search_online( custom_filters: List[str] = [], query_images: List[str] = None, agent: Agent = None, + tracer: dict = {}, ): query += " ".join(custom_filters) if not is_internet_connected(): @@ -73,7 +74,7 @@ async def search_online( # Breakdown the query into subqueries to get the correct answer subqueries = await generate_online_subqueries( - query, conversation_history, location, user, query_images=query_images, agent=agent + query, conversation_history, location, user, query_images=query_images, agent=agent, tracer=tracer ) response_dict = {} @@ -111,7 +112,7 @@ async def search_online( async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"): yield {ChatEvent.STATUS: event} tasks = [ - read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent) + read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent, tracer=tracer) for link, data in webpages.items() ] results = await asyncio.gather(*tasks) @@ -153,6 +154,7 @@ async def read_webpages( send_status_func: Optional[Callable] = None, query_images: List[str] = None, agent: Agent = None, + tracer: dict = {}, ): "Infer web pages to read from the query and extract relevant information from them" logger.info(f"Inferring web pages to read") @@ -166,7 +168,7 @@ async def read_webpages( webpage_links_str = "\n- " + "\n- ".join(list(urls)) async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"): yield {ChatEvent.STATUS: event} - tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent) for url in urls] + tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent, tracer=tracer) for url in urls] results = await asyncio.gather(*tasks) response: Dict[str, Dict] = defaultdict(dict) @@ -192,7 +194,12 @@ async def read_webpage( async def read_webpage_and_extract_content( - subqueries: set[str], url: str, content: str = None, user: KhojUser = None, agent: Agent = None + subqueries: set[str], + url: str, + content: str = None, + user: KhojUser = None, + agent: Agent = None, + tracer: dict = {}, ) -> Tuple[set[str], str, Union[None, str]]: # Select the web scrapers to use for reading the web page web_scrapers = await ConversationAdapters.aget_enabled_webscrapers() @@ -214,7 +221,9 @@ async def read_webpage_and_extract_content( # Extract relevant information from the web page if is_none_or_empty(extracted_info): with timer(f"Extracting relevant information from web page at '{url}' took", logger): - extracted_info = await extract_relevant_info(subqueries, content, user=user, agent=agent) + extracted_info = await extract_relevant_info( + subqueries, content, user=user, agent=agent, tracer=tracer + ) # If we successfully extracted information, break the loop if not is_none_or_empty(extracted_info): diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index f89ca87a..c1f218d7 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -350,6 +350,7 @@ async def extract_references_and_questions( send_status_func: Optional[Callable] = None, query_images: Optional[List[str]] = None, agent: Agent = None, + tracer: dict = {}, ): user = request.user.object if request.user.is_authenticated else None @@ -425,6 +426,7 @@ async def extract_references_and_questions( user=user, max_prompt_size=conversation_config.max_prompt_size, personality_context=personality_context, + tracer=tracer, ) elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI: openai_chat_config = conversation_config.openai_config @@ -442,6 +444,7 @@ async def extract_references_and_questions( query_images=query_images, vision_enabled=vision_enabled, personality_context=personality_context, + tracer=tracer, ) elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC: api_key = conversation_config.openai_config.api_key @@ -456,6 +459,7 @@ async def extract_references_and_questions( user=user, vision_enabled=vision_enabled, personality_context=personality_context, + tracer=tracer, ) elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE: api_key = conversation_config.openai_config.api_key @@ -471,6 +475,7 @@ async def extract_references_and_questions( user=user, vision_enabled=vision_enabled, personality_context=personality_context, + tracer=tracer, ) # Collate search results as context for GPT diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 09ea9eea..83881ddf 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -3,6 +3,7 @@ import base64 import json import logging import time +import uuid from datetime import datetime from functools import partial from typing import Dict, Optional @@ -563,6 +564,12 @@ async def chat( event_delimiter = "␃🔚␗" q = unquote(q) nonlocal conversation_id + tracer: dict = { + "mid": f"{uuid.uuid4()}", + "cid": conversation_id, + "uid": user.id, + "khoj_version": state.khoj_version, + } uploaded_images: list[str] = [] if images: @@ -682,6 +689,7 @@ async def chat( user=user, query_images=uploaded_images, agent=agent, + tracer=tracer, ) conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands]) async for result in send_event( @@ -689,7 +697,9 @@ async def chat( ): yield result - mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_images, agent) + mode = await aget_relevant_output_modes( + q, meta_log, is_automated_task, user, uploaded_images, agent, tracer=tracer + ) async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"): yield result if mode not in conversation_commands: @@ -755,6 +765,7 @@ async def chat( query_images=uploaded_images, user=user, agent=agent, + tracer=tracer, ) response_log = str(response) async for result in send_llm_response(response_log): @@ -774,6 +785,7 @@ async def chat( client_application=request.user.client_app, conversation_id=conversation_id, query_images=uploaded_images, + tracer=tracer, ) return @@ -795,7 +807,7 @@ async def chat( if ConversationCommand.Automation in conversation_commands: try: automation, crontime, query_to_run, subject = await create_automation( - q, timezone, user, request.url, meta_log + q, timezone, user, request.url, meta_log, tracer=tracer ) except Exception as e: logger.error(f"Error scheduling task {q} for {user.email}: {e}") @@ -817,6 +829,7 @@ async def chat( inferred_queries=[query_to_run], automation_id=automation.id, query_images=uploaded_images, + tracer=tracer, ) async for result in send_llm_response(llm_response): yield result @@ -838,6 +851,7 @@ async def chat( partial(send_event, ChatEvent.STATUS), query_images=uploaded_images, agent=agent, + tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -882,6 +896,7 @@ async def chat( custom_filters, query_images=uploaded_images, agent=agent, + tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -906,6 +921,7 @@ async def chat( partial(send_event, ChatEvent.STATUS), query_images=uploaded_images, agent=agent, + tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -956,6 +972,7 @@ async def chat( send_status_func=partial(send_event, ChatEvent.STATUS), query_images=uploaded_images, agent=agent, + tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -986,6 +1003,7 @@ async def chat( compiled_references=compiled_references, online_results=online_results, query_images=uploaded_images, + tracer=tracer, ) content_obj = { "intentType": intent_type, @@ -1014,6 +1032,7 @@ async def chat( user=user, agent=agent, send_status_func=partial(send_event, ChatEvent.STATUS), + tracer=tracer, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -1041,6 +1060,7 @@ async def chat( compiled_references=compiled_references, online_results=online_results, query_images=uploaded_images, + tracer=tracer, ) async for result in send_llm_response(json.dumps(content_obj)): @@ -1064,6 +1084,7 @@ async def chat( location, user_name, uploaded_images, + tracer, ) # Send Response diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 6cc44c4f..1475c5cd 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -301,6 +301,7 @@ async def aget_relevant_information_sources( user: KhojUser, query_images: List[str] = None, agent: Agent = None, + tracer: dict = {}, ): """ Given a query, determine which of the available tools the agent should use in order to answer appropriately. @@ -337,6 +338,7 @@ async def aget_relevant_information_sources( relevant_tools_prompt, response_type="json_object", user=user, + tracer=tracer, ) try: @@ -378,6 +380,7 @@ async def aget_relevant_output_modes( user: KhojUser = None, query_images: List[str] = None, agent: Agent = None, + tracer: dict = {}, ): """ Given a query, determine which of the available tools the agent should use in order to answer appropriately. @@ -413,7 +416,9 @@ async def aget_relevant_output_modes( ) with timer("Chat actor: Infer output mode for chat response", logger): - response = await send_message_to_model_wrapper(relevant_mode_prompt, response_type="json_object", user=user) + response = await send_message_to_model_wrapper( + relevant_mode_prompt, response_type="json_object", user=user, tracer=tracer + ) try: response = response.strip() @@ -444,6 +449,7 @@ async def infer_webpage_urls( user: KhojUser, query_images: List[str] = None, agent: Agent = None, + tracer: dict = {}, ) -> List[str]: """ Infer webpage links from the given query @@ -468,7 +474,11 @@ async def infer_webpage_urls( with timer("Chat actor: Infer webpage urls to read", logger): response = await send_message_to_model_wrapper( - online_queries_prompt, query_images=query_images, response_type="json_object", user=user + online_queries_prompt, + query_images=query_images, + response_type="json_object", + user=user, + tracer=tracer, ) # Validate that the response is a non-empty, JSON-serializable list of URLs @@ -490,6 +500,7 @@ async def generate_online_subqueries( user: KhojUser, query_images: List[str] = None, agent: Agent = None, + tracer: dict = {}, ) -> List[str]: """ Generate subqueries from the given query @@ -514,7 +525,11 @@ async def generate_online_subqueries( with timer("Chat actor: Generate online search subqueries", logger): response = await send_message_to_model_wrapper( - online_queries_prompt, query_images=query_images, response_type="json_object", user=user + online_queries_prompt, + query_images=query_images, + response_type="json_object", + user=user, + tracer=tracer, ) # Validate that the response is a non-empty, JSON-serializable list @@ -533,7 +548,7 @@ async def generate_online_subqueries( async def schedule_query( - q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None + q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None, tracer: dict = {} ) -> Tuple[str, ...]: """ Schedule the date, time to run the query. Assume the server timezone is UTC. @@ -546,7 +561,7 @@ async def schedule_query( ) raw_response = await send_message_to_model_wrapper( - crontime_prompt, query_images=query_images, response_type="json_object", user=user + crontime_prompt, query_images=query_images, response_type="json_object", user=user, tracer=tracer ) # Validate that the response is a non-empty, JSON-serializable list @@ -561,7 +576,7 @@ async def schedule_query( async def extract_relevant_info( - qs: set[str], corpus: str, user: KhojUser = None, agent: Agent = None + qs: set[str], corpus: str, user: KhojUser = None, agent: Agent = None, tracer: dict = {} ) -> Union[str, None]: """ Extract relevant information for a given query from the target corpus @@ -584,6 +599,7 @@ async def extract_relevant_info( extract_relevant_information, prompts.system_prompt_extract_relevant_information, user=user, + tracer=tracer, ) return response.strip() @@ -595,6 +611,7 @@ async def extract_relevant_summary( query_images: List[str] = None, user: KhojUser = None, agent: Agent = None, + tracer: dict = {}, ) -> Union[str, None]: """ Extract relevant information for a given query from the target corpus @@ -622,6 +639,7 @@ async def extract_relevant_summary( prompts.system_prompt_extract_relevant_summary, user=user, query_images=query_images, + tracer=tracer, ) return response.strip() @@ -636,6 +654,7 @@ async def generate_excalidraw_diagram( user: KhojUser = None, agent: Agent = None, send_status_func: Optional[Callable] = None, + tracer: dict = {}, ): if send_status_func: async for event in send_status_func("**Enhancing the Diagramming Prompt**"): @@ -650,6 +669,7 @@ async def generate_excalidraw_diagram( query_images=query_images, user=user, agent=agent, + tracer=tracer, ) if send_status_func: @@ -660,6 +680,7 @@ async def generate_excalidraw_diagram( q=better_diagram_description_prompt, user=user, agent=agent, + tracer=tracer, ) yield better_diagram_description_prompt, excalidraw_diagram_description @@ -674,6 +695,7 @@ async def generate_better_diagram_description( query_images: List[str] = None, user: KhojUser = None, agent: Agent = None, + tracer: dict = {}, ) -> str: """ Generate a diagram description from the given query and context @@ -711,7 +733,7 @@ async def generate_better_diagram_description( with timer("Chat actor: Generate better diagram description", logger): response = await send_message_to_model_wrapper( - improve_diagram_description_prompt, query_images=query_images, user=user + improve_diagram_description_prompt, query_images=query_images, user=user, tracer=tracer ) response = response.strip() if response.startswith(('"', "'")) and response.endswith(('"', "'")): @@ -724,6 +746,7 @@ async def generate_excalidraw_diagram_from_description( q: str, user: KhojUser = None, agent: Agent = None, + tracer: dict = {}, ) -> str: personality_context = ( prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else "" @@ -735,7 +758,9 @@ async def generate_excalidraw_diagram_from_description( ) with timer("Chat actor: Generate excalidraw diagram", logger): - raw_response = await send_message_to_model_wrapper(message=excalidraw_diagram_generation, user=user) + raw_response = await send_message_to_model_wrapper( + message=excalidraw_diagram_generation, user=user, tracer=tracer + ) raw_response = raw_response.strip() raw_response = remove_json_codeblock(raw_response) response: Dict[str, str] = json.loads(raw_response) @@ -756,6 +781,7 @@ async def generate_better_image_prompt( query_images: Optional[List[str]] = None, user: KhojUser = None, agent: Agent = None, + tracer: dict = {}, ) -> str: """ Generate a better image prompt from the given query @@ -802,7 +828,9 @@ async def generate_better_image_prompt( ) with timer("Chat actor: Generate contextual image prompt", logger): - response = await send_message_to_model_wrapper(image_prompt, query_images=query_images, user=user) + response = await send_message_to_model_wrapper( + image_prompt, query_images=query_images, user=user, tracer=tracer + ) response = response.strip() if response.startswith(('"', "'")) and response.endswith(('"', "'")): response = response[1:-1] @@ -816,6 +844,7 @@ async def send_message_to_model_wrapper( response_type: str = "text", user: KhojUser = None, query_images: List[str] = None, + tracer: dict = {}, ): conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user) vision_available = conversation_config.vision_enabled @@ -862,6 +891,7 @@ async def send_message_to_model_wrapper( max_prompt_size=max_tokens, streaming=False, response_type=response_type, + tracer=tracer, ) elif model_type == ChatModelOptions.ModelType.OPENAI: @@ -885,6 +915,7 @@ async def send_message_to_model_wrapper( model=chat_model, response_type=response_type, api_base_url=api_base_url, + tracer=tracer, ) elif model_type == ChatModelOptions.ModelType.ANTHROPIC: api_key = conversation_config.openai_config.api_key @@ -903,6 +934,7 @@ async def send_message_to_model_wrapper( messages=truncated_messages, api_key=api_key, model=chat_model, + tracer=tracer, ) elif model_type == ChatModelOptions.ModelType.GOOGLE: api_key = conversation_config.openai_config.api_key @@ -918,7 +950,7 @@ async def send_message_to_model_wrapper( ) return gemini_send_message_to_model( - messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type + messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type, tracer=tracer ) else: raise HTTPException(status_code=500, detail="Invalid conversation config") @@ -929,6 +961,7 @@ def send_message_to_model_wrapper_sync( system_message: str = "", response_type: str = "text", user: KhojUser = None, + tracer: dict = {}, ): conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config(user) @@ -961,6 +994,7 @@ def send_message_to_model_wrapper_sync( max_prompt_size=max_tokens, streaming=False, response_type=response_type, + tracer=tracer, ) elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI: @@ -975,7 +1009,11 @@ def send_message_to_model_wrapper_sync( ) openai_response = send_message_to_model( - messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type + messages=truncated_messages, + api_key=api_key, + model=chat_model, + response_type=response_type, + tracer=tracer, ) return openai_response @@ -995,6 +1033,7 @@ def send_message_to_model_wrapper_sync( messages=truncated_messages, api_key=api_key, model=chat_model, + tracer=tracer, ) elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE: @@ -1013,6 +1052,7 @@ def send_message_to_model_wrapper_sync( api_key=api_key, model=chat_model, response_type=response_type, + tracer=tracer, ) else: raise HTTPException(status_code=500, detail="Invalid conversation config") @@ -1032,6 +1072,7 @@ def generate_chat_response( location_data: LocationData = None, user_name: Optional[str] = None, query_images: Optional[List[str]] = None, + tracer: dict = {}, ) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]: # Initialize Variables chat_response = None @@ -1051,6 +1092,7 @@ def generate_chat_response( client_application=client_application, conversation_id=conversation_id, query_images=query_images, + tracer=tracer, ) conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation) @@ -1077,6 +1119,7 @@ def generate_chat_response( location_data=location_data, user_name=user_name, agent=agent, + tracer=tracer, ) elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI: @@ -1100,6 +1143,7 @@ def generate_chat_response( user_name=user_name, agent=agent, vision_available=vision_available, + tracer=tracer, ) elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC: @@ -1120,6 +1164,7 @@ def generate_chat_response( user_name=user_name, agent=agent, vision_available=vision_available, + tracer=tracer, ) elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE: api_key = conversation_config.openai_config.api_key @@ -1139,6 +1184,7 @@ def generate_chat_response( user_name=user_name, agent=agent, vision_available=vision_available, + tracer=tracer, ) metadata.update({"chat_model": conversation_config.chat_model}) @@ -1495,9 +1541,15 @@ def scheduled_chat( async def create_automation( - q: str, timezone: str, user: KhojUser, calling_url: URL, meta_log: dict = {}, conversation_id: str = None + q: str, + timezone: str, + user: KhojUser, + calling_url: URL, + meta_log: dict = {}, + conversation_id: str = None, + tracer: dict = {}, ): - crontime, query_to_run, subject = await schedule_query(q, meta_log, user) + crontime, query_to_run, subject = await schedule_query(q, meta_log, user, tracer=tracer) job = await schedule_automation(query_to_run, subject, crontime, timezone, q, user, calling_url, conversation_id) return job, crontime, query_to_run, subject