Use user chat models for train of thought when no server chat settings

Update chat actors to use user's chat model for train of thought. This
requires passing the user info as argument to all the chat actors.

Whether the user is subscribed or not can be inferred from the user
info being passed, so it doesn't need to be passed as a separate
argument to chat actor functions

Let send_message_to_model function infer chat model instead of passing
it as an argument from some chat actors. Better if this logic can be
done in a single place.
This commit is contained in:
Debanjum Singh Solanky 2024-10-08 18:46:27 -07:00
parent ec0c79217f
commit 05fb0f14d3
6 changed files with 63 additions and 66 deletions

View file

@ -25,7 +25,6 @@ async def text_to_image(
location_data: LocationData,
references: List[Dict[str, Any]],
online_results: Dict[str, Any],
subscribed: bool = False,
send_status_func: Optional[Callable] = None,
uploaded_image_url: Optional[str] = None,
agent: Agent = None,
@ -66,8 +65,8 @@ async def text_to_image(
note_references=references,
online_results=online_results,
model_type=text_to_image_config.model_type,
subscribed=subscribed,
uploaded_image_url=uploaded_image_url,
user=user,
agent=agent,
)

View file

@ -102,7 +102,7 @@ async def search_online(
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event}
tasks = [
read_webpage_and_extract_content(subquery, link, content, subscribed=subscribed, agent=agent)
read_webpage_and_extract_content(subquery, link, content, subscribed=subscribed, user=user, agent=agent)
for link, subquery, content in webpages
]
results = await asyncio.gather(*tasks)
@ -158,7 +158,9 @@ async def read_webpages(
webpage_links_str = "\n- " + "\n- ".join(list(urls))
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event}
tasks = [read_webpage_and_extract_content(query, url, subscribed=subscribed, agent=agent) for url in urls]
tasks = [
read_webpage_and_extract_content(query, url, subscribed=subscribed, user=user, agent=agent) for url in urls
]
results = await asyncio.gather(*tasks)
response: Dict[str, Dict] = defaultdict(dict)
@ -169,14 +171,14 @@ async def read_webpages(
async def read_webpage_and_extract_content(
subquery: str, url: str, content: str = None, subscribed: bool = False, agent: Agent = None
subquery: str, url: str, content: str = None, subscribed: bool = False, user: KhojUser = None, agent: Agent = None
) -> Tuple[str, Union[None, str], str]:
try:
if is_none_or_empty(content):
with timer(f"Reading web page at '{url}' took", logger):
content = await read_webpage_with_olostep(url) if OLOSTEP_API_KEY else await read_webpage_with_jina(url)
with timer(f"Extracting relevant information from web page at '{url}' took", logger):
extracted_info = await extract_relevant_info(subquery, content, subscribed=subscribed, agent=agent)
extracted_info = await extract_relevant_info(subquery, content, user=user, agent=agent)
return subquery, extracted_info, url
except Exception as e:
logger.error(f"Failed to read web page at '{url}' with {e}")

View file

@ -395,7 +395,7 @@ async def extract_references_and_questions(
# Infer search queries from user message
with timer("Extracting search queries took", logger):
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
conversation_config = await ConversationAdapters.aget_default_conversation_config()
conversation_config = await ConversationAdapters.aget_default_conversation_config(user)
vision_enabled = conversation_config.vision_enabled
if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:

View file

@ -194,7 +194,7 @@ def chat_history(
n: Optional[int] = None,
):
user = request.user.object
validate_conversation_config()
validate_conversation_config(user)
# Load Conversation History
conversation = ConversationAdapters.get_conversation_by_user(
@ -694,7 +694,7 @@ async def chat(
q,
meta_log,
is_automated_task,
subscribed=subscribed,
user=user,
uploaded_image_url=uploaded_image_url,
agent=agent,
)
@ -704,7 +704,7 @@ async def chat(
):
yield result
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, uploaded_image_url, agent)
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_image_url, agent)
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
yield result
if mode not in conversation_commands:
@ -767,8 +767,8 @@ async def chat(
q,
contextual_data,
conversation_history=meta_log,
subscribed=subscribed,
uploaded_image_url=uploaded_image_url,
user=user,
agent=agent,
)
response_log = str(response)
@ -957,7 +957,6 @@ async def chat(
location_data=location,
references=compiled_references,
online_results=online_results,
subscribed=subscribed,
send_status_func=partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url,
agent=agent,
@ -1192,7 +1191,7 @@ async def get_chat(
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
conversation_commands = await aget_relevant_information_sources(
q, meta_log, is_automated_task, subscribed=subscribed, uploaded_image_url=uploaded_image_url
q, meta_log, is_automated_task, user=user, uploaded_image_url=uploaded_image_url
)
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
async for result in send_event(
@ -1200,7 +1199,7 @@ async def get_chat(
):
yield result
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, uploaded_image_url)
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_image_url)
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
yield result
if mode not in conversation_commands:
@ -1252,7 +1251,7 @@ async def get_chat(
q,
contextual_data,
conversation_history=meta_log,
subscribed=subscribed,
user=user,
uploaded_image_url=uploaded_image_url,
)
response_log = str(response)
@ -1438,7 +1437,6 @@ async def get_chat(
location_data=location,
references=compiled_references,
online_results=online_results,
subscribed=subscribed,
send_status_func=partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url,
):

View file

@ -40,7 +40,7 @@ def get_user_chat_model(
chat_model = ConversationAdapters.get_conversation_config(user)
if chat_model is None:
chat_model = ConversationAdapters.get_default_conversation_config()
chat_model = ConversationAdapters.get_default_conversation_config(user)
return Response(status_code=200, content=json.dumps({"id": chat_model.id, "chat_model": chat_model.chat_model}))

View file

@ -39,6 +39,7 @@ from khoj.database.adapters import (
AutomationAdapters,
ConversationAdapters,
EntryAdapters,
ais_user_subscribed,
create_khoj_token,
get_khoj_tokens,
get_user_name,
@ -119,20 +120,20 @@ def is_query_empty(query: str) -> bool:
return is_none_or_empty(query.strip())
def validate_conversation_config():
default_config = ConversationAdapters.get_default_conversation_config()
def validate_conversation_config(user: KhojUser):
default_config = ConversationAdapters.get_default_conversation_config(user)
if default_config is None:
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
raise HTTPException(status_code=500, detail="Contact the server administrator to add a chat model.")
if default_config.model_type == "openai" and not default_config.openai_config:
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
raise HTTPException(status_code=500, detail="Contact the server administrator to add a chat model.")
async def is_ready_to_chat(user: KhojUser):
user_conversation_config = (await ConversationAdapters.aget_user_conversation_config(user)) or (
await ConversationAdapters.aget_default_conversation_config()
)
user_conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
if user_conversation_config == None:
user_conversation_config = await ConversationAdapters.aget_default_conversation_config()
if user_conversation_config and user_conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
chat_model = user_conversation_config.chat_model
@ -246,19 +247,19 @@ async def agenerate_chat_response(*args):
return await loop.run_in_executor(executor, generate_chat_response, *args)
async def acreate_title_from_query(query: str) -> str:
async def acreate_title_from_query(query: str, user: KhojUser = None) -> str:
"""
Create a title from the given query
"""
title_generation_prompt = prompts.subject_generation.format(query=query)
with timer("Chat actor: Generate title from query", logger):
response = await send_message_to_model_wrapper(title_generation_prompt)
response = await send_message_to_model_wrapper(title_generation_prompt, user=user)
return response.strip()
async def acheck_if_safe_prompt(system_prompt: str) -> Tuple[bool, str]:
async def acheck_if_safe_prompt(system_prompt: str, user: KhojUser = None) -> Tuple[bool, str]:
"""
Check if the system prompt is safe to use
"""
@ -267,7 +268,7 @@ async def acheck_if_safe_prompt(system_prompt: str) -> Tuple[bool, str]:
reason = ""
with timer("Chat actor: Check if safe prompt", logger):
response = await send_message_to_model_wrapper(safe_prompt_check)
response = await send_message_to_model_wrapper(safe_prompt_check, user=user)
response = response.strip()
try:
@ -288,7 +289,7 @@ async def aget_relevant_information_sources(
query: str,
conversation_history: dict,
is_task: bool,
subscribed: bool,
user: KhojUser,
uploaded_image_url: str = None,
agent: Agent = None,
):
@ -326,7 +327,7 @@ async def aget_relevant_information_sources(
response = await send_message_to_model_wrapper(
relevant_tools_prompt,
response_type="json_object",
subscribed=subscribed,
user=user,
)
try:
@ -362,7 +363,12 @@ async def aget_relevant_information_sources(
async def aget_relevant_output_modes(
query: str, conversation_history: dict, is_task: bool = False, uploaded_image_url: str = None, agent: Agent = None
query: str,
conversation_history: dict,
is_task: bool = False,
user: KhojUser = None,
uploaded_image_url: str = None,
agent: Agent = None,
):
"""
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
@ -398,7 +404,7 @@ async def aget_relevant_output_modes(
)
with timer("Chat actor: Infer output mode for chat response", logger):
response = await send_message_to_model_wrapper(relevant_mode_prompt, response_type="json_object")
response = await send_message_to_model_wrapper(relevant_mode_prompt, response_type="json_object", user=user)
try:
response = response.strip()
@ -453,7 +459,7 @@ async def infer_webpage_urls(
with timer("Chat actor: Infer webpage urls to read", logger):
response = await send_message_to_model_wrapper(
online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object"
online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
)
# Validate that the response is a non-empty, JSON-serializable list of URLs
@ -499,7 +505,7 @@ async def generate_online_subqueries(
with timer("Chat actor: Generate online search subqueries", logger):
response = await send_message_to_model_wrapper(
online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object"
online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
)
# Validate that the response is a non-empty, JSON-serializable list
@ -517,7 +523,9 @@ async def generate_online_subqueries(
return [q]
async def schedule_query(q: str, conversation_history: dict, uploaded_image_url: str = None) -> Tuple[str, ...]:
async def schedule_query(
q: str, conversation_history: dict, user: KhojUser, uploaded_image_url: str = None
) -> Tuple[str, ...]:
"""
Schedule the date, time to run the query. Assume the server timezone is UTC.
"""
@ -529,7 +537,7 @@ async def schedule_query(q: str, conversation_history: dict, uploaded_image_url:
)
raw_response = await send_message_to_model_wrapper(
crontime_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object"
crontime_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
)
# Validate that the response is a non-empty, JSON-serializable list
@ -543,7 +551,7 @@ async def schedule_query(q: str, conversation_history: dict, uploaded_image_url:
raise AssertionError(f"Invalid response for scheduling query: {raw_response}")
async def extract_relevant_info(q: str, corpus: str, subscribed: bool, agent: Agent = None) -> Union[str, None]:
async def extract_relevant_info(q: str, corpus: str, user: KhojUser = None, agent: Agent = None) -> Union[str, None]:
"""
Extract relevant information for a given query from the target corpus
"""
@ -561,14 +569,11 @@ async def extract_relevant_info(q: str, corpus: str, subscribed: bool, agent: Ag
personality_context=personality_context,
)
chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config()
with timer("Chat actor: Extract relevant information from data", logger):
response = await send_message_to_model_wrapper(
extract_relevant_information,
prompts.system_prompt_extract_relevant_information,
chat_model_option=chat_model,
subscribed=subscribed,
user=user,
)
return response.strip()
@ -577,8 +582,8 @@ async def extract_relevant_summary(
q: str,
corpus: str,
conversation_history: dict,
subscribed: bool = False,
uploaded_image_url: str = None,
user: KhojUser = None,
agent: Agent = None,
) -> Union[str, None]:
"""
@ -601,14 +606,11 @@ async def extract_relevant_summary(
personality_context=personality_context,
)
chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config()
with timer("Chat actor: Extract relevant information from data", logger):
response = await send_message_to_model_wrapper(
extract_relevant_information,
prompts.system_prompt_extract_relevant_summary,
chat_model_option=chat_model,
subscribed=subscribed,
user=user,
uploaded_image_url=uploaded_image_url,
)
return response.strip()
@ -621,8 +623,8 @@ async def generate_better_image_prompt(
note_references: List[Dict[str, Any]],
online_results: Optional[dict] = None,
model_type: Optional[str] = None,
subscribed: bool = False,
uploaded_image_url: Optional[str] = None,
user: KhojUser = None,
agent: Agent = None,
) -> str:
"""
@ -672,12 +674,8 @@ async def generate_better_image_prompt(
personality_context=personality_context,
)
chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config()
with timer("Chat actor: Generate contextual image prompt", logger):
response = await send_message_to_model_wrapper(
image_prompt, chat_model_option=chat_model, subscribed=subscribed, uploaded_image_url=uploaded_image_url
)
response = await send_message_to_model_wrapper(image_prompt, uploaded_image_url=uploaded_image_url, user=user)
response = response.strip()
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
response = response[1:-1]
@ -689,14 +687,10 @@ async def send_message_to_model_wrapper(
message: str,
system_message: str = "",
response_type: str = "text",
chat_model_option: ChatModelOptions = None,
subscribed: bool = False,
user: KhojUser = None,
uploaded_image_url: str = None,
):
conversation_config: ChatModelOptions = (
chat_model_option or await ConversationAdapters.aget_default_conversation_config()
)
conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
vision_available = conversation_config.vision_enabled
if not vision_available and uploaded_image_url:
vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
@ -704,6 +698,7 @@ async def send_message_to_model_wrapper(
conversation_config = vision_enabled_config
vision_available = True
subscribed = await ais_user_subscribed(user)
chat_model = conversation_config.chat_model
max_tokens = (
conversation_config.subscribed_max_prompt_size
@ -802,8 +797,9 @@ def send_message_to_model_wrapper_sync(
message: str,
system_message: str = "",
response_type: str = "text",
user: KhojUser = None,
):
conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config()
conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config(user)
if conversation_config is None:
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
@ -1182,7 +1178,7 @@ class CommonQueryParamsClass:
CommonQueryParams = Annotated[CommonQueryParamsClass, Depends()]
def should_notify(original_query: str, executed_query: str, ai_response: str) -> bool:
def should_notify(original_query: str, executed_query: str, ai_response: str, user: KhojUser) -> bool:
"""
Decide whether to notify the user of the AI response.
Default to notifying the user for now.
@ -1199,7 +1195,7 @@ def should_notify(original_query: str, executed_query: str, ai_response: str) ->
with timer("Chat actor: Decide to notify user of automation response", logger):
try:
# TODO Replace with async call so we don't have to maintain a sync version
response = send_message_to_model_wrapper_sync(to_notify_or_not)
response = send_message_to_model_wrapper_sync(to_notify_or_not, user)
should_notify_result = "no" not in response.lower()
logger.info(f'Decided to {"not " if not should_notify_result else ""}notify user of automation response.')
return should_notify_result
@ -1291,7 +1287,9 @@ def scheduled_chat(
ai_response = raw_response.text
# Notify user if the AI response is satisfactory
if should_notify(original_query=scheduling_request, executed_query=cleaned_query, ai_response=ai_response):
if should_notify(
original_query=scheduling_request, executed_query=cleaned_query, ai_response=ai_response, user=user
):
if is_resend_enabled():
send_task_email(user.get_short_name(), user.email, cleaned_query, ai_response, subject, is_image)
else:
@ -1301,7 +1299,7 @@ def scheduled_chat(
async def create_automation(
q: str, timezone: str, user: KhojUser, calling_url: URL, meta_log: dict = {}, conversation_id: str = None
):
crontime, query_to_run, subject = await schedule_query(q, meta_log)
crontime, query_to_run, subject = await schedule_query(q, meta_log, user)
job = await schedule_automation(query_to_run, subject, crontime, timezone, q, user, calling_url, conversation_id)
return job, crontime, query_to_run, subject
@ -1495,9 +1493,9 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
current_notion_config = get_user_notion_config(user)
notion_token = current_notion_config.token if current_notion_config else ""
selected_chat_model_config = (
ConversationAdapters.get_conversation_config(user) or ConversationAdapters.get_default_conversation_config()
)
selected_chat_model_config = ConversationAdapters.get_conversation_config(
user
) or ConversationAdapters.get_default_conversation_config(user)
chat_models = ConversationAdapters.get_conversation_processor_options().all()
chat_model_options = list()
for chat_model in chat_models: