mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Commit conversation traces using user, chat, message branch hierarchy
- Message train of thought forks and merges from its conversation branch - Conversation branches from user branch - User branches from root commit on the main branch - Weave chat tracer metadata from api endpoint through all chat actors and commit it to the prompt trace
This commit is contained in:
parent
a3022b7556
commit
ea0712424b
6 changed files with 114 additions and 21 deletions
|
@ -23,7 +23,7 @@ from khoj.database.adapters import ConversationAdapters
|
||||||
from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser
|
from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser
|
||||||
from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
|
from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
from khoj.utils.helpers import is_none_or_empty, merge_dicts
|
from khoj.utils.helpers import in_debug_mode, is_none_or_empty, merge_dicts
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
model_to_prompt_size = {
|
model_to_prompt_size = {
|
||||||
|
@ -119,6 +119,7 @@ def save_to_conversation_log(
|
||||||
conversation_id: str = None,
|
conversation_id: str = None,
|
||||||
automation_id: str = None,
|
automation_id: str = None,
|
||||||
query_images: List[str] = None,
|
query_images: List[str] = None,
|
||||||
|
tracer: Dict[str, Any] = {},
|
||||||
):
|
):
|
||||||
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
updated_conversation = message_to_log(
|
updated_conversation = message_to_log(
|
||||||
|
@ -144,6 +145,9 @@ def save_to_conversation_log(
|
||||||
user_message=q,
|
user_message=q,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if in_debug_mode() or state.verbose > 1:
|
||||||
|
merge_message_into_conversation_trace(q, chat_response, tracer)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"""
|
f"""
|
||||||
Saved Conversation Turn
|
Saved Conversation Turn
|
||||||
|
|
|
@ -28,6 +28,7 @@ async def text_to_image(
|
||||||
send_status_func: Optional[Callable] = None,
|
send_status_func: Optional[Callable] = None,
|
||||||
query_images: Optional[List[str]] = None,
|
query_images: Optional[List[str]] = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
status_code = 200
|
status_code = 200
|
||||||
image = None
|
image = None
|
||||||
|
@ -68,6 +69,7 @@ async def text_to_image(
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
user=user,
|
user=user,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
if send_status_func:
|
if send_status_func:
|
||||||
|
|
|
@ -64,6 +64,7 @@ async def search_online(
|
||||||
custom_filters: List[str] = [],
|
custom_filters: List[str] = [],
|
||||||
query_images: List[str] = None,
|
query_images: List[str] = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
query += " ".join(custom_filters)
|
query += " ".join(custom_filters)
|
||||||
if not is_internet_connected():
|
if not is_internet_connected():
|
||||||
|
@ -73,7 +74,7 @@ async def search_online(
|
||||||
|
|
||||||
# Breakdown the query into subqueries to get the correct answer
|
# Breakdown the query into subqueries to get the correct answer
|
||||||
subqueries = await generate_online_subqueries(
|
subqueries = await generate_online_subqueries(
|
||||||
query, conversation_history, location, user, query_images=query_images, agent=agent
|
query, conversation_history, location, user, query_images=query_images, agent=agent, tracer=tracer
|
||||||
)
|
)
|
||||||
response_dict = {}
|
response_dict = {}
|
||||||
|
|
||||||
|
@ -111,7 +112,7 @@ async def search_online(
|
||||||
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
|
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
|
||||||
yield {ChatEvent.STATUS: event}
|
yield {ChatEvent.STATUS: event}
|
||||||
tasks = [
|
tasks = [
|
||||||
read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent)
|
read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent, tracer=tracer)
|
||||||
for link, data in webpages.items()
|
for link, data in webpages.items()
|
||||||
]
|
]
|
||||||
results = await asyncio.gather(*tasks)
|
results = await asyncio.gather(*tasks)
|
||||||
|
@ -153,6 +154,7 @@ async def read_webpages(
|
||||||
send_status_func: Optional[Callable] = None,
|
send_status_func: Optional[Callable] = None,
|
||||||
query_images: List[str] = None,
|
query_images: List[str] = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
"Infer web pages to read from the query and extract relevant information from them"
|
"Infer web pages to read from the query and extract relevant information from them"
|
||||||
logger.info(f"Inferring web pages to read")
|
logger.info(f"Inferring web pages to read")
|
||||||
|
@ -166,7 +168,7 @@ async def read_webpages(
|
||||||
webpage_links_str = "\n- " + "\n- ".join(list(urls))
|
webpage_links_str = "\n- " + "\n- ".join(list(urls))
|
||||||
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
|
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
|
||||||
yield {ChatEvent.STATUS: event}
|
yield {ChatEvent.STATUS: event}
|
||||||
tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent) for url in urls]
|
tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent, tracer=tracer) for url in urls]
|
||||||
results = await asyncio.gather(*tasks)
|
results = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
response: Dict[str, Dict] = defaultdict(dict)
|
response: Dict[str, Dict] = defaultdict(dict)
|
||||||
|
@ -192,7 +194,12 @@ async def read_webpage(
|
||||||
|
|
||||||
|
|
||||||
async def read_webpage_and_extract_content(
|
async def read_webpage_and_extract_content(
|
||||||
subqueries: set[str], url: str, content: str = None, user: KhojUser = None, agent: Agent = None
|
subqueries: set[str],
|
||||||
|
url: str,
|
||||||
|
content: str = None,
|
||||||
|
user: KhojUser = None,
|
||||||
|
agent: Agent = None,
|
||||||
|
tracer: dict = {},
|
||||||
) -> Tuple[set[str], str, Union[None, str]]:
|
) -> Tuple[set[str], str, Union[None, str]]:
|
||||||
# Select the web scrapers to use for reading the web page
|
# Select the web scrapers to use for reading the web page
|
||||||
web_scrapers = await ConversationAdapters.aget_enabled_webscrapers()
|
web_scrapers = await ConversationAdapters.aget_enabled_webscrapers()
|
||||||
|
@ -214,7 +221,9 @@ async def read_webpage_and_extract_content(
|
||||||
# Extract relevant information from the web page
|
# Extract relevant information from the web page
|
||||||
if is_none_or_empty(extracted_info):
|
if is_none_or_empty(extracted_info):
|
||||||
with timer(f"Extracting relevant information from web page at '{url}' took", logger):
|
with timer(f"Extracting relevant information from web page at '{url}' took", logger):
|
||||||
extracted_info = await extract_relevant_info(subqueries, content, user=user, agent=agent)
|
extracted_info = await extract_relevant_info(
|
||||||
|
subqueries, content, user=user, agent=agent, tracer=tracer
|
||||||
|
)
|
||||||
|
|
||||||
# If we successfully extracted information, break the loop
|
# If we successfully extracted information, break the loop
|
||||||
if not is_none_or_empty(extracted_info):
|
if not is_none_or_empty(extracted_info):
|
||||||
|
|
|
@ -350,6 +350,7 @@ async def extract_references_and_questions(
|
||||||
send_status_func: Optional[Callable] = None,
|
send_status_func: Optional[Callable] = None,
|
||||||
query_images: Optional[List[str]] = None,
|
query_images: Optional[List[str]] = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
user = request.user.object if request.user.is_authenticated else None
|
user = request.user.object if request.user.is_authenticated else None
|
||||||
|
|
||||||
|
@ -425,6 +426,7 @@ async def extract_references_and_questions(
|
||||||
user=user,
|
user=user,
|
||||||
max_prompt_size=conversation_config.max_prompt_size,
|
max_prompt_size=conversation_config.max_prompt_size,
|
||||||
personality_context=personality_context,
|
personality_context=personality_context,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||||
openai_chat_config = conversation_config.openai_config
|
openai_chat_config = conversation_config.openai_config
|
||||||
|
@ -442,6 +444,7 @@ async def extract_references_and_questions(
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
vision_enabled=vision_enabled,
|
vision_enabled=vision_enabled,
|
||||||
personality_context=personality_context,
|
personality_context=personality_context,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
||||||
api_key = conversation_config.openai_config.api_key
|
api_key = conversation_config.openai_config.api_key
|
||||||
|
@ -456,6 +459,7 @@ async def extract_references_and_questions(
|
||||||
user=user,
|
user=user,
|
||||||
vision_enabled=vision_enabled,
|
vision_enabled=vision_enabled,
|
||||||
personality_context=personality_context,
|
personality_context=personality_context,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
|
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
|
||||||
api_key = conversation_config.openai_config.api_key
|
api_key = conversation_config.openai_config.api_key
|
||||||
|
@ -471,6 +475,7 @@ async def extract_references_and_questions(
|
||||||
user=user,
|
user=user,
|
||||||
vision_enabled=vision_enabled,
|
vision_enabled=vision_enabled,
|
||||||
personality_context=personality_context,
|
personality_context=personality_context,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Collate search results as context for GPT
|
# Collate search results as context for GPT
|
||||||
|
|
|
@ -3,6 +3,7 @@ import base64
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
@ -563,6 +564,12 @@ async def chat(
|
||||||
event_delimiter = "␃🔚␗"
|
event_delimiter = "␃🔚␗"
|
||||||
q = unquote(q)
|
q = unquote(q)
|
||||||
nonlocal conversation_id
|
nonlocal conversation_id
|
||||||
|
tracer: dict = {
|
||||||
|
"mid": f"{uuid.uuid4()}",
|
||||||
|
"cid": conversation_id,
|
||||||
|
"uid": user.id,
|
||||||
|
"khoj_version": state.khoj_version,
|
||||||
|
}
|
||||||
|
|
||||||
uploaded_images: list[str] = []
|
uploaded_images: list[str] = []
|
||||||
if images:
|
if images:
|
||||||
|
@ -682,6 +689,7 @@ async def chat(
|
||||||
user=user,
|
user=user,
|
||||||
query_images=uploaded_images,
|
query_images=uploaded_images,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
|
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
|
||||||
async for result in send_event(
|
async for result in send_event(
|
||||||
|
@ -689,7 +697,9 @@ async def chat(
|
||||||
):
|
):
|
||||||
yield result
|
yield result
|
||||||
|
|
||||||
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_images, agent)
|
mode = await aget_relevant_output_modes(
|
||||||
|
q, meta_log, is_automated_task, user, uploaded_images, agent, tracer=tracer
|
||||||
|
)
|
||||||
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
|
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
|
||||||
yield result
|
yield result
|
||||||
if mode not in conversation_commands:
|
if mode not in conversation_commands:
|
||||||
|
@ -755,6 +765,7 @@ async def chat(
|
||||||
query_images=uploaded_images,
|
query_images=uploaded_images,
|
||||||
user=user,
|
user=user,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
response_log = str(response)
|
response_log = str(response)
|
||||||
async for result in send_llm_response(response_log):
|
async for result in send_llm_response(response_log):
|
||||||
|
@ -774,6 +785,7 @@ async def chat(
|
||||||
client_application=request.user.client_app,
|
client_application=request.user.client_app,
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
query_images=uploaded_images,
|
query_images=uploaded_images,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -795,7 +807,7 @@ async def chat(
|
||||||
if ConversationCommand.Automation in conversation_commands:
|
if ConversationCommand.Automation in conversation_commands:
|
||||||
try:
|
try:
|
||||||
automation, crontime, query_to_run, subject = await create_automation(
|
automation, crontime, query_to_run, subject = await create_automation(
|
||||||
q, timezone, user, request.url, meta_log
|
q, timezone, user, request.url, meta_log, tracer=tracer
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error scheduling task {q} for {user.email}: {e}")
|
logger.error(f"Error scheduling task {q} for {user.email}: {e}")
|
||||||
|
@ -817,6 +829,7 @@ async def chat(
|
||||||
inferred_queries=[query_to_run],
|
inferred_queries=[query_to_run],
|
||||||
automation_id=automation.id,
|
automation_id=automation.id,
|
||||||
query_images=uploaded_images,
|
query_images=uploaded_images,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
async for result in send_llm_response(llm_response):
|
async for result in send_llm_response(llm_response):
|
||||||
yield result
|
yield result
|
||||||
|
@ -838,6 +851,7 @@ async def chat(
|
||||||
partial(send_event, ChatEvent.STATUS),
|
partial(send_event, ChatEvent.STATUS),
|
||||||
query_images=uploaded_images,
|
query_images=uploaded_images,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
|
tracer=tracer,
|
||||||
):
|
):
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
yield result[ChatEvent.STATUS]
|
yield result[ChatEvent.STATUS]
|
||||||
|
@ -882,6 +896,7 @@ async def chat(
|
||||||
custom_filters,
|
custom_filters,
|
||||||
query_images=uploaded_images,
|
query_images=uploaded_images,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
|
tracer=tracer,
|
||||||
):
|
):
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
yield result[ChatEvent.STATUS]
|
yield result[ChatEvent.STATUS]
|
||||||
|
@ -906,6 +921,7 @@ async def chat(
|
||||||
partial(send_event, ChatEvent.STATUS),
|
partial(send_event, ChatEvent.STATUS),
|
||||||
query_images=uploaded_images,
|
query_images=uploaded_images,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
|
tracer=tracer,
|
||||||
):
|
):
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
yield result[ChatEvent.STATUS]
|
yield result[ChatEvent.STATUS]
|
||||||
|
@ -956,6 +972,7 @@ async def chat(
|
||||||
send_status_func=partial(send_event, ChatEvent.STATUS),
|
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||||
query_images=uploaded_images,
|
query_images=uploaded_images,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
|
tracer=tracer,
|
||||||
):
|
):
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
yield result[ChatEvent.STATUS]
|
yield result[ChatEvent.STATUS]
|
||||||
|
@ -986,6 +1003,7 @@ async def chat(
|
||||||
compiled_references=compiled_references,
|
compiled_references=compiled_references,
|
||||||
online_results=online_results,
|
online_results=online_results,
|
||||||
query_images=uploaded_images,
|
query_images=uploaded_images,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
content_obj = {
|
content_obj = {
|
||||||
"intentType": intent_type,
|
"intentType": intent_type,
|
||||||
|
@ -1014,6 +1032,7 @@ async def chat(
|
||||||
user=user,
|
user=user,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
send_status_func=partial(send_event, ChatEvent.STATUS),
|
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||||
|
tracer=tracer,
|
||||||
):
|
):
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
yield result[ChatEvent.STATUS]
|
yield result[ChatEvent.STATUS]
|
||||||
|
@ -1041,6 +1060,7 @@ async def chat(
|
||||||
compiled_references=compiled_references,
|
compiled_references=compiled_references,
|
||||||
online_results=online_results,
|
online_results=online_results,
|
||||||
query_images=uploaded_images,
|
query_images=uploaded_images,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
async for result in send_llm_response(json.dumps(content_obj)):
|
async for result in send_llm_response(json.dumps(content_obj)):
|
||||||
|
@ -1064,6 +1084,7 @@ async def chat(
|
||||||
location,
|
location,
|
||||||
user_name,
|
user_name,
|
||||||
uploaded_images,
|
uploaded_images,
|
||||||
|
tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Send Response
|
# Send Response
|
||||||
|
|
|
@ -301,6 +301,7 @@ async def aget_relevant_information_sources(
|
||||||
user: KhojUser,
|
user: KhojUser,
|
||||||
query_images: List[str] = None,
|
query_images: List[str] = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
|
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
|
||||||
|
@ -337,6 +338,7 @@ async def aget_relevant_information_sources(
|
||||||
relevant_tools_prompt,
|
relevant_tools_prompt,
|
||||||
response_type="json_object",
|
response_type="json_object",
|
||||||
user=user,
|
user=user,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -378,6 +380,7 @@ async def aget_relevant_output_modes(
|
||||||
user: KhojUser = None,
|
user: KhojUser = None,
|
||||||
query_images: List[str] = None,
|
query_images: List[str] = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
|
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
|
||||||
|
@ -413,7 +416,9 @@ async def aget_relevant_output_modes(
|
||||||
)
|
)
|
||||||
|
|
||||||
with timer("Chat actor: Infer output mode for chat response", logger):
|
with timer("Chat actor: Infer output mode for chat response", logger):
|
||||||
response = await send_message_to_model_wrapper(relevant_mode_prompt, response_type="json_object", user=user)
|
response = await send_message_to_model_wrapper(
|
||||||
|
relevant_mode_prompt, response_type="json_object", user=user, tracer=tracer
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = response.strip()
|
response = response.strip()
|
||||||
|
@ -444,6 +449,7 @@ async def infer_webpage_urls(
|
||||||
user: KhojUser,
|
user: KhojUser,
|
||||||
query_images: List[str] = None,
|
query_images: List[str] = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
|
tracer: dict = {},
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Infer webpage links from the given query
|
Infer webpage links from the given query
|
||||||
|
@ -468,7 +474,11 @@ async def infer_webpage_urls(
|
||||||
|
|
||||||
with timer("Chat actor: Infer webpage urls to read", logger):
|
with timer("Chat actor: Infer webpage urls to read", logger):
|
||||||
response = await send_message_to_model_wrapper(
|
response = await send_message_to_model_wrapper(
|
||||||
online_queries_prompt, query_images=query_images, response_type="json_object", user=user
|
online_queries_prompt,
|
||||||
|
query_images=query_images,
|
||||||
|
response_type="json_object",
|
||||||
|
user=user,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate that the response is a non-empty, JSON-serializable list of URLs
|
# Validate that the response is a non-empty, JSON-serializable list of URLs
|
||||||
|
@ -490,6 +500,7 @@ async def generate_online_subqueries(
|
||||||
user: KhojUser,
|
user: KhojUser,
|
||||||
query_images: List[str] = None,
|
query_images: List[str] = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
|
tracer: dict = {},
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Generate subqueries from the given query
|
Generate subqueries from the given query
|
||||||
|
@ -514,7 +525,11 @@ async def generate_online_subqueries(
|
||||||
|
|
||||||
with timer("Chat actor: Generate online search subqueries", logger):
|
with timer("Chat actor: Generate online search subqueries", logger):
|
||||||
response = await send_message_to_model_wrapper(
|
response = await send_message_to_model_wrapper(
|
||||||
online_queries_prompt, query_images=query_images, response_type="json_object", user=user
|
online_queries_prompt,
|
||||||
|
query_images=query_images,
|
||||||
|
response_type="json_object",
|
||||||
|
user=user,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate that the response is a non-empty, JSON-serializable list
|
# Validate that the response is a non-empty, JSON-serializable list
|
||||||
|
@ -533,7 +548,7 @@ async def generate_online_subqueries(
|
||||||
|
|
||||||
|
|
||||||
async def schedule_query(
|
async def schedule_query(
|
||||||
q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None
|
q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None, tracer: dict = {}
|
||||||
) -> Tuple[str, ...]:
|
) -> Tuple[str, ...]:
|
||||||
"""
|
"""
|
||||||
Schedule the date, time to run the query. Assume the server timezone is UTC.
|
Schedule the date, time to run the query. Assume the server timezone is UTC.
|
||||||
|
@ -546,7 +561,7 @@ async def schedule_query(
|
||||||
)
|
)
|
||||||
|
|
||||||
raw_response = await send_message_to_model_wrapper(
|
raw_response = await send_message_to_model_wrapper(
|
||||||
crontime_prompt, query_images=query_images, response_type="json_object", user=user
|
crontime_prompt, query_images=query_images, response_type="json_object", user=user, tracer=tracer
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate that the response is a non-empty, JSON-serializable list
|
# Validate that the response is a non-empty, JSON-serializable list
|
||||||
|
@ -561,7 +576,7 @@ async def schedule_query(
|
||||||
|
|
||||||
|
|
||||||
async def extract_relevant_info(
|
async def extract_relevant_info(
|
||||||
qs: set[str], corpus: str, user: KhojUser = None, agent: Agent = None
|
qs: set[str], corpus: str, user: KhojUser = None, agent: Agent = None, tracer: dict = {}
|
||||||
) -> Union[str, None]:
|
) -> Union[str, None]:
|
||||||
"""
|
"""
|
||||||
Extract relevant information for a given query from the target corpus
|
Extract relevant information for a given query from the target corpus
|
||||||
|
@ -584,6 +599,7 @@ async def extract_relevant_info(
|
||||||
extract_relevant_information,
|
extract_relevant_information,
|
||||||
prompts.system_prompt_extract_relevant_information,
|
prompts.system_prompt_extract_relevant_information,
|
||||||
user=user,
|
user=user,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
return response.strip()
|
return response.strip()
|
||||||
|
|
||||||
|
@ -595,6 +611,7 @@ async def extract_relevant_summary(
|
||||||
query_images: List[str] = None,
|
query_images: List[str] = None,
|
||||||
user: KhojUser = None,
|
user: KhojUser = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
|
tracer: dict = {},
|
||||||
) -> Union[str, None]:
|
) -> Union[str, None]:
|
||||||
"""
|
"""
|
||||||
Extract relevant information for a given query from the target corpus
|
Extract relevant information for a given query from the target corpus
|
||||||
|
@ -622,6 +639,7 @@ async def extract_relevant_summary(
|
||||||
prompts.system_prompt_extract_relevant_summary,
|
prompts.system_prompt_extract_relevant_summary,
|
||||||
user=user,
|
user=user,
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
return response.strip()
|
return response.strip()
|
||||||
|
|
||||||
|
@ -636,6 +654,7 @@ async def generate_excalidraw_diagram(
|
||||||
user: KhojUser = None,
|
user: KhojUser = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
send_status_func: Optional[Callable] = None,
|
send_status_func: Optional[Callable] = None,
|
||||||
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
if send_status_func:
|
if send_status_func:
|
||||||
async for event in send_status_func("**Enhancing the Diagramming Prompt**"):
|
async for event in send_status_func("**Enhancing the Diagramming Prompt**"):
|
||||||
|
@ -650,6 +669,7 @@ async def generate_excalidraw_diagram(
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
user=user,
|
user=user,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
if send_status_func:
|
if send_status_func:
|
||||||
|
@ -660,6 +680,7 @@ async def generate_excalidraw_diagram(
|
||||||
q=better_diagram_description_prompt,
|
q=better_diagram_description_prompt,
|
||||||
user=user,
|
user=user,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
yield better_diagram_description_prompt, excalidraw_diagram_description
|
yield better_diagram_description_prompt, excalidraw_diagram_description
|
||||||
|
@ -674,6 +695,7 @@ async def generate_better_diagram_description(
|
||||||
query_images: List[str] = None,
|
query_images: List[str] = None,
|
||||||
user: KhojUser = None,
|
user: KhojUser = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
|
tracer: dict = {},
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Generate a diagram description from the given query and context
|
Generate a diagram description from the given query and context
|
||||||
|
@ -711,7 +733,7 @@ async def generate_better_diagram_description(
|
||||||
|
|
||||||
with timer("Chat actor: Generate better diagram description", logger):
|
with timer("Chat actor: Generate better diagram description", logger):
|
||||||
response = await send_message_to_model_wrapper(
|
response = await send_message_to_model_wrapper(
|
||||||
improve_diagram_description_prompt, query_images=query_images, user=user
|
improve_diagram_description_prompt, query_images=query_images, user=user, tracer=tracer
|
||||||
)
|
)
|
||||||
response = response.strip()
|
response = response.strip()
|
||||||
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
|
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
|
||||||
|
@ -724,6 +746,7 @@ async def generate_excalidraw_diagram_from_description(
|
||||||
q: str,
|
q: str,
|
||||||
user: KhojUser = None,
|
user: KhojUser = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
|
tracer: dict = {},
|
||||||
) -> str:
|
) -> str:
|
||||||
personality_context = (
|
personality_context = (
|
||||||
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
|
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
|
||||||
|
@ -735,7 +758,9 @@ async def generate_excalidraw_diagram_from_description(
|
||||||
)
|
)
|
||||||
|
|
||||||
with timer("Chat actor: Generate excalidraw diagram", logger):
|
with timer("Chat actor: Generate excalidraw diagram", logger):
|
||||||
raw_response = await send_message_to_model_wrapper(message=excalidraw_diagram_generation, user=user)
|
raw_response = await send_message_to_model_wrapper(
|
||||||
|
message=excalidraw_diagram_generation, user=user, tracer=tracer
|
||||||
|
)
|
||||||
raw_response = raw_response.strip()
|
raw_response = raw_response.strip()
|
||||||
raw_response = remove_json_codeblock(raw_response)
|
raw_response = remove_json_codeblock(raw_response)
|
||||||
response: Dict[str, str] = json.loads(raw_response)
|
response: Dict[str, str] = json.loads(raw_response)
|
||||||
|
@ -756,6 +781,7 @@ async def generate_better_image_prompt(
|
||||||
query_images: Optional[List[str]] = None,
|
query_images: Optional[List[str]] = None,
|
||||||
user: KhojUser = None,
|
user: KhojUser = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
|
tracer: dict = {},
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Generate a better image prompt from the given query
|
Generate a better image prompt from the given query
|
||||||
|
@ -802,7 +828,9 @@ async def generate_better_image_prompt(
|
||||||
)
|
)
|
||||||
|
|
||||||
with timer("Chat actor: Generate contextual image prompt", logger):
|
with timer("Chat actor: Generate contextual image prompt", logger):
|
||||||
response = await send_message_to_model_wrapper(image_prompt, query_images=query_images, user=user)
|
response = await send_message_to_model_wrapper(
|
||||||
|
image_prompt, query_images=query_images, user=user, tracer=tracer
|
||||||
|
)
|
||||||
response = response.strip()
|
response = response.strip()
|
||||||
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
|
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
|
||||||
response = response[1:-1]
|
response = response[1:-1]
|
||||||
|
@ -816,6 +844,7 @@ async def send_message_to_model_wrapper(
|
||||||
response_type: str = "text",
|
response_type: str = "text",
|
||||||
user: KhojUser = None,
|
user: KhojUser = None,
|
||||||
query_images: List[str] = None,
|
query_images: List[str] = None,
|
||||||
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
|
conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
|
||||||
vision_available = conversation_config.vision_enabled
|
vision_available = conversation_config.vision_enabled
|
||||||
|
@ -862,6 +891,7 @@ async def send_message_to_model_wrapper(
|
||||||
max_prompt_size=max_tokens,
|
max_prompt_size=max_tokens,
|
||||||
streaming=False,
|
streaming=False,
|
||||||
response_type=response_type,
|
response_type=response_type,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif model_type == ChatModelOptions.ModelType.OPENAI:
|
elif model_type == ChatModelOptions.ModelType.OPENAI:
|
||||||
|
@ -885,6 +915,7 @@ async def send_message_to_model_wrapper(
|
||||||
model=chat_model,
|
model=chat_model,
|
||||||
response_type=response_type,
|
response_type=response_type,
|
||||||
api_base_url=api_base_url,
|
api_base_url=api_base_url,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
elif model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
elif model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
||||||
api_key = conversation_config.openai_config.api_key
|
api_key = conversation_config.openai_config.api_key
|
||||||
|
@ -903,6 +934,7 @@ async def send_message_to_model_wrapper(
|
||||||
messages=truncated_messages,
|
messages=truncated_messages,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
model=chat_model,
|
model=chat_model,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
elif model_type == ChatModelOptions.ModelType.GOOGLE:
|
elif model_type == ChatModelOptions.ModelType.GOOGLE:
|
||||||
api_key = conversation_config.openai_config.api_key
|
api_key = conversation_config.openai_config.api_key
|
||||||
|
@ -918,7 +950,7 @@ async def send_message_to_model_wrapper(
|
||||||
)
|
)
|
||||||
|
|
||||||
return gemini_send_message_to_model(
|
return gemini_send_message_to_model(
|
||||||
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type
|
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type, tracer=tracer
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
||||||
|
@ -929,6 +961,7 @@ def send_message_to_model_wrapper_sync(
|
||||||
system_message: str = "",
|
system_message: str = "",
|
||||||
response_type: str = "text",
|
response_type: str = "text",
|
||||||
user: KhojUser = None,
|
user: KhojUser = None,
|
||||||
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config(user)
|
conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config(user)
|
||||||
|
|
||||||
|
@ -961,6 +994,7 @@ def send_message_to_model_wrapper_sync(
|
||||||
max_prompt_size=max_tokens,
|
max_prompt_size=max_tokens,
|
||||||
streaming=False,
|
streaming=False,
|
||||||
response_type=response_type,
|
response_type=response_type,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||||
|
@ -975,7 +1009,11 @@ def send_message_to_model_wrapper_sync(
|
||||||
)
|
)
|
||||||
|
|
||||||
openai_response = send_message_to_model(
|
openai_response = send_message_to_model(
|
||||||
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type
|
messages=truncated_messages,
|
||||||
|
api_key=api_key,
|
||||||
|
model=chat_model,
|
||||||
|
response_type=response_type,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
return openai_response
|
return openai_response
|
||||||
|
@ -995,6 +1033,7 @@ def send_message_to_model_wrapper_sync(
|
||||||
messages=truncated_messages,
|
messages=truncated_messages,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
model=chat_model,
|
model=chat_model,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
|
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
|
||||||
|
@ -1013,6 +1052,7 @@ def send_message_to_model_wrapper_sync(
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
model=chat_model,
|
model=chat_model,
|
||||||
response_type=response_type,
|
response_type=response_type,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
||||||
|
@ -1032,6 +1072,7 @@ def generate_chat_response(
|
||||||
location_data: LocationData = None,
|
location_data: LocationData = None,
|
||||||
user_name: Optional[str] = None,
|
user_name: Optional[str] = None,
|
||||||
query_images: Optional[List[str]] = None,
|
query_images: Optional[List[str]] = None,
|
||||||
|
tracer: dict = {},
|
||||||
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
|
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
|
||||||
# Initialize Variables
|
# Initialize Variables
|
||||||
chat_response = None
|
chat_response = None
|
||||||
|
@ -1051,6 +1092,7 @@ def generate_chat_response(
|
||||||
client_application=client_application,
|
client_application=client_application,
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
query_images=query_images,
|
query_images=query_images,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
|
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
|
||||||
|
@ -1077,6 +1119,7 @@ def generate_chat_response(
|
||||||
location_data=location_data,
|
location_data=location_data,
|
||||||
user_name=user_name,
|
user_name=user_name,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||||
|
@ -1100,6 +1143,7 @@ def generate_chat_response(
|
||||||
user_name=user_name,
|
user_name=user_name,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
vision_available=vision_available,
|
vision_available=vision_available,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
||||||
|
@ -1120,6 +1164,7 @@ def generate_chat_response(
|
||||||
user_name=user_name,
|
user_name=user_name,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
vision_available=vision_available,
|
vision_available=vision_available,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
|
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
|
||||||
api_key = conversation_config.openai_config.api_key
|
api_key = conversation_config.openai_config.api_key
|
||||||
|
@ -1139,6 +1184,7 @@ def generate_chat_response(
|
||||||
user_name=user_name,
|
user_name=user_name,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
vision_available=vision_available,
|
vision_available=vision_available,
|
||||||
|
tracer=tracer,
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata.update({"chat_model": conversation_config.chat_model})
|
metadata.update({"chat_model": conversation_config.chat_model})
|
||||||
|
@ -1495,9 +1541,15 @@ def scheduled_chat(
|
||||||
|
|
||||||
|
|
||||||
async def create_automation(
|
async def create_automation(
|
||||||
q: str, timezone: str, user: KhojUser, calling_url: URL, meta_log: dict = {}, conversation_id: str = None
|
q: str,
|
||||||
|
timezone: str,
|
||||||
|
user: KhojUser,
|
||||||
|
calling_url: URL,
|
||||||
|
meta_log: dict = {},
|
||||||
|
conversation_id: str = None,
|
||||||
|
tracer: dict = {},
|
||||||
):
|
):
|
||||||
crontime, query_to_run, subject = await schedule_query(q, meta_log, user)
|
crontime, query_to_run, subject = await schedule_query(q, meta_log, user, tracer=tracer)
|
||||||
job = await schedule_automation(query_to_run, subject, crontime, timezone, q, user, calling_url, conversation_id)
|
job = await schedule_automation(query_to_run, subject, crontime, timezone, q, user, calling_url, conversation_id)
|
||||||
return job, crontime, query_to_run, subject
|
return job, crontime, query_to_run, subject
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue