Fix default chat model to use user model if no server chat model set

- Advanced chat model should also fallback to user chat model if set
- Get conversation config should falback to user chat model if set

These assume no server chat model settings is configured
This commit is contained in:
Debanjum Singh Solanky 2024-10-13 03:02:29 -07:00
parent 81aa1b5589
commit 931c56182e
3 changed files with 9 additions and 13 deletions

View file

@ -939,21 +939,21 @@ class ConversationAdapters:
def get_conversation_config(user: KhojUser): def get_conversation_config(user: KhojUser):
subscribed = is_user_subscribed(user) subscribed = is_user_subscribed(user)
if not subscribed: if not subscribed:
return ConversationAdapters.get_default_conversation_config() return ConversationAdapters.get_default_conversation_config(user)
config = UserConversationConfig.objects.filter(user=user).first() config = UserConversationConfig.objects.filter(user=user).first()
if config: if config:
return config.setting return config.setting
return ConversationAdapters.get_advanced_conversation_config() return ConversationAdapters.get_advanced_conversation_config(user)
@staticmethod @staticmethod
async def aget_conversation_config(user: KhojUser): async def aget_conversation_config(user: KhojUser):
subscribed = await ais_user_subscribed(user) subscribed = await ais_user_subscribed(user)
if not subscribed: if not subscribed:
return await ConversationAdapters.aget_default_conversation_config() return await ConversationAdapters.aget_default_conversation_config(user)
config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst() config = await UserConversationConfig.objects.filter(user=user).prefetch_related("setting").afirst()
if config: if config:
return config.setting return config.setting
return ConversationAdapters.aget_advanced_conversation_config() return ConversationAdapters.aget_advanced_conversation_config(user)
@staticmethod @staticmethod
async def aget_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]: async def aget_voice_model_config(user: KhojUser) -> Optional[VoiceModelOption]:
@ -1014,22 +1014,22 @@ class ConversationAdapters:
return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst() return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst()
@staticmethod @staticmethod
def get_advanced_conversation_config(): def get_advanced_conversation_config(user: KhojUser):
server_chat_settings = ServerChatSettings.objects.first() server_chat_settings = ServerChatSettings.objects.first()
if server_chat_settings is not None and server_chat_settings.chat_advanced is not None: if server_chat_settings is not None and server_chat_settings.chat_advanced is not None:
return server_chat_settings.chat_advanced return server_chat_settings.chat_advanced
return ConversationAdapters.get_default_conversation_config() return ConversationAdapters.get_default_conversation_config(user)
@staticmethod @staticmethod
async def aget_advanced_conversation_config(): async def aget_advanced_conversation_config(user: KhojUser = None):
server_chat_settings: ServerChatSettings = ( server_chat_settings: ServerChatSettings = (
await ServerChatSettings.objects.filter() await ServerChatSettings.objects.filter()
.prefetch_related("chat_advanced", "chat_advanced__openai_config") .prefetch_related("chat_advanced", "chat_advanced__openai_config")
.afirst() .afirst()
) )
if server_chat_settings is not None or server_chat_settings.chat_advanced is not None: if server_chat_settings is not None and server_chat_settings.chat_advanced is not None:
return server_chat_settings.chat_advanced return server_chat_settings.chat_advanced
return await ConversationAdapters.aget_default_conversation_config() return await ConversationAdapters.aget_default_conversation_config(user)
@staticmethod @staticmethod
def create_conversation_from_public_conversation( def create_conversation_from_public_conversation(

View file

@ -53,7 +53,6 @@ async def search_online(
conversation_history: dict, conversation_history: dict,
location: LocationData, location: LocationData,
user: KhojUser, user: KhojUser,
subscribed: bool = False,
send_status_func: Optional[Callable] = None, send_status_func: Optional[Callable] = None,
custom_filters: List[str] = [], custom_filters: List[str] = [],
uploaded_image_url: str = None, uploaded_image_url: str = None,
@ -141,7 +140,6 @@ async def read_webpages(
conversation_history: dict, conversation_history: dict,
location: LocationData, location: LocationData,
user: KhojUser, user: KhojUser,
subscribed: bool = False,
send_status_func: Optional[Callable] = None, send_status_func: Optional[Callable] = None,
uploaded_image_url: str = None, uploaded_image_url: str = None,
agent: Agent = None, agent: Agent = None,

View file

@ -885,7 +885,6 @@ async def chat(
meta_log, meta_log,
location, location,
user, user,
subscribed,
partial(send_event, ChatEvent.STATUS), partial(send_event, ChatEvent.STATUS),
custom_filters, custom_filters,
uploaded_image_url=uploaded_image_url, uploaded_image_url=uploaded_image_url,
@ -910,7 +909,6 @@ async def chat(
meta_log, meta_log,
location, location,
user, user,
subscribed,
partial(send_event, ChatEvent.STATUS), partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url, uploaded_image_url=uploaded_image_url,
agent=agent, agent=agent,