Simplify switching chat model when self-hosting (#934)

# Overview
- Default to use user chat models for train of thought when no server chat settings created by admins
- Default to not create server chat settings on first run

# Details
This change simplifies switching chat models for self-hosted setups 
by just changing the chat model on the user settings page.

It falls back to use the user chat model for train of thought 
if server chat settings have not been created on the admin panel.

Server chat settings, when set, controls the chat model used 
for Khoj's train of thought and the default user chat model.

Previously a self-hosted user had to update
1. the server chat settings in the admin panel and
2. their own user chat model in the user settings panel

to completely switch to a different chat model 
for both train of thought & response generation respectively

You can still set server chat settings via the admin panel 
to use a different chat model for train of thought vs response generation. 
But this is only useful for advanced, multi-user setups.
This commit is contained in:
Debanjum 2024-10-12 19:58:05 -07:00 committed by GitHub
commit c66c571396
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 102 additions and 88 deletions

View file

@ -253,7 +253,7 @@ def configure_server(
state.SearchType = configure_search_types()
state.search_models = configure_search(state.search_models, state.config.search_type)
setup_default_agent()
setup_default_agent(user)
message = "📡 Telemetry disabled" if telemetry_disabled(state.config.app) else "📡 Telemetry enabled"
logger.info(message)
@ -265,8 +265,8 @@ def configure_server(
raise e
def setup_default_agent():
AgentAdapters.create_default_agent()
def setup_default_agent(user: KhojUser):
AgentAdapters.create_default_agent(user)
def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, user: KhojUser = None):

View file

@ -647,8 +647,8 @@ class AgentAdapters:
return Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).first()
@staticmethod
def create_default_agent():
default_conversation_config = ConversationAdapters.get_default_conversation_config()
def create_default_agent(user: KhojUser):
default_conversation_config = ConversationAdapters.get_default_conversation_config(user)
if default_conversation_config is None:
logger.info("No default conversation config found, skipping default agent creation")
return None
@ -972,29 +972,51 @@ class ConversationAdapters:
return VoiceModelOption.objects.first()
@staticmethod
def get_default_conversation_config():
def get_default_conversation_config(user: KhojUser = None):
"""Get default conversation config. Prefer chat model by server admin > user > first created chat model"""
# Get the server chat settings
server_chat_settings = ServerChatSettings.objects.first()
if server_chat_settings is None or server_chat_settings.chat_default is None:
return ChatModelOptions.objects.filter().first()
if server_chat_settings is not None and server_chat_settings.chat_default is not None:
return server_chat_settings.chat_default
# Get the user's chat settings, if the server chat settings are not set
user_chat_settings = UserConversationConfig.objects.filter(user=user).first() if user else None
if user_chat_settings is not None and user_chat_settings.setting is not None:
return user_chat_settings.setting
# Get the first chat model if even the user chat settings are not set
return ChatModelOptions.objects.filter().first()
@staticmethod
async def aget_default_conversation_config():
async def aget_default_conversation_config(user: KhojUser = None):
"""Get default conversation config. Prefer chat model by server admin > user > first created chat model"""
# Get the server chat settings
server_chat_settings: ServerChatSettings = (
await ServerChatSettings.objects.filter()
.prefetch_related("chat_default", "chat_default__openai_config")
.afirst()
)
if server_chat_settings is None or server_chat_settings.chat_default is None:
return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst()
if server_chat_settings is not None and server_chat_settings.chat_default is not None:
return server_chat_settings.chat_default
# Get the user's chat settings, if the server chat settings are not set
user_chat_settings = (
(await UserConversationConfig.objects.filter(user=user).prefetch_related("setting__openai_config").afirst())
if user
else None
)
if user_chat_settings is not None and user_chat_settings.setting is not None:
return user_chat_settings.setting
# Get the first chat model if even the user chat settings are not set
return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst()
@staticmethod
def get_advanced_conversation_config():
server_chat_settings = ServerChatSettings.objects.first()
if server_chat_settings is None or server_chat_settings.chat_advanced is None:
return ConversationAdapters.get_default_conversation_config()
if server_chat_settings is not None and server_chat_settings.chat_advanced is not None:
return server_chat_settings.chat_advanced
return ConversationAdapters.get_default_conversation_config()
@staticmethod
async def aget_advanced_conversation_config():
@ -1003,9 +1025,9 @@ class ConversationAdapters:
.prefetch_related("chat_advanced", "chat_advanced__openai_config")
.afirst()
)
if server_chat_settings is None or server_chat_settings.chat_advanced is None:
return await ConversationAdapters.aget_default_conversation_config()
if server_chat_settings is not None or server_chat_settings.chat_advanced is not None:
return server_chat_settings.chat_advanced
return await ConversationAdapters.aget_default_conversation_config()
@staticmethod
def create_conversation_from_public_conversation(

View file

@ -25,7 +25,6 @@ async def text_to_image(
location_data: LocationData,
references: List[Dict[str, Any]],
online_results: Dict[str, Any],
subscribed: bool = False,
send_status_func: Optional[Callable] = None,
uploaded_image_url: Optional[str] = None,
agent: Agent = None,
@ -66,8 +65,8 @@ async def text_to_image(
note_references=references,
online_results=online_results,
model_type=text_to_image_config.model_type,
subscribed=subscribed,
uploaded_image_url=uploaded_image_url,
user=user,
agent=agent,
)

View file

@ -102,7 +102,7 @@ async def search_online(
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event}
tasks = [
read_webpage_and_extract_content(subquery, link, content, subscribed=subscribed, agent=agent)
read_webpage_and_extract_content(subquery, link, content, user=user, agent=agent)
for link, subquery, content in webpages
]
results = await asyncio.gather(*tasks)
@ -158,7 +158,7 @@ async def read_webpages(
webpage_links_str = "\n- " + "\n- ".join(list(urls))
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event}
tasks = [read_webpage_and_extract_content(query, url, subscribed=subscribed, agent=agent) for url in urls]
tasks = [read_webpage_and_extract_content(query, url, user=user, agent=agent) for url in urls]
results = await asyncio.gather(*tasks)
response: Dict[str, Dict] = defaultdict(dict)
@ -169,14 +169,14 @@ async def read_webpages(
async def read_webpage_and_extract_content(
subquery: str, url: str, content: str = None, subscribed: bool = False, agent: Agent = None
subquery: str, url: str, content: str = None, user: KhojUser = None, agent: Agent = None
) -> Tuple[str, Union[None, str], str]:
try:
if is_none_or_empty(content):
with timer(f"Reading web page at '{url}' took", logger):
content = await read_webpage_with_olostep(url) if OLOSTEP_API_KEY else await read_webpage_with_jina(url)
with timer(f"Extracting relevant information from web page at '{url}' took", logger):
extracted_info = await extract_relevant_info(subquery, content, subscribed=subscribed, agent=agent)
extracted_info = await extract_relevant_info(subquery, content, user=user, agent=agent)
return subquery, extracted_info, url
except Exception as e:
logger.error(f"Failed to read web page at '{url}' with {e}")

View file

@ -395,7 +395,7 @@ async def extract_references_and_questions(
# Infer search queries from user message
with timer("Extracting search queries took", logger):
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
conversation_config = await ConversationAdapters.aget_default_conversation_config()
conversation_config = await ConversationAdapters.aget_default_conversation_config(user)
vision_enabled = conversation_config.vision_enabled
if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:

View file

@ -194,7 +194,7 @@ def chat_history(
n: Optional[int] = None,
):
user = request.user.object
validate_conversation_config()
validate_conversation_config(user)
# Load Conversation History
conversation = ConversationAdapters.get_conversation_by_user(
@ -694,7 +694,7 @@ async def chat(
q,
meta_log,
is_automated_task,
subscribed=subscribed,
user=user,
uploaded_image_url=uploaded_image_url,
agent=agent,
)
@ -704,7 +704,7 @@ async def chat(
):
yield result
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, uploaded_image_url, agent)
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_image_url, agent)
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
yield result
if mode not in conversation_commands:
@ -767,8 +767,8 @@ async def chat(
q,
contextual_data,
conversation_history=meta_log,
subscribed=subscribed,
uploaded_image_url=uploaded_image_url,
user=user,
agent=agent,
)
response_log = str(response)
@ -957,7 +957,6 @@ async def chat(
location_data=location,
references=compiled_references,
online_results=online_results,
subscribed=subscribed,
send_status_func=partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url,
agent=agent,
@ -1192,7 +1191,7 @@ async def get_chat(
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
conversation_commands = await aget_relevant_information_sources(
q, meta_log, is_automated_task, subscribed=subscribed, uploaded_image_url=uploaded_image_url
q, meta_log, is_automated_task, user=user, uploaded_image_url=uploaded_image_url
)
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
async for result in send_event(
@ -1200,7 +1199,7 @@ async def get_chat(
):
yield result
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, uploaded_image_url)
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_image_url)
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
yield result
if mode not in conversation_commands:
@ -1252,7 +1251,7 @@ async def get_chat(
q,
contextual_data,
conversation_history=meta_log,
subscribed=subscribed,
user=user,
uploaded_image_url=uploaded_image_url,
)
response_log = str(response)
@ -1438,7 +1437,6 @@ async def get_chat(
location_data=location,
references=compiled_references,
online_results=online_results,
subscribed=subscribed,
send_status_func=partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url,
):

View file

@ -40,7 +40,7 @@ def get_user_chat_model(
chat_model = ConversationAdapters.get_conversation_config(user)
if chat_model is None:
chat_model = ConversationAdapters.get_default_conversation_config()
chat_model = ConversationAdapters.get_default_conversation_config(user)
return Response(status_code=200, content=json.dumps({"id": chat_model.id, "chat_model": chat_model.chat_model}))

View file

@ -39,6 +39,7 @@ from khoj.database.adapters import (
AutomationAdapters,
ConversationAdapters,
EntryAdapters,
ais_user_subscribed,
create_khoj_token,
get_khoj_tokens,
get_user_name,
@ -119,20 +120,20 @@ def is_query_empty(query: str) -> bool:
return is_none_or_empty(query.strip())
def validate_conversation_config():
default_config = ConversationAdapters.get_default_conversation_config()
def validate_conversation_config(user: KhojUser):
default_config = ConversationAdapters.get_default_conversation_config(user)
if default_config is None:
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
raise HTTPException(status_code=500, detail="Contact the server administrator to add a chat model.")
if default_config.model_type == "openai" and not default_config.openai_config:
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
raise HTTPException(status_code=500, detail="Contact the server administrator to add a chat model.")
async def is_ready_to_chat(user: KhojUser):
user_conversation_config = (await ConversationAdapters.aget_user_conversation_config(user)) or (
await ConversationAdapters.aget_default_conversation_config()
)
user_conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
if user_conversation_config == None:
user_conversation_config = await ConversationAdapters.aget_default_conversation_config()
if user_conversation_config and user_conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
chat_model = user_conversation_config.chat_model
@ -246,19 +247,19 @@ async def agenerate_chat_response(*args):
return await loop.run_in_executor(executor, generate_chat_response, *args)
async def acreate_title_from_query(query: str) -> str:
async def acreate_title_from_query(query: str, user: KhojUser = None) -> str:
"""
Create a title from the given query
"""
title_generation_prompt = prompts.subject_generation.format(query=query)
with timer("Chat actor: Generate title from query", logger):
response = await send_message_to_model_wrapper(title_generation_prompt)
response = await send_message_to_model_wrapper(title_generation_prompt, user=user)
return response.strip()
async def acheck_if_safe_prompt(system_prompt: str) -> Tuple[bool, str]:
async def acheck_if_safe_prompt(system_prompt: str, user: KhojUser = None) -> Tuple[bool, str]:
"""
Check if the system prompt is safe to use
"""
@ -267,7 +268,7 @@ async def acheck_if_safe_prompt(system_prompt: str) -> Tuple[bool, str]:
reason = ""
with timer("Chat actor: Check if safe prompt", logger):
response = await send_message_to_model_wrapper(safe_prompt_check)
response = await send_message_to_model_wrapper(safe_prompt_check, user=user)
response = response.strip()
try:
@ -288,7 +289,7 @@ async def aget_relevant_information_sources(
query: str,
conversation_history: dict,
is_task: bool,
subscribed: bool,
user: KhojUser,
uploaded_image_url: str = None,
agent: Agent = None,
):
@ -326,7 +327,7 @@ async def aget_relevant_information_sources(
response = await send_message_to_model_wrapper(
relevant_tools_prompt,
response_type="json_object",
subscribed=subscribed,
user=user,
)
try:
@ -362,7 +363,12 @@ async def aget_relevant_information_sources(
async def aget_relevant_output_modes(
query: str, conversation_history: dict, is_task: bool = False, uploaded_image_url: str = None, agent: Agent = None
query: str,
conversation_history: dict,
is_task: bool = False,
user: KhojUser = None,
uploaded_image_url: str = None,
agent: Agent = None,
):
"""
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
@ -398,7 +404,7 @@ async def aget_relevant_output_modes(
)
with timer("Chat actor: Infer output mode for chat response", logger):
response = await send_message_to_model_wrapper(relevant_mode_prompt, response_type="json_object")
response = await send_message_to_model_wrapper(relevant_mode_prompt, response_type="json_object", user=user)
try:
response = response.strip()
@ -453,7 +459,7 @@ async def infer_webpage_urls(
with timer("Chat actor: Infer webpage urls to read", logger):
response = await send_message_to_model_wrapper(
online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object"
online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
)
# Validate that the response is a non-empty, JSON-serializable list of URLs
@ -499,7 +505,7 @@ async def generate_online_subqueries(
with timer("Chat actor: Generate online search subqueries", logger):
response = await send_message_to_model_wrapper(
online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object"
online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
)
# Validate that the response is a non-empty, JSON-serializable list
@ -517,7 +523,9 @@ async def generate_online_subqueries(
return [q]
async def schedule_query(q: str, conversation_history: dict, uploaded_image_url: str = None) -> Tuple[str, ...]:
async def schedule_query(
q: str, conversation_history: dict, user: KhojUser, uploaded_image_url: str = None
) -> Tuple[str, ...]:
"""
Schedule the date, time to run the query. Assume the server timezone is UTC.
"""
@ -529,7 +537,7 @@ async def schedule_query(q: str, conversation_history: dict, uploaded_image_url:
)
raw_response = await send_message_to_model_wrapper(
crontime_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object"
crontime_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
)
# Validate that the response is a non-empty, JSON-serializable list
@ -543,7 +551,7 @@ async def schedule_query(q: str, conversation_history: dict, uploaded_image_url:
raise AssertionError(f"Invalid response for scheduling query: {raw_response}")
async def extract_relevant_info(q: str, corpus: str, subscribed: bool, agent: Agent = None) -> Union[str, None]:
async def extract_relevant_info(q: str, corpus: str, user: KhojUser = None, agent: Agent = None) -> Union[str, None]:
"""
Extract relevant information for a given query from the target corpus
"""
@ -561,14 +569,11 @@ async def extract_relevant_info(q: str, corpus: str, subscribed: bool, agent: Ag
personality_context=personality_context,
)
chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config()
with timer("Chat actor: Extract relevant information from data", logger):
response = await send_message_to_model_wrapper(
extract_relevant_information,
prompts.system_prompt_extract_relevant_information,
chat_model_option=chat_model,
subscribed=subscribed,
user=user,
)
return response.strip()
@ -577,8 +582,8 @@ async def extract_relevant_summary(
q: str,
corpus: str,
conversation_history: dict,
subscribed: bool = False,
uploaded_image_url: str = None,
user: KhojUser = None,
agent: Agent = None,
) -> Union[str, None]:
"""
@ -601,14 +606,11 @@ async def extract_relevant_summary(
personality_context=personality_context,
)
chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config()
with timer("Chat actor: Extract relevant information from data", logger):
response = await send_message_to_model_wrapper(
extract_relevant_information,
prompts.system_prompt_extract_relevant_summary,
chat_model_option=chat_model,
subscribed=subscribed,
user=user,
uploaded_image_url=uploaded_image_url,
)
return response.strip()
@ -621,8 +623,8 @@ async def generate_better_image_prompt(
note_references: List[Dict[str, Any]],
online_results: Optional[dict] = None,
model_type: Optional[str] = None,
subscribed: bool = False,
uploaded_image_url: Optional[str] = None,
user: KhojUser = None,
agent: Agent = None,
) -> str:
"""
@ -672,12 +674,8 @@ async def generate_better_image_prompt(
personality_context=personality_context,
)
chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config()
with timer("Chat actor: Generate contextual image prompt", logger):
response = await send_message_to_model_wrapper(
image_prompt, chat_model_option=chat_model, subscribed=subscribed, uploaded_image_url=uploaded_image_url
)
response = await send_message_to_model_wrapper(image_prompt, uploaded_image_url=uploaded_image_url, user=user)
response = response.strip()
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
response = response[1:-1]
@ -689,14 +687,10 @@ async def send_message_to_model_wrapper(
message: str,
system_message: str = "",
response_type: str = "text",
chat_model_option: ChatModelOptions = None,
subscribed: bool = False,
user: KhojUser = None,
uploaded_image_url: str = None,
):
conversation_config: ChatModelOptions = (
chat_model_option or await ConversationAdapters.aget_default_conversation_config()
)
conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
vision_available = conversation_config.vision_enabled
if not vision_available and uploaded_image_url:
vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
@ -704,6 +698,7 @@ async def send_message_to_model_wrapper(
conversation_config = vision_enabled_config
vision_available = True
subscribed = await ais_user_subscribed(user)
chat_model = conversation_config.chat_model
max_tokens = (
conversation_config.subscribed_max_prompt_size
@ -802,8 +797,9 @@ def send_message_to_model_wrapper_sync(
message: str,
system_message: str = "",
response_type: str = "text",
user: KhojUser = None,
):
conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config()
conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config(user)
if conversation_config is None:
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
@ -1182,7 +1178,7 @@ class CommonQueryParamsClass:
CommonQueryParams = Annotated[CommonQueryParamsClass, Depends()]
def should_notify(original_query: str, executed_query: str, ai_response: str) -> bool:
def should_notify(original_query: str, executed_query: str, ai_response: str, user: KhojUser) -> bool:
"""
Decide whether to notify the user of the AI response.
Default to notifying the user for now.
@ -1199,7 +1195,7 @@ def should_notify(original_query: str, executed_query: str, ai_response: str) ->
with timer("Chat actor: Decide to notify user of automation response", logger):
try:
# TODO Replace with async call so we don't have to maintain a sync version
response = send_message_to_model_wrapper_sync(to_notify_or_not)
response = send_message_to_model_wrapper_sync(to_notify_or_not, user)
should_notify_result = "no" not in response.lower()
logger.info(f'Decided to {"not " if not should_notify_result else ""}notify user of automation response.')
return should_notify_result
@ -1291,7 +1287,9 @@ def scheduled_chat(
ai_response = raw_response.text
# Notify user if the AI response is satisfactory
if should_notify(original_query=scheduling_request, executed_query=cleaned_query, ai_response=ai_response):
if should_notify(
original_query=scheduling_request, executed_query=cleaned_query, ai_response=ai_response, user=user
):
if is_resend_enabled():
send_task_email(user.get_short_name(), user.email, cleaned_query, ai_response, subject, is_image)
else:
@ -1301,7 +1299,7 @@ def scheduled_chat(
async def create_automation(
q: str, timezone: str, user: KhojUser, calling_url: URL, meta_log: dict = {}, conversation_id: str = None
):
crontime, query_to_run, subject = await schedule_query(q, meta_log)
crontime, query_to_run, subject = await schedule_query(q, meta_log, user)
job = await schedule_automation(query_to_run, subject, crontime, timezone, q, user, calling_url, conversation_id)
return job, crontime, query_to_run, subject
@ -1495,9 +1493,9 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
current_notion_config = get_user_notion_config(user)
notion_token = current_notion_config.token if current_notion_config else ""
selected_chat_model_config = (
ConversationAdapters.get_conversation_config(user) or ConversationAdapters.get_default_conversation_config()
)
selected_chat_model_config = ConversationAdapters.get_conversation_config(
user
) or ConversationAdapters.get_default_conversation_config(user)
chat_models = ConversationAdapters.get_conversation_processor_options().all()
chat_model_options = list()
for chat_model in chat_models:

View file

@ -129,9 +129,6 @@ def initialization(interactive: bool = True):
if user_chat_model_name and ChatModelOptions.objects.filter(chat_model=user_chat_model_name).exists():
default_chat_model_name = user_chat_model_name
# Create a server chat settings object with the default chat model
default_chat_model = ChatModelOptions.objects.filter(chat_model=default_chat_model_name).first()
ServerChatSettings.objects.create(chat_default=default_chat_model)
logger.info("🗣️ Chat model configuration complete")
# Set up offline speech to text model