From a9009ea774608507e07709a9084c8e91bbff0fb9 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Tue, 8 Oct 2024 18:09:03 -0700 Subject: [PATCH] Default to use user chat model if server chat settings not defined Fallback to use user chat model for train of thought if server chat settings not defined. This simplifies switching chat models for single-user, self-hosted setups by just changing the chat model on the user settings page. Server chat settings, when set, controls the default user chat model and the chat model that is used for Khoj's train of thought. Previously a self-hosted user had to update both the server chat settings in the admin panel and their own user chat model in the user settings panel to explicitly switch to a different chat model (i.e to switch to a new model for both train of thought & response generation) You can still set server chat settings to use a different chat model for train of thought and response generation. But this is only necessary for advanced self-hosted or cloud hosted setups of Khoj. --- src/khoj/configure.py | 6 +-- src/khoj/database/adapters/__init__.py | 54 ++++++++++++++++++-------- 2 files changed, 41 insertions(+), 19 deletions(-) diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 53e19f71..27c5ed08 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -244,7 +244,7 @@ def configure_server( state.SearchType = configure_search_types() state.search_models = configure_search(state.search_models, state.config.search_type) - setup_default_agent() + setup_default_agent(user) message = "📡 Telemetry disabled" if telemetry_disabled(state.config.app) else "📡 Telemetry enabled" logger.info(message) @@ -256,8 +256,8 @@ def configure_server( raise e -def setup_default_agent(): - AgentAdapters.create_default_agent() +def setup_default_agent(user: KhojUser): + AgentAdapters.create_default_agent(user) def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, user: KhojUser = None): diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 9687ec01..40c82fa6 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -643,8 +643,8 @@ class AgentAdapters: return Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).first() @staticmethod - def create_default_agent(): - default_conversation_config = ConversationAdapters.get_default_conversation_config() + def create_default_agent(user: KhojUser): + default_conversation_config = ConversationAdapters.get_default_conversation_config(user) if default_conversation_config is None: logger.info("No default conversation config found, skipping default agent creation") return None @@ -968,29 +968,51 @@ class ConversationAdapters: return VoiceModelOption.objects.first() @staticmethod - def get_default_conversation_config(): + def get_default_conversation_config(user: KhojUser = None): + """Get default conversation config. Prefer chat model by server admin > user > first created chat model""" + # Get the server chat settings server_chat_settings = ServerChatSettings.objects.first() - if server_chat_settings is None or server_chat_settings.chat_default is None: - return ChatModelOptions.objects.filter().first() - return server_chat_settings.chat_default + if server_chat_settings is not None and server_chat_settings.chat_default is not None: + return server_chat_settings.chat_default + + # Get the user's chat settings, if the server chat settings are not set + user_chat_settings = UserConversationConfig.objects.filter(user=user).first() if user else None + if user_chat_settings is not None and user_chat_settings.setting is not None: + return user_chat_settings.setting + + # Get the first chat model if even the user chat settings are not set + return ChatModelOptions.objects.filter().first() @staticmethod - async def aget_default_conversation_config(): + async def aget_default_conversation_config(user: KhojUser = None): + """Get default conversation config. Prefer chat model by server admin > user > first created chat model""" + # Get the server chat settings server_chat_settings: ServerChatSettings = ( await ServerChatSettings.objects.filter() .prefetch_related("chat_default", "chat_default__openai_config") .afirst() ) - if server_chat_settings is None or server_chat_settings.chat_default is None: - return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst() - return server_chat_settings.chat_default + if server_chat_settings is not None and server_chat_settings.chat_default is not None: + return server_chat_settings.chat_default + + # Get the user's chat settings, if the server chat settings are not set + user_chat_settings = ( + (await UserConversationConfig.objects.filter(user=user).prefetch_related("setting__openai_config").afirst()) + if user + else None + ) + if user_chat_settings is not None and user_chat_settings.setting is not None: + return user_chat_settings.setting + + # Get the first chat model if even the user chat settings are not set + return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst() @staticmethod def get_advanced_conversation_config(): server_chat_settings = ServerChatSettings.objects.first() - if server_chat_settings is None or server_chat_settings.chat_advanced is None: - return ConversationAdapters.get_default_conversation_config() - return server_chat_settings.chat_advanced + if server_chat_settings is not None and server_chat_settings.chat_advanced is not None: + return server_chat_settings.chat_advanced + return ConversationAdapters.get_default_conversation_config() @staticmethod async def aget_advanced_conversation_config(): @@ -999,9 +1021,9 @@ class ConversationAdapters: .prefetch_related("chat_advanced", "chat_advanced__openai_config") .afirst() ) - if server_chat_settings is None or server_chat_settings.chat_advanced is None: - return await ConversationAdapters.aget_default_conversation_config() - return server_chat_settings.chat_advanced + if server_chat_settings is not None or server_chat_settings.chat_advanced is not None: + return server_chat_settings.chat_advanced + return await ConversationAdapters.aget_default_conversation_config() @staticmethod def create_conversation_from_public_conversation(