Commit conversation traces using user, chat, message branch hierarchy

- A message's train of thought forks from and merges back into its conversation branch
- Each conversation branches from its user branch
- Each user branches from the root commit on the main branch

- Weave chat tracer metadata from the API endpoint through all chat actors
  and commit it to the prompt trace
Debanjum Singh Solanky 2024-10-23 20:02:28 -07:00
parent a3022b7556
commit ea0712424b
6 changed files with 114 additions and 21 deletions
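
For orientation, here is a minimal sketch of the branch hierarchy described above, assuming GitPython and a git new enough to support --initial-branch. Only the tracer keys (mid, cid, uid, khoj_version) and the merge_message_into_conversation_trace(q, chat_response, tracer) entry point appear in this commit's diff; the branch names, trace file layout, repository path, and _checkout helper below are illustrative assumptions, not the actual Khoj implementation.

# Sketch only: not the code this commit adds to khoj, just an illustration of the
# user -> conversation -> message branch hierarchy from the commit message.
# Branch names, file layout, and repo path are assumptions.
from pathlib import Path

from git import Repo  # GitPython


def _checkout(repo: Repo, name: str, parent: str) -> None:
    """Check out branch `name`, creating it from `parent` if it does not exist yet."""
    if name not in [head.name for head in repo.heads]:
        repo.create_head(name, repo.heads[parent].commit)
    repo.heads[name].checkout()


def merge_message_into_conversation_trace(q: str, chat_response: str, tracer: dict,
                                          repo_path: str = "/tmp/khoj_promptrace") -> None:
    repo = Repo.init(repo_path, mkdir=True, initial_branch="main")

    # Root commit on the main branch: every user branch forks from here.
    if "main" not in [head.name for head in repo.heads]:
        Path(repo_path, "README.md").write_text("Khoj prompt traces\n")
        repo.index.add(["README.md"])
        repo.index.commit("Root commit")

    # User branches from the root commit, the conversation from its user branch,
    # and the message's train of thought from its conversation branch.
    user_branch = f"user/{tracer['uid']}"
    conv_branch = f"conv/{tracer['cid']}"
    msg_branch = f"msg/{tracer['mid']}"
    _checkout(repo, user_branch, parent="main")
    _checkout(repo, conv_branch, parent=user_branch)
    _checkout(repo, msg_branch, parent=conv_branch)

    # Commit this turn's prompt trace on the message branch ...
    Path(repo_path, "trace.md").write_text(f"# Query\n\n{q}\n\n# Response\n\n{chat_response}\n")
    repo.index.add(["trace.md"])
    repo.index.commit(f"Message {tracer['mid']} (khoj {tracer.get('khoj_version')})")

    # ... then merge it back into its conversation branch.
    repo.heads[conv_branch].checkout()
    repo.git.merge(msg_branch)
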


@ -23,7 +23,7 @@ from khoj.database.adapters import ConversationAdapters
from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser
from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
from khoj.utils import state
from khoj.utils.helpers import is_none_or_empty, merge_dicts
from khoj.utils.helpers import in_debug_mode, is_none_or_empty, merge_dicts
logger = logging.getLogger(__name__)
model_to_prompt_size = {
@ -119,6 +119,7 @@ def save_to_conversation_log(
conversation_id: str = None,
automation_id: str = None,
query_images: List[str] = None,
tracer: Dict[str, Any] = {},
):
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
updated_conversation = message_to_log(
@ -144,6 +145,9 @@ def save_to_conversation_log(
user_message=q,
)
if in_debug_mode() or state.verbose > 1:
merge_message_into_conversation_trace(q, chat_response, tracer)
logger.info(
f"""
Saved Conversation Turn


@ -28,6 +28,7 @@ async def text_to_image(
send_status_func: Optional[Callable] = None,
query_images: Optional[List[str]] = None,
agent: Agent = None,
tracer: dict = {},
):
status_code = 200
image = None
@ -68,6 +69,7 @@ async def text_to_image(
query_images=query_images,
user=user,
agent=agent,
tracer=tracer,
)
if send_status_func:


@ -64,6 +64,7 @@ async def search_online(
custom_filters: List[str] = [],
query_images: List[str] = None,
agent: Agent = None,
tracer: dict = {},
):
query += " ".join(custom_filters)
if not is_internet_connected():
@ -73,7 +74,7 @@ async def search_online(
# Breakdown the query into subqueries to get the correct answer
subqueries = await generate_online_subqueries(
query, conversation_history, location, user, query_images=query_images, agent=agent
query, conversation_history, location, user, query_images=query_images, agent=agent, tracer=tracer
)
response_dict = {}
@ -111,7 +112,7 @@ async def search_online(
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event}
tasks = [
read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent)
read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent, tracer=tracer)
for link, data in webpages.items()
]
results = await asyncio.gather(*tasks)
@ -153,6 +154,7 @@ async def read_webpages(
send_status_func: Optional[Callable] = None,
query_images: List[str] = None,
agent: Agent = None,
tracer: dict = {},
):
"Infer web pages to read from the query and extract relevant information from them"
logger.info(f"Inferring web pages to read")
@ -166,7 +168,7 @@ async def read_webpages(
webpage_links_str = "\n- " + "\n- ".join(list(urls))
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event}
tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent) for url in urls]
tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent, tracer=tracer) for url in urls]
results = await asyncio.gather(*tasks)
response: Dict[str, Dict] = defaultdict(dict)
@ -192,7 +194,12 @@ async def read_webpage(
async def read_webpage_and_extract_content(
subqueries: set[str], url: str, content: str = None, user: KhojUser = None, agent: Agent = None
subqueries: set[str],
url: str,
content: str = None,
user: KhojUser = None,
agent: Agent = None,
tracer: dict = {},
) -> Tuple[set[str], str, Union[None, str]]:
# Select the web scrapers to use for reading the web page
web_scrapers = await ConversationAdapters.aget_enabled_webscrapers()
@ -214,7 +221,9 @@ async def read_webpage_and_extract_content(
# Extract relevant information from the web page
if is_none_or_empty(extracted_info):
with timer(f"Extracting relevant information from web page at '{url}' took", logger):
extracted_info = await extract_relevant_info(subqueries, content, user=user, agent=agent)
extracted_info = await extract_relevant_info(
subqueries, content, user=user, agent=agent, tracer=tracer
)
# If we successfully extracted information, break the loop
if not is_none_or_empty(extracted_info):


@ -350,6 +350,7 @@ async def extract_references_and_questions(
send_status_func: Optional[Callable] = None,
query_images: Optional[List[str]] = None,
agent: Agent = None,
tracer: dict = {},
):
user = request.user.object if request.user.is_authenticated else None
@ -425,6 +426,7 @@ async def extract_references_and_questions(
user=user,
max_prompt_size=conversation_config.max_prompt_size,
personality_context=personality_context,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
openai_chat_config = conversation_config.openai_config
@ -442,6 +444,7 @@ async def extract_references_and_questions(
query_images=query_images,
vision_enabled=vision_enabled,
personality_context=personality_context,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
api_key = conversation_config.openai_config.api_key
@ -456,6 +459,7 @@ async def extract_references_and_questions(
user=user,
vision_enabled=vision_enabled,
personality_context=personality_context,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
api_key = conversation_config.openai_config.api_key
@ -471,6 +475,7 @@ async def extract_references_and_questions(
user=user,
vision_enabled=vision_enabled,
personality_context=personality_context,
tracer=tracer,
)
# Collate search results as context for GPT


@ -3,6 +3,7 @@ import base64
import json
import logging
import time
import uuid
from datetime import datetime
from functools import partial
from typing import Dict, Optional
@ -563,6 +564,12 @@ async def chat(
event_delimiter = "␃🔚␗"
q = unquote(q)
nonlocal conversation_id
tracer: dict = {
"mid": f"{uuid.uuid4()}",
"cid": conversation_id,
"uid": user.id,
"khoj_version": state.khoj_version,
}
uploaded_images: list[str] = []
if images:
@ -682,6 +689,7 @@ async def chat(
user=user,
query_images=uploaded_images,
agent=agent,
tracer=tracer,
)
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
async for result in send_event(
@ -689,7 +697,9 @@ async def chat(
):
yield result
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_images, agent)
mode = await aget_relevant_output_modes(
q, meta_log, is_automated_task, user, uploaded_images, agent, tracer=tracer
)
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
yield result
if mode not in conversation_commands:
@ -755,6 +765,7 @@ async def chat(
query_images=uploaded_images,
user=user,
agent=agent,
tracer=tracer,
)
response_log = str(response)
async for result in send_llm_response(response_log):
@ -774,6 +785,7 @@ async def chat(
client_application=request.user.client_app,
conversation_id=conversation_id,
query_images=uploaded_images,
tracer=tracer,
)
return
@ -795,7 +807,7 @@ async def chat(
if ConversationCommand.Automation in conversation_commands:
try:
automation, crontime, query_to_run, subject = await create_automation(
q, timezone, user, request.url, meta_log
q, timezone, user, request.url, meta_log, tracer=tracer
)
except Exception as e:
logger.error(f"Error scheduling task {q} for {user.email}: {e}")
@ -817,6 +829,7 @@ async def chat(
inferred_queries=[query_to_run],
automation_id=automation.id,
query_images=uploaded_images,
tracer=tracer,
)
async for result in send_llm_response(llm_response):
yield result
@ -838,6 +851,7 @@ async def chat(
partial(send_event, ChatEvent.STATUS),
query_images=uploaded_images,
agent=agent,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@ -882,6 +896,7 @@ async def chat(
custom_filters,
query_images=uploaded_images,
agent=agent,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@ -906,6 +921,7 @@ async def chat(
partial(send_event, ChatEvent.STATUS),
query_images=uploaded_images,
agent=agent,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@ -956,6 +972,7 @@ async def chat(
send_status_func=partial(send_event, ChatEvent.STATUS),
query_images=uploaded_images,
agent=agent,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@ -986,6 +1003,7 @@ async def chat(
compiled_references=compiled_references,
online_results=online_results,
query_images=uploaded_images,
tracer=tracer,
)
content_obj = {
"intentType": intent_type,
@ -1014,6 +1032,7 @@ async def chat(
user=user,
agent=agent,
send_status_func=partial(send_event, ChatEvent.STATUS),
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@ -1041,6 +1060,7 @@ async def chat(
compiled_references=compiled_references,
online_results=online_results,
query_images=uploaded_images,
tracer=tracer,
)
async for result in send_llm_response(json.dumps(content_obj)):
@ -1064,6 +1084,7 @@ async def chat(
location,
user_name,
uploaded_images,
tracer,
)
# Send Response


@ -301,6 +301,7 @@ async def aget_relevant_information_sources(
user: KhojUser,
query_images: List[str] = None,
agent: Agent = None,
tracer: dict = {},
):
"""
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
@ -337,6 +338,7 @@ async def aget_relevant_information_sources(
relevant_tools_prompt,
response_type="json_object",
user=user,
tracer=tracer,
)
try:
@ -378,6 +380,7 @@ async def aget_relevant_output_modes(
user: KhojUser = None,
query_images: List[str] = None,
agent: Agent = None,
tracer: dict = {},
):
"""
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
@ -413,7 +416,9 @@ async def aget_relevant_output_modes(
)
with timer("Chat actor: Infer output mode for chat response", logger):
response = await send_message_to_model_wrapper(relevant_mode_prompt, response_type="json_object", user=user)
response = await send_message_to_model_wrapper(
relevant_mode_prompt, response_type="json_object", user=user, tracer=tracer
)
try:
response = response.strip()
@ -444,6 +449,7 @@ async def infer_webpage_urls(
user: KhojUser,
query_images: List[str] = None,
agent: Agent = None,
tracer: dict = {},
) -> List[str]:
"""
Infer webpage links from the given query
@ -468,7 +474,11 @@ async def infer_webpage_urls(
with timer("Chat actor: Infer webpage urls to read", logger):
response = await send_message_to_model_wrapper(
online_queries_prompt, query_images=query_images, response_type="json_object", user=user
online_queries_prompt,
query_images=query_images,
response_type="json_object",
user=user,
tracer=tracer,
)
# Validate that the response is a non-empty, JSON-serializable list of URLs
@ -490,6 +500,7 @@ async def generate_online_subqueries(
user: KhojUser,
query_images: List[str] = None,
agent: Agent = None,
tracer: dict = {},
) -> List[str]:
"""
Generate subqueries from the given query
@ -514,7 +525,11 @@ async def generate_online_subqueries(
with timer("Chat actor: Generate online search subqueries", logger):
response = await send_message_to_model_wrapper(
online_queries_prompt, query_images=query_images, response_type="json_object", user=user
online_queries_prompt,
query_images=query_images,
response_type="json_object",
user=user,
tracer=tracer,
)
# Validate that the response is a non-empty, JSON-serializable list
@ -533,7 +548,7 @@ async def generate_online_subqueries(
async def schedule_query(
q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None
q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None, tracer: dict = {}
) -> Tuple[str, ...]:
"""
Schedule the date, time to run the query. Assume the server timezone is UTC.
@ -546,7 +561,7 @@ async def schedule_query(
)
raw_response = await send_message_to_model_wrapper(
crontime_prompt, query_images=query_images, response_type="json_object", user=user
crontime_prompt, query_images=query_images, response_type="json_object", user=user, tracer=tracer
)
# Validate that the response is a non-empty, JSON-serializable list
@ -561,7 +576,7 @@ async def schedule_query(
async def extract_relevant_info(
qs: set[str], corpus: str, user: KhojUser = None, agent: Agent = None
qs: set[str], corpus: str, user: KhojUser = None, agent: Agent = None, tracer: dict = {}
) -> Union[str, None]:
"""
Extract relevant information for a given query from the target corpus
@ -584,6 +599,7 @@ async def extract_relevant_info(
extract_relevant_information,
prompts.system_prompt_extract_relevant_information,
user=user,
tracer=tracer,
)
return response.strip()
@ -595,6 +611,7 @@ async def extract_relevant_summary(
query_images: List[str] = None,
user: KhojUser = None,
agent: Agent = None,
tracer: dict = {},
) -> Union[str, None]:
"""
Extract relevant information for a given query from the target corpus
@ -622,6 +639,7 @@ async def extract_relevant_summary(
prompts.system_prompt_extract_relevant_summary,
user=user,
query_images=query_images,
tracer=tracer,
)
return response.strip()
@ -636,6 +654,7 @@ async def generate_excalidraw_diagram(
user: KhojUser = None,
agent: Agent = None,
send_status_func: Optional[Callable] = None,
tracer: dict = {},
):
if send_status_func:
async for event in send_status_func("**Enhancing the Diagramming Prompt**"):
@ -650,6 +669,7 @@ async def generate_excalidraw_diagram(
query_images=query_images,
user=user,
agent=agent,
tracer=tracer,
)
if send_status_func:
@ -660,6 +680,7 @@ async def generate_excalidraw_diagram(
q=better_diagram_description_prompt,
user=user,
agent=agent,
tracer=tracer,
)
yield better_diagram_description_prompt, excalidraw_diagram_description
@ -674,6 +695,7 @@ async def generate_better_diagram_description(
query_images: List[str] = None,
user: KhojUser = None,
agent: Agent = None,
tracer: dict = {},
) -> str:
"""
Generate a diagram description from the given query and context
@ -711,7 +733,7 @@ async def generate_better_diagram_description(
with timer("Chat actor: Generate better diagram description", logger):
response = await send_message_to_model_wrapper(
improve_diagram_description_prompt, query_images=query_images, user=user
improve_diagram_description_prompt, query_images=query_images, user=user, tracer=tracer
)
response = response.strip()
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
@ -724,6 +746,7 @@ async def generate_excalidraw_diagram_from_description(
q: str,
user: KhojUser = None,
agent: Agent = None,
tracer: dict = {},
) -> str:
personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
@ -735,7 +758,9 @@ async def generate_excalidraw_diagram_from_description(
)
with timer("Chat actor: Generate excalidraw diagram", logger):
raw_response = await send_message_to_model_wrapper(message=excalidraw_diagram_generation, user=user)
raw_response = await send_message_to_model_wrapper(
message=excalidraw_diagram_generation, user=user, tracer=tracer
)
raw_response = raw_response.strip()
raw_response = remove_json_codeblock(raw_response)
response: Dict[str, str] = json.loads(raw_response)
@ -756,6 +781,7 @@ async def generate_better_image_prompt(
query_images: Optional[List[str]] = None,
user: KhojUser = None,
agent: Agent = None,
tracer: dict = {},
) -> str:
"""
Generate a better image prompt from the given query
@ -802,7 +828,9 @@ async def generate_better_image_prompt(
)
with timer("Chat actor: Generate contextual image prompt", logger):
response = await send_message_to_model_wrapper(image_prompt, query_images=query_images, user=user)
response = await send_message_to_model_wrapper(
image_prompt, query_images=query_images, user=user, tracer=tracer
)
response = response.strip()
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
response = response[1:-1]
@ -816,6 +844,7 @@ async def send_message_to_model_wrapper(
response_type: str = "text",
user: KhojUser = None,
query_images: List[str] = None,
tracer: dict = {},
):
conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
vision_available = conversation_config.vision_enabled
@ -862,6 +891,7 @@ async def send_message_to_model_wrapper(
max_prompt_size=max_tokens,
streaming=False,
response_type=response_type,
tracer=tracer,
)
elif model_type == ChatModelOptions.ModelType.OPENAI:
@ -885,6 +915,7 @@ async def send_message_to_model_wrapper(
model=chat_model,
response_type=response_type,
api_base_url=api_base_url,
tracer=tracer,
)
elif model_type == ChatModelOptions.ModelType.ANTHROPIC:
api_key = conversation_config.openai_config.api_key
@ -903,6 +934,7 @@ async def send_message_to_model_wrapper(
messages=truncated_messages,
api_key=api_key,
model=chat_model,
tracer=tracer,
)
elif model_type == ChatModelOptions.ModelType.GOOGLE:
api_key = conversation_config.openai_config.api_key
@ -918,7 +950,7 @@ async def send_message_to_model_wrapper(
)
return gemini_send_message_to_model(
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type, tracer=tracer
)
else:
raise HTTPException(status_code=500, detail="Invalid conversation config")
@ -929,6 +961,7 @@ def send_message_to_model_wrapper_sync(
system_message: str = "",
response_type: str = "text",
user: KhojUser = None,
tracer: dict = {},
):
conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config(user)
@ -961,6 +994,7 @@ def send_message_to_model_wrapper_sync(
max_prompt_size=max_tokens,
streaming=False,
response_type=response_type,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
@ -975,7 +1009,11 @@ def send_message_to_model_wrapper_sync(
)
openai_response = send_message_to_model(
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type
messages=truncated_messages,
api_key=api_key,
model=chat_model,
response_type=response_type,
tracer=tracer,
)
return openai_response
@ -995,6 +1033,7 @@ def send_message_to_model_wrapper_sync(
messages=truncated_messages,
api_key=api_key,
model=chat_model,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
@ -1013,6 +1052,7 @@ def send_message_to_model_wrapper_sync(
api_key=api_key,
model=chat_model,
response_type=response_type,
tracer=tracer,
)
else:
raise HTTPException(status_code=500, detail="Invalid conversation config")
@ -1032,6 +1072,7 @@ def generate_chat_response(
location_data: LocationData = None,
user_name: Optional[str] = None,
query_images: Optional[List[str]] = None,
tracer: dict = {},
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
# Initialize Variables
chat_response = None
@ -1051,6 +1092,7 @@ def generate_chat_response(
client_application=client_application,
conversation_id=conversation_id,
query_images=query_images,
tracer=tracer,
)
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
@ -1077,6 +1119,7 @@ def generate_chat_response(
location_data=location_data,
user_name=user_name,
agent=agent,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
@ -1100,6 +1143,7 @@ def generate_chat_response(
user_name=user_name,
agent=agent,
vision_available=vision_available,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
@ -1120,6 +1164,7 @@ def generate_chat_response(
user_name=user_name,
agent=agent,
vision_available=vision_available,
tracer=tracer,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
api_key = conversation_config.openai_config.api_key
@ -1139,6 +1184,7 @@ def generate_chat_response(
user_name=user_name,
agent=agent,
vision_available=vision_available,
tracer=tracer,
)
metadata.update({"chat_model": conversation_config.chat_model})
@ -1495,9 +1541,15 @@ def scheduled_chat(
async def create_automation(
q: str, timezone: str, user: KhojUser, calling_url: URL, meta_log: dict = {}, conversation_id: str = None
q: str,
timezone: str,
user: KhojUser,
calling_url: URL,
meta_log: dict = {},
conversation_id: str = None,
tracer: dict = {},
):
crontime, query_to_run, subject = await schedule_query(q, meta_log, user)
crontime, query_to_run, subject = await schedule_query(q, meta_log, user, tracer=tracer)
job = await schedule_automation(query_to_run, subject, crontime, timezone, q, user, calling_url, conversation_id)
return job, crontime, query_to_run, subject