From 362bdebd026a7d65c2d4c2fcddca54939a103f64 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Mon, 4 Nov 2024 16:37:13 -0800 Subject: [PATCH 01/42] Add methods for reading full files by name and including context Now that models have much larger context windows, we can reasonably include full texts of certain files in the messages. Do this when an explicit file filter is set in a conversation. Do so in a separate user message in order to mitigate any confusion in the operation. Pipe the relevant attached_files context through all methods calling into models. We'll want to limit the file sizes for which this is used and provide more helpful UI indicators that this sort of behavior is taking place. --- src/khoj/database/adapters/__init__.py | 4 + .../conversation/anthropic/anthropic_chat.py | 2 + .../conversation/google/gemini_chat.py | 11 ++- .../conversation/offline/chat_model.py | 4 + src/khoj/processor/conversation/openai/gpt.py | 2 + src/khoj/processor/conversation/utils.py | 8 ++ src/khoj/processor/image/generate.py | 2 + src/khoj/processor/tools/online_search.py | 20 ++++- src/khoj/processor/tools/run_code.py | 5 ++ src/khoj/routers/api_chat.py | 14 +++- src/khoj/routers/helpers.py | 76 +++++++++++++++++-- src/khoj/routers/research.py | 7 ++ 12 files changed, 142 insertions(+), 13 deletions(-) diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 164b4023..1c8336a3 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -1387,6 +1387,10 @@ class FileObjectAdapters: async def async_get_file_objects_by_name(user: KhojUser, file_name: str, agent: Agent = None): return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name=file_name, agent=agent)) + @staticmethod + async def async_get_file_objects_by_names(user: KhojUser, file_names: List[str]): + return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name__in=file_names)) + @staticmethod async def async_get_all_file_objects(user: KhojUser): return await sync_to_async(list)(FileObject.objects.filter(user=user)) diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index e2fd0c74..df81f56f 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -147,6 +147,7 @@ def converse_anthropic( query_images: Optional[list[str]] = None, vision_available: bool = False, tracer: dict = {}, + attached_files: str = None, ): """ Converse with user using Anthropic's Claude @@ -203,6 +204,7 @@ def converse_anthropic( query_images=query_images, vision_enabled=vision_available, model_type=ChatModelOptions.ModelType.ANTHROPIC, + attached_files=attached_files, ) messages, system_prompt = format_messages_for_anthropic(messages, system_prompt) diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index aebda1a8..b7ec018d 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -108,7 +108,14 @@ def extract_questions_gemini( def gemini_send_message_to_model( - messages, api_key, model, response_type="text", temperature=0, model_kwargs=None, tracer={} + messages, + api_key, + model, + response_type="text", + temperature=0, + model_kwargs=None, + tracer={}, + attached_files: str = None, ): """ Send message to model @@ -152,6 +159,7 @@ def converse_gemini( query_images: Optional[list[str]] = None, vision_available: bool = False, tracer={}, + attached_files: str = None, ): """ Converse with user using Google's Gemini @@ -209,6 +217,7 @@ def converse_gemini( query_images=query_images, vision_enabled=vision_available, model_type=ChatModelOptions.ModelType.GOOGLE, + attached_files=attached_files, ) messages, system_prompt = format_messages_for_gemini(messages, system_prompt) diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py index aaaaa081..d0b62f3d 100644 --- a/src/khoj/processor/conversation/offline/chat_model.py +++ b/src/khoj/processor/conversation/offline/chat_model.py @@ -38,6 +38,7 @@ def extract_questions_offline( temperature: float = 0.7, personality_context: Optional[str] = None, tracer: dict = {}, + attached_files: str = None, ) -> List[str]: """ Infer search queries to retrieve relevant notes to answer user query @@ -87,6 +88,7 @@ def extract_questions_offline( loaded_model=offline_chat_model, max_prompt_size=max_prompt_size, model_type=ChatModelOptions.ModelType.OFFLINE, + attached_files=attached_files, ) state.chat_lock.acquire() @@ -153,6 +155,7 @@ def converse_offline( user_name: str = None, agent: Agent = None, tracer: dict = {}, + attached_files: str = None, ) -> Union[ThreadedGenerator, Iterator[str]]: """ Converse with user using Llama @@ -216,6 +219,7 @@ def converse_offline( max_prompt_size=max_prompt_size, tokenizer_name=tokenizer_name, model_type=ChatModelOptions.ModelType.OFFLINE, + attached_files=attached_files, ) truncated_messages = "\n".join({f"{message.content[:70]}..." for message in messages}) diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index c376a90e..bdb67448 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -149,6 +149,7 @@ def converse( query_images: Optional[list[str]] = None, vision_available: bool = False, tracer: dict = {}, + attached_files: str = None, ): """ Converse with user using OpenAI's ChatGPT @@ -206,6 +207,7 @@ def converse( query_images=query_images, vision_enabled=vision_available, model_type=ChatModelOptions.ModelType.OPENAI, + attached_files=attached_files, ) truncated_messages = "\n".join({f"{message.content[:70]}..." for message in messages}) logger.debug(f"Conversation Context for GPT: {truncated_messages}") diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 74c464d9..b510a09a 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -318,6 +318,7 @@ def generate_chatml_messages_with_context( vision_enabled=False, model_type="", context_message="", + attached_files: str = None, ): """Generate chat messages with appropriate context from previous conversation to send to the chat model""" # Set max prompt size from user config or based on pre-configured for model and machine specs @@ -341,8 +342,10 @@ def generate_chatml_messages_with_context( {f"# File: {item['file']}\n## {item['compiled']}\n" for item in chat.get("context") or []} ) message_context += f"{prompts.notes_conversation.format(references=references)}\n\n" + if not is_none_or_empty(chat.get("onlineContext")): message_context += f"{prompts.online_search_conversation.format(online_results=chat.get('onlineContext'))}" + if not is_none_or_empty(message_context): reconstructed_context_message = ChatMessage(content=message_context, role="user") chatml_messages.insert(0, reconstructed_context_message) @@ -366,8 +369,13 @@ def generate_chatml_messages_with_context( ) if not is_none_or_empty(context_message): messages.append(ChatMessage(content=context_message, role="user")) + + if not is_none_or_empty(attached_files): + messages.append(ChatMessage(content=attached_files, role="user")) + if len(chatml_messages) > 0: messages += chatml_messages + if not is_none_or_empty(system_message): messages.append(ChatMessage(content=system_message, role="system")) diff --git a/src/khoj/processor/image/generate.py b/src/khoj/processor/image/generate.py index bdc00e09..ec5254ec 100644 --- a/src/khoj/processor/image/generate.py +++ b/src/khoj/processor/image/generate.py @@ -29,6 +29,7 @@ async def text_to_image( query_images: Optional[List[str]] = None, agent: Agent = None, tracer: dict = {}, + attached_files: str = None, ): status_code = 200 image = None @@ -70,6 +71,7 @@ async def text_to_image( user=user, agent=agent, tracer=tracer, + attached_files=attached_files, ) if send_status_func: diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py index c6fc7c20..3b4bd16a 100644 --- a/src/khoj/processor/tools/online_search.py +++ b/src/khoj/processor/tools/online_search.py @@ -68,6 +68,7 @@ async def search_online( query_images: List[str] = None, agent: Agent = None, tracer: dict = {}, + attached_files: str = None, ): query += " ".join(custom_filters) if not is_internet_connected(): @@ -77,7 +78,14 @@ async def search_online( # Breakdown the query into subqueries to get the correct answer subqueries = await generate_online_subqueries( - query, conversation_history, location, user, query_images=query_images, agent=agent, tracer=tracer + query, + conversation_history, + location, + user, + query_images=query_images, + agent=agent, + tracer=tracer, + attached_files=attached_files, ) response_dict = {} @@ -159,11 +167,19 @@ async def read_webpages( agent: Agent = None, tracer: dict = {}, max_webpages_to_read: int = DEFAULT_MAX_WEBPAGES_TO_READ, + attached_files: str = None, ): "Infer web pages to read from the query and extract relevant information from them" logger.info(f"Inferring web pages to read") urls = await infer_webpage_urls( - query, conversation_history, location, user, query_images, agent=agent, tracer=tracer + query, + conversation_history, + location, + user, + query_images, + agent=agent, + tracer=tracer, + attached_files=attached_files, ) # Get the top 10 web pages to read diff --git a/src/khoj/processor/tools/run_code.py b/src/khoj/processor/tools/run_code.py index d4ba9af1..86388aee 100644 --- a/src/khoj/processor/tools/run_code.py +++ b/src/khoj/processor/tools/run_code.py @@ -6,6 +6,7 @@ import os from typing import Any, Callable, List, Optional import aiohttp +import requests from khoj.database.adapters import ais_user_subscribed from khoj.database.models import Agent, KhojUser @@ -37,6 +38,7 @@ async def run_code( agent: Agent = None, sandbox_url: str = SANDBOX_URL, tracer: dict = {}, + attached_files: str = None, ): # Generate Code if send_status_func: @@ -53,6 +55,7 @@ async def run_code( query_images, agent, tracer, + attached_files, ) except Exception as e: raise ValueError(f"Failed to generate code for {query} with error: {e}") @@ -82,6 +85,7 @@ async def generate_python_code( query_images: List[str] = None, agent: Agent = None, tracer: dict = {}, + attached_files: str = None, ) -> List[str]: location = f"{location_data}" if location_data else "Unknown" username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else "" @@ -109,6 +113,7 @@ async def generate_python_code( response_type="json_object", user=user, tracer=tracer, + attached_files=attached_files, ) # Validate that the response is a non-empty, JSON-serializable list diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index a20982ea..bb561ca5 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -19,7 +19,6 @@ from khoj.database.adapters import ( AgentAdapters, ConversationAdapters, EntryAdapters, - FileObjectAdapters, PublicConversationAdapters, aget_user_name, ) @@ -46,7 +45,7 @@ from khoj.routers.helpers import ( aget_relevant_output_modes, construct_automation_created_message, create_automation, - extract_relevant_info, + gather_attached_files, generate_excalidraw_diagram, generate_summary_from_files, get_conversation_command, @@ -707,6 +706,8 @@ async def chat( ## Extract Document References compiled_references: List[Any] = [] inferred_queries: List[Any] = [] + file_filters = conversation.file_filters if conversation and conversation.file_filters else [] + attached_file_context = await gather_attached_files(user, file_filters) if conversation_commands == [ConversationCommand.Default] or is_automated_task: conversation_commands = await aget_relevant_information_sources( @@ -717,6 +718,7 @@ async def chat( query_images=uploaded_images, agent=agent, tracer=tracer, + attached_files=attached_file_context, ) # If we're doing research, we don't want to do anything else @@ -757,6 +759,7 @@ async def chat( location=location, file_filters=conversation.file_filters if conversation else [], tracer=tracer, + attached_files=attached_file_context, ): if isinstance(research_result, InformationCollectionIteration): if research_result.summarizedResult: @@ -812,6 +815,7 @@ async def chat( agent=agent, send_status_func=partial(send_event, ChatEvent.STATUS), tracer=tracer, + attached_files=attached_file_context, ): if isinstance(response, dict) and ChatEvent.STATUS in response: yield response[ChatEvent.STATUS] @@ -945,6 +949,7 @@ async def chat( query_images=uploaded_images, agent=agent, tracer=tracer, + attached_files=attached_file_context, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -970,6 +975,7 @@ async def chat( query_images=uploaded_images, agent=agent, tracer=tracer, + attached_files=attached_file_context, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -1010,6 +1016,7 @@ async def chat( query_images=uploaded_images, agent=agent, tracer=tracer, + attached_files=attached_file_context, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -1049,6 +1056,7 @@ async def chat( query_images=uploaded_images, agent=agent, tracer=tracer, + attached_files=attached_file_context, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -1110,6 +1118,7 @@ async def chat( agent=agent, send_status_func=partial(send_event, ChatEvent.STATUS), tracer=tracer, + attached_files=attached_file_context, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -1166,6 +1175,7 @@ async def chat( uploaded_images, tracer, train_of_thought, + attached_file_context, ) # Send Response diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 990fa33f..760b6f2e 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -248,6 +248,25 @@ async def agenerate_chat_response(*args): return await loop.run_in_executor(executor, generate_chat_response, *args) +async def gather_attached_files( + user: KhojUser, + file_filters: List[str], +) -> str: + """ + Gather contextual data from the given files + """ + if len(file_filters) == 0: + return "" + + file_objects = await FileObjectAdapters.async_get_file_objects_by_names(user, file_filters) + + if len(file_objects) == 0: + return "" + + contextual_data = " ".join([f"File: {file.file_name}\n\n{file.raw_text}" for file in file_objects]) + return contextual_data + + async def acreate_title_from_query(query: str, user: KhojUser = None) -> str: """ Create a title from the given query @@ -294,6 +313,7 @@ async def aget_relevant_information_sources( query_images: List[str] = None, agent: Agent = None, tracer: dict = {}, + attached_files: str = None, ): """ Given a query, determine which of the available tools the agent should use in order to answer appropriately. @@ -331,6 +351,7 @@ async def aget_relevant_information_sources( response_type="json_object", user=user, tracer=tracer, + attached_files=attached_files, ) try: @@ -440,6 +461,7 @@ async def infer_webpage_urls( query_images: List[str] = None, agent: Agent = None, tracer: dict = {}, + attached_files: str = None, ) -> List[str]: """ Infer webpage links from the given query @@ -469,6 +491,7 @@ async def infer_webpage_urls( response_type="json_object", user=user, tracer=tracer, + attached_files=attached_files, ) # Validate that the response is a non-empty, JSON-serializable list of URLs @@ -494,6 +517,7 @@ async def generate_online_subqueries( query_images: List[str] = None, agent: Agent = None, tracer: dict = {}, + attached_files: str = None, ) -> List[str]: """ Generate subqueries from the given query @@ -523,6 +547,7 @@ async def generate_online_subqueries( response_type="json_object", user=user, tracer=tracer, + attached_files=attached_files, ) # Validate that the response is a non-empty, JSON-serializable list @@ -645,6 +670,7 @@ async def generate_summary_from_files( agent: Agent = None, send_status_func: Optional[Callable] = None, tracer: dict = {}, + attached_files: str = None, ): try: file_object = None @@ -653,17 +679,28 @@ async def generate_summary_from_files( if len(file_names) > 0: file_object = await FileObjectAdapters.async_get_file_objects_by_name(None, file_names.pop(), agent) - if len(file_filters) > 0: - file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0]) - - if len(file_object) == 0: + if len(file_object) == 0 and not attached_files: response_log = "Sorry, I couldn't find the full text of this file." yield response_log return - contextual_data = " ".join([file.raw_text for file in file_object]) + + contextual_data = " ".join([f"File: {file.file_name}\n\n{file.raw_text}" for file in file_object]) + + if attached_files: + contextual_data += f"\n\n{attached_files}" + if not q: q = "Create a general summary of the file" - async for result in send_status_func(f"**Constructing Summary Using:** {file_object[0].file_name}"): + + file_names = [file.file_name for file in file_object] + file_names.extend(file_filters) + + all_file_names = "" + + for file_name in file_names: + all_file_names += f"- {file_name}\n" + + async for result in send_status_func(f"**Constructing Summary Using:**\n{all_file_names}"): yield {ChatEvent.STATUS: result} response = await extract_relevant_summary( @@ -694,6 +731,7 @@ async def generate_excalidraw_diagram( agent: Agent = None, send_status_func: Optional[Callable] = None, tracer: dict = {}, + attached_files: str = None, ): if send_status_func: async for event in send_status_func("**Enhancing the Diagramming Prompt**"): @@ -709,6 +747,7 @@ async def generate_excalidraw_diagram( user=user, agent=agent, tracer=tracer, + attached_files=attached_files, ) if send_status_func: @@ -735,6 +774,7 @@ async def generate_better_diagram_description( user: KhojUser = None, agent: Agent = None, tracer: dict = {}, + attached_files: str = None, ) -> str: """ Generate a diagram description from the given query and context @@ -772,7 +812,11 @@ async def generate_better_diagram_description( with timer("Chat actor: Generate better diagram description", logger): response = await send_message_to_model_wrapper( - improve_diagram_description_prompt, query_images=query_images, user=user, tracer=tracer + improve_diagram_description_prompt, + query_images=query_images, + user=user, + tracer=tracer, + attached_files=attached_files, ) response = response.strip() if response.startswith(('"', "'")) and response.endswith(('"', "'")): @@ -820,6 +864,7 @@ async def generate_better_image_prompt( user: KhojUser = None, agent: Agent = None, tracer: dict = {}, + attached_files: str = "", ) -> str: """ Generate a better image prompt from the given query @@ -867,7 +912,7 @@ async def generate_better_image_prompt( with timer("Chat actor: Generate contextual image prompt", logger): response = await send_message_to_model_wrapper( - image_prompt, query_images=query_images, user=user, tracer=tracer + image_prompt, query_images=query_images, user=user, tracer=tracer, attached_files=attached_files ) response = response.strip() if response.startswith(('"', "'")) and response.endswith(('"', "'")): @@ -884,6 +929,7 @@ async def send_message_to_model_wrapper( query_images: List[str] = None, context: str = "", tracer: dict = {}, + attached_files: str = None, ): conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user) vision_available = conversation_config.vision_enabled @@ -922,6 +968,7 @@ async def send_message_to_model_wrapper( max_prompt_size=max_tokens, vision_enabled=vision_available, model_type=conversation_config.model_type, + attached_files=attached_files, ) return send_message_to_model_offline( @@ -948,6 +995,7 @@ async def send_message_to_model_wrapper( vision_enabled=vision_available, query_images=query_images, model_type=conversation_config.model_type, + attached_files=attached_files, ) return send_message_to_model( @@ -970,6 +1018,7 @@ async def send_message_to_model_wrapper( vision_enabled=vision_available, query_images=query_images, model_type=conversation_config.model_type, + attached_files=attached_files, ) return anthropic_send_message_to_model( @@ -991,6 +1040,7 @@ async def send_message_to_model_wrapper( vision_enabled=vision_available, query_images=query_images, model_type=conversation_config.model_type, + attached_files=attached_files, ) return gemini_send_message_to_model( @@ -1006,6 +1056,7 @@ def send_message_to_model_wrapper_sync( response_type: str = "text", user: KhojUser = None, tracer: dict = {}, + attached_files: str = "", ): conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config(user) @@ -1029,6 +1080,7 @@ def send_message_to_model_wrapper_sync( max_prompt_size=max_tokens, vision_enabled=vision_available, model_type=conversation_config.model_type, + attached_files=attached_files, ) return send_message_to_model_offline( @@ -1050,6 +1102,7 @@ def send_message_to_model_wrapper_sync( max_prompt_size=max_tokens, vision_enabled=vision_available, model_type=conversation_config.model_type, + attached_files=attached_files, ) openai_response = send_message_to_model( @@ -1071,6 +1124,7 @@ def send_message_to_model_wrapper_sync( max_prompt_size=max_tokens, vision_enabled=vision_available, model_type=conversation_config.model_type, + attached_files=attached_files, ) return anthropic_send_message_to_model( @@ -1090,6 +1144,7 @@ def send_message_to_model_wrapper_sync( max_prompt_size=max_tokens, vision_enabled=vision_available, model_type=conversation_config.model_type, + attached_files=attached_files, ) return gemini_send_message_to_model( @@ -1121,6 +1176,7 @@ def generate_chat_response( query_images: Optional[List[str]] = None, tracer: dict = {}, train_of_thought: List[Any] = [], + attached_files: str = None, ) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]: # Initialize Variables chat_response = None @@ -1173,6 +1229,7 @@ def generate_chat_response( user_name=user_name, agent=agent, tracer=tracer, + attached_files=attached_files, ) elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI: @@ -1198,6 +1255,7 @@ def generate_chat_response( agent=agent, vision_available=vision_available, tracer=tracer, + attached_files=attached_files, ) elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC: @@ -1220,6 +1278,7 @@ def generate_chat_response( agent=agent, vision_available=vision_available, tracer=tracer, + attached_files=attached_files, ) elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE: api_key = conversation_config.openai_config.api_key @@ -1240,6 +1299,7 @@ def generate_chat_response( agent=agent, vision_available=vision_available, tracer=tracer, + attached_files=attached_files, ) metadata.update({"chat_model": conversation_config.chat_model}) diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 4f9c6b4e..960cf52f 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -47,6 +47,7 @@ async def apick_next_tool( max_iterations: int = 5, send_status_func: Optional[Callable] = None, tracer: dict = {}, + attached_files: str = None, ): """ Given a query, determine which of the available tools the agent should use in order to answer appropriately. One at a time, and it's able to use subsequent iterations to refine the answer. @@ -95,6 +96,7 @@ async def apick_next_tool( user=user, query_images=query_images, tracer=tracer, + attached_files=attached_files, ) try: @@ -137,6 +139,7 @@ async def execute_information_collection( location: LocationData = None, file_filters: List[str] = [], tracer: dict = {}, + attached_files: str = None, ): current_iteration = 0 MAX_ITERATIONS = 5 @@ -161,6 +164,7 @@ async def execute_information_collection( MAX_ITERATIONS, send_status_func, tracer=tracer, + attached_files=attached_files, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -233,6 +237,7 @@ async def execute_information_collection( query_images=query_images, agent=agent, tracer=tracer, + attached_files=attached_files, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -264,6 +269,7 @@ async def execute_information_collection( query_images=query_images, agent=agent, tracer=tracer, + attached_files=attached_files, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -288,6 +294,7 @@ async def execute_information_collection( query_images=query_images, agent=agent, send_status_func=send_status_func, + attached_files=attached_files, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] From a27b8d3e5462d1e55cf13e0c584e67b3ce49057a Mon Sep 17 00:00:00 2001 From: sabaimran Date: Mon, 4 Nov 2024 16:51:37 -0800 Subject: [PATCH 02/42] Remove summarize condition for only 1 file filter --- src/khoj/routers/api_chat.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index bb561ca5..a9a6f09f 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -801,10 +801,6 @@ async def chat( response_log = "No files selected for summarization. Please add files using the section on the left." async for result in send_llm_response(response_log): yield result - elif len(file_filters) > 1 and not agent_has_entries: - response_log = "Only one file can be selected for summarization." - async for result in send_llm_response(response_log): - yield result else: async for response in generate_summary_from_files( q=q, From 3dc9139cee9f9e9d21c1d500f015aef68e6b0d7d Mon Sep 17 00:00:00 2001 From: sabaimran Date: Mon, 4 Nov 2024 16:53:07 -0800 Subject: [PATCH 03/42] Add additional handling for when file_object comes back empty --- src/khoj/routers/helpers.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 760b6f2e..bf010034 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -673,18 +673,20 @@ async def generate_summary_from_files( attached_files: str = None, ): try: - file_object = None + file_objects = None if await EntryAdapters.aagent_has_entries(agent): file_names = await EntryAdapters.aget_agent_entry_filepaths(agent) if len(file_names) > 0: - file_object = await FileObjectAdapters.async_get_file_objects_by_name(None, file_names.pop(), agent) + file_objects = await FileObjectAdapters.async_get_file_objects_by_name(None, file_names.pop(), agent) - if len(file_object) == 0 and not attached_files: - response_log = "Sorry, I couldn't find the full text of this file." + if (file_objects and len(file_objects) == 0 and not attached_files) or ( + not file_objects and not attached_files + ): + response_log = "Sorry, I couldn't find anything to summarize." yield response_log return - contextual_data = " ".join([f"File: {file.file_name}\n\n{file.raw_text}" for file in file_object]) + contextual_data = " ".join([f"File: {file.file_name}\n\n{file.raw_text}" for file in file_objects]) if attached_files: contextual_data += f"\n\n{attached_files}" @@ -692,7 +694,7 @@ async def generate_summary_from_files( if not q: q = "Create a general summary of the file" - file_names = [file.file_name for file in file_object] + file_names = [file.file_name for file in file_objects] file_names.extend(file_filters) all_file_names = "" From 1f372bf2b1268aee47f26187b6b9d246191c7f9d Mon Sep 17 00:00:00 2001 From: sabaimran Date: Mon, 4 Nov 2024 17:45:54 -0800 Subject: [PATCH 04/42] Update file summarization unit tests now that multiple files are allowed --- tests/test_offline_chat_director.py | 7 ++----- tests/test_openai_chat_director.py | 5 +---- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/tests/test_offline_chat_director.py b/tests/test_offline_chat_director.py index afb5d4ce..f8285f40 100644 --- a/tests/test_offline_chat_director.py +++ b/tests/test_offline_chat_director.py @@ -10,7 +10,7 @@ from khoj.processor.conversation.utils import message_to_log from khoj.routers.helpers import aget_relevant_information_sources from tests.helpers import ConversationFactory -SKIP_TESTS = True +SKIP_TESTS = False pytestmark = pytest.mark.skipif( SKIP_TESTS, reason="Disable in CI to avoid long test runs.", @@ -337,7 +337,6 @@ def test_summarize_one_file(client_offline_chat, default_user2: KhojUser): # Assert assert response_message != "" assert response_message != "No files selected for summarization. Please add files using the section on the left." - assert response_message != "Only one file can be selected for summarization." @pytest.mark.django_db(transaction=True) @@ -375,7 +374,6 @@ def test_summarize_extra_text(client_offline_chat, default_user2: KhojUser): # Assert assert response_message != "" assert response_message != "No files selected for summarization. Please add files using the section on the left." - assert response_message != "Only one file can be selected for summarization." @pytest.mark.django_db(transaction=True) @@ -404,7 +402,7 @@ def test_summarize_multiple_files(client_offline_chat, default_user2: KhojUser): response_message = response.json()["response"] # Assert - assert response_message == "Only one file can be selected for summarization." + assert response_message is not None @pytest.mark.django_db(transaction=True) @@ -460,7 +458,6 @@ def test_summarize_different_conversation(client_offline_chat, default_user2: Kh # Assert assert response_message != "" assert response_message != "No files selected for summarization. Please add files using the section on the left." - assert response_message != "Only one file can be selected for summarization." @pytest.mark.django_db(transaction=True) diff --git a/tests/test_openai_chat_director.py b/tests/test_openai_chat_director.py index 7d460408..56dde6d3 100644 --- a/tests/test_openai_chat_director.py +++ b/tests/test_openai_chat_director.py @@ -312,7 +312,6 @@ def test_summarize_one_file(chat_client, default_user2: KhojUser): # Assert assert response_message != "" assert response_message != "No files selected for summarization. Please add files using the section on the left." - assert response_message != "Only one file can be selected for summarization." @pytest.mark.django_db(transaction=True) @@ -344,7 +343,6 @@ def test_summarize_extra_text(chat_client, default_user2: KhojUser): # Assert assert response_message != "" assert response_message != "No files selected for summarization. Please add files using the section on the left." - assert response_message != "Only one file can be selected for summarization." @pytest.mark.django_db(transaction=True) @@ -371,7 +369,7 @@ def test_summarize_multiple_files(chat_client, default_user2: KhojUser): response_message = response.json()["response"] # Assert - assert response_message == "Only one file can be selected for summarization." + assert response_message is not None @pytest.mark.django_db(transaction=True) @@ -435,7 +433,6 @@ def test_summarize_different_conversation(chat_client, default_user2: KhojUser): assert ( response_message_conv1 != "No files selected for summarization. Please add files using the section on the left." ) - assert response_message_conv1 != "Only one file can be selected for summarization." @pytest.mark.django_db(transaction=True) From cf0bcec0e7ec7b179c8bfffbe03be5d230bb7bda Mon Sep 17 00:00:00 2001 From: sabaimran Date: Mon, 4 Nov 2024 19:06:54 -0800 Subject: [PATCH 05/42] Revert SKIP_TESTS flag in offline chat director tests --- tests/test_offline_chat_director.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_offline_chat_director.py b/tests/test_offline_chat_director.py index f8285f40..45f540ed 100644 --- a/tests/test_offline_chat_director.py +++ b/tests/test_offline_chat_director.py @@ -10,7 +10,7 @@ from khoj.processor.conversation.utils import message_to_log from khoj.routers.helpers import aget_relevant_information_sources from tests.helpers import ConversationFactory -SKIP_TESTS = False +SKIP_TESTS = True pytestmark = pytest.mark.skipif( SKIP_TESTS, reason="Disable in CI to avoid long test runs.", From dc26da0a12762c71ae2650c903c1c9bfdbfd90f4 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Mon, 4 Nov 2024 22:00:47 -0800 Subject: [PATCH 06/42] Add uploaded files in the conversation file filter for a new convo --- src/interface/web/app/page.tsx | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/src/interface/web/app/page.tsx b/src/interface/web/app/page.tsx index 532b4420..1b714414 100644 --- a/src/interface/web/app/page.tsx +++ b/src/interface/web/app/page.tsx @@ -30,6 +30,7 @@ import { useRouter, useSearchParams } from "next/navigation"; import { ScrollArea, ScrollBar } from "@/components/ui/scroll-area"; import { AgentCard } from "@/app/components/agentCard/agentCard"; import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; +import { modifyFileFilterForConversation } from "./common/chatFunctions"; interface ChatBodyDataProps { chatOptionsData: ChatOptions | null; @@ -150,12 +151,26 @@ function ChatBodyData(props: ChatBodyDataProps) { setProcessingMessage(true); try { const newConversationId = await createNewConversation(selectedAgent || "khoj"); + const uploadedFiles = localStorage.getItem("uploadedFiles"); onConversationIdChange?.(newConversationId); localStorage.setItem("message", message); if (images.length > 0) { localStorage.setItem("images", JSON.stringify(images)); } - window.location.href = `/chat?conversationId=${newConversationId}`; + + if (uploadedFiles) { + modifyFileFilterForConversation( + newConversationId, + JSON.parse(uploadedFiles), + () => { + window.location.href = `/chat?conversationId=${newConversationId}`; + }, + "add", + ); + localStorage.removeItem("uploadedFiles"); + } else { + window.location.href = `/chat?conversationId=${newConversationId}`; + } } catch (error) { console.error("Error creating new conversation:", error); setProcessingMessage(false); @@ -417,6 +432,10 @@ export default function Home() { setUserConfig(initialUserConfig); }, [initialUserConfig]); + useEffect(() => { + localStorage.setItem("uploadedFiles", JSON.stringify(uploadedFiles)); + }, [uploadedFiles]); + useEffect(() => { fetch("/api/chat/options") .then((response) => response.json()) From a0480d5f6c741ceabf2cb8a9406fb37d7af09113 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Mon, 4 Nov 2024 22:01:09 -0800 Subject: [PATCH 07/42] use fill weight for the toggle right (enabled state) for research mode --- src/interface/web/app/components/chatInputArea/chatInputArea.tsx | 1 + 1 file changed, 1 insertion(+) diff --git a/src/interface/web/app/components/chatInputArea/chatInputArea.tsx b/src/interface/web/app/components/chatInputArea/chatInputArea.tsx index 74e15523..9f8f8c18 100644 --- a/src/interface/web/app/components/chatInputArea/chatInputArea.tsx +++ b/src/interface/web/app/components/chatInputArea/chatInputArea.tsx @@ -582,6 +582,7 @@ export const ChatInputArea = forwardRef((pr Research Mode {useResearchMode ? ( ) : ( From de73cbc6106ede0d02db64da707c2cb96acd8655 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Thu, 7 Nov 2024 15:58:52 -0800 Subject: [PATCH 08/42] Add support for relaying attached files through backend calls to models --- .../conversation/anthropic/anthropic_chat.py | 8 +- .../conversation/google/gemini_chat.py | 9 ++- src/khoj/processor/conversation/openai/gpt.py | 8 +- src/khoj/processor/conversation/utils.py | 49 +++++++----- src/khoj/processor/tools/run_code.py | 10 +-- src/khoj/routers/helpers.py | 80 ++++++++++++------- src/khoj/routers/research.py | 2 +- src/khoj/utils/rawconfig.py | 32 ++++++++ 8 files changed, 136 insertions(+), 62 deletions(-) diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index df81f56f..c171c8fb 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -37,6 +37,7 @@ def extract_questions_anthropic( vision_enabled: bool = False, personality_context: Optional[str] = None, tracer: dict = {}, + attached_files: str = None, ): """ Infer search queries to retrieve relevant notes to answer user query @@ -84,7 +85,12 @@ def extract_questions_anthropic( vision_enabled=vision_enabled, ) - messages = [ChatMessage(content=prompt, role="user")] + messages = [] + + if attached_files: + messages.append(ChatMessage(content=attached_files, role="user")) + + messages.append(ChatMessage(content=prompt, role="user")) response = anthropic_completion_with_backoff( messages=messages, diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index b7ec018d..6d257faa 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -38,6 +38,7 @@ def extract_questions_gemini( vision_enabled: bool = False, personality_context: Optional[str] = None, tracer: dict = {}, + attached_files: str = None, ): """ Infer search queries to retrieve relevant notes to answer user query @@ -85,7 +86,13 @@ def extract_questions_gemini( vision_enabled=vision_enabled, ) - messages = [ChatMessage(content=prompt, role="user"), ChatMessage(content=system_prompt, role="system")] + messages = [] + + if attached_files: + messages.append(ChatMessage(content=attached_files, role="user")) + + messages.append(ChatMessage(content=prompt, role="user")) + messages.append(ChatMessage(content=system_prompt, role="system")) response = gemini_send_message_to_model( messages, api_key, model, response_type="json_object", temperature=temperature, tracer=tracer diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index bdb67448..65cdfa3f 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -35,6 +35,7 @@ def extract_questions( vision_enabled: bool = False, personality_context: Optional[str] = None, tracer: dict = {}, + attached_files: str = None, ): """ Infer search queries to retrieve relevant notes to answer user query @@ -81,7 +82,12 @@ def extract_questions( vision_enabled=vision_enabled, ) - messages = [ChatMessage(content=prompt, role="user")] + messages = [] + + if attached_files: + messages.append(ChatMessage(content=attached_files, role="user")) + + messages.append(ChatMessage(content=prompt, role="user")) response = send_message_to_model( messages, diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index b510a09a..791a98e0 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -36,6 +36,7 @@ from khoj.utils.helpers import ( is_none_or_empty, merge_dicts, ) +from khoj.utils.rawconfig import FileAttachment logger = logging.getLogger(__name__) @@ -137,25 +138,6 @@ def construct_iteration_history( return previous_iterations_history -def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str: - chat_history = "" - for chat in conversation_history.get("chat", [])[-n:]: - if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder", "summarize"]: - chat_history += f"User: {chat['intent']['query']}\n" - - if chat["intent"].get("inferred-queries"): - chat_history += f'Khoj: {{"queries": {chat["intent"].get("inferred-queries")}}}\n' - - chat_history += f"{agent_name}: {chat['message']}\n\n" - elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")): - chat_history += f"User: {chat['intent']['query']}\n" - chat_history += f"{agent_name}: [generated image redacted for space]\n" - elif chat["by"] == "khoj" and ("excalidraw" in chat["intent"].get("type")): - chat_history += f"User: {chat['intent']['query']}\n" - chat_history += f"{agent_name}: {chat['intent']['inferred-queries'][0]}\n" - return chat_history - - def construct_tool_chat_history( previous_iterations: List[InformationCollectionIteration], tool: ConversationCommand = None ) -> Dict[str, list]: @@ -241,6 +223,7 @@ def save_to_conversation_log( conversation_id: str = None, automation_id: str = None, query_images: List[str] = None, + raw_attached_files: List[FileAttachment] = [], tracer: Dict[str, Any] = {}, train_of_thought: List[Any] = [], ): @@ -253,6 +236,7 @@ def save_to_conversation_log( "created": user_message_time, "images": query_images, "turnId": turn_id, + "attachedFiles": [file.model_dump(mode="json") for file in raw_attached_files], }, khoj_message_metadata={ "context": compiled_references, @@ -306,6 +290,22 @@ def construct_structured_message(message: str, images: list[str], model_type: st return message +def gather_raw_attached_files( + attached_files: Dict[str, str], +): + """_summary_ + Gather contextual data from the given (raw) files + """ + + if len(attached_files) == 0: + return "" + + contextual_data = " ".join( + [f"File: {file_name}\n\n{file_content}\n\n" for file_name, file_content in attached_files.items()] + ) + return f"I have attached the following files:\n\n{contextual_data}" + + def generate_chatml_messages_with_context( user_message, system_message=None, @@ -335,6 +335,8 @@ def generate_chatml_messages_with_context( chatml_messages: List[ChatMessage] = [] for chat in conversation_log.get("chat", []): message_context = "" + message_attached_files = "" + if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type", ""): message_context += chat.get("intent").get("inferred-queries")[0] if not is_none_or_empty(chat.get("context")): @@ -343,6 +345,15 @@ def generate_chatml_messages_with_context( ) message_context += f"{prompts.notes_conversation.format(references=references)}\n\n" + if chat.get("attachedFiles"): + raw_attached_files = chat.get("attachedFiles") + attached_files_dict = dict() + for file in raw_attached_files: + attached_files_dict[file["name"]] = file["content"] + + message_attached_files = gather_raw_attached_files(attached_files_dict) + chatml_messages.append(ChatMessage(content=message_attached_files, role="user")) + if not is_none_or_empty(chat.get("onlineContext")): message_context += f"{prompts.online_search_conversation.format(online_results=chat.get('onlineContext'))}" diff --git a/src/khoj/processor/tools/run_code.py b/src/khoj/processor/tools/run_code.py index 86388aee..418ab3a2 100644 --- a/src/khoj/processor/tools/run_code.py +++ b/src/khoj/processor/tools/run_code.py @@ -6,18 +6,12 @@ import os from typing import Any, Callable, List, Optional import aiohttp -import requests from khoj.database.adapters import ais_user_subscribed from khoj.database.models import Agent, KhojUser from khoj.processor.conversation import prompts -from khoj.processor.conversation.utils import ( - ChatEvent, - clean_code_python, - clean_json, - construct_chat_history, -) -from khoj.routers.helpers import send_message_to_model_wrapper +from khoj.processor.conversation.utils import ChatEvent, clean_code_python, clean_json +from khoj.routers.helpers import construct_chat_history, send_message_to_model_wrapper from khoj.utils.helpers import timer from khoj.utils.rawconfig import LocationData diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index bf010034..c62fe4bf 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -91,7 +91,6 @@ from khoj.processor.conversation.utils import ( ChatEvent, ThreadedGenerator, clean_json, - construct_chat_history, generate_chatml_messages_with_context, save_to_conversation_log, ) @@ -104,6 +103,7 @@ from khoj.utils.config import OfflineChatProcessorModel from khoj.utils.helpers import ( LRU, ConversationCommand, + get_file_type, is_none_or_empty, is_valid_url, log_telemetry, @@ -111,7 +111,7 @@ from khoj.utils.helpers import ( timer, tool_descriptions_for_llm, ) -from khoj.utils.rawconfig import LocationData +from khoj.utils.rawconfig import ChatRequestBody, FileAttachment, FileData, LocationData logger = logging.getLogger(__name__) @@ -167,6 +167,12 @@ async def is_ready_to_chat(user: KhojUser): raise HTTPException(status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings.") +def get_file_content(file: UploadFile): + file_content = file.file.read() + file_type, encoding = get_file_type(file.content_type, file_content) + return FileData(name=file.filename, content=file_content, file_type=file_type, encoding=encoding) + + def update_telemetry_state( request: Request, telemetry_type: str, @@ -248,23 +254,49 @@ async def agenerate_chat_response(*args): return await loop.run_in_executor(executor, generate_chat_response, *args) -async def gather_attached_files( - user: KhojUser, - file_filters: List[str], -) -> str: +def gather_raw_attached_files( + attached_files: Dict[str, str], +): + """_summary_ + Gather contextual data from the given (raw) files """ - Gather contextual data from the given files - """ - if len(file_filters) == 0: + + if len(attached_files) == 0: return "" - file_objects = await FileObjectAdapters.async_get_file_objects_by_names(user, file_filters) + contextual_data = " ".join( + [f"File: {file_name}\n\n{file_content}\n\n" for file_name, file_content in attached_files.items()] + ) + return f"I have attached the following files:\n\n{contextual_data}" - if len(file_objects) == 0: - return "" - contextual_data = " ".join([f"File: {file.file_name}\n\n{file.raw_text}" for file in file_objects]) - return contextual_data +def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str: + chat_history = "" + for chat in conversation_history.get("chat", [])[-n:]: + if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder", "summarize"]: + chat_history += f"User: {chat['intent']['query']}\n" + + if chat["intent"].get("inferred-queries"): + chat_history += f'{agent_name}: {{"queries": {chat["intent"].get("inferred-queries")}}}\n' + + chat_history += f"{agent_name}: {chat['message']}\n\n" + elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")): + chat_history += f"User: {chat['intent']['query']}\n" + chat_history += f"{agent_name}: [generated image redacted for space]\n" + elif chat["by"] == "khoj" and ("excalidraw" in chat["intent"].get("type")): + chat_history += f"User: {chat['intent']['query']}\n" + chat_history += f"{agent_name}: {chat['intent']['inferred-queries'][0]}\n" + elif chat["by"] == "you": + raw_attached_files = chat.get("attachedFiles") + if raw_attached_files: + attached_files: Dict[str, str] = {} + for file in raw_attached_files: + attached_files[file["name"]] = file["content"] + + attached_file_context = gather_raw_attached_files(attached_files) + chat_history += f"User: {attached_file_context}\n" + + return chat_history async def acreate_title_from_query(query: str, user: KhojUser = None) -> str: @@ -1179,6 +1211,7 @@ def generate_chat_response( tracer: dict = {}, train_of_thought: List[Any] = [], attached_files: str = None, + raw_attached_files: List[FileAttachment] = None, ) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]: # Initialize Variables chat_response = None @@ -1204,6 +1237,7 @@ def generate_chat_response( query_images=query_images, tracer=tracer, train_of_thought=train_of_thought, + raw_attached_files=raw_attached_files, ) conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation) @@ -1299,6 +1333,7 @@ def generate_chat_response( location_data=location_data, user_name=user_name, agent=agent, + query_images=query_images, vision_available=vision_available, tracer=tracer, attached_files=attached_files, @@ -1313,23 +1348,6 @@ def generate_chat_response( return chat_response, metadata -class ChatRequestBody(BaseModel): - q: str - n: Optional[int] = 7 - d: Optional[float] = None - stream: Optional[bool] = False - title: Optional[str] = None - conversation_id: Optional[str] = None - turn_id: Optional[str] = None - city: Optional[str] = None - region: Optional[str] = None - country: Optional[str] = None - country_code: Optional[str] = None - timezone: Optional[str] = None - images: Optional[list[str]] = None - create_new: Optional[bool] = False - - class DeleteMessageRequestBody(BaseModel): conversation_id: str turn_id: str diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 960cf52f..dc34009c 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -20,7 +20,6 @@ from khoj.routers.api import extract_references_and_questions from khoj.routers.helpers import ( ChatEvent, construct_chat_history, - extract_relevant_info, generate_summary_from_files, send_message_to_model_wrapper, ) @@ -187,6 +186,7 @@ async def execute_information_collection( query_images, agent=agent, tracer=tracer, + attached_files=attached_files, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] diff --git a/src/khoj/utils/rawconfig.py b/src/khoj/utils/rawconfig.py index 15f5ea01..2c956c2c 100644 --- a/src/khoj/utils/rawconfig.py +++ b/src/khoj/utils/rawconfig.py @@ -138,6 +138,38 @@ class SearchResponse(ConfigBase): corpus_id: str +class FileData(BaseModel): + name: str + content: bytes + file_type: str + encoding: str | None = None + + +class FileAttachment(BaseModel): + name: str + content: str + file_type: str + size: int + + +class ChatRequestBody(BaseModel): + q: str + n: Optional[int] = 7 + d: Optional[float] = None + stream: Optional[bool] = False + title: Optional[str] = None + conversation_id: Optional[str] = None + turn_id: Optional[str] = None + city: Optional[str] = None + region: Optional[str] = None + country: Optional[str] = None + country_code: Optional[str] = None + timezone: Optional[str] = None + images: Optional[list[str]] = None + files: Optional[list[FileAttachment]] = None + create_new: Optional[bool] = False + + class Entry: raw: str compiled: str From 3b1e8462cd6daf60216c679e76d617efac1932a2 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Thu, 7 Nov 2024 15:59:15 -0800 Subject: [PATCH 09/42] Include attach files in calls to extract questions --- src/khoj/routers/api.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index bed7c27b..5474497d 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -351,6 +351,7 @@ async def extract_references_and_questions( query_images: Optional[List[str]] = None, agent: Agent = None, tracer: dict = {}, + attached_files: str = None, ): user = request.user.object if request.user.is_authenticated else None @@ -425,6 +426,7 @@ async def extract_references_and_questions( max_prompt_size=conversation_config.max_prompt_size, personality_context=personality_context, tracer=tracer, + attached_files=attached_files, ) elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI: openai_chat_config = conversation_config.openai_config @@ -443,6 +445,7 @@ async def extract_references_and_questions( vision_enabled=vision_enabled, personality_context=personality_context, tracer=tracer, + attached_files=attached_files, ) elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC: api_key = conversation_config.openai_config.api_key @@ -458,6 +461,7 @@ async def extract_references_and_questions( vision_enabled=vision_enabled, personality_context=personality_context, tracer=tracer, + attached_files=attached_files, ) elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE: api_key = conversation_config.openai_config.api_key @@ -474,6 +478,7 @@ async def extract_references_and_questions( vision_enabled=vision_enabled, personality_context=personality_context, tracer=tracer, + attached_files=attached_files, ) # Collate search results as context for GPT From 394035136d791cf6ef74b0d822b7b999b2815bc4 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Thu, 7 Nov 2024 16:00:10 -0800 Subject: [PATCH 10/42] Add an api that gets a document, and converts it to just text --- src/khoj/routers/api_content.py | 61 ++++++++++++++++++++++++++++++--- 1 file changed, 56 insertions(+), 5 deletions(-) diff --git a/src/khoj/routers/api_content.py b/src/khoj/routers/api_content.py index 72b304ef..d5a6f2ad 100644 --- a/src/khoj/routers/api_content.py +++ b/src/khoj/routers/api_content.py @@ -36,16 +36,18 @@ from khoj.database.models import ( LocalPlaintextConfig, NotionConfig, ) +from khoj.processor.content.docx.docx_to_entries import DocxToEntries +from khoj.processor.content.pdf.pdf_to_entries import PdfToEntries from khoj.routers.helpers import ( ApiIndexedDataLimiter, CommonQueryParams, configure_content, + get_file_content, get_user_config, update_telemetry_state, ) from khoj.utils import constants, state from khoj.utils.config import SearchModels -from khoj.utils.helpers import get_file_type from khoj.utils.rawconfig import ( ContentConfig, FullConfig, @@ -375,6 +377,54 @@ async def delete_content_source( return {"status": "ok"} +@api_content.post("/convert", status_code=200) +@requires(["authenticated"]) +async def convert_documents( + request: Request, + files: List[UploadFile], + client: Optional[str] = None, +): + converted_files = [] + supported_files = ["org", "markdown", "pdf", "plaintext", "docx"] + + for file in files: + file_data = get_file_content(file) + if file_data.file_type in supported_files: + extracted_content = ( + file_data.content.decode(file_data.encoding) if file_data.encoding else file_data.content + ) + + if file_data.file_type == "docx": + entries_per_page = DocxToEntries.extract_text(file_data.content) + extracted_content = "\n".join(entries_per_page) + + elif file_data.file_type == "pdf": + entries_per_page = PdfToEntries.extract_text(file_data.content) + extracted_content = "\n".join(entries_per_page) + + size_in_bytes = len(extracted_content.encode("utf-8")) + + converted_files.append( + { + "name": file_data.name, + "content": extracted_content, + "file_type": file_data.file_type, + "size": size_in_bytes, + } + ) + else: + logger.warning(f"Skipped converting unsupported file type sent by {client} client: {file.filename}") + + update_telemetry_state( + request=request, + telemetry_type="api", + api="convert_documents", + client=client, + ) + + return Response(content=json.dumps(converted_files), media_type="application/json", status_code=200) + + async def indexer( request: Request, files: list[UploadFile], @@ -398,10 +448,11 @@ async def indexer( try: logger.info(f"📬 Updating content index via API call by {client} client") for file in files: - file_content = file.file.read() - file_type, encoding = get_file_type(file.content_type, file_content) - if file_type in index_files: - index_files[file_type][file.filename] = file_content.decode(encoding) if encoding else file_content + file_data = get_file_content(file) + if file_data.file_type in index_files: + index_files[file_data.file_type][file_data.filename] = ( + file_data.content.decode(file_data.encoding) if file_data.encoding else file_data.content + ) else: logger.warning(f"Skipped indexing unsupported file type sent by {client} client: {file.filename}") From ecc81e06a7ecffec003b4f41281eda190ea0df57 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Thu, 7 Nov 2024 16:01:08 -0800 Subject: [PATCH 11/42] Add separate methods for docx and pdf files to just convert files to raw text, before further processing --- .../processor/content/docx/docx_to_entries.py | 46 +++++++++------ .../processor/content/pdf/pdf_to_entries.py | 56 ++++++++++--------- 2 files changed, 59 insertions(+), 43 deletions(-) diff --git a/src/khoj/processor/content/docx/docx_to_entries.py b/src/khoj/processor/content/docx/docx_to_entries.py index 00ed3ca4..a2948caa 100644 --- a/src/khoj/processor/content/docx/docx_to_entries.py +++ b/src/khoj/processor/content/docx/docx_to_entries.py @@ -58,28 +58,13 @@ class DocxToEntries(TextToEntries): file_to_text_map = dict() for docx_file in docx_files: try: - timestamp_now = datetime.utcnow().timestamp() - tmp_file = f"tmp_docx_file_{timestamp_now}.docx" - with open(tmp_file, "wb") as f: - bytes_content = docx_files[docx_file] - f.write(bytes_content) - - # Load the content using Docx2txtLoader - loader = Docx2txtLoader(tmp_file) - docx_entries_per_file = loader.load() - - # Convert the loaded entries into the desired format - docx_texts = [page.page_content for page in docx_entries_per_file] - + docx_texts = DocxToEntries.extract_text(docx_files[docx_file]) entry_to_location_map += zip(docx_texts, [docx_file] * len(docx_texts)) entries.extend(docx_texts) file_to_text_map[docx_file] = docx_texts except Exception as e: - logger.warning(f"Unable to process file: {docx_file}. This file will not be indexed.") + logger.warning(f"Unable to extract entries from file: {docx_file}") logger.warning(e, exc_info=True) - finally: - if os.path.exists(f"{tmp_file}"): - os.remove(f"{tmp_file}") return file_to_text_map, DocxToEntries.convert_docx_entries_to_maps(entries, dict(entry_to_location_map)) @staticmethod @@ -103,3 +88,30 @@ class DocxToEntries(TextToEntries): logger.debug(f"Converted {len(parsed_entries)} DOCX entries to dictionaries") return entries + + @staticmethod + def extract_text(docx_file): + """Extract text from specified DOCX file""" + try: + timestamp_now = datetime.utcnow().timestamp() + tmp_file = f"tmp_docx_file_{timestamp_now}.docx" + docx_entry_by_pages = [] + with open(tmp_file, "wb") as f: + bytes_content = docx_file + f.write(bytes_content) + + # Load the content using Docx2txtLoader + loader = Docx2txtLoader(tmp_file) + docx_entries_per_file = loader.load() + + # Convert the loaded entries into the desired format + docx_entry_by_pages = [page.page_content for page in docx_entries_per_file] + + except Exception as e: + logger.warning(f"Unable to extract text from file: {docx_file}") + logger.warning(e, exc_info=True) + finally: + if os.path.exists(f"{tmp_file}"): + os.remove(f"{tmp_file}") + + return docx_entry_by_pages diff --git a/src/khoj/processor/content/pdf/pdf_to_entries.py b/src/khoj/processor/content/pdf/pdf_to_entries.py index 063d1e74..35aa203f 100644 --- a/src/khoj/processor/content/pdf/pdf_to_entries.py +++ b/src/khoj/processor/content/pdf/pdf_to_entries.py @@ -59,32 +59,9 @@ class PdfToEntries(TextToEntries): entries: List[str] = [] entry_to_location_map: List[Tuple[str, str]] = [] for pdf_file in pdf_files: - try: - # Write the PDF file to a temporary file, as it is stored in byte format in the pdf_file object and the PDF Loader expects a file path - timestamp_now = datetime.utcnow().timestamp() - tmp_file = f"tmp_pdf_file_{timestamp_now}.pdf" - with open(f"{tmp_file}", "wb") as f: - bytes = pdf_files[pdf_file] - f.write(bytes) - try: - loader = PyMuPDFLoader(f"{tmp_file}", extract_images=False) - pdf_entries_per_file = [page.page_content for page in loader.load()] - except ImportError: - loader = PyMuPDFLoader(f"{tmp_file}") - pdf_entries_per_file = [ - page.page_content for page in loader.load() - ] # page_content items list for a given pdf. - entry_to_location_map += zip( - pdf_entries_per_file, [pdf_file] * len(pdf_entries_per_file) - ) # this is an indexed map of pdf_entries for the pdf. - entries.extend(pdf_entries_per_file) - file_to_text_map[pdf_file] = pdf_entries_per_file - except Exception as e: - logger.warning(f"Unable to process file: {pdf_file}. This file will not be indexed.") - logger.warning(e, exc_info=True) - finally: - if os.path.exists(f"{tmp_file}"): - os.remove(f"{tmp_file}") + pdf_entries_per_file = PdfToEntries.extract_text(pdf_file) + entries.extend(pdf_entries_per_file) + file_to_text_map[pdf_file] = pdf_entries_per_file return file_to_text_map, PdfToEntries.convert_pdf_entries_to_maps(entries, dict(entry_to_location_map)) @@ -109,3 +86,30 @@ class PdfToEntries(TextToEntries): logger.debug(f"Converted {len(parsed_entries)} PDF entries to dictionaries") return entries + + @staticmethod + def extract_text(pdf_file): + """Extract text from specified PDF files""" + try: + # Write the PDF file to a temporary file, as it is stored in byte format in the pdf_file object and the PDF Loader expects a file path + timestamp_now = datetime.utcnow().timestamp() + tmp_file = f"tmp_pdf_file_{timestamp_now}.pdf" + pdf_entry_by_pages = [] + with open(f"{tmp_file}", "wb") as f: + f.write(pdf_file) + try: + loader = PyMuPDFLoader(f"{tmp_file}", extract_images=False) + pdf_entry_by_pages = [page.page_content for page in loader.load()] + except ImportError: + loader = PyMuPDFLoader(f"{tmp_file}") + pdf_entry_by_pages = [ + page.page_content for page in loader.load() + ] # page_content items list for a given pdf. + except Exception as e: + logger.warning(f"Unable to process file: {pdf_file}. This file will not be indexed.") + logger.warning(e, exc_info=True) + finally: + if os.path.exists(f"{tmp_file}"): + os.remove(f"{tmp_file}") + + return pdf_entry_by_pages From b8ed98530f310bbb25433cca115d58d4f36b3e1b Mon Sep 17 00:00:00 2001 From: sabaimran Date: Thu, 7 Nov 2024 16:01:48 -0800 Subject: [PATCH 12/42] Accept attached files in the chat API - weave through all subsequent subcalls to models, where relevant, and save to conversation log --- src/khoj/routers/api_chat.py | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index a9a6f09f..cc69930e 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -45,7 +45,7 @@ from khoj.routers.helpers import ( aget_relevant_output_modes, construct_automation_created_message, create_automation, - gather_attached_files, + gather_raw_attached_files, generate_excalidraw_diagram, generate_summary_from_files, get_conversation_command, @@ -71,7 +71,12 @@ from khoj.utils.helpers import ( get_device, is_none_or_empty, ) -from khoj.utils.rawconfig import FileFilterRequest, FilesFilterRequest, LocationData +from khoj.utils.rawconfig import ( + ChatRequestBody, + FileFilterRequest, + FilesFilterRequest, + LocationData, +) # Initialize Router logger = logging.getLogger(__name__) @@ -566,6 +571,7 @@ async def chat( country_code = body.country_code or get_country_code_from_timezone(body.timezone) timezone = body.timezone raw_images = body.images + raw_attached_files = body.files async def event_generator(q: str, images: list[str]): start_time = time.perf_counter() @@ -577,6 +583,7 @@ async def chat( q = unquote(q) train_of_thought = [] nonlocal conversation_id + nonlocal raw_attached_files tracer: dict = { "mid": turn_id, @@ -596,6 +603,11 @@ async def chat( if uploaded_image: uploaded_images.append(uploaded_image) + attached_files: Dict[str, str] = {} + if raw_attached_files: + for file in raw_attached_files: + attached_files[file.name] = file.content + async def send_event(event_type: ChatEvent, data: str | dict): nonlocal connection_alive, ttft, train_of_thought if not connection_alive or await request.is_disconnected(): @@ -707,7 +719,7 @@ async def chat( compiled_references: List[Any] = [] inferred_queries: List[Any] = [] file_filters = conversation.file_filters if conversation and conversation.file_filters else [] - attached_file_context = await gather_attached_files(user, file_filters) + attached_file_context = gather_raw_attached_files(attached_files) if conversation_commands == [ConversationCommand.Default] or is_automated_task: conversation_commands = await aget_relevant_information_sources( @@ -833,6 +845,7 @@ async def chat( query_images=uploaded_images, tracer=tracer, train_of_thought=train_of_thought, + raw_attached_files=raw_attached_files, ) return @@ -878,6 +891,7 @@ async def chat( query_images=uploaded_images, tracer=tracer, train_of_thought=train_of_thought, + raw_attached_files=raw_attached_files, ) async for result in send_llm_response(llm_response): yield result @@ -900,6 +914,7 @@ async def chat( query_images=uploaded_images, agent=agent, tracer=tracer, + attached_files=attached_file_context, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] @@ -1085,6 +1100,8 @@ async def chat( query_images=uploaded_images, tracer=tracer, train_of_thought=train_of_thought, + attached_file_context=attached_file_context, + raw_attached_files=raw_attached_files, ) content_obj = { "intentType": intent_type, @@ -1144,6 +1161,8 @@ async def chat( query_images=uploaded_images, tracer=tracer, train_of_thought=train_of_thought, + attached_file_context=attached_file_context, + raw_attached_files=raw_attached_files, ) async for result in send_llm_response(json.dumps(content_obj)): @@ -1172,6 +1191,7 @@ async def chat( tracer, train_of_thought, attached_file_context, + raw_attached_files, ) # Send Response From 140c67f6b55ec1999f832f04957cfa232ef462e0 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Thu, 7 Nov 2024 16:02:02 -0800 Subject: [PATCH 13/42] Remove focus ring from the text area component --- src/interface/web/components/ui/textarea.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/interface/web/components/ui/textarea.tsx b/src/interface/web/components/ui/textarea.tsx index 68aefb3e..ce071877 100644 --- a/src/interface/web/components/ui/textarea.tsx +++ b/src/interface/web/components/ui/textarea.tsx @@ -9,7 +9,7 @@ const Textarea = React.forwardRef( return (