mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-27 17:35:07 +01:00
Simplify switching chat model when self-hosting (#934)
# Overview
- Default to using the user's chat model for train of thought when no server chat settings have been created by admins
- Default to not creating server chat settings on first run

# Details
This change simplifies switching chat models for self-hosted setups: just change the chat model on the user settings page. Khoj falls back to using the user's chat model for train of thought if server chat settings have not been created on the admin panel.

Server chat settings, when set, control the chat model used for Khoj's train of thought and the default user chat model.

Previously, a self-hosted user had to update
1. the server chat settings in the admin panel, and
2. their own user chat model in the user settings panel

to completely switch to a different chat model for train of thought and response generation, respectively.

You can still set server chat settings via the admin panel to use a different chat model for train of thought vs. response generation, but this is only useful for advanced, multi-user setups.
This commit is contained in:
commit
c66c571396
9 changed files with 102 additions and 88 deletions
|
@ -253,7 +253,7 @@ def configure_server(
|
||||||
|
|
||||||
state.SearchType = configure_search_types()
|
state.SearchType = configure_search_types()
|
||||||
state.search_models = configure_search(state.search_models, state.config.search_type)
|
state.search_models = configure_search(state.search_models, state.config.search_type)
|
||||||
setup_default_agent()
|
setup_default_agent(user)
|
||||||
|
|
||||||
message = "📡 Telemetry disabled" if telemetry_disabled(state.config.app) else "📡 Telemetry enabled"
|
message = "📡 Telemetry disabled" if telemetry_disabled(state.config.app) else "📡 Telemetry enabled"
|
||||||
logger.info(message)
|
logger.info(message)
|
||||||
|
@ -265,8 +265,8 @@ def configure_server(
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
def setup_default_agent():
|
def setup_default_agent(user: KhojUser):
|
||||||
AgentAdapters.create_default_agent()
|
AgentAdapters.create_default_agent(user)
|
||||||
|
|
||||||
|
|
||||||
def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, user: KhojUser = None):
|
def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, user: KhojUser = None):
|
||||||
|
|
|
@ -647,8 +647,8 @@ class AgentAdapters:
|
||||||
return Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).first()
|
return Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).first()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_default_agent():
|
def create_default_agent(user: KhojUser):
|
||||||
default_conversation_config = ConversationAdapters.get_default_conversation_config()
|
default_conversation_config = ConversationAdapters.get_default_conversation_config(user)
|
||||||
if default_conversation_config is None:
|
if default_conversation_config is None:
|
||||||
logger.info("No default conversation config found, skipping default agent creation")
|
logger.info("No default conversation config found, skipping default agent creation")
|
||||||
return None
|
return None
|
||||||
|
@ -972,29 +972,51 @@ class ConversationAdapters:
|
||||||
return VoiceModelOption.objects.first()
|
return VoiceModelOption.objects.first()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_default_conversation_config():
|
def get_default_conversation_config(user: KhojUser = None):
|
||||||
|
"""Get default conversation config. Prefer chat model by server admin > user > first created chat model"""
|
||||||
|
# Get the server chat settings
|
||||||
server_chat_settings = ServerChatSettings.objects.first()
|
server_chat_settings = ServerChatSettings.objects.first()
|
||||||
if server_chat_settings is None or server_chat_settings.chat_default is None:
|
if server_chat_settings is not None and server_chat_settings.chat_default is not None:
|
||||||
return ChatModelOptions.objects.filter().first()
|
|
||||||
return server_chat_settings.chat_default
|
return server_chat_settings.chat_default
|
||||||
|
|
||||||
|
# Get the user's chat settings, if the server chat settings are not set
|
||||||
|
user_chat_settings = UserConversationConfig.objects.filter(user=user).first() if user else None
|
||||||
|
if user_chat_settings is not None and user_chat_settings.setting is not None:
|
||||||
|
return user_chat_settings.setting
|
||||||
|
|
||||||
|
# Get the first chat model if even the user chat settings are not set
|
||||||
|
return ChatModelOptions.objects.filter().first()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def aget_default_conversation_config():
|
async def aget_default_conversation_config(user: KhojUser = None):
|
||||||
|
"""Get default conversation config. Prefer chat model by server admin > user > first created chat model"""
|
||||||
|
# Get the server chat settings
|
||||||
server_chat_settings: ServerChatSettings = (
|
server_chat_settings: ServerChatSettings = (
|
||||||
await ServerChatSettings.objects.filter()
|
await ServerChatSettings.objects.filter()
|
||||||
.prefetch_related("chat_default", "chat_default__openai_config")
|
.prefetch_related("chat_default", "chat_default__openai_config")
|
||||||
.afirst()
|
.afirst()
|
||||||
)
|
)
|
||||||
if server_chat_settings is None or server_chat_settings.chat_default is None:
|
if server_chat_settings is not None and server_chat_settings.chat_default is not None:
|
||||||
return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst()
|
|
||||||
return server_chat_settings.chat_default
|
return server_chat_settings.chat_default
|
||||||
|
|
||||||
|
# Get the user's chat settings, if the server chat settings are not set
|
||||||
|
user_chat_settings = (
|
||||||
|
(await UserConversationConfig.objects.filter(user=user).prefetch_related("setting__openai_config").afirst())
|
||||||
|
if user
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
if user_chat_settings is not None and user_chat_settings.setting is not None:
|
||||||
|
return user_chat_settings.setting
|
||||||
|
|
||||||
|
# Get the first chat model if even the user chat settings are not set
|
||||||
|
return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_advanced_conversation_config():
|
def get_advanced_conversation_config():
|
||||||
server_chat_settings = ServerChatSettings.objects.first()
|
server_chat_settings = ServerChatSettings.objects.first()
|
||||||
if server_chat_settings is None or server_chat_settings.chat_advanced is None:
|
if server_chat_settings is not None and server_chat_settings.chat_advanced is not None:
|
||||||
return ConversationAdapters.get_default_conversation_config()
|
|
||||||
return server_chat_settings.chat_advanced
|
return server_chat_settings.chat_advanced
|
||||||
|
return ConversationAdapters.get_default_conversation_config()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def aget_advanced_conversation_config():
|
async def aget_advanced_conversation_config():
|
||||||
|
@ -1003,9 +1025,9 @@ class ConversationAdapters:
|
||||||
.prefetch_related("chat_advanced", "chat_advanced__openai_config")
|
.prefetch_related("chat_advanced", "chat_advanced__openai_config")
|
||||||
.afirst()
|
.afirst()
|
||||||
)
|
)
|
||||||
if server_chat_settings is None or server_chat_settings.chat_advanced is None:
|
if server_chat_settings is not None and server_chat_settings.chat_advanced is not None:
|
||||||
return await ConversationAdapters.aget_default_conversation_config()
|
|
||||||
return server_chat_settings.chat_advanced
|
return server_chat_settings.chat_advanced
|
||||||
|
return await ConversationAdapters.aget_default_conversation_config()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_conversation_from_public_conversation(
|
def create_conversation_from_public_conversation(
|
||||||
|
|
|
@ -25,7 +25,6 @@ async def text_to_image(
|
||||||
location_data: LocationData,
|
location_data: LocationData,
|
||||||
references: List[Dict[str, Any]],
|
references: List[Dict[str, Any]],
|
||||||
online_results: Dict[str, Any],
|
online_results: Dict[str, Any],
|
||||||
subscribed: bool = False,
|
|
||||||
send_status_func: Optional[Callable] = None,
|
send_status_func: Optional[Callable] = None,
|
||||||
uploaded_image_url: Optional[str] = None,
|
uploaded_image_url: Optional[str] = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
|
@ -66,8 +65,8 @@ async def text_to_image(
|
||||||
note_references=references,
|
note_references=references,
|
||||||
online_results=online_results,
|
online_results=online_results,
|
||||||
model_type=text_to_image_config.model_type,
|
model_type=text_to_image_config.model_type,
|
||||||
subscribed=subscribed,
|
|
||||||
uploaded_image_url=uploaded_image_url,
|
uploaded_image_url=uploaded_image_url,
|
||||||
|
user=user,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -102,7 +102,7 @@ async def search_online(
|
||||||
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
|
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
|
||||||
yield {ChatEvent.STATUS: event}
|
yield {ChatEvent.STATUS: event}
|
||||||
tasks = [
|
tasks = [
|
||||||
read_webpage_and_extract_content(subquery, link, content, subscribed=subscribed, agent=agent)
|
read_webpage_and_extract_content(subquery, link, content, user=user, agent=agent)
|
||||||
for link, subquery, content in webpages
|
for link, subquery, content in webpages
|
||||||
]
|
]
|
||||||
results = await asyncio.gather(*tasks)
|
results = await asyncio.gather(*tasks)
|
||||||
|
@ -158,7 +158,7 @@ async def read_webpages(
|
||||||
webpage_links_str = "\n- " + "\n- ".join(list(urls))
|
webpage_links_str = "\n- " + "\n- ".join(list(urls))
|
||||||
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
|
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
|
||||||
yield {ChatEvent.STATUS: event}
|
yield {ChatEvent.STATUS: event}
|
||||||
tasks = [read_webpage_and_extract_content(query, url, subscribed=subscribed, agent=agent) for url in urls]
|
tasks = [read_webpage_and_extract_content(query, url, user=user, agent=agent) for url in urls]
|
||||||
results = await asyncio.gather(*tasks)
|
results = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
response: Dict[str, Dict] = defaultdict(dict)
|
response: Dict[str, Dict] = defaultdict(dict)
|
||||||
|
@ -169,14 +169,14 @@ async def read_webpages(
|
||||||
|
|
||||||
|
|
||||||
async def read_webpage_and_extract_content(
|
async def read_webpage_and_extract_content(
|
||||||
subquery: str, url: str, content: str = None, subscribed: bool = False, agent: Agent = None
|
subquery: str, url: str, content: str = None, user: KhojUser = None, agent: Agent = None
|
||||||
) -> Tuple[str, Union[None, str], str]:
|
) -> Tuple[str, Union[None, str], str]:
|
||||||
try:
|
try:
|
||||||
if is_none_or_empty(content):
|
if is_none_or_empty(content):
|
||||||
with timer(f"Reading web page at '{url}' took", logger):
|
with timer(f"Reading web page at '{url}' took", logger):
|
||||||
content = await read_webpage_with_olostep(url) if OLOSTEP_API_KEY else await read_webpage_with_jina(url)
|
content = await read_webpage_with_olostep(url) if OLOSTEP_API_KEY else await read_webpage_with_jina(url)
|
||||||
with timer(f"Extracting relevant information from web page at '{url}' took", logger):
|
with timer(f"Extracting relevant information from web page at '{url}' took", logger):
|
||||||
extracted_info = await extract_relevant_info(subquery, content, subscribed=subscribed, agent=agent)
|
extracted_info = await extract_relevant_info(subquery, content, user=user, agent=agent)
|
||||||
return subquery, extracted_info, url
|
return subquery, extracted_info, url
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to read web page at '{url}' with {e}")
|
logger.error(f"Failed to read web page at '{url}' with {e}")
|
||||||
|
|
|
@ -395,7 +395,7 @@ async def extract_references_and_questions(
|
||||||
# Infer search queries from user message
|
# Infer search queries from user message
|
||||||
with timer("Extracting search queries took", logger):
|
with timer("Extracting search queries took", logger):
|
||||||
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
|
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
|
||||||
conversation_config = await ConversationAdapters.aget_default_conversation_config()
|
conversation_config = await ConversationAdapters.aget_default_conversation_config(user)
|
||||||
vision_enabled = conversation_config.vision_enabled
|
vision_enabled = conversation_config.vision_enabled
|
||||||
|
|
||||||
if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
|
if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
|
||||||
|
|
|
@ -194,7 +194,7 @@ def chat_history(
|
||||||
n: Optional[int] = None,
|
n: Optional[int] = None,
|
||||||
):
|
):
|
||||||
user = request.user.object
|
user = request.user.object
|
||||||
validate_conversation_config()
|
validate_conversation_config(user)
|
||||||
|
|
||||||
# Load Conversation History
|
# Load Conversation History
|
||||||
conversation = ConversationAdapters.get_conversation_by_user(
|
conversation = ConversationAdapters.get_conversation_by_user(
|
||||||
|
@ -694,7 +694,7 @@ async def chat(
|
||||||
q,
|
q,
|
||||||
meta_log,
|
meta_log,
|
||||||
is_automated_task,
|
is_automated_task,
|
||||||
subscribed=subscribed,
|
user=user,
|
||||||
uploaded_image_url=uploaded_image_url,
|
uploaded_image_url=uploaded_image_url,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
)
|
)
|
||||||
|
@ -704,7 +704,7 @@ async def chat(
|
||||||
):
|
):
|
||||||
yield result
|
yield result
|
||||||
|
|
||||||
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, uploaded_image_url, agent)
|
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_image_url, agent)
|
||||||
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
|
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
|
||||||
yield result
|
yield result
|
||||||
if mode not in conversation_commands:
|
if mode not in conversation_commands:
|
||||||
|
@ -767,8 +767,8 @@ async def chat(
|
||||||
q,
|
q,
|
||||||
contextual_data,
|
contextual_data,
|
||||||
conversation_history=meta_log,
|
conversation_history=meta_log,
|
||||||
subscribed=subscribed,
|
|
||||||
uploaded_image_url=uploaded_image_url,
|
uploaded_image_url=uploaded_image_url,
|
||||||
|
user=user,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
)
|
)
|
||||||
response_log = str(response)
|
response_log = str(response)
|
||||||
|
@ -957,7 +957,6 @@ async def chat(
|
||||||
location_data=location,
|
location_data=location,
|
||||||
references=compiled_references,
|
references=compiled_references,
|
||||||
online_results=online_results,
|
online_results=online_results,
|
||||||
subscribed=subscribed,
|
|
||||||
send_status_func=partial(send_event, ChatEvent.STATUS),
|
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||||
uploaded_image_url=uploaded_image_url,
|
uploaded_image_url=uploaded_image_url,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
|
@ -1192,7 +1191,7 @@ async def get_chat(
|
||||||
|
|
||||||
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
|
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
|
||||||
conversation_commands = await aget_relevant_information_sources(
|
conversation_commands = await aget_relevant_information_sources(
|
||||||
q, meta_log, is_automated_task, subscribed=subscribed, uploaded_image_url=uploaded_image_url
|
q, meta_log, is_automated_task, user=user, uploaded_image_url=uploaded_image_url
|
||||||
)
|
)
|
||||||
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
|
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
|
||||||
async for result in send_event(
|
async for result in send_event(
|
||||||
|
@ -1200,7 +1199,7 @@ async def get_chat(
|
||||||
):
|
):
|
||||||
yield result
|
yield result
|
||||||
|
|
||||||
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, uploaded_image_url)
|
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_image_url)
|
||||||
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
|
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
|
||||||
yield result
|
yield result
|
||||||
if mode not in conversation_commands:
|
if mode not in conversation_commands:
|
||||||
|
@ -1252,7 +1251,7 @@ async def get_chat(
|
||||||
q,
|
q,
|
||||||
contextual_data,
|
contextual_data,
|
||||||
conversation_history=meta_log,
|
conversation_history=meta_log,
|
||||||
subscribed=subscribed,
|
user=user,
|
||||||
uploaded_image_url=uploaded_image_url,
|
uploaded_image_url=uploaded_image_url,
|
||||||
)
|
)
|
||||||
response_log = str(response)
|
response_log = str(response)
|
||||||
|
@ -1438,7 +1437,6 @@ async def get_chat(
|
||||||
location_data=location,
|
location_data=location,
|
||||||
references=compiled_references,
|
references=compiled_references,
|
||||||
online_results=online_results,
|
online_results=online_results,
|
||||||
subscribed=subscribed,
|
|
||||||
send_status_func=partial(send_event, ChatEvent.STATUS),
|
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||||
uploaded_image_url=uploaded_image_url,
|
uploaded_image_url=uploaded_image_url,
|
||||||
):
|
):
|
||||||
|
|
|
@ -40,7 +40,7 @@ def get_user_chat_model(
|
||||||
chat_model = ConversationAdapters.get_conversation_config(user)
|
chat_model = ConversationAdapters.get_conversation_config(user)
|
||||||
|
|
||||||
if chat_model is None:
|
if chat_model is None:
|
||||||
chat_model = ConversationAdapters.get_default_conversation_config()
|
chat_model = ConversationAdapters.get_default_conversation_config(user)
|
||||||
|
|
||||||
return Response(status_code=200, content=json.dumps({"id": chat_model.id, "chat_model": chat_model.chat_model}))
|
return Response(status_code=200, content=json.dumps({"id": chat_model.id, "chat_model": chat_model.chat_model}))
|
||||||
|
|
||||||
|
|
|
@ -39,6 +39,7 @@ from khoj.database.adapters import (
|
||||||
AutomationAdapters,
|
AutomationAdapters,
|
||||||
ConversationAdapters,
|
ConversationAdapters,
|
||||||
EntryAdapters,
|
EntryAdapters,
|
||||||
|
ais_user_subscribed,
|
||||||
create_khoj_token,
|
create_khoj_token,
|
||||||
get_khoj_tokens,
|
get_khoj_tokens,
|
||||||
get_user_name,
|
get_user_name,
|
||||||
|
@ -119,20 +120,20 @@ def is_query_empty(query: str) -> bool:
|
||||||
return is_none_or_empty(query.strip())
|
return is_none_or_empty(query.strip())
|
||||||
|
|
||||||
|
|
||||||
def validate_conversation_config():
|
def validate_conversation_config(user: KhojUser):
|
||||||
default_config = ConversationAdapters.get_default_conversation_config()
|
default_config = ConversationAdapters.get_default_conversation_config(user)
|
||||||
|
|
||||||
if default_config is None:
|
if default_config is None:
|
||||||
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
|
raise HTTPException(status_code=500, detail="Contact the server administrator to add a chat model.")
|
||||||
|
|
||||||
if default_config.model_type == "openai" and not default_config.openai_config:
|
if default_config.model_type == "openai" and not default_config.openai_config:
|
||||||
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
|
raise HTTPException(status_code=500, detail="Contact the server administrator to add a chat model.")
|
||||||
|
|
||||||
|
|
||||||
async def is_ready_to_chat(user: KhojUser):
|
async def is_ready_to_chat(user: KhojUser):
|
||||||
user_conversation_config = (await ConversationAdapters.aget_user_conversation_config(user)) or (
|
user_conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
|
||||||
await ConversationAdapters.aget_default_conversation_config()
|
if user_conversation_config == None:
|
||||||
)
|
user_conversation_config = await ConversationAdapters.aget_default_conversation_config()
|
||||||
|
|
||||||
if user_conversation_config and user_conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
|
if user_conversation_config and user_conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
|
||||||
chat_model = user_conversation_config.chat_model
|
chat_model = user_conversation_config.chat_model
|
||||||
|
@ -246,19 +247,19 @@ async def agenerate_chat_response(*args):
|
||||||
return await loop.run_in_executor(executor, generate_chat_response, *args)
|
return await loop.run_in_executor(executor, generate_chat_response, *args)
|
||||||
|
|
||||||
|
|
||||||
async def acreate_title_from_query(query: str) -> str:
|
async def acreate_title_from_query(query: str, user: KhojUser = None) -> str:
|
||||||
"""
|
"""
|
||||||
Create a title from the given query
|
Create a title from the given query
|
||||||
"""
|
"""
|
||||||
title_generation_prompt = prompts.subject_generation.format(query=query)
|
title_generation_prompt = prompts.subject_generation.format(query=query)
|
||||||
|
|
||||||
with timer("Chat actor: Generate title from query", logger):
|
with timer("Chat actor: Generate title from query", logger):
|
||||||
response = await send_message_to_model_wrapper(title_generation_prompt)
|
response = await send_message_to_model_wrapper(title_generation_prompt, user=user)
|
||||||
|
|
||||||
return response.strip()
|
return response.strip()
|
||||||
|
|
||||||
|
|
||||||
async def acheck_if_safe_prompt(system_prompt: str) -> Tuple[bool, str]:
|
async def acheck_if_safe_prompt(system_prompt: str, user: KhojUser = None) -> Tuple[bool, str]:
|
||||||
"""
|
"""
|
||||||
Check if the system prompt is safe to use
|
Check if the system prompt is safe to use
|
||||||
"""
|
"""
|
||||||
|
@ -267,7 +268,7 @@ async def acheck_if_safe_prompt(system_prompt: str) -> Tuple[bool, str]:
|
||||||
reason = ""
|
reason = ""
|
||||||
|
|
||||||
with timer("Chat actor: Check if safe prompt", logger):
|
with timer("Chat actor: Check if safe prompt", logger):
|
||||||
response = await send_message_to_model_wrapper(safe_prompt_check)
|
response = await send_message_to_model_wrapper(safe_prompt_check, user=user)
|
||||||
|
|
||||||
response = response.strip()
|
response = response.strip()
|
||||||
try:
|
try:
|
||||||
|
@ -288,7 +289,7 @@ async def aget_relevant_information_sources(
|
||||||
query: str,
|
query: str,
|
||||||
conversation_history: dict,
|
conversation_history: dict,
|
||||||
is_task: bool,
|
is_task: bool,
|
||||||
subscribed: bool,
|
user: KhojUser,
|
||||||
uploaded_image_url: str = None,
|
uploaded_image_url: str = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
):
|
):
|
||||||
|
@ -326,7 +327,7 @@ async def aget_relevant_information_sources(
|
||||||
response = await send_message_to_model_wrapper(
|
response = await send_message_to_model_wrapper(
|
||||||
relevant_tools_prompt,
|
relevant_tools_prompt,
|
||||||
response_type="json_object",
|
response_type="json_object",
|
||||||
subscribed=subscribed,
|
user=user,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -362,7 +363,12 @@ async def aget_relevant_information_sources(
|
||||||
|
|
||||||
|
|
||||||
async def aget_relevant_output_modes(
|
async def aget_relevant_output_modes(
|
||||||
query: str, conversation_history: dict, is_task: bool = False, uploaded_image_url: str = None, agent: Agent = None
|
query: str,
|
||||||
|
conversation_history: dict,
|
||||||
|
is_task: bool = False,
|
||||||
|
user: KhojUser = None,
|
||||||
|
uploaded_image_url: str = None,
|
||||||
|
agent: Agent = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
|
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
|
||||||
|
@ -398,7 +404,7 @@ async def aget_relevant_output_modes(
|
||||||
)
|
)
|
||||||
|
|
||||||
with timer("Chat actor: Infer output mode for chat response", logger):
|
with timer("Chat actor: Infer output mode for chat response", logger):
|
||||||
response = await send_message_to_model_wrapper(relevant_mode_prompt, response_type="json_object")
|
response = await send_message_to_model_wrapper(relevant_mode_prompt, response_type="json_object", user=user)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = response.strip()
|
response = response.strip()
|
||||||
|
@ -453,7 +459,7 @@ async def infer_webpage_urls(
|
||||||
|
|
||||||
with timer("Chat actor: Infer webpage urls to read", logger):
|
with timer("Chat actor: Infer webpage urls to read", logger):
|
||||||
response = await send_message_to_model_wrapper(
|
response = await send_message_to_model_wrapper(
|
||||||
online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object"
|
online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate that the response is a non-empty, JSON-serializable list of URLs
|
# Validate that the response is a non-empty, JSON-serializable list of URLs
|
||||||
|
@ -499,7 +505,7 @@ async def generate_online_subqueries(
|
||||||
|
|
||||||
with timer("Chat actor: Generate online search subqueries", logger):
|
with timer("Chat actor: Generate online search subqueries", logger):
|
||||||
response = await send_message_to_model_wrapper(
|
response = await send_message_to_model_wrapper(
|
||||||
online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object"
|
online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate that the response is a non-empty, JSON-serializable list
|
# Validate that the response is a non-empty, JSON-serializable list
|
||||||
|
@ -517,7 +523,9 @@ async def generate_online_subqueries(
|
||||||
return [q]
|
return [q]
|
||||||
|
|
||||||
|
|
||||||
async def schedule_query(q: str, conversation_history: dict, uploaded_image_url: str = None) -> Tuple[str, ...]:
|
async def schedule_query(
|
||||||
|
q: str, conversation_history: dict, user: KhojUser, uploaded_image_url: str = None
|
||||||
|
) -> Tuple[str, ...]:
|
||||||
"""
|
"""
|
||||||
Schedule the date, time to run the query. Assume the server timezone is UTC.
|
Schedule the date, time to run the query. Assume the server timezone is UTC.
|
||||||
"""
|
"""
|
||||||
|
@ -529,7 +537,7 @@ async def schedule_query(q: str, conversation_history: dict, uploaded_image_url:
|
||||||
)
|
)
|
||||||
|
|
||||||
raw_response = await send_message_to_model_wrapper(
|
raw_response = await send_message_to_model_wrapper(
|
||||||
crontime_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object"
|
crontime_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate that the response is a non-empty, JSON-serializable list
|
# Validate that the response is a non-empty, JSON-serializable list
|
||||||
|
@ -543,7 +551,7 @@ async def schedule_query(q: str, conversation_history: dict, uploaded_image_url:
|
||||||
raise AssertionError(f"Invalid response for scheduling query: {raw_response}")
|
raise AssertionError(f"Invalid response for scheduling query: {raw_response}")
|
||||||
|
|
||||||
|
|
||||||
async def extract_relevant_info(q: str, corpus: str, subscribed: bool, agent: Agent = None) -> Union[str, None]:
|
async def extract_relevant_info(q: str, corpus: str, user: KhojUser = None, agent: Agent = None) -> Union[str, None]:
|
||||||
"""
|
"""
|
||||||
Extract relevant information for a given query from the target corpus
|
Extract relevant information for a given query from the target corpus
|
||||||
"""
|
"""
|
||||||
|
@ -561,14 +569,11 @@ async def extract_relevant_info(q: str, corpus: str, subscribed: bool, agent: Ag
|
||||||
personality_context=personality_context,
|
personality_context=personality_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config()
|
|
||||||
|
|
||||||
with timer("Chat actor: Extract relevant information from data", logger):
|
with timer("Chat actor: Extract relevant information from data", logger):
|
||||||
response = await send_message_to_model_wrapper(
|
response = await send_message_to_model_wrapper(
|
||||||
extract_relevant_information,
|
extract_relevant_information,
|
||||||
prompts.system_prompt_extract_relevant_information,
|
prompts.system_prompt_extract_relevant_information,
|
||||||
chat_model_option=chat_model,
|
user=user,
|
||||||
subscribed=subscribed,
|
|
||||||
)
|
)
|
||||||
return response.strip()
|
return response.strip()
|
||||||
|
|
||||||
|
@ -577,8 +582,8 @@ async def extract_relevant_summary(
|
||||||
q: str,
|
q: str,
|
||||||
corpus: str,
|
corpus: str,
|
||||||
conversation_history: dict,
|
conversation_history: dict,
|
||||||
subscribed: bool = False,
|
|
||||||
uploaded_image_url: str = None,
|
uploaded_image_url: str = None,
|
||||||
|
user: KhojUser = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
) -> Union[str, None]:
|
) -> Union[str, None]:
|
||||||
"""
|
"""
|
||||||
|
@ -601,14 +606,11 @@ async def extract_relevant_summary(
|
||||||
personality_context=personality_context,
|
personality_context=personality_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config()
|
|
||||||
|
|
||||||
with timer("Chat actor: Extract relevant information from data", logger):
|
with timer("Chat actor: Extract relevant information from data", logger):
|
||||||
response = await send_message_to_model_wrapper(
|
response = await send_message_to_model_wrapper(
|
||||||
extract_relevant_information,
|
extract_relevant_information,
|
||||||
prompts.system_prompt_extract_relevant_summary,
|
prompts.system_prompt_extract_relevant_summary,
|
||||||
chat_model_option=chat_model,
|
user=user,
|
||||||
subscribed=subscribed,
|
|
||||||
uploaded_image_url=uploaded_image_url,
|
uploaded_image_url=uploaded_image_url,
|
||||||
)
|
)
|
||||||
return response.strip()
|
return response.strip()
|
||||||
|
@ -621,8 +623,8 @@ async def generate_better_image_prompt(
|
||||||
note_references: List[Dict[str, Any]],
|
note_references: List[Dict[str, Any]],
|
||||||
online_results: Optional[dict] = None,
|
online_results: Optional[dict] = None,
|
||||||
model_type: Optional[str] = None,
|
model_type: Optional[str] = None,
|
||||||
subscribed: bool = False,
|
|
||||||
uploaded_image_url: Optional[str] = None,
|
uploaded_image_url: Optional[str] = None,
|
||||||
|
user: KhojUser = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
|
@ -672,12 +674,8 @@ async def generate_better_image_prompt(
|
||||||
personality_context=personality_context,
|
personality_context=personality_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config()
|
|
||||||
|
|
||||||
with timer("Chat actor: Generate contextual image prompt", logger):
|
with timer("Chat actor: Generate contextual image prompt", logger):
|
||||||
response = await send_message_to_model_wrapper(
|
response = await send_message_to_model_wrapper(image_prompt, uploaded_image_url=uploaded_image_url, user=user)
|
||||||
image_prompt, chat_model_option=chat_model, subscribed=subscribed, uploaded_image_url=uploaded_image_url
|
|
||||||
)
|
|
||||||
response = response.strip()
|
response = response.strip()
|
||||||
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
|
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
|
||||||
response = response[1:-1]
|
response = response[1:-1]
|
||||||
|
@ -689,14 +687,10 @@ async def send_message_to_model_wrapper(
|
||||||
message: str,
|
message: str,
|
||||||
system_message: str = "",
|
system_message: str = "",
|
||||||
response_type: str = "text",
|
response_type: str = "text",
|
||||||
chat_model_option: ChatModelOptions = None,
|
user: KhojUser = None,
|
||||||
subscribed: bool = False,
|
|
||||||
uploaded_image_url: str = None,
|
uploaded_image_url: str = None,
|
||||||
):
|
):
|
||||||
conversation_config: ChatModelOptions = (
|
conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
|
||||||
chat_model_option or await ConversationAdapters.aget_default_conversation_config()
|
|
||||||
)
|
|
||||||
|
|
||||||
vision_available = conversation_config.vision_enabled
|
vision_available = conversation_config.vision_enabled
|
||||||
if not vision_available and uploaded_image_url:
|
if not vision_available and uploaded_image_url:
|
||||||
vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
|
vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
|
||||||
|
@ -704,6 +698,7 @@ async def send_message_to_model_wrapper(
|
||||||
conversation_config = vision_enabled_config
|
conversation_config = vision_enabled_config
|
||||||
vision_available = True
|
vision_available = True
|
||||||
|
|
||||||
|
subscribed = await ais_user_subscribed(user)
|
||||||
chat_model = conversation_config.chat_model
|
chat_model = conversation_config.chat_model
|
||||||
max_tokens = (
|
max_tokens = (
|
||||||
conversation_config.subscribed_max_prompt_size
|
conversation_config.subscribed_max_prompt_size
|
||||||
|
@ -802,8 +797,9 @@ def send_message_to_model_wrapper_sync(
|
||||||
message: str,
|
message: str,
|
||||||
system_message: str = "",
|
system_message: str = "",
|
||||||
response_type: str = "text",
|
response_type: str = "text",
|
||||||
|
user: KhojUser = None,
|
||||||
):
|
):
|
||||||
conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config()
|
conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config(user)
|
||||||
|
|
||||||
if conversation_config is None:
|
if conversation_config is None:
|
||||||
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
|
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
|
||||||
|
@ -1182,7 +1178,7 @@ class CommonQueryParamsClass:
|
||||||
CommonQueryParams = Annotated[CommonQueryParamsClass, Depends()]
|
CommonQueryParams = Annotated[CommonQueryParamsClass, Depends()]
|
||||||
|
|
||||||
|
|
||||||
def should_notify(original_query: str, executed_query: str, ai_response: str) -> bool:
|
def should_notify(original_query: str, executed_query: str, ai_response: str, user: KhojUser) -> bool:
|
||||||
"""
|
"""
|
||||||
Decide whether to notify the user of the AI response.
|
Decide whether to notify the user of the AI response.
|
||||||
Default to notifying the user for now.
|
Default to notifying the user for now.
|
||||||
|
@ -1199,7 +1195,7 @@ def should_notify(original_query: str, executed_query: str, ai_response: str) ->
|
||||||
with timer("Chat actor: Decide to notify user of automation response", logger):
|
with timer("Chat actor: Decide to notify user of automation response", logger):
|
||||||
try:
|
try:
|
||||||
# TODO Replace with async call so we don't have to maintain a sync version
|
# TODO Replace with async call so we don't have to maintain a sync version
|
||||||
response = send_message_to_model_wrapper_sync(to_notify_or_not)
|
response = send_message_to_model_wrapper_sync(to_notify_or_not, user)
|
||||||
should_notify_result = "no" not in response.lower()
|
should_notify_result = "no" not in response.lower()
|
||||||
logger.info(f'Decided to {"not " if not should_notify_result else ""}notify user of automation response.')
|
logger.info(f'Decided to {"not " if not should_notify_result else ""}notify user of automation response.')
|
||||||
return should_notify_result
|
return should_notify_result
|
||||||
|
@ -1291,7 +1287,9 @@ def scheduled_chat(
|
||||||
ai_response = raw_response.text
|
ai_response = raw_response.text
|
||||||
|
|
||||||
# Notify user if the AI response is satisfactory
|
# Notify user if the AI response is satisfactory
|
||||||
if should_notify(original_query=scheduling_request, executed_query=cleaned_query, ai_response=ai_response):
|
if should_notify(
|
||||||
|
original_query=scheduling_request, executed_query=cleaned_query, ai_response=ai_response, user=user
|
||||||
|
):
|
||||||
if is_resend_enabled():
|
if is_resend_enabled():
|
||||||
send_task_email(user.get_short_name(), user.email, cleaned_query, ai_response, subject, is_image)
|
send_task_email(user.get_short_name(), user.email, cleaned_query, ai_response, subject, is_image)
|
||||||
else:
|
else:
|
||||||
|
@ -1301,7 +1299,7 @@ def scheduled_chat(
|
||||||
async def create_automation(
|
async def create_automation(
|
||||||
q: str, timezone: str, user: KhojUser, calling_url: URL, meta_log: dict = {}, conversation_id: str = None
|
q: str, timezone: str, user: KhojUser, calling_url: URL, meta_log: dict = {}, conversation_id: str = None
|
||||||
):
|
):
|
||||||
crontime, query_to_run, subject = await schedule_query(q, meta_log)
|
crontime, query_to_run, subject = await schedule_query(q, meta_log, user)
|
||||||
job = await schedule_automation(query_to_run, subject, crontime, timezone, q, user, calling_url, conversation_id)
|
job = await schedule_automation(query_to_run, subject, crontime, timezone, q, user, calling_url, conversation_id)
|
||||||
return job, crontime, query_to_run, subject
|
return job, crontime, query_to_run, subject
|
||||||
|
|
||||||
|
@ -1495,9 +1493,9 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
|
||||||
current_notion_config = get_user_notion_config(user)
|
current_notion_config = get_user_notion_config(user)
|
||||||
notion_token = current_notion_config.token if current_notion_config else ""
|
notion_token = current_notion_config.token if current_notion_config else ""
|
||||||
|
|
||||||
selected_chat_model_config = (
|
selected_chat_model_config = ConversationAdapters.get_conversation_config(
|
||||||
ConversationAdapters.get_conversation_config(user) or ConversationAdapters.get_default_conversation_config()
|
user
|
||||||
)
|
) or ConversationAdapters.get_default_conversation_config(user)
|
||||||
chat_models = ConversationAdapters.get_conversation_processor_options().all()
|
chat_models = ConversationAdapters.get_conversation_processor_options().all()
|
||||||
chat_model_options = list()
|
chat_model_options = list()
|
||||||
for chat_model in chat_models:
|
for chat_model in chat_models:
|
||||||
|
|
|
@ -129,9 +129,6 @@ def initialization(interactive: bool = True):
|
||||||
if user_chat_model_name and ChatModelOptions.objects.filter(chat_model=user_chat_model_name).exists():
|
if user_chat_model_name and ChatModelOptions.objects.filter(chat_model=user_chat_model_name).exists():
|
||||||
default_chat_model_name = user_chat_model_name
|
default_chat_model_name = user_chat_model_name
|
||||||
|
|
||||||
# Create a server chat settings object with the default chat model
|
|
||||||
default_chat_model = ChatModelOptions.objects.filter(chat_model=default_chat_model_name).first()
|
|
||||||
ServerChatSettings.objects.create(chat_default=default_chat_model)
|
|
||||||
logger.info("🗣️ Chat model configuration complete")
|
logger.info("🗣️ Chat model configuration complete")
|
||||||
|
|
||||||
# Set up offline speech to text model
|
# Set up offline speech to text model
|
||||||
|
|
Loading…
Reference in a new issue