From de73cbc6106ede0d02db64da707c2cb96acd8655 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Thu, 7 Nov 2024 15:58:52 -0800 Subject: [PATCH] Add support for relaying attached files through backend calls to models --- .../conversation/anthropic/anthropic_chat.py | 8 +- .../conversation/google/gemini_chat.py | 9 ++- src/khoj/processor/conversation/openai/gpt.py | 8 +- src/khoj/processor/conversation/utils.py | 49 +++++++----- src/khoj/processor/tools/run_code.py | 10 +-- src/khoj/routers/helpers.py | 80 ++++++++++++------- src/khoj/routers/research.py | 2 +- src/khoj/utils/rawconfig.py | 32 ++++++++ 8 files changed, 136 insertions(+), 62 deletions(-) diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py index df81f56f..c171c8fb 100644 --- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py +++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py @@ -37,6 +37,7 @@ def extract_questions_anthropic( vision_enabled: bool = False, personality_context: Optional[str] = None, tracer: dict = {}, + attached_files: str = None, ): """ Infer search queries to retrieve relevant notes to answer user query @@ -84,7 +85,12 @@ def extract_questions_anthropic( vision_enabled=vision_enabled, ) - messages = [ChatMessage(content=prompt, role="user")] + messages = [] + + if attached_files: + messages.append(ChatMessage(content=attached_files, role="user")) + + messages.append(ChatMessage(content=prompt, role="user")) response = anthropic_completion_with_backoff( messages=messages, diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py index b7ec018d..6d257faa 100644 --- a/src/khoj/processor/conversation/google/gemini_chat.py +++ b/src/khoj/processor/conversation/google/gemini_chat.py @@ -38,6 +38,7 @@ def extract_questions_gemini( vision_enabled: bool = False, personality_context: Optional[str] = None, tracer: dict = {}, + attached_files: str = None, ): """ Infer search queries to retrieve relevant notes to answer user query @@ -85,7 +86,13 @@ def extract_questions_gemini( vision_enabled=vision_enabled, ) - messages = [ChatMessage(content=prompt, role="user"), ChatMessage(content=system_prompt, role="system")] + messages = [] + + if attached_files: + messages.append(ChatMessage(content=attached_files, role="user")) + + messages.append(ChatMessage(content=prompt, role="user")) + messages.append(ChatMessage(content=system_prompt, role="system")) response = gemini_send_message_to_model( messages, api_key, model, response_type="json_object", temperature=temperature, tracer=tracer diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index bdb67448..65cdfa3f 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -35,6 +35,7 @@ def extract_questions( vision_enabled: bool = False, personality_context: Optional[str] = None, tracer: dict = {}, + attached_files: str = None, ): """ Infer search queries to retrieve relevant notes to answer user query @@ -81,7 +82,12 @@ def extract_questions( vision_enabled=vision_enabled, ) - messages = [ChatMessage(content=prompt, role="user")] + messages = [] + + if attached_files: + messages.append(ChatMessage(content=attached_files, role="user")) + + messages.append(ChatMessage(content=prompt, role="user")) response = send_message_to_model( messages, diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index b510a09a..791a98e0 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -36,6 +36,7 @@ from khoj.utils.helpers import ( is_none_or_empty, merge_dicts, ) +from khoj.utils.rawconfig import FileAttachment logger = logging.getLogger(__name__) @@ -137,25 +138,6 @@ def construct_iteration_history( return previous_iterations_history -def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str: - chat_history = "" - for chat in conversation_history.get("chat", [])[-n:]: - if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder", "summarize"]: - chat_history += f"User: {chat['intent']['query']}\n" - - if chat["intent"].get("inferred-queries"): - chat_history += f'Khoj: {{"queries": {chat["intent"].get("inferred-queries")}}}\n' - - chat_history += f"{agent_name}: {chat['message']}\n\n" - elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")): - chat_history += f"User: {chat['intent']['query']}\n" - chat_history += f"{agent_name}: [generated image redacted for space]\n" - elif chat["by"] == "khoj" and ("excalidraw" in chat["intent"].get("type")): - chat_history += f"User: {chat['intent']['query']}\n" - chat_history += f"{agent_name}: {chat['intent']['inferred-queries'][0]}\n" - return chat_history - - def construct_tool_chat_history( previous_iterations: List[InformationCollectionIteration], tool: ConversationCommand = None ) -> Dict[str, list]: @@ -241,6 +223,7 @@ def save_to_conversation_log( conversation_id: str = None, automation_id: str = None, query_images: List[str] = None, + raw_attached_files: List[FileAttachment] = [], tracer: Dict[str, Any] = {}, train_of_thought: List[Any] = [], ): @@ -253,6 +236,7 @@ def save_to_conversation_log( "created": user_message_time, "images": query_images, "turnId": turn_id, + "attachedFiles": [file.model_dump(mode="json") for file in raw_attached_files], }, khoj_message_metadata={ "context": compiled_references, @@ -306,6 +290,22 @@ def construct_structured_message(message: str, images: list[str], model_type: st return message +def gather_raw_attached_files( + attached_files: Dict[str, str], +): + """_summary_ + Gather contextual data from the given (raw) files + """ + + if len(attached_files) == 0: + return "" + + contextual_data = " ".join( + [f"File: {file_name}\n\n{file_content}\n\n" for file_name, file_content in attached_files.items()] + ) + return f"I have attached the following files:\n\n{contextual_data}" + + def generate_chatml_messages_with_context( user_message, system_message=None, @@ -335,6 +335,8 @@ def generate_chatml_messages_with_context( chatml_messages: List[ChatMessage] = [] for chat in conversation_log.get("chat", []): message_context = "" + message_attached_files = "" + if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type", ""): message_context += chat.get("intent").get("inferred-queries")[0] if not is_none_or_empty(chat.get("context")): @@ -343,6 +345,15 @@ def generate_chatml_messages_with_context( ) message_context += f"{prompts.notes_conversation.format(references=references)}\n\n" + if chat.get("attachedFiles"): + raw_attached_files = chat.get("attachedFiles") + attached_files_dict = dict() + for file in raw_attached_files: + attached_files_dict[file["name"]] = file["content"] + + message_attached_files = gather_raw_attached_files(attached_files_dict) + chatml_messages.append(ChatMessage(content=message_attached_files, role="user")) + if not is_none_or_empty(chat.get("onlineContext")): message_context += f"{prompts.online_search_conversation.format(online_results=chat.get('onlineContext'))}" diff --git a/src/khoj/processor/tools/run_code.py b/src/khoj/processor/tools/run_code.py index 86388aee..418ab3a2 100644 --- a/src/khoj/processor/tools/run_code.py +++ b/src/khoj/processor/tools/run_code.py @@ -6,18 +6,12 @@ import os from typing import Any, Callable, List, Optional import aiohttp -import requests from khoj.database.adapters import ais_user_subscribed from khoj.database.models import Agent, KhojUser from khoj.processor.conversation import prompts -from khoj.processor.conversation.utils import ( - ChatEvent, - clean_code_python, - clean_json, - construct_chat_history, -) -from khoj.routers.helpers import send_message_to_model_wrapper +from khoj.processor.conversation.utils import ChatEvent, clean_code_python, clean_json +from khoj.routers.helpers import construct_chat_history, send_message_to_model_wrapper from khoj.utils.helpers import timer from khoj.utils.rawconfig import LocationData diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index bf010034..c62fe4bf 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -91,7 +91,6 @@ from khoj.processor.conversation.utils import ( ChatEvent, ThreadedGenerator, clean_json, - construct_chat_history, generate_chatml_messages_with_context, save_to_conversation_log, ) @@ -104,6 +103,7 @@ from khoj.utils.config import OfflineChatProcessorModel from khoj.utils.helpers import ( LRU, ConversationCommand, + get_file_type, is_none_or_empty, is_valid_url, log_telemetry, @@ -111,7 +111,7 @@ from khoj.utils.helpers import ( timer, tool_descriptions_for_llm, ) -from khoj.utils.rawconfig import LocationData +from khoj.utils.rawconfig import ChatRequestBody, FileAttachment, FileData, LocationData logger = logging.getLogger(__name__) @@ -167,6 +167,12 @@ async def is_ready_to_chat(user: KhojUser): raise HTTPException(status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings.") +def get_file_content(file: UploadFile): + file_content = file.file.read() + file_type, encoding = get_file_type(file.content_type, file_content) + return FileData(name=file.filename, content=file_content, file_type=file_type, encoding=encoding) + + def update_telemetry_state( request: Request, telemetry_type: str, @@ -248,23 +254,49 @@ async def agenerate_chat_response(*args): return await loop.run_in_executor(executor, generate_chat_response, *args) -async def gather_attached_files( - user: KhojUser, - file_filters: List[str], -) -> str: +def gather_raw_attached_files( + attached_files: Dict[str, str], +): + """_summary_ + Gather contextual data from the given (raw) files """ - Gather contextual data from the given files - """ - if len(file_filters) == 0: + + if len(attached_files) == 0: return "" - file_objects = await FileObjectAdapters.async_get_file_objects_by_names(user, file_filters) + contextual_data = " ".join( + [f"File: {file_name}\n\n{file_content}\n\n" for file_name, file_content in attached_files.items()] + ) + return f"I have attached the following files:\n\n{contextual_data}" - if len(file_objects) == 0: - return "" - contextual_data = " ".join([f"File: {file.file_name}\n\n{file.raw_text}" for file in file_objects]) - return contextual_data +def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str: + chat_history = "" + for chat in conversation_history.get("chat", [])[-n:]: + if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder", "summarize"]: + chat_history += f"User: {chat['intent']['query']}\n" + + if chat["intent"].get("inferred-queries"): + chat_history += f'{agent_name}: {{"queries": {chat["intent"].get("inferred-queries")}}}\n' + + chat_history += f"{agent_name}: {chat['message']}\n\n" + elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")): + chat_history += f"User: {chat['intent']['query']}\n" + chat_history += f"{agent_name}: [generated image redacted for space]\n" + elif chat["by"] == "khoj" and ("excalidraw" in chat["intent"].get("type")): + chat_history += f"User: {chat['intent']['query']}\n" + chat_history += f"{agent_name}: {chat['intent']['inferred-queries'][0]}\n" + elif chat["by"] == "you": + raw_attached_files = chat.get("attachedFiles") + if raw_attached_files: + attached_files: Dict[str, str] = {} + for file in raw_attached_files: + attached_files[file["name"]] = file["content"] + + attached_file_context = gather_raw_attached_files(attached_files) + chat_history += f"User: {attached_file_context}\n" + + return chat_history async def acreate_title_from_query(query: str, user: KhojUser = None) -> str: @@ -1179,6 +1211,7 @@ def generate_chat_response( tracer: dict = {}, train_of_thought: List[Any] = [], attached_files: str = None, + raw_attached_files: List[FileAttachment] = None, ) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]: # Initialize Variables chat_response = None @@ -1204,6 +1237,7 @@ def generate_chat_response( query_images=query_images, tracer=tracer, train_of_thought=train_of_thought, + raw_attached_files=raw_attached_files, ) conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation) @@ -1299,6 +1333,7 @@ def generate_chat_response( location_data=location_data, user_name=user_name, agent=agent, + query_images=query_images, vision_available=vision_available, tracer=tracer, attached_files=attached_files, @@ -1313,23 +1348,6 @@ def generate_chat_response( return chat_response, metadata -class ChatRequestBody(BaseModel): - q: str - n: Optional[int] = 7 - d: Optional[float] = None - stream: Optional[bool] = False - title: Optional[str] = None - conversation_id: Optional[str] = None - turn_id: Optional[str] = None - city: Optional[str] = None - region: Optional[str] = None - country: Optional[str] = None - country_code: Optional[str] = None - timezone: Optional[str] = None - images: Optional[list[str]] = None - create_new: Optional[bool] = False - - class DeleteMessageRequestBody(BaseModel): conversation_id: str turn_id: str diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py index 960cf52f..dc34009c 100644 --- a/src/khoj/routers/research.py +++ b/src/khoj/routers/research.py @@ -20,7 +20,6 @@ from khoj.routers.api import extract_references_and_questions from khoj.routers.helpers import ( ChatEvent, construct_chat_history, - extract_relevant_info, generate_summary_from_files, send_message_to_model_wrapper, ) @@ -187,6 +186,7 @@ async def execute_information_collection( query_images, agent=agent, tracer=tracer, + attached_files=attached_files, ): if isinstance(result, dict) and ChatEvent.STATUS in result: yield result[ChatEvent.STATUS] diff --git a/src/khoj/utils/rawconfig.py b/src/khoj/utils/rawconfig.py index 15f5ea01..2c956c2c 100644 --- a/src/khoj/utils/rawconfig.py +++ b/src/khoj/utils/rawconfig.py @@ -138,6 +138,38 @@ class SearchResponse(ConfigBase): corpus_id: str +class FileData(BaseModel): + name: str + content: bytes + file_type: str + encoding: str | None = None + + +class FileAttachment(BaseModel): + name: str + content: str + file_type: str + size: int + + +class ChatRequestBody(BaseModel): + q: str + n: Optional[int] = 7 + d: Optional[float] = None + stream: Optional[bool] = False + title: Optional[str] = None + conversation_id: Optional[str] = None + turn_id: Optional[str] = None + city: Optional[str] = None + region: Optional[str] = None + country: Optional[str] = None + country_code: Optional[str] = None + timezone: Optional[str] = None + images: Optional[list[str]] = None + files: Optional[list[FileAttachment]] = None + create_new: Optional[bool] = False + + class Entry: raw: str compiled: str