Commit conversation traces using user, chat, message branch hierarchy

- Message train of thought forks and merges from its conversation branch
- Conversation branches from user branch
- User branches from root commit on the main branch

- Weave chat tracer metadata from api endpoint through all chat actors
  and commit it to the prompt trace
This commit is contained in:
Debanjum Singh Solanky 2024-10-23 20:02:28 -07:00
parent a3022b7556
commit ea0712424b
6 changed files with 114 additions and 21 deletions

View file

@ -23,7 +23,7 @@ from khoj.database.adapters import ConversationAdapters
from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser
from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
from khoj.utils import state from khoj.utils import state
from khoj.utils.helpers import is_none_or_empty, merge_dicts from khoj.utils.helpers import in_debug_mode, is_none_or_empty, merge_dicts
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
model_to_prompt_size = { model_to_prompt_size = {
@ -119,6 +119,7 @@ def save_to_conversation_log(
conversation_id: str = None, conversation_id: str = None,
automation_id: str = None, automation_id: str = None,
query_images: List[str] = None, query_images: List[str] = None,
tracer: Dict[str, Any] = {},
): ):
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S") user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
updated_conversation = message_to_log( updated_conversation = message_to_log(
@ -144,6 +145,9 @@ def save_to_conversation_log(
user_message=q, user_message=q,
) )
if in_debug_mode() or state.verbose > 1:
merge_message_into_conversation_trace(q, chat_response, tracer)
logger.info( logger.info(
f""" f"""
Saved Conversation Turn Saved Conversation Turn

View file

@ -28,6 +28,7 @@ async def text_to_image(
send_status_func: Optional[Callable] = None, send_status_func: Optional[Callable] = None,
query_images: Optional[List[str]] = None, query_images: Optional[List[str]] = None,
agent: Agent = None, agent: Agent = None,
tracer: dict = {},
): ):
status_code = 200 status_code = 200
image = None image = None
@ -68,6 +69,7 @@ async def text_to_image(
query_images=query_images, query_images=query_images,
user=user, user=user,
agent=agent, agent=agent,
tracer=tracer,
) )
if send_status_func: if send_status_func:

View file

@ -64,6 +64,7 @@ async def search_online(
custom_filters: List[str] = [], custom_filters: List[str] = [],
query_images: List[str] = None, query_images: List[str] = None,
agent: Agent = None, agent: Agent = None,
tracer: dict = {},
): ):
query += " ".join(custom_filters) query += " ".join(custom_filters)
if not is_internet_connected(): if not is_internet_connected():
@ -73,7 +74,7 @@ async def search_online(
# Breakdown the query into subqueries to get the correct answer # Breakdown the query into subqueries to get the correct answer
subqueries = await generate_online_subqueries( subqueries = await generate_online_subqueries(
query, conversation_history, location, user, query_images=query_images, agent=agent query, conversation_history, location, user, query_images=query_images, agent=agent, tracer=tracer
) )
response_dict = {} response_dict = {}
@ -111,7 +112,7 @@ async def search_online(
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"): async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event} yield {ChatEvent.STATUS: event}
tasks = [ tasks = [
read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent) read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent, tracer=tracer)
for link, data in webpages.items() for link, data in webpages.items()
] ]
results = await asyncio.gather(*tasks) results = await asyncio.gather(*tasks)
@ -153,6 +154,7 @@ async def read_webpages(
send_status_func: Optional[Callable] = None, send_status_func: Optional[Callable] = None,
query_images: List[str] = None, query_images: List[str] = None,
agent: Agent = None, agent: Agent = None,
tracer: dict = {},
): ):
"Infer web pages to read from the query and extract relevant information from them" "Infer web pages to read from the query and extract relevant information from them"
logger.info(f"Inferring web pages to read") logger.info(f"Inferring web pages to read")
@ -166,7 +168,7 @@ async def read_webpages(
webpage_links_str = "\n- " + "\n- ".join(list(urls)) webpage_links_str = "\n- " + "\n- ".join(list(urls))
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"): async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event} yield {ChatEvent.STATUS: event}
tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent) for url in urls] tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent, tracer=tracer) for url in urls]
results = await asyncio.gather(*tasks) results = await asyncio.gather(*tasks)
response: Dict[str, Dict] = defaultdict(dict) response: Dict[str, Dict] = defaultdict(dict)
@ -192,7 +194,12 @@ async def read_webpage(
async def read_webpage_and_extract_content( async def read_webpage_and_extract_content(
subqueries: set[str], url: str, content: str = None, user: KhojUser = None, agent: Agent = None subqueries: set[str],
url: str,
content: str = None,
user: KhojUser = None,
agent: Agent = None,
tracer: dict = {},
) -> Tuple[set[str], str, Union[None, str]]: ) -> Tuple[set[str], str, Union[None, str]]:
# Select the web scrapers to use for reading the web page # Select the web scrapers to use for reading the web page
web_scrapers = await ConversationAdapters.aget_enabled_webscrapers() web_scrapers = await ConversationAdapters.aget_enabled_webscrapers()
@ -214,7 +221,9 @@ async def read_webpage_and_extract_content(
# Extract relevant information from the web page # Extract relevant information from the web page
if is_none_or_empty(extracted_info): if is_none_or_empty(extracted_info):
with timer(f"Extracting relevant information from web page at '{url}' took", logger): with timer(f"Extracting relevant information from web page at '{url}' took", logger):
extracted_info = await extract_relevant_info(subqueries, content, user=user, agent=agent) extracted_info = await extract_relevant_info(
subqueries, content, user=user, agent=agent, tracer=tracer
)
# If we successfully extracted information, break the loop # If we successfully extracted information, break the loop
if not is_none_or_empty(extracted_info): if not is_none_or_empty(extracted_info):

View file

@ -350,6 +350,7 @@ async def extract_references_and_questions(
send_status_func: Optional[Callable] = None, send_status_func: Optional[Callable] = None,
query_images: Optional[List[str]] = None, query_images: Optional[List[str]] = None,
agent: Agent = None, agent: Agent = None,
tracer: dict = {},
): ):
user = request.user.object if request.user.is_authenticated else None user = request.user.object if request.user.is_authenticated else None
@ -425,6 +426,7 @@ async def extract_references_and_questions(
user=user, user=user,
max_prompt_size=conversation_config.max_prompt_size, max_prompt_size=conversation_config.max_prompt_size,
personality_context=personality_context, personality_context=personality_context,
tracer=tracer,
) )
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI: elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
openai_chat_config = conversation_config.openai_config openai_chat_config = conversation_config.openai_config
@ -442,6 +444,7 @@ async def extract_references_and_questions(
query_images=query_images, query_images=query_images,
vision_enabled=vision_enabled, vision_enabled=vision_enabled,
personality_context=personality_context, personality_context=personality_context,
tracer=tracer,
) )
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC: elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
api_key = conversation_config.openai_config.api_key api_key = conversation_config.openai_config.api_key
@ -456,6 +459,7 @@ async def extract_references_and_questions(
user=user, user=user,
vision_enabled=vision_enabled, vision_enabled=vision_enabled,
personality_context=personality_context, personality_context=personality_context,
tracer=tracer,
) )
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE: elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
api_key = conversation_config.openai_config.api_key api_key = conversation_config.openai_config.api_key
@ -471,6 +475,7 @@ async def extract_references_and_questions(
user=user, user=user,
vision_enabled=vision_enabled, vision_enabled=vision_enabled,
personality_context=personality_context, personality_context=personality_context,
tracer=tracer,
) )
# Collate search results as context for GPT # Collate search results as context for GPT

