diff --git a/src/khoj/configure.py b/src/khoj/configure.py
index 1ca7e041..b60c00d1 100644
--- a/src/khoj/configure.py
+++ b/src/khoj/configure.py
@@ -253,7 +253,7 @@ def configure_server(
         state.SearchType = configure_search_types()
         state.search_models = configure_search(state.search_models, state.config.search_type)
-        setup_default_agent()
+        setup_default_agent(user)
 
         message = "📡 Telemetry disabled" if telemetry_disabled(state.config.app) else "📡 Telemetry enabled"
         logger.info(message)
@@ -265,8 +265,8 @@ def configure_server(
         raise e
 
 
-def setup_default_agent():
-    AgentAdapters.create_default_agent()
+def setup_default_agent(user: KhojUser):
+    AgentAdapters.create_default_agent(user)
 
 
 def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, user: KhojUser = None):
diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py
index 76fd31ab..51a211b6 100644
--- a/src/khoj/database/adapters/__init__.py
+++ b/src/khoj/database/adapters/__init__.py
@@ -647,8 +647,8 @@ class AgentAdapters:
         return Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).first()
 
     @staticmethod
-    def create_default_agent():
-        default_conversation_config = ConversationAdapters.get_default_conversation_config()
+    def create_default_agent(user: KhojUser):
+        default_conversation_config = ConversationAdapters.get_default_conversation_config(user)
         if default_conversation_config is None:
             logger.info("No default conversation config found, skipping default agent creation")
             return None
@@ -972,29 +972,51 @@ class ConversationAdapters:
         return VoiceModelOption.objects.first()
 
     @staticmethod
-    def get_default_conversation_config():
+    def get_default_conversation_config(user: KhojUser = None):
+        """Get default conversation config. Prefer chat model by server admin > user > first created chat model"""
+        # Get the server chat settings
         server_chat_settings = ServerChatSettings.objects.first()
-        if server_chat_settings is None or server_chat_settings.chat_default is None:
-            return ChatModelOptions.objects.filter().first()
-        return server_chat_settings.chat_default
+        if server_chat_settings is not None and server_chat_settings.chat_default is not None:
+            return server_chat_settings.chat_default
+
+        # Get the user's chat settings, if the server chat settings are not set
+        user_chat_settings = UserConversationConfig.objects.filter(user=user).first() if user else None
+        if user_chat_settings is not None and user_chat_settings.setting is not None:
+            return user_chat_settings.setting
+
+        # Get the first chat model if even the user chat settings are not set
+        return ChatModelOptions.objects.filter().first()
 
     @staticmethod
-    async def aget_default_conversation_config():
+    async def aget_default_conversation_config(user: KhojUser = None):
+        """Get default conversation config. Prefer chat model by server admin > user > first created chat model"""
+        # Get the server chat settings
         server_chat_settings: ServerChatSettings = (
             await ServerChatSettings.objects.filter()
             .prefetch_related("chat_default", "chat_default__openai_config")
             .afirst()
         )
-        if server_chat_settings is None or server_chat_settings.chat_default is None:
-            return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst()
-        return server_chat_settings.chat_default
+        if server_chat_settings is not None and server_chat_settings.chat_default is not None:
+            return server_chat_settings.chat_default
+
+        # Get the user's chat settings, if the server chat settings are not set
+        user_chat_settings = (
+            (await UserConversationConfig.objects.filter(user=user).prefetch_related("setting__openai_config").afirst())
+            if user
+            else None
+        )
+        if user_chat_settings is not None and user_chat_settings.setting is not None:
+            return user_chat_settings.setting
+
+        # Get the first chat model if even the user chat settings are not set
+        return await ChatModelOptions.objects.filter().prefetch_related("openai_config").afirst()
 
     @staticmethod
     def get_advanced_conversation_config():
         server_chat_settings = ServerChatSettings.objects.first()
-        if server_chat_settings is None or server_chat_settings.chat_advanced is None:
-            return ConversationAdapters.get_default_conversation_config()
-        return server_chat_settings.chat_advanced
+        if server_chat_settings is not None and server_chat_settings.chat_advanced is not None:
+            return server_chat_settings.chat_advanced
+        return ConversationAdapters.get_default_conversation_config()
 
     @staticmethod
     async def aget_advanced_conversation_config():
@@ -1003,9 +1025,9 @@ class ConversationAdapters:
             .prefetch_related("chat_advanced", "chat_advanced__openai_config")
             .afirst()
         )
-        if server_chat_settings is None or server_chat_settings.chat_advanced is None:
-            return await ConversationAdapters.aget_default_conversation_config()
-        return server_chat_settings.chat_advanced
+        if server_chat_settings is not None and server_chat_settings.chat_advanced is not None:
+            return server_chat_settings.chat_advanced
+        return await ConversationAdapters.aget_default_conversation_config()
 
     @staticmethod
     def create_conversation_from_public_conversation(
diff --git a/src/khoj/processor/image/generate.py b/src/khoj/processor/image/generate.py
index ef7105ca..59073731 100644
--- a/src/khoj/processor/image/generate.py
+++ b/src/khoj/processor/image/generate.py
@@ -25,7 +25,6 @@ async def text_to_image(
     location_data: LocationData,
     references: List[Dict[str, Any]],
     online_results: Dict[str, Any],
-    subscribed: bool = False,
     send_status_func: Optional[Callable] = None,
     uploaded_image_url: Optional[str] = None,
     agent: Agent = None,
@@ -66,8 +65,8 @@ async def text_to_image(
         note_references=references,
         online_results=online_results,
         model_type=text_to_image_config.model_type,
-        subscribed=subscribed,
         uploaded_image_url=uploaded_image_url,
+        user=user,
         agent=agent,
     )
diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py
index 393442c4..16539b5c 100644
--- a/src/khoj/processor/tools/online_search.py
+++ b/src/khoj/processor/tools/online_search.py
@@ -102,7 +102,7 @@ async def search_online(
         async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
             yield {ChatEvent.STATUS: event}
     tasks = [
-        read_webpage_and_extract_content(subquery, link, content, subscribed=subscribed, agent=agent)
+        read_webpage_and_extract_content(subquery, link, content, user=user, agent=agent)
        for link, subquery, content in webpages
     ]
     results = await asyncio.gather(*tasks)
@@ -158,7 +158,7 @@ async def read_webpages(
         webpage_links_str = "\n- " + "\n- ".join(list(urls))
         async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
             yield {ChatEvent.STATUS: event}
-    tasks = [read_webpage_and_extract_content(query, url, subscribed=subscribed, agent=agent) for url in urls]
+    tasks = [read_webpage_and_extract_content(query, url, user=user, agent=agent) for url in urls]
     results = await asyncio.gather(*tasks)
 
     response: Dict[str, Dict] = defaultdict(dict)
@@ -169,14 +169,14 @@ async def read_webpages(
 
 
 async def read_webpage_and_extract_content(
-    subquery: str, url: str, content: str = None, subscribed: bool = False, agent: Agent = None
+    subquery: str, url: str, content: str = None, user: KhojUser = None, agent: Agent = None
 ) -> Tuple[str, Union[None, str], str]:
     try:
         if is_none_or_empty(content):
             with timer(f"Reading web page at '{url}' took", logger):
                 content = await read_webpage_with_olostep(url) if OLOSTEP_API_KEY else await read_webpage_with_jina(url)
         with timer(f"Extracting relevant information from web page at '{url}' took", logger):
-            extracted_info = await extract_relevant_info(subquery, content, subscribed=subscribed, agent=agent)
+            extracted_info = await extract_relevant_info(subquery, content, user=user, agent=agent)
         return subquery, extracted_info, url
     except Exception as e:
         logger.error(f"Failed to read web page at '{url}' with {e}")
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index 11ab1112..59948b47 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -395,7 +395,7 @@ async def extract_references_and_questions(
     # Infer search queries from user message
     with timer("Extracting search queries took", logger):
         # If we've reached here, either the user has enabled offline chat or the openai model is enabled.
-        conversation_config = await ConversationAdapters.aget_default_conversation_config()
+        conversation_config = await ConversationAdapters.aget_default_conversation_config(user)
         vision_enabled = conversation_config.vision_enabled
 
         if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py
index b175ef8a..d4f0c27f 100644
--- a/src/khoj/routers/api_chat.py
+++ b/src/khoj/routers/api_chat.py
@@ -194,7 +194,7 @@ def chat_history(
     n: Optional[int] = None,
 ):
     user = request.user.object
-    validate_conversation_config()
+    validate_conversation_config(user)
 
     # Load Conversation History
     conversation = ConversationAdapters.get_conversation_by_user(
@@ -694,7 +694,7 @@ async def chat(
                 q,
                 meta_log,
                 is_automated_task,
-                subscribed=subscribed,
+                user=user,
                 uploaded_image_url=uploaded_image_url,
                 agent=agent,
             )
@@ -704,7 +704,7 @@ async def chat(
             ):
                 yield result
 
-            mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, uploaded_image_url, agent)
+            mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_image_url, agent)
             async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
                 yield result
             if mode not in conversation_commands:
@@ -767,8 +767,8 @@ async def chat(
                 q,
                 contextual_data,
                 conversation_history=meta_log,
-                subscribed=subscribed,
                 uploaded_image_url=uploaded_image_url,
+                user=user,
                 agent=agent,
             )
             response_log = str(response)
@@ -957,7 +957,6 @@ async def chat(
                 location_data=location,
                 references=compiled_references,
                 online_results=online_results,
-                subscribed=subscribed,
                 send_status_func=partial(send_event, ChatEvent.STATUS),
                 uploaded_image_url=uploaded_image_url,
                 agent=agent,
@@ -1192,7 +1191,7 @@ async def get_chat(
 
         if conversation_commands == [ConversationCommand.Default] or is_automated_task:
             conversation_commands = await aget_relevant_information_sources(
-                q, meta_log, is_automated_task, subscribed=subscribed, uploaded_image_url=uploaded_image_url
+                q, meta_log, is_automated_task, user=user, uploaded_image_url=uploaded_image_url
             )
             conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
             async for result in send_event(
@@ -1200,7 +1199,7 @@
             ):
                 yield result
 
-            mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, uploaded_image_url)
+            mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_image_url)
             async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
                 yield result
             if mode not in conversation_commands:
@@ -1252,7 +1251,7 @@ async def get_chat(
                 q,
                 contextual_data,
                 conversation_history=meta_log,
-                subscribed=subscribed,
+                user=user,
                 uploaded_image_url=uploaded_image_url,
             )
             response_log = str(response)
@@ -1438,7 +1437,6 @@ async def get_chat(
                 location_data=location,
                 references=compiled_references,
                 online_results=online_results,
-                subscribed=subscribed,
                 send_status_func=partial(send_event, ChatEvent.STATUS),
                 uploaded_image_url=uploaded_image_url,
             ):
diff --git a/src/khoj/routers/api_model.py b/src/khoj/routers/api_model.py
index d5af4ba0..fc6be626 100644
--- a/src/khoj/routers/api_model.py
+++ b/src/khoj/routers/api_model.py
@@ -40,7 +40,7 @@ def get_user_chat_model(
     chat_model = ConversationAdapters.get_conversation_config(user)
 
     if chat_model is None:
-        chat_model = ConversationAdapters.get_default_conversation_config()
+        chat_model = ConversationAdapters.get_default_conversation_config(user)
 
     return Response(status_code=200, content=json.dumps({"id": chat_model.id, "chat_model": chat_model.chat_model}))
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index fdb1aa12..245fdf09 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -39,6 +39,7 @@ from khoj.database.adapters import (
     AutomationAdapters,
     ConversationAdapters,
     EntryAdapters,
+    ais_user_subscribed,
     create_khoj_token,
     get_khoj_tokens,
     get_user_name,
@@ -119,20 +120,20 @@ def is_query_empty(query: str) -> bool:
     return is_none_or_empty(query.strip())
 
 
-def validate_conversation_config():
-    default_config = ConversationAdapters.get_default_conversation_config()
+def validate_conversation_config(user: KhojUser):
+    default_config = ConversationAdapters.get_default_conversation_config(user)
 
     if default_config is None:
-        raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
+        raise HTTPException(status_code=500, detail="Contact the server administrator to add a chat model.")
 
     if default_config.model_type == "openai" and not default_config.openai_config:
-        raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
+        raise HTTPException(status_code=500, detail="Contact the server administrator to add a chat model.")
 
 
 async def is_ready_to_chat(user: KhojUser):
-    user_conversation_config = (await ConversationAdapters.aget_user_conversation_config(user)) or (
-        await ConversationAdapters.aget_default_conversation_config()
-    )
+    user_conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
+    if user_conversation_config == None:
+        user_conversation_config = await ConversationAdapters.aget_default_conversation_config()
 
     if user_conversation_config and user_conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
         chat_model = user_conversation_config.chat_model
@@ -246,19 +247,19 @@ async def agenerate_chat_response(*args):
     return await loop.run_in_executor(executor, generate_chat_response, *args)
 
 
-async def acreate_title_from_query(query: str) -> str:
+async def acreate_title_from_query(query: str, user: KhojUser = None) -> str:
     """
     Create a title from the given query
     """
     title_generation_prompt = prompts.subject_generation.format(query=query)
 
     with timer("Chat actor: Generate title from query", logger):
-        response = await send_message_to_model_wrapper(title_generation_prompt)
+        response = await send_message_to_model_wrapper(title_generation_prompt, user=user)
 
     return response.strip()
 
 
-async def acheck_if_safe_prompt(system_prompt: str) -> Tuple[bool, str]:
+async def acheck_if_safe_prompt(system_prompt: str, user: KhojUser = None) -> Tuple[bool, str]:
     """
     Check if the system prompt is safe to use
     """
@@ -267,7 +268,7 @@ async def acheck_if_safe_prompt(system_prompt: str) -> Tuple[bool, str]:
     reason = ""
 
     with timer("Chat actor: Check if safe prompt", logger):
-        response = await send_message_to_model_wrapper(safe_prompt_check)
+        response = await send_message_to_model_wrapper(safe_prompt_check, user=user)
 
         response = response.strip()
         try:
@@ -288,7 +289,7 @@ async def aget_relevant_information_sources(
     query: str,
     conversation_history: dict,
     is_task: bool,
-    subscribed: bool,
+    user: KhojUser,
     uploaded_image_url: str = None,
     agent: Agent = None,
 ):
@@ -326,7 +327,7 @@ async def aget_relevant_information_sources(
         response = await send_message_to_model_wrapper(
             relevant_tools_prompt,
             response_type="json_object",
-            subscribed=subscribed,
+            user=user,
         )
 
         try:
@@ -362,7 +363,12 @@ async def aget_relevant_information_sources(
 
 
 async def aget_relevant_output_modes(
-    query: str, conversation_history: dict, is_task: bool = False, uploaded_image_url: str = None, agent: Agent = None
+    query: str,
+    conversation_history: dict,
+    is_task: bool = False,
+    user: KhojUser = None,
+    uploaded_image_url: str = None,
+    agent: Agent = None,
 ):
     """
     Given a query, determine which of the available tools the agent should use in order to answer appropriately.
@@ -398,7 +404,7 @@ async def aget_relevant_output_modes(
     )
 
     with timer("Chat actor: Infer output mode for chat response", logger):
-        response = await send_message_to_model_wrapper(relevant_mode_prompt, response_type="json_object")
+        response = await send_message_to_model_wrapper(relevant_mode_prompt, response_type="json_object", user=user)
 
     try:
         response = response.strip()
@@ -453,7 +459,7 @@ async def infer_webpage_urls(
 
     with timer("Chat actor: Infer webpage urls to read", logger):
         response = await send_message_to_model_wrapper(
-            online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object"
+            online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
         )
 
     # Validate that the response is a non-empty, JSON-serializable list of URLs
@@ -499,7 +505,7 @@ async def generate_online_subqueries(
 
     with timer("Chat actor: Generate online search subqueries", logger):
         response = await send_message_to_model_wrapper(
-            online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object"
+            online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
         )
 
     # Validate that the response is a non-empty, JSON-serializable list
@@ -517,7 +523,9 @@ async def generate_online_subqueries(
         return [q]
 
 
-async def schedule_query(q: str, conversation_history: dict, uploaded_image_url: str = None) -> Tuple[str, ...]:
+async def schedule_query(
+    q: str, conversation_history: dict, user: KhojUser, uploaded_image_url: str = None
+) -> Tuple[str, ...]:
     """
     Schedule the date, time to run the query. Assume the server timezone is UTC.
""" @@ -529,7 +537,7 @@ async def schedule_query(q: str, conversation_history: dict, uploaded_image_url: ) raw_response = await send_message_to_model_wrapper( - crontime_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object" + crontime_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user ) # Validate that the response is a non-empty, JSON-serializable list @@ -543,7 +551,7 @@ async def schedule_query(q: str, conversation_history: dict, uploaded_image_url: raise AssertionError(f"Invalid response for scheduling query: {raw_response}") -async def extract_relevant_info(q: str, corpus: str, subscribed: bool, agent: Agent = None) -> Union[str, None]: +async def extract_relevant_info(q: str, corpus: str, user: KhojUser = None, agent: Agent = None) -> Union[str, None]: """ Extract relevant information for a given query from the target corpus """ @@ -561,14 +569,11 @@ async def extract_relevant_info(q: str, corpus: str, subscribed: bool, agent: Ag personality_context=personality_context, ) - chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config() - with timer("Chat actor: Extract relevant information from data", logger): response = await send_message_to_model_wrapper( extract_relevant_information, prompts.system_prompt_extract_relevant_information, - chat_model_option=chat_model, - subscribed=subscribed, + user=user, ) return response.strip() @@ -577,8 +582,8 @@ async def extract_relevant_summary( q: str, corpus: str, conversation_history: dict, - subscribed: bool = False, uploaded_image_url: str = None, + user: KhojUser = None, agent: Agent = None, ) -> Union[str, None]: """ @@ -601,14 +606,11 @@ async def extract_relevant_summary( personality_context=personality_context, ) - chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config() - with timer("Chat actor: Extract relevant information from data", logger): response = await send_message_to_model_wrapper( extract_relevant_information, prompts.system_prompt_extract_relevant_summary, - chat_model_option=chat_model, - subscribed=subscribed, + user=user, uploaded_image_url=uploaded_image_url, ) return response.strip() @@ -621,8 +623,8 @@ async def generate_better_image_prompt( note_references: List[Dict[str, Any]], online_results: Optional[dict] = None, model_type: Optional[str] = None, - subscribed: bool = False, uploaded_image_url: Optional[str] = None, + user: KhojUser = None, agent: Agent = None, ) -> str: """ @@ -672,12 +674,8 @@ async def generate_better_image_prompt( personality_context=personality_context, ) - chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config() - with timer("Chat actor: Generate contextual image prompt", logger): - response = await send_message_to_model_wrapper( - image_prompt, chat_model_option=chat_model, subscribed=subscribed, uploaded_image_url=uploaded_image_url - ) + response = await send_message_to_model_wrapper(image_prompt, uploaded_image_url=uploaded_image_url, user=user) response = response.strip() if response.startswith(('"', "'")) and response.endswith(('"', "'")): response = response[1:-1] @@ -689,14 +687,10 @@ async def send_message_to_model_wrapper( message: str, system_message: str = "", response_type: str = "text", - chat_model_option: ChatModelOptions = None, - subscribed: bool = False, + user: KhojUser = None, uploaded_image_url: str = None, ): - conversation_config: ChatModelOptions = ( - chat_model_option or await 
ConversationAdapters.aget_default_conversation_config() - ) - + conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user) vision_available = conversation_config.vision_enabled if not vision_available and uploaded_image_url: vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config() @@ -704,6 +698,7 @@ async def send_message_to_model_wrapper( conversation_config = vision_enabled_config vision_available = True + subscribed = await ais_user_subscribed(user) chat_model = conversation_config.chat_model max_tokens = ( conversation_config.subscribed_max_prompt_size @@ -802,8 +797,9 @@ def send_message_to_model_wrapper_sync( message: str, system_message: str = "", response_type: str = "text", + user: KhojUser = None, ): - conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config() + conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config(user) if conversation_config is None: raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.") @@ -1182,7 +1178,7 @@ class CommonQueryParamsClass: CommonQueryParams = Annotated[CommonQueryParamsClass, Depends()] -def should_notify(original_query: str, executed_query: str, ai_response: str) -> bool: +def should_notify(original_query: str, executed_query: str, ai_response: str, user: KhojUser) -> bool: """ Decide whether to notify the user of the AI response. Default to notifying the user for now. @@ -1199,7 +1195,7 @@ def should_notify(original_query: str, executed_query: str, ai_response: str) -> with timer("Chat actor: Decide to notify user of automation response", logger): try: # TODO Replace with async call so we don't have to maintain a sync version - response = send_message_to_model_wrapper_sync(to_notify_or_not) + response = send_message_to_model_wrapper_sync(to_notify_or_not, user) should_notify_result = "no" not in response.lower() logger.info(f'Decided to {"not " if not should_notify_result else ""}notify user of automation response.') return should_notify_result @@ -1291,7 +1287,9 @@ def scheduled_chat( ai_response = raw_response.text # Notify user if the AI response is satisfactory - if should_notify(original_query=scheduling_request, executed_query=cleaned_query, ai_response=ai_response): + if should_notify( + original_query=scheduling_request, executed_query=cleaned_query, ai_response=ai_response, user=user + ): if is_resend_enabled(): send_task_email(user.get_short_name(), user.email, cleaned_query, ai_response, subject, is_image) else: @@ -1301,7 +1299,7 @@ def scheduled_chat( async def create_automation( q: str, timezone: str, user: KhojUser, calling_url: URL, meta_log: dict = {}, conversation_id: str = None ): - crontime, query_to_run, subject = await schedule_query(q, meta_log) + crontime, query_to_run, subject = await schedule_query(q, meta_log, user) job = await schedule_automation(query_to_run, subject, crontime, timezone, q, user, calling_url, conversation_id) return job, crontime, query_to_run, subject @@ -1495,9 +1493,9 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False) current_notion_config = get_user_notion_config(user) notion_token = current_notion_config.token if current_notion_config else "" - selected_chat_model_config = ( - ConversationAdapters.get_conversation_config(user) or ConversationAdapters.get_default_conversation_config() - ) + selected_chat_model_config = ConversationAdapters.get_conversation_config( + 
user + ) or ConversationAdapters.get_default_conversation_config(user) chat_models = ConversationAdapters.get_conversation_processor_options().all() chat_model_options = list() for chat_model in chat_models: diff --git a/src/khoj/utils/initialization.py b/src/khoj/utils/initialization.py index 90bb9921..6a39c41a 100644 --- a/src/khoj/utils/initialization.py +++ b/src/khoj/utils/initialization.py @@ -129,9 +129,6 @@ def initialization(interactive: bool = True): if user_chat_model_name and ChatModelOptions.objects.filter(chat_model=user_chat_model_name).exists(): default_chat_model_name = user_chat_model_name - # Create a server chat settings object with the default chat model - default_chat_model = ChatModelOptions.objects.filter(chat_model=default_chat_model_name).first() - ServerChatSettings.objects.create(chat_default=default_chat_model) logger.info("🗣️ Chat model configuration complete") # Set up offline speech to text model
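
Illustrative sketch, not part of the patch above: with this change, ConversationAdapters.get_default_conversation_config(user) resolves the chat model as server admin default, then the requesting user's own setting, then the first configured chat model. The snippet below shows how a caller might exercise that fallback, for example from a Django shell; the specific user lookup is hypothetical.

# Hedged usage sketch of the user-aware default chat model resolution.
# Assumes Django is configured and at least one KhojUser exists; adjust the lookup as needed.
from khoj.database.adapters import ConversationAdapters
from khoj.database.models import KhojUser

user = KhojUser.objects.first()  # hypothetical: pick any existing user
chat_model = ConversationAdapters.get_default_conversation_config(user)

if chat_model is None:
    # No ChatModelOptions rows exist yet; validate_conversation_config(user) raises HTTP 500 in this case.
    print("No chat model configured on this server")
else:
    print(f"Resolved chat model for {user}: {chat_model.chat_model}")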