mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Commit conversation traces using user, chat, message branch hierarchy
- Message train of thought forks and merges from its conversation branch - Conversation branches from user branch - User branches from root commit on the main branch - Weave chat tracer metadata from api endpoint through all chat actors and commit it to the prompt trace
This commit is contained in:
parent
a3022b7556
commit
ea0712424b
6 changed files with 114 additions and 21 deletions
|
@ -23,7 +23,7 @@ from khoj.database.adapters import ConversationAdapters
|
|||
from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser
|
||||
from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import is_none_or_empty, merge_dicts
|
||||
from khoj.utils.helpers import in_debug_mode, is_none_or_empty, merge_dicts
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
model_to_prompt_size = {
|
||||
|
@ -119,6 +119,7 @@ def save_to_conversation_log(
|
|||
conversation_id: str = None,
|
||||
automation_id: str = None,
|
||||
query_images: List[str] = None,
|
||||
tracer: Dict[str, Any] = {},
|
||||
):
|
||||
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
updated_conversation = message_to_log(
|
||||
|
@ -144,6 +145,9 @@ def save_to_conversation_log(
|
|||
user_message=q,
|
||||
)
|
||||
|
||||
if in_debug_mode() or state.verbose > 1:
|
||||
merge_message_into_conversation_trace(q, chat_response, tracer)
|
||||
|
||||
logger.info(
|
||||
f"""
|
||||
Saved Conversation Turn
|
||||
|
|
|
@ -28,6 +28,7 @@ async def text_to_image(
|
|||
send_status_func: Optional[Callable] = None,
|
||||
query_images: Optional[List[str]] = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
status_code = 200
|
||||
image = None
|
||||
|
@ -68,6 +69,7 @@ async def text_to_image(
|
|||
query_images=query_images,
|
||||
user=user,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
if send_status_func:
|
||||
|
|
|
@ -64,6 +64,7 @@ async def search_online(
|
|||
custom_filters: List[str] = [],
|
||||
query_images: List[str] = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
query += " ".join(custom_filters)
|
||||
if not is_internet_connected():
|
||||
|
@ -73,7 +74,7 @@ async def search_online(
|
|||
|
||||
# Breakdown the query into subqueries to get the correct answer
|
||||
subqueries = await generate_online_subqueries(
|
||||
query, conversation_history, location, user, query_images=query_images, agent=agent
|
||||
query, conversation_history, location, user, query_images=query_images, agent=agent, tracer=tracer
|
||||
)
|
||||
response_dict = {}
|
||||
|
||||
|
@ -111,7 +112,7 @@ async def search_online(
|
|||
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
|
||||
yield {ChatEvent.STATUS: event}
|
||||
tasks = [
|
||||
read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent)
|
||||
read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent, tracer=tracer)
|
||||
for link, data in webpages.items()
|
||||
]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
@ -153,6 +154,7 @@ async def read_webpages(
|
|||
send_status_func: Optional[Callable] = None,
|
||||
query_images: List[str] = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
"Infer web pages to read from the query and extract relevant information from them"
|
||||
logger.info(f"Inferring web pages to read")
|
||||
|
@ -166,7 +168,7 @@ async def read_webpages(
|
|||
webpage_links_str = "\n- " + "\n- ".join(list(urls))
|
||||
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
|
||||
yield {ChatEvent.STATUS: event}
|
||||
tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent) for url in urls]
|
||||
tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent, tracer=tracer) for url in urls]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
response: Dict[str, Dict] = defaultdict(dict)
|
||||
|
@ -192,7 +194,12 @@ async def read_webpage(
|
|||
|
||||
|
||||
async def read_webpage_and_extract_content(
|
||||
subqueries: set[str], url: str, content: str = None, user: KhojUser = None, agent: Agent = None
|
||||
subqueries: set[str],
|
||||
url: str,
|
||||
content: str = None,
|
||||
user: KhojUser = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
) -> Tuple[set[str], str, Union[None, str]]:
|
||||
# Select the web scrapers to use for reading the web page
|
||||
web_scrapers = await ConversationAdapters.aget_enabled_webscrapers()
|
||||
|
@ -214,7 +221,9 @@ async def read_webpage_and_extract_content(
|
|||
# Extract relevant information from the web page
|
||||
if is_none_or_empty(extracted_info):
|
||||
with timer(f"Extracting relevant information from web page at '{url}' took", logger):
|
||||
extracted_info = await extract_relevant_info(subqueries, content, user=user, agent=agent)
|
||||
extracted_info = await extract_relevant_info(
|
||||
subqueries, content, user=user, agent=agent, tracer=tracer
|
||||
)
|
||||
|
||||
# If we successfully extracted information, break the loop
|
||||
if not is_none_or_empty(extracted_info):
|
||||
|
|
|
@ -350,6 +350,7 @@ async def extract_references_and_questions(
|
|||
send_status_func: Optional[Callable] = None,
|
||||
query_images: Optional[List[str]] = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
user = request.user.object if request.user.is_authenticated else None
|
||||
|
||||
|
@ -425,6 +426,7 @@ async def extract_references_and_questions(
|
|||
user=user,
|
||||
max_prompt_size=conversation_config.max_prompt_size,
|
||||
personality_context=personality_context,
|
||||
tracer=tracer,
|
||||
)
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
openai_chat_config = conversation_config.openai_config
|
||||
|
@ -442,6 +444,7 @@ async def extract_references_and_questions(
|
|||
query_images=query_images,
|
||||
vision_enabled=vision_enabled,
|
||||
personality_context=personality_context,
|
||||
tracer=tracer,
|
||||
)
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
||||
api_key = conversation_config.openai_config.api_key
|
||||
|
@ -456,6 +459,7 @@ async def extract_references_and_questions(
|
|||
user=user,
|
||||
vision_enabled=vision_enabled,
|
||||
personality_context=personality_context,
|
||||
tracer=tracer,
|
||||
)
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
|
||||
api_key = conversation_config.openai_config.api_key
|
||||
|
@ -471,6 +475,7 @@ async def extract_references_and_questions(
|
|||
user=user,
|
||||
vision_enabled=vision_enabled,
|
||||
personality_context=personality_context,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
# Collate search results as context for GPT
|
||||
|
|
|
@ -3,6 +3,7 @@ import base64
|
|||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from typing import Dict, Optional
|
||||
|
@ -563,6 +564,12 @@ async def chat(
|
|||
event_delimiter = "␃🔚␗"
|
||||
q = unquote(q)
|
||||
nonlocal conversation_id
|
||||
tracer: dict = {
|
||||
"mid": f"{uuid.uuid4()}",
|
||||
"cid": conversation_id,
|
||||
"uid": user.id,
|
||||
"khoj_version": state.khoj_version,
|
||||
}
|
||||
|
||||
uploaded_images: list[str] = []
|
||||
if images:
|
||||
|
@ -682,6 +689,7 @@ async def chat(
|
|||
user=user,
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
)
|
||||
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
|
||||
async for result in send_event(
|
||||
|
@ -689,7 +697,9 @@ async def chat(
|
|||
):
|
||||
yield result
|
||||
|
||||
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_images, agent)
|
||||
mode = await aget_relevant_output_modes(
|
||||
q, meta_log, is_automated_task, user, uploaded_images, agent, tracer=tracer
|
||||
)
|
||||
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
|
||||
yield result
|
||||
if mode not in conversation_commands:
|
||||
|
@ -755,6 +765,7 @@ async def chat(
|
|||
query_images=uploaded_images,
|
||||
user=user,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
)
|
||||
response_log = str(response)
|
||||
async for result in send_llm_response(response_log):
|
||||
|
@ -774,6 +785,7 @@ async def chat(
|
|||
client_application=request.user.client_app,
|
||||
conversation_id=conversation_id,
|
||||
query_images=uploaded_images,
|
||||
tracer=tracer,
|
||||
)
|
||||
return
|
||||
|
||||
|
@ -795,7 +807,7 @@ async def chat(
|
|||
if ConversationCommand.Automation in conversation_commands:
|
||||
try:
|
||||
automation, crontime, query_to_run, subject = await create_automation(
|
||||
q, timezone, user, request.url, meta_log
|
||||
q, timezone, user, request.url, meta_log, tracer=tracer
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error scheduling task {q} for {user.email}: {e}")
|
||||
|
@ -817,6 +829,7 @@ async def chat(
|
|||
inferred_queries=[query_to_run],
|
||||
automation_id=automation.id,
|
||||
query_images=uploaded_images,
|
||||
tracer=tracer,
|
||||
)
|
||||
async for result in send_llm_response(llm_response):
|
||||
yield result
|
||||
|
@ -838,6 +851,7 @@ async def chat(
|
|||
partial(send_event, ChatEvent.STATUS),
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
|
@ -882,6 +896,7 @@ async def chat(
|
|||
custom_filters,
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
|
@ -906,6 +921,7 @@ async def chat(
|
|||
partial(send_event, ChatEvent.STATUS),
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
|
@ -956,6 +972,7 @@ async def chat(
|
|||
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
|
@ -986,6 +1003,7 @@ async def chat(
|
|||
compiled_references=compiled_references,
|
||||
online_results=online_results,
|
||||
query_images=uploaded_images,
|
||||
tracer=tracer,
|
||||
)
|
||||
content_obj = {
|
||||
"intentType": intent_type,
|
||||
|
@ -1014,6 +1032,7 @@ async def chat(
|
|||
user=user,
|
||||
agent=agent,
|
||||
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||
tracer=tracer,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
|
@ -1041,6 +1060,7 @@ async def chat(
|
|||
compiled_references=compiled_references,
|
||||
online_results=online_results,
|
||||
query_images=uploaded_images,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
async for result in send_llm_response(json.dumps(content_obj)):
|
||||
|
@ -1064,6 +1084,7 @@ async def chat(
|
|||
location,
|
||||
user_name,
|
||||
uploaded_images,
|
||||
tracer,
|
||||
)
|
||||
|
||||
# Send Response
|
||||
|
|
|
@ -301,6 +301,7 @@ async def aget_relevant_information_sources(
|
|||
user: KhojUser,
|
||||
query_images: List[str] = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
"""
|
||||
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
|
||||
|
@ -337,6 +338,7 @@ async def aget_relevant_information_sources(
|
|||
relevant_tools_prompt,
|
||||
response_type="json_object",
|
||||
user=user,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
try:
|
||||
|
@ -378,6 +380,7 @@ async def aget_relevant_output_modes(
|
|||
user: KhojUser = None,
|
||||
query_images: List[str] = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
"""
|
||||
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
|
||||
|
@ -413,7 +416,9 @@ async def aget_relevant_output_modes(
|
|||
)
|
||||
|
||||
with timer("Chat actor: Infer output mode for chat response", logger):
|
||||
response = await send_message_to_model_wrapper(relevant_mode_prompt, response_type="json_object", user=user)
|
||||
response = await send_message_to_model_wrapper(
|
||||
relevant_mode_prompt, response_type="json_object", user=user, tracer=tracer
|
||||
)
|
||||
|
||||
try:
|
||||
response = response.strip()
|
||||
|
@ -444,6 +449,7 @@ async def infer_webpage_urls(
|
|||
user: KhojUser,
|
||||
query_images: List[str] = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
) -> List[str]:
|
||||
"""
|
||||
Infer webpage links from the given query
|
||||
|
@ -468,7 +474,11 @@ async def infer_webpage_urls(
|
|||
|
||||
with timer("Chat actor: Infer webpage urls to read", logger):
|
||||
response = await send_message_to_model_wrapper(
|
||||
online_queries_prompt, query_images=query_images, response_type="json_object", user=user
|
||||
online_queries_prompt,
|
||||
query_images=query_images,
|
||||
response_type="json_object",
|
||||
user=user,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
# Validate that the response is a non-empty, JSON-serializable list of URLs
|
||||
|
@ -490,6 +500,7 @@ async def generate_online_subqueries(
|
|||
user: KhojUser,
|
||||
query_images: List[str] = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
) -> List[str]:
|
||||
"""
|
||||
Generate subqueries from the given query
|
||||
|
@ -514,7 +525,11 @@ async def generate_online_subqueries(
|
|||
|
||||
with timer("Chat actor: Generate online search subqueries", logger):
|
||||
response = await send_message_to_model_wrapper(
|
||||
online_queries_prompt, query_images=query_images, response_type="json_object", user=user
|
||||
online_queries_prompt,
|
||||
query_images=query_images,
|
||||
response_type="json_object",
|
||||
user=user,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
# Validate that the response is a non-empty, JSON-serializable list
|
||||
|
@ -533,7 +548,7 @@ async def generate_online_subqueries(
|
|||
|
||||
|
||||
async def schedule_query(
|
||||
q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None
|
||||
q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None, tracer: dict = {}
|
||||
) -> Tuple[str, ...]:
|
||||
"""
|
||||
Schedule the date, time to run the query. Assume the server timezone is UTC.
|
||||
|
@ -546,7 +561,7 @@ async def schedule_query(
|
|||
)
|
||||
|
||||
raw_response = await send_message_to_model_wrapper(
|
||||
crontime_prompt, query_images=query_images, response_type="json_object", user=user
|
||||
crontime_prompt, query_images=query_images, response_type="json_object", user=user, tracer=tracer
|
||||
)
|
||||
|
||||
# Validate that the response is a non-empty, JSON-serializable list
|
||||
|
@ -561,7 +576,7 @@ async def schedule_query(
|
|||
|
||||
|
||||
async def extract_relevant_info(
|
||||
qs: set[str], corpus: str, user: KhojUser = None, agent: Agent = None
|
||||
qs: set[str], corpus: str, user: KhojUser = None, agent: Agent = None, tracer: dict = {}
|
||||
) -> Union[str, None]:
|
||||
"""
|
||||
Extract relevant information for a given query from the target corpus
|
||||
|
@ -584,6 +599,7 @@ async def extract_relevant_info(
|
|||
extract_relevant_information,
|
||||
prompts.system_prompt_extract_relevant_information,
|
||||
user=user,
|
||||
tracer=tracer,
|
||||
)
|
||||
return response.strip()
|
||||
|
||||
|
@ -595,6 +611,7 @@ async def extract_relevant_summary(
|
|||
query_images: List[str] = None,
|
||||
user: KhojUser = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
) -> Union[str, None]:
|
||||
"""
|
||||
Extract relevant information for a given query from the target corpus
|
||||
|
@ -622,6 +639,7 @@ async def extract_relevant_summary(
|
|||
prompts.system_prompt_extract_relevant_summary,
|
||||
user=user,
|
||||
query_images=query_images,
|
||||
tracer=tracer,
|
||||
)
|
||||
return response.strip()
|
||||
|
||||
|
@ -636,6 +654,7 @@ async def generate_excalidraw_diagram(
|
|||
user: KhojUser = None,
|
||||
agent: Agent = None,
|
||||
send_status_func: Optional[Callable] = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
if send_status_func:
|
||||
async for event in send_status_func("**Enhancing the Diagramming Prompt**"):
|
||||
|
@ -650,6 +669,7 @@ async def generate_excalidraw_diagram(
|
|||
query_images=query_images,
|
||||
user=user,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
if send_status_func:
|
||||
|
@ -660,6 +680,7 @@ async def generate_excalidraw_diagram(
|
|||
q=better_diagram_description_prompt,
|
||||
user=user,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
yield better_diagram_description_prompt, excalidraw_diagram_description
|
||||
|
@ -674,6 +695,7 @@ async def generate_better_diagram_description(
|
|||
query_images: List[str] = None,
|
||||
user: KhojUser = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
) -> str:
|
||||
"""
|
||||
Generate a diagram description from the given query and context
|
||||
|
@ -711,7 +733,7 @@ async def generate_better_diagram_description(
|
|||
|
||||
with timer("Chat actor: Generate better diagram description", logger):
|
||||
response = await send_message_to_model_wrapper(
|
||||
improve_diagram_description_prompt, query_images=query_images, user=user
|
||||
improve_diagram_description_prompt, query_images=query_images, user=user, tracer=tracer
|
||||
)
|
||||
response = response.strip()
|
||||
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
|
||||
|
@ -724,6 +746,7 @@ async def generate_excalidraw_diagram_from_description(
|
|||
q: str,
|
||||
user: KhojUser = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
) -> str:
|
||||
personality_context = (
|
||||
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
|
||||
|
@ -735,7 +758,9 @@ async def generate_excalidraw_diagram_from_description(
|
|||
)
|
||||
|
||||
with timer("Chat actor: Generate excalidraw diagram", logger):
|
||||
raw_response = await send_message_to_model_wrapper(message=excalidraw_diagram_generation, user=user)
|
||||
raw_response = await send_message_to_model_wrapper(
|
||||
message=excalidraw_diagram_generation, user=user, tracer=tracer
|
||||
)
|
||||
raw_response = raw_response.strip()
|
||||
raw_response = remove_json_codeblock(raw_response)
|
||||
response: Dict[str, str] = json.loads(raw_response)
|
||||
|
@ -756,6 +781,7 @@ async def generate_better_image_prompt(
|
|||
query_images: Optional[List[str]] = None,
|
||||
user: KhojUser = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
) -> str:
|
||||
"""
|
||||
Generate a better image prompt from the given query
|
||||
|
@ -802,7 +828,9 @@ async def generate_better_image_prompt(
|
|||
)
|
||||
|
||||
with timer("Chat actor: Generate contextual image prompt", logger):
|
||||
response = await send_message_to_model_wrapper(image_prompt, query_images=query_images, user=user)
|
||||
response = await send_message_to_model_wrapper(
|
||||
image_prompt, query_images=query_images, user=user, tracer=tracer
|
||||
)
|
||||
response = response.strip()
|
||||
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
|
||||
response = response[1:-1]
|
||||
|
@ -816,6 +844,7 @@ async def send_message_to_model_wrapper(
|
|||
response_type: str = "text",
|
||||
user: KhojUser = None,
|
||||
query_images: List[str] = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
|
||||
vision_available = conversation_config.vision_enabled
|
||||
|
@ -862,6 +891,7 @@ async def send_message_to_model_wrapper(
|
|||
max_prompt_size=max_tokens,
|
||||
streaming=False,
|
||||
response_type=response_type,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
elif model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
|
@ -885,6 +915,7 @@ async def send_message_to_model_wrapper(
|
|||
model=chat_model,
|
||||
response_type=response_type,
|
||||
api_base_url=api_base_url,
|
||||
tracer=tracer,
|
||||
)
|
||||
elif model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
||||
api_key = conversation_config.openai_config.api_key
|
||||
|
@ -903,6 +934,7 @@ async def send_message_to_model_wrapper(
|
|||
messages=truncated_messages,
|
||||
api_key=api_key,
|
||||
model=chat_model,
|
||||
tracer=tracer,
|
||||
)
|
||||
elif model_type == ChatModelOptions.ModelType.GOOGLE:
|
||||
api_key = conversation_config.openai_config.api_key
|
||||
|
@ -918,7 +950,7 @@ async def send_message_to_model_wrapper(
|
|||
)
|
||||
|
||||
return gemini_send_message_to_model(
|
||||
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type
|
||||
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type, tracer=tracer
|
||||
)
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
||||
|
@ -929,6 +961,7 @@ def send_message_to_model_wrapper_sync(
|
|||
system_message: str = "",
|
||||
response_type: str = "text",
|
||||
user: KhojUser = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config(user)
|
||||
|
||||
|
@ -961,6 +994,7 @@ def send_message_to_model_wrapper_sync(
|
|||
max_prompt_size=max_tokens,
|
||||
streaming=False,
|
||||
response_type=response_type,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
|
@ -975,7 +1009,11 @@ def send_message_to_model_wrapper_sync(
|
|||
)
|
||||
|
||||
openai_response = send_message_to_model(
|
||||
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type
|
||||
messages=truncated_messages,
|
||||
api_key=api_key,
|
||||
model=chat_model,
|
||||
response_type=response_type,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
return openai_response
|
||||
|
@ -995,6 +1033,7 @@ def send_message_to_model_wrapper_sync(
|
|||
messages=truncated_messages,
|
||||
api_key=api_key,
|
||||
model=chat_model,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
|
||||
|
@ -1013,6 +1052,7 @@ def send_message_to_model_wrapper_sync(
|
|||
api_key=api_key,
|
||||
model=chat_model,
|
||||
response_type=response_type,
|
||||
tracer=tracer,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
||||
|
@ -1032,6 +1072,7 @@ def generate_chat_response(
|
|||
location_data: LocationData = None,
|
||||
user_name: Optional[str] = None,
|
||||
query_images: Optional[List[str]] = None,
|
||||
tracer: dict = {},
|
||||
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
|
||||
# Initialize Variables
|
||||
chat_response = None
|
||||
|
@ -1051,6 +1092,7 @@ def generate_chat_response(
|
|||
client_application=client_application,
|
||||
conversation_id=conversation_id,
|
||||
query_images=query_images,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
|
||||
|
@ -1077,6 +1119,7 @@ def generate_chat_response(
|
|||
location_data=location_data,
|
||||
user_name=user_name,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
|
@ -1100,6 +1143,7 @@ def generate_chat_response(
|
|||
user_name=user_name,
|
||||
agent=agent,
|
||||
vision_available=vision_available,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
||||
|
@ -1120,6 +1164,7 @@ def generate_chat_response(
|
|||
user_name=user_name,
|
||||
agent=agent,
|
||||
vision_available=vision_available,
|
||||
tracer=tracer,
|
||||
)
|
||||
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
|
||||
api_key = conversation_config.openai_config.api_key
|
||||
|
@ -1139,6 +1184,7 @@ def generate_chat_response(
|
|||
user_name=user_name,
|
||||
agent=agent,
|
||||
vision_available=vision_available,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
metadata.update({"chat_model": conversation_config.chat_model})
|
||||
|
@ -1495,9 +1541,15 @@ def scheduled_chat(
|
|||
|
||||
|
||||
async def create_automation(
|
||||
q: str, timezone: str, user: KhojUser, calling_url: URL, meta_log: dict = {}, conversation_id: str = None
|
||||
q: str,
|
||||
timezone: str,
|
||||
user: KhojUser,
|
||||
calling_url: URL,
|
||||
meta_log: dict = {},
|
||||
conversation_id: str = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
crontime, query_to_run, subject = await schedule_query(q, meta_log, user)
|
||||
crontime, query_to_run, subject = await schedule_query(q, meta_log, user, tracer=tracer)
|
||||
job = await schedule_automation(query_to_run, subject, crontime, timezone, q, user, calling_url, conversation_id)
|
||||
return job, crontime, query_to_run, subject
|
||||
|
||||
|
|
Loading…
Reference in a new issue