diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 4490b7d2..a2c531f8 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -1009,8 +1009,15 @@ class ConversationAdapters: """Get default conversation config. Prefer chat model by server admin > user > first created chat model""" # Get the server chat settings server_chat_settings = ServerChatSettings.objects.first() - if server_chat_settings is not None and server_chat_settings.chat_default is not None: - return server_chat_settings.chat_default + + is_subscribed = is_user_subscribed(user) if user else False + if server_chat_settings: + # If the user is subscribed and the advanced model is enabled, return the advanced model + if is_subscribed and server_chat_settings.chat_advanced: + return server_chat_settings.chat_advanced + # If the default model is set, return it + if server_chat_settings.chat_default: + return server_chat_settings.chat_default # Get the user's chat settings, if the server chat settings are not set user_chat_settings = UserConversationConfig.objects.filter(user=user).first() if user else None @@ -1026,11 +1033,20 @@ class ConversationAdapters: # Get the server chat settings server_chat_settings: ServerChatSettings = ( await ServerChatSettings.objects.filter() - .prefetch_related("chat_default", "chat_default__openai_config") + .prefetch_related( + "chat_default", "chat_default__openai_config", "chat_advanced", "chat_advanced__openai_config" + ) .afirst() ) - if server_chat_settings is not None and server_chat_settings.chat_default is not None: - return server_chat_settings.chat_default + is_subscribed = await ais_user_subscribed(user) if user else False + + if server_chat_settings: + # If the user is subscribed and the advanced model is enabled, return the advanced model + if is_subscribed and server_chat_settings.chat_advanced: + return server_chat_settings.chat_advanced + # If the default model is set, return it + if server_chat_settings.chat_default: + return server_chat_settings.chat_default # Get the user's chat settings, if the server chat settings are not set user_chat_settings = ( diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index d0b78d9a..5ebfd911 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -817,7 +817,7 @@ async def chat( if not q: conversation_config = await ConversationAdapters.aget_user_conversation_config(user) if conversation_config == None: - conversation_config = await ConversationAdapters.aget_default_conversation_config() + conversation_config = await ConversationAdapters.aget_default_conversation_config(user) model_type = conversation_config.model_type formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device()) async for result in send_llm_response(formatted_help): diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 0f5a7006..2af1f64d 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -140,7 +140,7 @@ def validate_conversation_config(user: KhojUser): async def is_ready_to_chat(user: KhojUser): user_conversation_config = await ConversationAdapters.aget_user_conversation_config(user) if user_conversation_config == None: - user_conversation_config = await ConversationAdapters.aget_default_conversation_config() + user_conversation_config = await ConversationAdapters.aget_default_conversation_config(user) if user_conversation_config and user_conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE: chat_model = user_conversation_config.chat_model