Clean up summarization code paths

Use assumption of summarization response being a str
This commit is contained in:
sabaimran 2024-10-23 20:06:04 -07:00 committed by Debanjum Singh Solanky
parent 12b32a3d04
commit 5acf40c440
3 changed files with 32 additions and 55 deletions

View file

@ -704,7 +704,7 @@ async def chat(
location=location,
file_filters=conversation.file_filters if conversation else [],
):
if type(research_result) == InformationCollectionIteration:
if isinstance(research_result, InformationCollectionIteration):
if research_result.summarizedResult:
pending_research = False
if research_result.onlineContext:
@ -778,12 +778,13 @@ async def chat(
query_images=uploaded_images,
agent=agent,
send_status_func=partial(send_event, ChatEvent.STATUS),
send_response_func=partial(send_llm_response),
):
if isinstance(response, dict) and ChatEvent.STATUS in response:
yield result[ChatEvent.STATUS]
else:
response
if isinstance(response, str):
async for result in send_llm_response(response):
yield result
await sync_to_async(save_to_conversation_log)(
q,

View file

@ -623,7 +623,6 @@ async def generate_summary_from_files(
query_images: List[str] = None,
agent: Agent = None,
send_status_func: Optional[Callable] = None,
send_response_func: Optional[Callable] = None,
):
try:
file_object = None
@ -636,11 +635,8 @@ async def generate_summary_from_files(
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
if len(file_object) == 0:
response_log = (
"Sorry, I couldn't find the full text of this file. Please re-upload the document and try again."
)
async for result in send_response_func(response_log):
yield result
response_log = "Sorry, I couldn't find the full text of this file."
yield response_log
return
contextual_data = " ".join([file.raw_text for file in file_object])
if not q:
@ -657,12 +653,11 @@ async def generate_summary_from_files(
agent=agent,
)
response_log = str(response)
async for result in send_response_func(response_log):
yield result
except Exception as e:
response_log = "Error summarizing file. Please try again, or contact support."
logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
async for result in send_response_func(response_log):
yield result

View file

@ -143,6 +143,7 @@ async def execute_information_collection(
online_results: Dict = dict()
code_results: Dict = dict()
compiled_references: List[Any] = []
summarize_files: str = ""
inferred_queries: List[Any] = []
this_iteration = InformationCollectionIteration(tool=None, query=query)
previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration)
@ -271,53 +272,31 @@ async def execute_information_collection(
exc_info=True,
)
# TODO: Fix summarize later
# elif this_iteration.data_source == ConversationCommand.Summarize:
# response_log = ""
# agent_has_entries = await EntryAdapters.aagent_has_entries(agent)
# if len(file_filters) == 0 and not agent_has_entries:
# previous_iterations.append(
# InformationCollectionIteration(
# data_source=this_iteration.data_source,
# query=this_iteration.query,
# context="No files selected for summarization.",
# )
# )
# elif len(file_filters) > 1 and not agent_has_entries:
# response_log = "Only one file can be selected for summarization."
# previous_iterations.append(
# InformationCollectionIteration(
# data_source=this_iteration.data_source,
# query=this_iteration.query,
# context=response_log,
# )
# )
# else:
# async for response in generate_summary_from_files(
# q=query,
# user=user,
# file_filters=file_filters,
# meta_log=conversation_history,
# subscribed=subscribed,
# send_status_func=send_status_func,
# ):
# if isinstance(response, dict) and ChatEvent.STATUS in response:
# yield response[ChatEvent.STATUS]
# else:
# response_log = response # type: ignore
# previous_iterations.append(
# InformationCollectionIteration(
# data_source=this_iteration.data_source,
# query=this_iteration.query,
# context=response_log,
# )
# )
elif this_iteration.tool == ConversationCommand.Summarize:
try:
async for result in generate_summary_from_files(
this_iteration.query,
user,
file_filters,
conversation_history,
query_images=query_images,
agent=agent,
send_status_func=send_status_func,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
summarize_files = result # type: ignore
except Exception as e:
logger.error(f"Error generating summary: {e}", exc_info=True)
else:
# No valid tools. This is our exit condition.
current_iteration = MAX_ITERATIONS
current_iteration += 1
if compiled_references or online_results or code_results:
if compiled_references or online_results or code_results or summarize_files:
results_data = f"**Results**:\n"
if compiled_references:
results_data += f"**Document References**: {compiled_references}\n"
@ -325,6 +304,8 @@ async def execute_information_collection(
results_data += f"**Online Results**: {online_results}\n"
if code_results:
results_data += f"**Code Results**: {code_results}\n"
if summarize_files:
results_data += f"**Summarized Files**: {summarize_files}\n"
# intermediate_result = await extract_relevant_info(this_iteration.query, results_data, agent)
this_iteration.summarizedResult = results_data