View file

@ -3,6 +3,7 @@ import base64
import json import json
import logging import logging
import time import time
import uuid
from datetime import datetime from datetime import datetime
from functools import partial from functools import partial
from typing import Dict, Optional from typing import Dict, Optional
@ -563,6 +564,12 @@ async def chat(
event_delimiter = "␃🔚␗" event_delimiter = "␃🔚␗"
q = unquote(q) q = unquote(q)
nonlocal conversation_id nonlocal conversation_id
tracer: dict = {
"mid": f"{uuid.uuid4()}",
"cid": conversation_id,
"uid": user.id,
"khoj_version": state.khoj_version,
}
uploaded_images: list[str] = [] uploaded_images: list[str] = []
if images: if images:
@ -682,6 +689,7 @@ async def chat(
user=user, user=user,
query_images=uploaded_images, query_images=uploaded_images,
agent=agent, agent=agent,
tracer=tracer,
) )
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands]) conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
async for result in send_event( async for result in send_event(
@ -689,7 +697,9 @@ async def chat(
): ):
yield result yield result
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_images, agent) mode = await aget_relevant_output_modes(
q, meta_log, is_automated_task, user, uploaded_images, agent, tracer=tracer
)
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"): async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
yield result yield result
if mode not in conversation_commands: if mode not in conversation_commands:
@ -755,6 +765,7 @@ async def chat(
query_images=uploaded_images, query_images=uploaded_images,
user=user, user=user,
agent=agent, agent=agent,
tracer=tracer,
) )
response_log = str(response) response_log = str(response)
async for result in send_llm_response(response_log): async for result in send_llm_response(response_log):
@ -774,6 +785,7 @@ async def chat(
client_application=request.user.client_app, client_application=request.user.client_app,
conversation_id=conversation_id, conversation_id=conversation_id,
query_images=uploaded_images, query_images=uploaded_images,
tracer=tracer,
) )
return return
@ -795,7 +807,7 @@ async def chat(
if ConversationCommand.Automation in conversation_commands: if ConversationCommand.Automation in conversation_commands:
try: try:
automation, crontime, query_to_run, subject = await create_automation( automation, crontime, query_to_run, subject = await create_automation(
q, timezone, user, request.url, meta_log q, timezone, user, request.url, meta_log, tracer=tracer
) )
except Exception as e: except Exception as e:
logger.error(f"Error scheduling task {q} for {user.email}: {e}") logger.error(f"Error scheduling task {q} for {user.email}: {e}")
@ -817,6 +829,7 @@ async def chat(
inferred_queries=[query_to_run], inferred_queries=[query_to_run],
automation_id=automation.id, automation_id=automation.id,
query_images=uploaded_images, query_images=uploaded_images,
tracer=tracer,
) )
async for result in send_llm_response(llm_response): async for result in send_llm_response(llm_response):
yield result yield result
@ -838,6 +851,7 @@ async def chat(
partial(send_event, ChatEvent.STATUS), partial(send_event, ChatEvent.STATUS),
query_images=uploaded_images, query_images=uploaded_images,
agent=agent, agent=agent,
tracer=tracer,
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS] yield result[ChatEvent.STATUS]
@ -882,6 +896,7 @@ async def chat(
custom_filters, custom_filters,
query_images=uploaded_images, query_images=uploaded_images,
agent=agent, agent=agent,
tracer=tracer,
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS] yield result[ChatEvent.STATUS]
@ -906,6 +921,7 @@ async def chat(
partial(send_event, ChatEvent.STATUS), partial(send_event, ChatEvent.STATUS),
query_images=uploaded_images, query_images=uploaded_images,
agent=agent, agent=agent,
tracer=tracer,
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS] yield result[ChatEvent.STATUS]
@ -956,6 +972,7 @@ async def chat(
send_status_func=partial(send_event, ChatEvent.STATUS), send_status_func=partial(send_event, ChatEvent.STATUS),
query_images=uploaded_images, query_images=uploaded_images,
agent=agent, agent=agent,
tracer=tracer,
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS] yield result[ChatEvent.STATUS]
@ -986,6 +1003,7 @@ async def chat(
compiled_references=compiled_references, compiled_references=compiled_references,
online_results=online_results, online_results=online_results,
query_images=uploaded_images, query_images=uploaded_images,
tracer=tracer,
) )
content_obj = { content_obj = {
"intentType": intent_type, "intentType": intent_type,
@ -1014,6 +1032,7 @@ async def chat(
user=user, user=user,
agent=agent, agent=agent,
send_status_func=partial(send_event, ChatEvent.STATUS), send_status_func=partial(send_event, ChatEvent.STATUS),
tracer=tracer,
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS] yield result[ChatEvent.STATUS]
@ -1041,6 +1060,7 @@ async def chat(
compiled_references=compiled_references, compiled_references=compiled_references,
online_results=online_results, online_results=online_results,
query_images=uploaded_images, query_images=uploaded_images,
tracer=tracer,
) )
async for result in send_llm_response(json.dumps(content_obj)): async for result in send_llm_response(json.dumps(content_obj)):
@ -1064,6 +1084,7 @@ async def chat(
location, location,
user_name, user_name,
uploaded_images, uploaded_images,
tracer,
) )
# Send Response # Send Response

