From 9e8ac7f89e7316077351b69bb7a0aebfee86826c Mon Sep 17 00:00:00 2001 From: sabaimran Date: Sat, 26 Oct 2024 16:37:58 -0700 Subject: [PATCH] Fix input/output mismatches in the /summarize command --- src/khoj/database/adapters/__init__.py | 4 +++- src/khoj/routers/api_chat.py | 3 ++- src/khoj/routers/helpers.py | 7 +++---- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index a2c531f8..14b092d7 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -1480,7 +1480,9 @@ class EntryAdapters: @staticmethod async def aget_agent_entry_filepaths(agent: Agent): - return await sync_to_async(list)(Entry.objects.filter(agent=agent).values_list("file_path", flat=True)) + return await sync_to_async(set)( + Entry.objects.filter(agent=agent).distinct("file_path").values_list("file_path", flat=True) + ) @staticmethod def get_all_filenames_by_source(user: KhojUser, file_source: str): diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 5ebfd911..69894a68 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -792,9 +792,10 @@ async def chat( tracer=tracer, ): if isinstance(response, dict) and ChatEvent.STATUS in response: - yield result[ChatEvent.STATUS] + yield response[ChatEvent.STATUS] else: if isinstance(response, str): + response_log = response async for result in send_llm_response(response): yield result diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 2af1f64d..c648c12b 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -650,7 +650,7 @@ async def generate_summary_from_files( if await EntryAdapters.aagent_has_entries(agent): file_names = await EntryAdapters.aget_agent_entry_filepaths(agent) if len(file_names) > 0: - file_object = await FileObjectAdapters.async_get_file_objects_by_name(None, file_names[0], agent) + file_object = await FileObjectAdapters.async_get_file_objects_by_name(None, file_names.pop(), agent) if len(file_filters) > 0: file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0]) @@ -663,7 +663,7 @@ async def generate_summary_from_files( if not q: q = "Create a general summary of the file" async for result in send_status_func(f"**Constructing Summary Using:** {file_object[0].file_name}"): - yield result + yield {ChatEvent.STATUS: result} response = await extract_relevant_summary( q, @@ -674,9 +674,8 @@ async def generate_summary_from_files( agent=agent, tracer=tracer, ) - response_log = str(response) - yield result + yield str(response) except Exception as e: response_log = "Error summarizing file. Please try again, or contact support." logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)