From 5acf40c440fc7aec95ba244946dd468e450f7a38 Mon Sep 17 00:00:00 2001
From: sabaimran
Date: Wed, 23 Oct 2024 20:06:04 -0700
Subject: [PATCH] Clean up summarization code paths

Use assumption of summarization response being a str
---
 src/khoj/routers/api_chat.py |  7 ++--
 src/khoj/routers/helpers.py  | 15 +++------
 src/khoj/routers/research.py | 65 +++++++++++++-----------------------
 3 files changed, 32 insertions(+), 55 deletions(-)

diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py
index 69196e83..32513b55 100644
--- a/src/khoj/routers/api_chat.py
+++ b/src/khoj/routers/api_chat.py
@@ -704,7 +704,7 @@ async def chat(
                 location=location,
                 file_filters=conversation.file_filters if conversation else [],
             ):
-                if type(research_result) == InformationCollectionIteration:
+                if isinstance(research_result, InformationCollectionIteration):
                     if research_result.summarizedResult:
                         pending_research = False
                         if research_result.onlineContext:
@@ -778,12 +778,13 @@ async def chat(
                 query_images=uploaded_images,
                 agent=agent,
                 send_status_func=partial(send_event, ChatEvent.STATUS),
-                send_response_func=partial(send_llm_response),
             ):
                 if isinstance(response, dict) and ChatEvent.STATUS in response:
                     yield result[ChatEvent.STATUS]
                 else:
-                    response
+                    if isinstance(response, str):
+                        async for result in send_llm_response(response):
+                            yield result

             await sync_to_async(save_to_conversation_log)(
                 q,
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index 2585f77d..bfe25fe3 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -623,7 +623,6 @@ async def generate_summary_from_files(
     query_images: List[str] = None,
     agent: Agent = None,
     send_status_func: Optional[Callable] = None,
-    send_response_func: Optional[Callable] = None,
 ):
     try:
         file_object = None
@@ -636,11 +635,8 @@ async def generate_summary_from_files(
             file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])

         if len(file_object) == 0:
-            response_log = (
-                "Sorry, I couldn't find the full text of this file. Please re-upload the document and try again."
-            )
-            async for result in send_response_func(response_log):
-                yield result
+            response_log = "Sorry, I couldn't find the full text of this file."
+            yield response_log
             return
         contextual_data = " ".join([file.raw_text for file in file_object])
         if not q:
@@ -657,13 +653,12 @@ async def generate_summary_from_files(
             agent=agent,
         )
         response_log = str(response)
-        async for result in send_response_func(response_log):
-            yield result
+
+        yield result
     except Exception as e:
         response_log = "Error summarizing file. Please try again, or contact support."
         logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
-        async for result in send_response_func(response_log):
-            yield result
+        yield result


 async def generate_excalidraw_diagram(
diff --git a/src/khoj/routers/research.py b/src/khoj/routers/research.py
index d1578b5b..1beb1f69 100644
--- a/src/khoj/routers/research.py
+++ b/src/khoj/routers/research.py
@@ -143,6 +143,7 @@ async def execute_information_collection(
         online_results: Dict = dict()
         code_results: Dict = dict()
         compiled_references: List[Any] = []
+        summarize_files: str = ""
         inferred_queries: List[Any] = []
         this_iteration = InformationCollectionIteration(tool=None, query=query)
         previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration)
@@ -271,53 +272,31 @@ async def execute_information_collection(
                     exc_info=True,
                 )

-        # TODO: Fix summarize later
-        # elif this_iteration.data_source == ConversationCommand.Summarize:
-        #     response_log = ""
-        #     agent_has_entries = await EntryAdapters.aagent_has_entries(agent)
-        #     if len(file_filters) == 0 and not agent_has_entries:
-        #         previous_iterations.append(
-        #             InformationCollectionIteration(
-        #             data_source=this_iteration.data_source,
-        #             query=this_iteration.query,
-        #             context="No files selected for summarization.",
-        #             )
-        #         )
-        #     elif len(file_filters) > 1 and not agent_has_entries:
-        #         response_log = "Only one file can be selected for summarization."
-        #         previous_iterations.append(
-        #             InformationCollectionIteration(
-        #                 data_source=this_iteration.data_source,
-        #                 query=this_iteration.query,
-        #                 context=response_log,
-        #             )
-        #         )
-        #     else:
-        #         async for response in generate_summary_from_files(
-        #             q=query,
-        #             user=user,
-        #             file_filters=file_filters,
-        #             meta_log=conversation_history,
-        #             subscribed=subscribed,
-        #             send_status_func=send_status_func,
-        #         ):
-        #             if isinstance(response, dict) and ChatEvent.STATUS in response:
-        #                 yield response[ChatEvent.STATUS]
-        #             else:
-        #                 response_log = response  # type: ignore
-        #                 previous_iterations.append(
-        #                     InformationCollectionIteration(
-        #                         data_source=this_iteration.data_source,
-        #                         query=this_iteration.query,
-        #                         context=response_log,
-        #                     )
-        #                 )
+        elif this_iteration.tool == ConversationCommand.Summarize:
+            try:
+                async for result in generate_summary_from_files(
+                    this_iteration.query,
+                    user,
+                    file_filters,
+                    conversation_history,
+                    query_images=query_images,
+                    agent=agent,
+                    send_status_func=send_status_func,
+                ):
+                    if isinstance(result, dict) and ChatEvent.STATUS in result:
+                        yield result[ChatEvent.STATUS]
+                    else:
+                        summarize_files = result  # type: ignore
+            except Exception as e:
+                logger.error(f"Error generating summary: {e}", exc_info=True)
+
         else:
+            # No valid tools. This is our exit condition.
             current_iteration = MAX_ITERATIONS

         current_iteration += 1

-        if compiled_references or online_results or code_results:
+        if compiled_references or online_results or code_results or summarize_files:
             results_data = f"**Results**:\n"
             if compiled_references:
                 results_data += f"**Document References**: {compiled_references}\n"
@@ -325,6 +304,8 @@
                 results_data += f"**Online Results**: {online_results}\n"
             if code_results:
                 results_data += f"**Code Results**: {code_results}\n"
+            if summarize_files:
+                results_data += f"**Summarized Files**: {summarize_files}\n"

             # intermediate_result = await extract_relevant_info(this_iteration.query, results_data, agent)
             this_iteration.summarizedResult = results_data