View file

@ -301,6 +301,7 @@ async def aget_relevant_information_sources(
user: KhojUser, user: KhojUser,
query_images: List[str] = None, query_images: List[str] = None,
agent: Agent = None, agent: Agent = None,
tracer: dict = {},
): ):
""" """
Given a query, determine which of the available tools the agent should use in order to answer appropriately. Given a query, determine which of the available tools the agent should use in order to answer appropriately.
@ -337,6 +338,7 @@ async def aget_relevant_information_sources(
relevant_tools_prompt, relevant_tools_prompt,
response_type="json_object", response_type="json_object",
user=user, user=user,
tracer=tracer,
) )
try: try:
@ -378,6 +380,7 @@ async def aget_relevant_output_modes(
user: KhojUser = None, user: KhojUser = None,
query_images: List[str] = None, query_images: List[str] = None,
agent: Agent = None, agent: Agent = None,
tracer: dict = {},
): ):
""" """
Given a query, determine which of the available tools the agent should use in order to answer appropriately. Given a query, determine which of the available tools the agent should use in order to answer appropriately.
@ -413,7 +416,9 @@ async def aget_relevant_output_modes(
) )
with timer("Chat actor: Infer output mode for chat response", logger): with timer("Chat actor: Infer output mode for chat response", logger):
response = await send_message_to_model_wrapper(relevant_mode_prompt, response_type="json_object", user=user) response = await send_message_to_model_wrapper(
relevant_mode_prompt, response_type="json_object", user=user, tracer=tracer
)
try: try:
response = response.strip() response = response.strip()
@ -444,6 +449,7 @@ async def infer_webpage_urls(
user: KhojUser, user: KhojUser,
query_images: List[str] = None, query_images: List[str] = None,
agent: Agent = None, agent: Agent = None,
tracer: dict = {},
) -> List[str]: ) -> List[str]:
""" """
Infer webpage links from the given query Infer webpage links from the given query
@ -468,7 +474,11 @@ async def infer_webpage_urls(
with timer("Chat actor: Infer webpage urls to read", logger): with timer("Chat actor: Infer webpage urls to read", logger):
response = await send_message_to_model_wrapper( response = await send_message_to_model_wrapper(
online_queries_prompt, query_images=query_images, response_type="json_object", user=user online_queries_prompt,
query_images=query_images,
response_type="json_object",
user=user,
tracer=tracer,
) )
# Validate that the response is a non-empty, JSON-serializable list of URLs # Validate that the response is a non-empty, JSON-serializable list of URLs
@ -490,6 +500,7 @@ async def generate_online_subqueries(
user: KhojUser, user: KhojUser,
query_images: List[str] = None, query_images: List[str] = None,
agent: Agent = None, agent: Agent = None,
tracer: dict = {},
) -> List[str]: ) -> List[str]:
""" """
Generate subqueries from the given query Generate subqueries from the given query
@ -514,7 +525,11 @@ async def generate_online_subqueries(
with timer("Chat actor: Generate online search subqueries", logger): with timer("Chat actor: Generate online search subqueries", logger):
response = await send_message_to_model_wrapper( response = await send_message_to_model_wrapper(
online_queries_prompt, query_images=query_images, response_type="json_object", user=user online_queries_prompt,
query_images=query_images,
response_type="json_object",
user=user,
tracer=tracer,
) )
# Validate that the response is a non-empty, JSON-serializable list # Validate that the response is a non-empty, JSON-serializable list
@ -533,7 +548,7 @@ async def generate_online_subqueries(
async def schedule_query( async def schedule_query(
q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None, tracer: dict = {}
) -> Tuple[str, ...]: ) -> Tuple[str, ...]:
""" """
Schedule the date, time to run the query. Assume the server timezone is UTC. Schedule the date, time to run the query. Assume the server timezone is UTC.
@ -546,7 +561,7 @@ async def schedule_query(
) )
raw_response = await send_message_to_model_wrapper( raw_response = await send_message_to_model_wrapper(
crontime_prompt, query_images=query_images, response_type="json_object", user=user crontime_prompt, query_images=query_images, response_type="json_object", user=user, tracer=tracer
) )
# Validate that the response is a non-empty, JSON-serializable list # Validate that the response is a non-empty, JSON-serializable list
@ -561,7 +576,7 @@ async def schedule_query(
async def extract_relevant_info( async def extract_relevant_info(
qs: set[str], corpus: str, user: KhojUser = None, agent: Agent = None qs: set[str], corpus: str, user: KhojUser = None, agent: Agent = None, tracer: dict = {}
) -> Union[str, None]: ) -> Union[str, None]:
""" """
Extract relevant information for a given query from the target corpus Extract relevant information for a given query from the target corpus
@ -584,6 +599,7 @@ async def extract_relevant_info(
extract_relevant_information, extract_relevant_information,
prompts.system_prompt_extract_relevant_information, prompts.system_prompt_extract_relevant_information,
user=user, user=user,
tracer=tracer,
) )
return response.strip() return response.strip()
@ -595,6 +611,7 @@ async def extract_relevant_summary(
query_images: List[str] = None, query_images: List[str] = None,
user: KhojUser = None, user: KhojUser = None,
agent: Agent = None, agent: Agent = None,
tracer: dict = {},
) -> Union[str, None]: ) -> Union[str, None]:
""" """
Extract relevant information for a given query from the target corpus Extract relevant information for a given query from the target corpus
@ -622,6 +639,7 @@ async def extract_relevant_summary(
prompts.system_prompt_extract_relevant_summary, prompts.system_prompt_extract_relevant_summary,
user=user, user=user,
query_images=query_images, query_images=query_images,
tracer=tracer,
) )
return response.strip() return response.strip()
@ -636,6 +654,7 @@ async def generate_excalidraw_diagram(
user: KhojUser = None, user: KhojUser = None,
agent: Agent = None, agent: Agent = None,
send_status_func: Optional[Callable] = None, send_status_func: Optional[Callable] = None,
tracer: dict = {},
): ):
if send_status_func: if send_status_func:
async for event in send_status_func("**Enhancing the Diagramming Prompt**"): async for event in send_status_func("**Enhancing the Diagramming Prompt**"):
@ -650,6 +669,7 @@ async def generate_excalidraw_diagram(
query_images=query_images, query_images=query_images,
user=user, user=user,
agent=agent, agent=agent,
tracer=tracer,
) )
if send_status_func: if send_status_func:
@ -660,6 +680,7 @@ async def generate_excalidraw_diagram(
q=better_diagram_description_prompt, q=better_diagram_description_prompt,
user=user, user=user,
agent=agent, agent=agent,
tracer=tracer,
) )
yield better_diagram_description_prompt, excalidraw_diagram_description yield better_diagram_description_prompt, excalidraw_diagram_description
@ -674,6 +695,7 @@ async def generate_better_diagram_description(
query_images: List[str] = None, query_images: List[str] = None,
user: KhojUser = None, user: KhojUser = None,
agent: Agent = None, agent: Agent = None,
tracer: dict = {},
) -> str: ) -> str:
""" """
Generate a diagram description from the given query and context Generate a diagram description from the given query and context
@ -711,7 +733,7 @@ async def generate_better_diagram_description(
with timer("Chat actor: Generate better diagram description", logger): with timer("Chat actor: Generate better diagram description", logger):
response = await send_message_to_model_wrapper( response = await send_message_to_model_wrapper(
improve_diagram_description_prompt, query_images=query_images, user=user improve_diagram_description_prompt, query_images=query_images, user=user, tracer=tracer
) )
response = response.strip() response = response.strip()
if response.startswith(('"', "'")) and response.endswith(('"', "'")): if response.startswith(('"', "'")) and response.endswith(('"', "'")):
@ -724,6 +746,7 @@ async def generate_excalidraw_diagram_from_description(
q: str, q: str,
user: KhojUser = None, user: KhojUser = None,
agent: Agent = None, agent: Agent = None,
tracer: dict = {},
) -> str: ) -> str:
personality_context = ( personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else "" prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
@ -735,7 +758,9 @@ async def generate_excalidraw_diagram_from_description(
) )
with timer("Chat actor: Generate excalidraw diagram", logger): with timer("Chat actor: Generate excalidraw diagram", logger):
raw_response = await send_message_to_model_wrapper(message=excalidraw_diagram_generation, user=user) raw_response = await send_message_to_model_wrapper(
message=excalidraw_diagram_generation, user=user, tracer=tracer
)
raw_response = raw_response.strip() raw_response = raw_response.strip()
raw_response = remove_json_codeblock(raw_response) raw_response = remove_json_codeblock(raw_response)
response: Dict[str, str] = json.loads(raw_response) response: Dict[str, str] = json.loads(raw_response)
@ -756,6 +781,7 @@ async def generate_better_image_prompt(
query_images: Optional[List[str]] = None, query_images: Optional[List[str]] = None,
user: KhojUser = None, user: KhojUser = None,
agent: Agent = None, agent: Agent = None,
tracer: dict = {},
) -> str: ) -> str:
""" """
Generate a better image prompt from the given query Generate a better image prompt from the given query
@ -802,7 +828,9 @@ async def generate_better_image_prompt(
) )
with timer("Chat actor: Generate contextual image prompt", logger): with timer("Chat actor: Generate contextual image prompt", logger):
response = await send_message_to_model_wrapper(image_prompt, query_images=query_images, user=user) response = await send_message_to_model_wrapper(
image_prompt, query_images=query_images, user=user, tracer=tracer
)
response = response.strip() response = response.strip()
if response.startswith(('"', "'")) and response.endswith(('"', "'")): if response.startswith(('"', "'")) and response.endswith(('"', "'")):
response = response[1:-1] response = response[1:-1]
@ -816,6 +844,7 @@ async def send_message_to_model_wrapper(
response_type: str = "text", response_type: str = "text",
user: KhojUser = None, user: KhojUser = None,
query_images: List[str] = None, query_images: List[str] = None,
tracer: dict = {},
): ):
conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user) conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
vision_available = conversation_config.vision_enabled vision_available = conversation_config.vision_enabled
@ -862,6 +891,7 @@ async def send_message_to_model_wrapper(
max_prompt_size=max_tokens, max_prompt_size=max_tokens,
streaming=False, streaming=False,
response_type=response_type, response_type=response_type,
tracer=tracer,
) )
elif model_type == ChatModelOptions.ModelType.OPENAI: elif model_type == ChatModelOptions.ModelType.OPENAI:
@ -885,6 +915,7 @@ async def send_message_to_model_wrapper(
model=chat_model, model=chat_model,
response_type=response_type, response_type=response_type,
api_base_url=api_base_url, api_base_url=api_base_url,
tracer=tracer,
) )
elif model_type == ChatModelOptions.ModelType.ANTHROPIC: elif model_type == ChatModelOptions.ModelType.ANTHROPIC:
api_key = conversation_config.openai_config.api_key api_key = conversation_config.openai_config.api_key
@ -903,6 +934,7 @@ async def send_message_to_model_wrapper(
messages=truncated_messages, messages=truncated_messages,
api_key=api_key, api_key=api_key,
model=chat_model, model=chat_model,
tracer=tracer,
) )
elif model_type == ChatModelOptions.ModelType.GOOGLE: elif model_type == ChatModelOptions.ModelType.GOOGLE:
api_key = conversation_config.openai_config.api_key api_key = conversation_config.openai_config.api_key
@ -918,7 +950,7 @@ async def send_message_to_model_wrapper(
) )
return gemini_send_message_to_model( return gemini_send_message_to_model(
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type, tracer=tracer
) )
else: else:
raise HTTPException(status_code=500, detail="Invalid conversation config") raise HTTPException(status_code=500, detail="Invalid conversation config")
@ -929,6 +961,7 @@ def send_message_to_model_wrapper_sync(
system_message: str = "", system_message: str = "",
response_type: str = "text", response_type: str = "text",
user: KhojUser = None, user: KhojUser = None,
tracer: dict = {},
): ):
conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config(user) conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config(user)
@ -961,6 +994,7 @@ def send_message_to_model_wrapper_sync(
max_prompt_size=max_tokens, max_prompt_size=max_tokens,
streaming=False, streaming=False,
response_type=response_type, response_type=response_type,
tracer=tracer,
) )
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI: elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
@ -975,7 +1009,11 @@ def send_message_to_model_wrapper_sync(
) )
openai_response = send_message_to_model( openai_response = send_message_to_model(
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type messages=truncated_messages,
api_key=api_key,
model=chat_model,
response_type=response_type,
tracer=tracer,
) )
return openai_response return openai_response
@ -995,6 +1033,7 @@ def send_message_to_model_wrapper_sync(
messages=truncated_messages, messages=truncated_messages,
api_key=api_key, api_key=api_key,
model=chat_model, model=chat_model,
tracer=tracer,
) )
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE: elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
@ -1013,6 +1052,7 @@ def send_message_to_model_wrapper_sync(
api_key=api_key, api_key=api_key,
model=chat_model, model=chat_model,
response_type=response_type, response_type=response_type,
tracer=tracer,
) )
else: else:
raise HTTPException(status_code=500, detail="Invalid conversation config") raise HTTPException(status_code=500, detail="Invalid conversation config")
@ -1032,6 +1072,7 @@ def generate_chat_response(
location_data: LocationData = None, location_data: LocationData = None,
user_name: Optional[str] = None, user_name: Optional[str] = None,
query_images: Optional[List[str]] = None, query_images: Optional[List[str]] = None,
tracer: dict = {},
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]: ) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
# Initialize Variables # Initialize Variables
chat_response = None chat_response = None
@ -1051,6 +1092,7 @@ def generate_chat_response(
client_application=client_application, client_application=client_application,
conversation_id=conversation_id, conversation_id=conversation_id,
query_images=query_images, query_images=query_images,
tracer=tracer,
) )
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation) conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
@ -1077,6 +1119,7 @@ def generate_chat_response(
location_data=location_data, location_data=location_data,
user_name=user_name, user_name=user_name,
agent=agent, agent=agent,
tracer=tracer,
) )
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI: elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
@ -1100,6 +1143,7 @@ def generate_chat_response(
user_name=user_name, user_name=user_name,
agent=agent, agent=agent,
vision_available=vision_available, vision_available=vision_available,
tracer=tracer,
) )
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC: elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
@ -1120,6 +1164,7 @@ def generate_chat_response(
user_name=user_name, user_name=user_name,
agent=agent, agent=agent,
vision_available=vision_available, vision_available=vision_available,
tracer=tracer,
) )
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE: elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
api_key = conversation_config.openai_config.api_key api_key = conversation_config.openai_config.api_key
@ -1139,6 +1184,7 @@ def generate_chat_response(
user_name=user_name, user_name=user_name,
agent=agent, agent=agent,
vision_available=vision_available, vision_available=vision_available,
tracer=tracer,
) )
metadata.update({"chat_model": conversation_config.chat_model}) metadata.update({"chat_model": conversation_config.chat_model})
@ -1495,9 +1541,15 @@ def scheduled_chat(
async def create_automation( async def create_automation(
q: str, timezone: str, user: KhojUser, calling_url: URL, meta_log: dict = {}, conversation_id: str = None q: str,
timezone: str,
user: KhojUser,
calling_url: URL,
meta_log: dict = {},
conversation_id: str = None,
tracer: dict = {},
): ):
crontime, query_to_run, subject = await schedule_query(q, meta_log, user) crontime, query_to_run, subject = await schedule_query(q, meta_log, user, tracer=tracer)
job = await schedule_automation(query_to_run, subject, crontime, timezone, q, user, calling_url, conversation_id) job = await schedule_automation(query_to_run, subject, crontime, timezone, q, user, calling_url, conversation_id)
return job, crontime, query_to_run, subject return job, crontime, query_to_run, subject