From f700d5bddb99660b646a6d741b2bf68de995919c Mon Sep 17 00:00:00 2001 From: sabaimran Date: Mon, 7 Oct 2024 21:20:23 -0700 Subject: [PATCH] Add summarization capability with agent knowledge base --- src/khoj/database/adapters/__init__.py | 12 ++++++++++-- src/khoj/routers/api_chat.py | 21 ++++++++++++++++----- 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index a6afb15c..9485005a 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -1214,8 +1214,8 @@ class FileObjectAdapters: return await FileObject.objects.acreate(user=user, file_name=file_name, raw_text=raw_text) @staticmethod - async def async_get_file_objects_by_name(user: KhojUser, file_name: str): - return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name=file_name)) + async def async_get_file_objects_by_name(user: KhojUser, file_name: str, agent: Agent = None): + return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name=file_name, agent=agent)) @staticmethod async def async_get_all_file_objects(user: KhojUser): @@ -1305,6 +1305,10 @@ class EntryAdapters: async def auser_has_entries(user: KhojUser): return await Entry.objects.filter(user=user).aexists() + @staticmethod + async def aagent_has_entries(agent: Agent): + return await Entry.objects.filter(agent=agent).aexists() + @staticmethod async def adelete_entry_by_file(user: KhojUser, file_path: str): return await Entry.objects.filter(user=user, file_path=file_path).adelete() @@ -1319,6 +1323,10 @@ class EntryAdapters: return deleted_count + @staticmethod + async def aget_agent_entry_filepaths(agent: Agent): + return await sync_to_async(list)(Entry.objects.filter(agent=agent).values_list("file_path", flat=True)) + @staticmethod def get_all_filenames_by_source(user: KhojUser, file_source: str): return ( diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 45ded48b..37a33534 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -722,19 +722,30 @@ async def chat( conversation_commands.remove(ConversationCommand.Summarize) elif ConversationCommand.Summarize in conversation_commands: response_log = "" - if len(file_filters) == 0: + agent_has_entries = await EntryAdapters.aagent_has_entries(agent) + if len(file_filters) == 0 and not agent_has_entries: response_log = "No files selected for summarization. Please add files using the section on the left." async for result in send_llm_response(response_log): yield result - elif len(file_filters) > 1: + elif len(file_filters) > 1 and not agent_has_entries: response_log = "Only one file can be selected for summarization." async for result in send_llm_response(response_log): yield result else: try: - file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0]) + file_object = None + if await EntryAdapters.aagent_has_entries(agent): + file_names = await EntryAdapters.aget_agent_entry_filepaths(agent) + if len(file_names) > 0: + file_object = await FileObjectAdapters.async_get_file_objects_by_name( + None, file_names[0], agent + ) + + if len(file_filters) > 0: + file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0]) + if len(file_object) == 0: - response_log = "Sorry, we couldn't find the full text of this file. Please re-upload the document and try again." + response_log = "Sorry, I couldn't find the full text of this file. Please re-upload the document and try again." async for result in send_llm_response(response_log): yield result return @@ -753,7 +764,7 @@ async def chat( async for result in send_llm_response(response_log): yield result except Exception as e: - response_log = "Error summarizing file." + response_log = "Error summarizing file. Please try again, or contact support." logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True) async for result in send_llm_response(response_log): yield result