diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py
index 164b4023..1c8336a3 100644
--- a/src/khoj/database/adapters/__init__.py
+++ b/src/khoj/database/adapters/__init__.py
@@ -1387,6 +1387,10 @@ class FileObjectAdapters:
     async def async_get_file_objects_by_name(user: KhojUser, file_name: str, agent: Agent = None):
         return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name=file_name, agent=agent))
 
+    @staticmethod
+    async def async_get_file_objects_by_names(user: KhojUser, file_names: List[str]):
+        return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name__in=file_names))
+
     @staticmethod
     async def async_get_all_file_objects(user: KhojUser):
         return await sync_to_async(list)(FileObject.objects.filter(user=user))
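The new bulk adapter fetches every `FileObject` matching any of the given names in a single query, instead of one `async_get_file_objects_by_name` call per file. A minimal usage sketch — the calling helper below is hypothetical, not part of this diff:

```python
# Hypothetical caller (not in this diff): load several files' raw text at once
# via the new bulk adapter added above.
from khoj.database.adapters import FileObjectAdapters
from khoj.database.models import KhojUser


async def load_file_texts(user: KhojUser, file_names: list[str]) -> dict[str, str]:
    file_objects = await FileObjectAdapters.async_get_file_objects_by_names(user, file_names)
    # file_name and raw_text are the FileObject fields this diff relies on elsewhere.
    return {file_object.file_name: file_object.raw_text for file_object in file_objects}
```

Note that, unlike its single-name counterpart, the bulk variant does not filter on `agent`.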
diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py
index e2fd0c74..df81f56f 100644
--- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py
+++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py
@@ -147,6 +147,7 @@ def converse_anthropic(
     query_images: Optional[list[str]] = None,
     vision_available: bool = False,
     tracer: dict = {},
+    attached_files: str = None,
 ):
     """
     Converse with user using Anthropic's Claude
@@ -203,6 +204,7 @@ def converse_anthropic(
         query_images=query_images,
         vision_enabled=vision_available,
         model_type=ChatModelOptions.ModelType.ANTHROPIC,
+        attached_files=attached_files,
     )
 
     messages, system_prompt = format_messages_for_anthropic(messages, system_prompt)
diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py
index aebda1a8..b7ec018d 100644
--- a/src/khoj/processor/conversation/google/gemini_chat.py
+++ b/src/khoj/processor/conversation/google/gemini_chat.py
@@ -108,7 +108,14 @@ def extract_questions_gemini(
 
 
 def gemini_send_message_to_model(
-    messages, api_key, model, response_type="text", temperature=0, model_kwargs=None, tracer={}
+    messages,
+    api_key,
+    model,
+    response_type="text",
+    temperature=0,
+    model_kwargs=None,
+    tracer={},
+    attached_files: str = None,
 ):
     """
     Send message to model
@@ -152,6 +159,7 @@ def converse_gemini(
     query_images: Optional[list[str]] = None,
     vision_available: bool = False,
     tracer={},
+    attached_files: str = None,
 ):
     """
     Converse with user using Google's Gemini
@@ -209,6 +217,7 @@ def converse_gemini(
         query_images=query_images,
         vision_enabled=vision_available,
         model_type=ChatModelOptions.ModelType.GOOGLE,
+        attached_files=attached_files,
     )
 
     messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py
index aaaaa081..d0b62f3d 100644
--- a/src/khoj/processor/conversation/offline/chat_model.py
+++ b/src/khoj/processor/conversation/offline/chat_model.py
@@ -38,6 +38,7 @@ def extract_questions_offline(
     temperature: float = 0.7,
     personality_context: Optional[str] = None,
     tracer: dict = {},
+    attached_files: str = None,
 ) -> List[str]:
     """
     Infer search queries to retrieve relevant notes to answer user query
@@ -87,6 +88,7 @@ def extract_questions_offline(
         loaded_model=offline_chat_model,
         max_prompt_size=max_prompt_size,
         model_type=ChatModelOptions.ModelType.OFFLINE,
+        attached_files=attached_files,
     )
 
     state.chat_lock.acquire()
@@ -153,6 +155,7 @@ def converse_offline(
     user_name: str = None,
     agent: Agent = None,
     tracer: dict = {},
+    attached_files: str = None,
 ) -> Union[ThreadedGenerator, Iterator[str]]:
     """
     Converse with user using Llama
@@ -216,6 +219,7 @@ def converse_offline(
         max_prompt_size=max_prompt_size,
         tokenizer_name=tokenizer_name,
         model_type=ChatModelOptions.ModelType.OFFLINE,
+        attached_files=attached_files,
     )
 
     truncated_messages = "\n".join({f"{message.content[:70]}..." for message in messages})
diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py
index c376a90e..bdb67448 100644
--- a/src/khoj/processor/conversation/openai/gpt.py
+++ b/src/khoj/processor/conversation/openai/gpt.py
@@ -149,6 +149,7 @@ def converse(
     query_images: Optional[list[str]] = None,
     vision_available: bool = False,
     tracer: dict = {},
+    attached_files: str = None,
 ):
     """
     Converse with user using OpenAI's ChatGPT
@@ -206,6 +207,7 @@ def converse(
         query_images=query_images,
         vision_enabled=vision_available,
         model_type=ChatModelOptions.ModelType.OPENAI,
+        attached_files=attached_files,
     )
     truncated_messages = "\n".join({f"{message.content[:70]}..." for message in messages})
     logger.debug(f"Conversation Context for GPT: {truncated_messages}")
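All four chat backends make the same mechanical change: accept an `attached_files` string and forward it to the shared prompt builder. A self-contained schematic of that shape, using stand-in functions rather than the real Khoj signatures (which take many more parameters):

```python
# Schematic only: stand-ins for converse_anthropic / converse_gemini /
# converse_offline / converse. Real signatures and message types are elided.
from typing import Optional


def build_messages(user_message: str, attached_files: Optional[str] = None) -> list[dict]:
    # Stand-in for generate_chatml_messages_with_context (see the utils.py hunk below).
    messages = [{"role": "user", "content": user_message}]
    if attached_files:
        # Attached file text travels as one extra user-role message.
        messages.append({"role": "user", "content": attached_files})
    return messages


def converse_backend(user_message: str, attached_files: Optional[str] = None) -> list[dict]:
    # Every backend forwards the new keyword unchanged to the prompt builder.
    return build_messages(user_message, attached_files=attached_files)
```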
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 74c464d9..b510a09a 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -318,6 +318,7 @@ def generate_chatml_messages_with_context(
     vision_enabled=False,
     model_type="",
     context_message="",
+    attached_files: str = None,
 ):
     """Generate chat messages with appropriate context from previous conversation to send to the chat model"""
     # Set max prompt size from user config or based on pre-configured for model and machine specs
@@ -341,8 +342,10 @@ def generate_chatml_messages_with_context(
                 {f"# File: {item['file']}\n## {item['compiled']}\n" for item in chat.get("context") or []}
             )
             message_context += f"{prompts.notes_conversation.format(references=references)}\n\n"
+
         if not is_none_or_empty(chat.get("onlineContext")):
             message_context += f"{prompts.online_search_conversation.format(online_results=chat.get('onlineContext'))}"
+
         if not is_none_or_empty(message_context):
             reconstructed_context_message = ChatMessage(content=message_context, role="user")
             chatml_messages.insert(0, reconstructed_context_message)
@@ -366,8 +369,13 @@ def generate_chatml_messages_with_context(
         )
     if not is_none_or_empty(context_message):
         messages.append(ChatMessage(content=context_message, role="user"))
+
+    if not is_none_or_empty(attached_files):
+        messages.append(ChatMessage(content=attached_files, role="user"))
+
     if len(chatml_messages) > 0:
         messages += chatml_messages
+
     if not is_none_or_empty(system_message):
         messages.append(ChatMessage(content=system_message, role="system"))
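To make the insertion point concrete, here is the append sequence from the hunk above, re-implemented with plain dicts. This mirrors only what the hunk shows; the real function also handles truncation, images, and history reconstruction:

```python
# Append order visible in the hunk above, with plain dicts standing in for
# langchain ChatMessage objects. Truncation and image handling are omitted.
def order_messages(context_message: str, attached_files: str, history: list, system_message: str) -> list:
    messages = []
    if context_message:
        messages.append({"role": "user", "content": context_message})
    if attached_files:
        # New in this diff: attached file text is appended as its own user message,
        # after the reference context and before the reconstructed chat history.
        messages.append({"role": "user", "content": attached_files})
    messages += history
    if system_message:
        messages.append({"role": "system", "content": system_message})
    return messages
```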
diff --git a/src/khoj/processor/image/generate.py b/src/khoj/processor/image/generate.py
index bdc00e09..ec5254ec 100644
--- a/src/khoj/processor/image/generate.py
+++ b/src/khoj/processor/image/generate.py
@@ -29,6 +29,7 @@ async def text_to_image(
     query_images: Optional[List[str]] = None,
     agent: Agent = None,
     tracer: dict = {},
+    attached_files: str = None,
 ):
     status_code = 200
     image = None
@@ -70,6 +71,7 @@ async def text_to_image(
         user=user,
         agent=agent,
         tracer=tracer,
+        attached_files=attached_files,
     )
 
     if send_status_func:
diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py
index c6fc7c20..3b4bd16a 100644
--- a/src/khoj/processor/tools/online_search.py
+++ b/src/khoj/processor/tools/online_search.py
@@ -68,6 +68,7 @@ async def search_online(
     query_images: List[str] = None,
     agent: Agent = None,
     tracer: dict = {},
+    attached_files: str = None,
 ):
     query += " ".join(custom_filters)
     if not is_internet_connected():
@@ -77,7 +78,14 @@ async def search_online(
 
     # Breakdown the query into subqueries to get the correct answer
     subqueries = await generate_online_subqueries(
-        query, conversation_history, location, user, query_images=query_images, agent=agent, tracer=tracer
+        query,
+        conversation_history,
+        location,
+        user,
+        query_images=query_images,
+        agent=agent,
+        tracer=tracer,
+        attached_files=attached_files,
     )
 
     response_dict = {}
@@ -159,11 +167,19 @@ async def read_webpages(
     agent: Agent = None,
     tracer: dict = {},
     max_webpages_to_read: int = DEFAULT_MAX_WEBPAGES_TO_READ,
+    attached_files: str = None,
 ):
     "Infer web pages to read from the query and extract relevant information from them"
     logger.info(f"Inferring web pages to read")
     urls = await infer_webpage_urls(
-        query, conversation_history, location, user, query_images, agent=agent, tracer=tracer
+        query,
+        conversation_history,
+        location,
+        user,
+        query_images,
+        agent=agent,
+        tracer=tracer,
+        attached_files=attached_files,
     )
 
     # Get the top 10 web pages to read
diff --git a/src/khoj/processor/tools/run_code.py b/src/khoj/processor/tools/run_code.py
index d4ba9af1..86388aee 100644
--- a/src/khoj/processor/tools/run_code.py
+++ b/src/khoj/processor/tools/run_code.py
@@ -6,6 +6,7 @@ import os
 from typing import Any, Callable, List, Optional
 
 import aiohttp
+import requests
 
 from khoj.database.adapters import ais_user_subscribed
 from khoj.database.models import Agent, KhojUser
@@ -37,6 +38,7 @@ async def run_code(
     agent: Agent = None,
     sandbox_url: str = SANDBOX_URL,
     tracer: dict = {},
+    attached_files: str = None,
 ):
     # Generate Code
     if send_status_func:
@@ -53,6 +55,7 @@ async def run_code(
             query_images,
             agent,
             tracer,
+            attached_files,
         )
     except Exception as e:
         raise ValueError(f"Failed to generate code for {query} with error: {e}")
@@ -82,6 +85,7 @@ async def generate_python_code(
     query_images: List[str] = None,
     agent: Agent = None,
     tracer: dict = {},
+    attached_files: str = None,
 ) -> List[str]:
     location = f"{location_data}" if location_data else "Unknown"
     username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
@@ -109,6 +113,7 @@ async def generate_python_code(
         response_type="json_object",
         user=user,
         tracer=tracer,
+        attached_files=attached_files,
     )
 
     # Validate that the response is a non-empty, JSON-serializable list
diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py
index a20982ea..bb561ca5 100644
--- a/src/khoj/routers/api_chat.py
+++ b/src/khoj/routers/api_chat.py
@@ -19,7 +19,6 @@ from khoj.database.adapters import (
     AgentAdapters,
     ConversationAdapters,
     EntryAdapters,
-    FileObjectAdapters,
     PublicConversationAdapters,
     aget_user_name,
 )
@@ -46,7 +45,7 @@ from khoj.routers.helpers import (
     aget_relevant_output_modes,
     construct_automation_created_message,
     create_automation,
-    extract_relevant_info,
+    gather_attached_files,
     generate_excalidraw_diagram,
     generate_summary_from_files,
     get_conversation_command,
@@ -707,6 +706,8 @@ async def chat(
         ## Extract Document References
         compiled_references: List[Any] = []
         inferred_queries: List[Any] = []
+        file_filters = conversation.file_filters if conversation and conversation.file_filters else []
+        attached_file_context = await gather_attached_files(user, file_filters)
 
         if conversation_commands == [ConversationCommand.Default] or is_automated_task:
             conversation_commands = await aget_relevant_information_sources(
@@ -717,6 +718,7 @@ async def chat(
                 query_images=uploaded_images,
                 agent=agent,
                 tracer=tracer,
+                attached_files=attached_file_context,
             )
 
             # If we're doing research, we don't want to do anything else
@@ -757,6 +759,7 @@ async def chat(
                 location=location,
                 file_filters=conversation.file_filters if conversation else [],
                 tracer=tracer,
+                attached_files=attached_file_context,
             ):
                 if isinstance(research_result, InformationCollectionIteration):
                     if research_result.summarizedResult:
@@ -812,6 +815,7 @@ async def chat(
                 agent=agent,
                 send_status_func=partial(send_event, ChatEvent.STATUS),
                 tracer=tracer,
+                attached_files=attached_file_context,
             ):
                 if isinstance(response, dict) and ChatEvent.STATUS in response:
                     yield response[ChatEvent.STATUS]
@@ -945,6 +949,7 @@ async def chat(
                 query_images=uploaded_images,
                 agent=agent,
                 tracer=tracer,
+                attached_files=attached_file_context,
             ):
                 if isinstance(result, dict) and ChatEvent.STATUS in result:
                     yield result[ChatEvent.STATUS]
@@ -970,6 +975,7 @@ async def chat(
                 query_images=uploaded_images,
                 agent=agent,
                 tracer=tracer,
+                attached_files=attached_file_context,
             ):
                 if isinstance(result, dict) and ChatEvent.STATUS in result:
                     yield result[ChatEvent.STATUS]
@@ -1010,6 +1016,7 @@ async def chat(
                 query_images=uploaded_images,
                 agent=agent,
                 tracer=tracer,
+                attached_files=attached_file_context,
             ):
                 if isinstance(result, dict) and ChatEvent.STATUS in result:
                     yield result[ChatEvent.STATUS]
@@ -1049,6 +1056,7 @@ async def chat(
                 query_images=uploaded_images,
                 agent=agent,
                 tracer=tracer,
+                attached_files=attached_file_context,
             ):
                 if isinstance(result, dict) and ChatEvent.STATUS in result:
                     yield result[ChatEvent.STATUS]
@@ -1110,6 +1118,7 @@ async def chat(
                 agent=agent,
                 send_status_func=partial(send_event, ChatEvent.STATUS),
                 tracer=tracer,
+                attached_files=attached_file_context,
             ):
                 if isinstance(result, dict) and ChatEvent.STATUS in result:
                     yield result[ChatEvent.STATUS]
@@ -1166,6 +1175,7 @@ async def chat(
                 uploaded_images,
                 tracer,
                 train_of_thought,
+                attached_file_context,
             )
 
             # Send Response
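The endpoint change reduces to: read the conversation's file filters, gather their full text once up front, then pass the same `attached_file_context` string to every downstream actor. A condensed sketch of that setup — the wrapper function below is hypothetical, though `gather_attached_files` is the real helper defined in `helpers.py` below:

```python
# Condensed, hypothetical sketch of the setup added to the /chat endpoint above.
from khoj.routers.helpers import gather_attached_files


async def resolve_attached_context(conversation, user) -> str:
    file_filters = conversation.file_filters if conversation and conversation.file_filters else []
    # Empty filters short-circuit to "" inside gather_attached_files, so every
    # downstream actor sees a falsy attached_files and adds no extra message.
    return await gather_attached_files(user, file_filters)
```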
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index 990fa33f..760b6f2e 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -248,6 +248,25 @@ async def agenerate_chat_response(*args):
     return await loop.run_in_executor(executor, generate_chat_response, *args)
 
 
+async def gather_attached_files(
+    user: KhojUser,
+    file_filters: List[str],
+) -> str:
+    """
+    Gather contextual data from the given files
+    """
+    if len(file_filters) == 0:
+        return ""
+
+    file_objects = await FileObjectAdapters.async_get_file_objects_by_names(user, file_filters)
+
+    if len(file_objects) == 0:
+        return ""
+
+    contextual_data = " ".join([f"File: {file.file_name}\n\n{file.raw_text}" for file in file_objects])
+    return contextual_data
+
+
 async def acreate_title_from_query(query: str, user: KhojUser = None) -> str:
     """
     Create a title from the given query
@@ -294,6 +313,7 @@ async def aget_relevant_information_sources(
     query_images: List[str] = None,
     agent: Agent = None,
     tracer: dict = {},
+    attached_files: str = None,
 ):
     """
     Given a query, determine which of the available tools the agent should use in order to answer appropriately.
@@ -331,6 +351,7 @@ async def aget_relevant_information_sources(
         response_type="json_object",
         user=user,
         tracer=tracer,
+        attached_files=attached_files,
     )
 
     try:
@@ -440,6 +461,7 @@ async def infer_webpage_urls(
     query_images: List[str] = None,
     agent: Agent = None,
     tracer: dict = {},
+    attached_files: str = None,
 ) -> List[str]:
     """
     Infer webpage links from the given query
@@ -469,6 +491,7 @@ async def infer_webpage_urls(
         response_type="json_object",
         user=user,
         tracer=tracer,
+        attached_files=attached_files,
     )
 
     # Validate that the response is a non-empty, JSON-serializable list of URLs
@@ -494,6 +517,7 @@ async def generate_online_subqueries(
     query_images: List[str] = None,
     agent: Agent = None,
     tracer: dict = {},
+    attached_files: str = None,
 ) -> List[str]:
     """
     Generate subqueries from the given query
@@ -523,6 +547,7 @@ async def generate_online_subqueries(
         response_type="json_object",
         user=user,
         tracer=tracer,
+        attached_files=attached_files,
     )
 
     # Validate that the response is a non-empty, JSON-serializable list
@@ -645,6 +670,7 @@ async def generate_summary_from_files(
     agent: Agent = None,
     send_status_func: Optional[Callable] = None,
     tracer: dict = {},
+    attached_files: str = None,
 ):
     try:
         file_object = None
@@ -653,17 +679,28 @@ async def generate_summary_from_files(
             if len(file_names) > 0:
                 file_object = await FileObjectAdapters.async_get_file_objects_by_name(None, file_names.pop(), agent)
 
-        if len(file_filters) > 0:
-            file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
-
-        if len(file_object) == 0:
+        if len(file_object) == 0 and not attached_files:
             response_log = "Sorry, I couldn't find the full text of this file."
             yield response_log
             return
-        contextual_data = " ".join([file.raw_text for file in file_object])
+
+        contextual_data = " ".join([f"File: {file.file_name}\n\n{file.raw_text}" for file in file_object])
+
+        if attached_files:
+            contextual_data += f"\n\n{attached_files}"
+
         if not q:
             q = "Create a general summary of the file"
-        async for result in send_status_func(f"**Constructing Summary Using:** {file_object[0].file_name}"):
+
+        file_names = [file.file_name for file in file_object]
+        file_names.extend(file_filters)
+
+        all_file_names = ""
+
+        for file_name in file_names:
+            all_file_names += f"- {file_name}\n"
+
+        async for result in send_status_func(f"**Constructing Summary Using:**\n{all_file_names}"):
             yield {ChatEvent.STATUS: result}
 
         response = await extract_relevant_summary(
@@ -694,6 +731,7 @@ async def generate_excalidraw_diagram(
     agent: Agent = None,
     send_status_func: Optional[Callable] = None,
     tracer: dict = {},
+    attached_files: str = None,
 ):
     if send_status_func:
         async for event in send_status_func("**Enhancing the Diagramming Prompt**"):
@@ -709,6 +747,7 @@ async def generate_excalidraw_diagram(
         user=user,
         agent=agent,
         tracer=tracer,
+        attached_files=attached_files,
     )
 
     if send_status_func:
@@ -735,6 +774,7 @@ async def generate_better_diagram_description(
     user: KhojUser = None,
     agent: Agent = None,
     tracer: dict = {},
+    attached_files: str = None,
 ) -> str:
     """
     Generate a diagram description from the given query and context
@@ -772,7 +812,11 @@ async def generate_better_diagram_description(
 
     with timer("Chat actor: Generate better diagram description", logger):
         response = await send_message_to_model_wrapper(
-            improve_diagram_description_prompt, query_images=query_images, user=user, tracer=tracer
+            improve_diagram_description_prompt,
+            query_images=query_images,
+            user=user,
+            tracer=tracer,
+            attached_files=attached_files,
         )
         response = response.strip()
         if response.startswith(('"', "'")) and response.endswith(('"', "'")):
@@ -820,6 +864,7 @@ async def generate_better_image_prompt(
     user: KhojUser = None,
     agent: Agent = None,
     tracer: dict = {},
+    attached_files: str = "",
 ) -> str:
     """
     Generate a better image prompt from the given query
@@ -867,7 +912,7 @@ async def generate_better_image_prompt(
 
     with timer("Chat actor: Generate contextual image prompt", logger):
         response = await send_message_to_model_wrapper(
-            image_prompt, query_images=query_images, user=user, tracer=tracer
+            image_prompt, query_images=query_images, user=user, tracer=tracer, attached_files=attached_files
         )
         response = response.strip()
         if response.startswith(('"', "'")) and response.endswith(('"', "'")):
@@ -884,6 +929,7 @@ async def send_message_to_model_wrapper(
     query_images: List[str] = None,
     context: str = "",
     tracer: dict = {},
+    attached_files: str = None,
 ):
     conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
     vision_available = conversation_config.vision_enabled
@@ -922,6 +968,7 @@ async def send_message_to_model_wrapper(
             max_prompt_size=max_tokens,
             vision_enabled=vision_available,
             model_type=conversation_config.model_type,
+            attached_files=attached_files,
         )
 
         return send_message_to_model_offline(
@@ -948,6 +995,7 @@ async def send_message_to_model_wrapper(
             vision_enabled=vision_available,
             query_images=query_images,
             model_type=conversation_config.model_type,
+            attached_files=attached_files,
         )
 
         return send_message_to_model(
@@ -970,6 +1018,7 @@ async def send_message_to_model_wrapper(
             vision_enabled=vision_available,
             query_images=query_images,
             model_type=conversation_config.model_type,
+            attached_files=attached_files,
        )
 
         return anthropic_send_message_to_model(
@@ -991,6 +1040,7 @@ async def send_message_to_model_wrapper(
             vision_enabled=vision_available,
             query_images=query_images,
             model_type=conversation_config.model_type,
+            attached_files=attached_files,
         )
 
         return gemini_send_message_to_model(
@@ -1006,6 +1056,7 @@ def send_message_to_model_wrapper_sync(
     response_type: str = "text",
     user: KhojUser = None,
     tracer: dict = {},
+    attached_files: str = "",
 ):
     conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config(user)
 
@@ -1029,6 +1080,7 @@ def send_message_to_model_wrapper_sync(
             max_prompt_size=max_tokens,
             vision_enabled=vision_available,
             model_type=conversation_config.model_type,
+            attached_files=attached_files,
         )
 
         return send_message_to_model_offline(
@@ -1050,6 +1102,7 @@ def send_message_to_model_wrapper_sync(
             max_prompt_size=max_tokens,
             vision_enabled=vision_available,
             model_type=conversation_config.model_type,
+            attached_files=attached_files,
         )
 
         openai_response = send_message_to_model(
@@ -1071,6 +1124,7 @@ def send_message_to_model_wrapper_sync(
             max_prompt_size=max_tokens,
             vision_enabled=vision_available,
             model_type=conversation_config.model_type,
+            attached_files=attached_files,
         )
 
         return anthropic_send_message_to_model(
@@ -1090,6 +1144,7 @@ def send_message_to_model_wrapper_sync(
             max_prompt_size=max_tokens,
             vision_enabled=vision_available,
             model_type=conversation_config.model_type,
+            attached_files=attached_files,
         )
 
         return gemini_send_message_to_model(
@@ -1121,6 +1176,7 @@ def generate_chat_response(
     query_images: Optional[List[str]] = None,
     tracer: dict = {},
     train_of_thought: List[Any] = [],
+    attached_files: str = None,
 ) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
     # Initialize Variables
     chat_response = None
@@ -1173,6 +1229,7 @@ def generate_chat_response(
                 user_name=user_name,
                 agent=agent,
                 tracer=tracer,
+                attached_files=attached_files,
             )
 
         elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
@@ -1198,6 +1255,7 @@ def generate_chat_response(
                 agent=agent,
                 vision_available=vision_available,
                 tracer=tracer,
+                attached_files=attached_files,
             )
 
         elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
@@ -1220,6 +1278,7 @@ def generate_chat_response(
                 agent=agent,
                 vision_available=vision_available,
                 tracer=tracer,
+                attached_files=attached_files,
             )
         elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
             api_key = conversation_config.openai_config.api_key
@@ -1240,6 +1299,7 @@ def generate_chat_response(
                 agent=agent,
                 vision_available=vision_available,
                 tracer=tracer,
+                attached_files=attached_files,
             )
 
         metadata.update({"chat_model": conversation_config.chat_model})
diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py
index 4f9c6b4e..960cf52f 100644
--- a/src/khoj/routers/research.py
+++ b/src/khoj/routers/research.py
@@ -47,6 +47,7 @@ async def apick_next_tool(
     max_iterations: int = 5,
     send_status_func: Optional[Callable] = None,
     tracer: dict = {},
+    attached_files: str = None,
 ):
     """
     Given a query, determine which of the available tools the agent should use in order to answer appropriately. One at a time, and it's able to use subsequent iterations to refine the answer.
@@ -95,6 +96,7 @@ async def apick_next_tool(
         user=user,
         query_images=query_images,
         tracer=tracer,
+        attached_files=attached_files,
     )
 
     try:
@@ -137,6 +139,7 @@ async def execute_information_collection(
     location: LocationData = None,
     file_filters: List[str] = [],
     tracer: dict = {},
+    attached_files: str = None,
 ):
     current_iteration = 0
     MAX_ITERATIONS = 5
@@ -161,6 +164,7 @@ async def execute_information_collection(
         MAX_ITERATIONS,
         send_status_func,
         tracer=tracer,
+        attached_files=attached_files,
    ):
         if isinstance(result, dict) and ChatEvent.STATUS in result:
             yield result[ChatEvent.STATUS]
@@ -233,6 +237,7 @@ async def execute_information_collection(
                     query_images=query_images,
                     agent=agent,
                     tracer=tracer,
+                    attached_files=attached_files,
                 ):
                     if isinstance(result, dict) and ChatEvent.STATUS in result:
                         yield result[ChatEvent.STATUS]
@@ -264,6 +269,7 @@ async def execute_information_collection(
                     query_images=query_images,
                     agent=agent,
                     tracer=tracer,
+                    attached_files=attached_files,
                 ):
                     if isinstance(result, dict) and ChatEvent.STATUS in result:
                         yield result[ChatEvent.STATUS]
@@ -288,6 +294,7 @@ async def execute_information_collection(
                     query_images=query_images,
                     agent=agent,
                     send_status_func=send_status_func,
+                    attached_files=attached_files,
                 ):
                     if isinstance(result, dict) and ChatEvent.STATUS in result:
                         yield result[ChatEvent.STATUS]
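For reference, the string `gather_attached_files` builds joins one `File: <name>` block per file with a single space. An illustration with invented file contents:

```python
# Illustration of gather_attached_files' output format, with invented contents.
file_objects = [
    ("notes.org", "* Project ideas\n- build a rocket"),
    ("todo.md", "- [ ] water plants"),
]
contextual_data = " ".join(f"File: {name}\n\n{raw_text}" for name, raw_text in file_objects)
print(contextual_data)
# File: notes.org
#
# * Project ideas
# - build a rocket File: todo.md
#
# - [ ] water plants
```

Because the separator is a single space, the tail of one file runs directly into the next `File:` header; a newline separator would keep the blocks visually distinct, but the sketch mirrors the code as written.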