Add intermediate summarization of results when planning with o1

sabaimran 2024-10-09 17:40:56 -07:00
parent 7b288a1179
commit 5b8d663cf1
3 changed files with 84 additions and 78 deletions


@@ -490,9 +490,11 @@ plan_function_execution = PromptTemplate.from_template(
 You are an extremely methodical planner. Your goal is to make a plan to execute a function based on the user's query.
 {personality_context}
 - You have access to a variety of data sources to help you answer the user's question
-- You can use the data sources listed below to collect more relevant information, one at a time
+- You can use the data sources listed below to collect more relevant information, one at a time. The outputs will be chained.
 - You are given multiple iterations to work with these data sources to answer the user's question
 - You are provided with additional context. If you have enough context to answer the question, then exit execution
+- Each query is self-contained and you can use the data source to answer the user's question. There will be no additional data injected between queries, so make sure the query you're asking is answered in the current iteration.
+- Limit each query to a *single* intention. For example, do not say "Look up the top city by population and output the GDP." Instead, say "Look up the top city by population." and then "Tell me the GDP of <the city>."
 If you already know the answer to the question, return an empty response, e.g., {{}}.
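Taken together, these new rules steer the planner toward one narrowly scoped tool call per turn, with each turn's output chained into the next. A pair of planner responses following them might look like the sketch below; the key names are an assumption, since the prompt only requires a data source and its query in one JSON object:

# Hypothetical planner outputs obeying the single-intention rule.
# Key names are illustrative; the prompt only asks for the data
# source and its associated query in a single JSON object.
first_turn = {"data_source": "online", "query": "Look up the top city by population."}
second_turn = {"data_source": "online", "query": "Tell me the GDP of <the city>."}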
@@ -500,7 +502,7 @@ Which of the data sources listed below you would use to answer the user's question:
 {tools}

-Now it's your turn to pick the data sources you would like to use to answer the user's question. Provide the data source and associated query in a JSON object. Do not say anything else.
+Provide the data source and associated query in a JSON object. Do not say anything else.

 Previous Iterations:
 {previous_iterations}
@@ -520,8 +522,7 @@ previous_iteration = PromptTemplate.from_template(
     """
 data_source: {data_source}
 query: {query}
-context: {context}
-onlineContext: {onlineContext}
+summary: {summary}
 ---
 """.strip()
 )
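For illustration, here is roughly what one entry of the planner's iteration history looks like after this change, a minimal sketch using plain str.format in place of the real PromptTemplate, with made-up field values:

# The planner now sees a compact summary per past iteration instead of
# raw context/onlineContext dumps, keeping the o1 planning prompt short.
previous_iteration_template = """
data_source: {data_source}
query: {query}
summary: {summary}
---
""".strip()

history_entry = previous_iteration_template.format(
    data_source="online",
    query="Look up the top city by population.",
    summary="Tokyo is the most populous city, with about 37M residents.",
)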


@@ -718,18 +718,20 @@ async def chat(
         ):
             if type(research_result) == InformationCollectionIteration:
                 pending_research = False
-                if research_result.onlineContext:
-                    researched_results += str(research_result.onlineContext)
-                    online_results.update(research_result.onlineContext)
+                # if research_result.onlineContext:
+                #     researched_results += str(research_result.onlineContext)
+                #     online_results.update(research_result.onlineContext)

-                if research_result.context:
-                    researched_results += str(research_result.context)
-                    compiled_references.extend(research_result.context)
+                # if research_result.context:
+                #     researched_results += str(research_result.context)
+                #     compiled_references.extend(research_result.context)
+                researched_results += research_result.summarizedResult
             else:
                 yield research_result

-        researched_results = await extract_relevant_info(q, researched_results, agent)
+        # researched_results = await extract_relevant_info(q, researched_results, agent)
         logger.info(f"Researched Results: {researched_results}")


@@ -13,6 +13,7 @@ from khoj.routers.api import extract_references_and_questions
 from khoj.routers.helpers import (
     ChatEvent,
     construct_chat_history,
+    extract_relevant_info,
     generate_summary_from_files,
     send_message_to_model_wrapper,
 )
@@ -27,11 +28,19 @@ logger = logging.getLogger(__name__)


 class InformationCollectionIteration:
-    def __init__(self, data_source: str, query: str, context: str = None, onlineContext: dict = None):
+    def __init__(
+        self,
+        data_source: str,
+        query: str,
+        context: str = None,
+        onlineContext: dict = None,
+        summarizedResult: str = None,
+    ):
         self.data_source = data_source
         self.query = query
         self.context = context
         self.onlineContext = onlineContext
+        self.summarizedResult = summarizedResult


 async def apick_next_tool(
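Constructing one of these records with the new field might look like the following; the payload values are illustrative, not taken from the commit:

# Illustrative record for one research iteration; summarizedResult holds
# the distilled takeaway that subsequent planning turns will see.
iteration = InformationCollectionIteration(
    data_source="online",
    query="Look up the top city by population.",
    onlineContext={"https://example.com": {"snippet": "..."}},  # hypothetical shape
    summarizedResult="Tokyo is the most populous city.",
)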
@@ -63,8 +72,7 @@ async def apick_next_tool(
         iteration_data = prompts.previous_iteration.format(
             query=iteration.query,
             data_source=iteration.data_source,
-            context=str(iteration.context),
-            onlineContext=str(iteration.onlineContext),
+            summary=iteration.summarizedResult,
         )
         previous_iterations_history += iteration_data
@@ -138,7 +146,8 @@ async def execute_information_collection(
         compiled_references: List[Any] = []
         inferred_queries: List[Any] = []
         defiltered_query = None
+        result: str = ""

         this_iteration = await apick_next_tool(
             query, conversation_history, subscribed, uploaded_image_url, agent, previous_iterations
@@ -165,13 +174,7 @@
                     compiled_references.extend(result[0])
                     inferred_queries.extend(result[1])
                     defiltered_query = result[2]
-            previous_iterations.append(
-                InformationCollectionIteration(
-                    data_source=this_iteration.data_source,
-                    query=this_iteration.query,
-                    context=str(compiled_references),
-                )
-            )
+            this_iteration.context = str(compiled_references)

         elif this_iteration.data_source == ConversationCommand.Online:
             async for result in search_online(
@@ -189,13 +192,7 @@
                     yield result[ChatEvent.STATUS]
                 else:
                     online_results = result
-            previous_iterations.append(
-                InformationCollectionIteration(
-                    data_source=this_iteration.data_source,
-                    query=this_iteration.query,
-                    onlineContext=online_results,
-                )
-            )
+            this_iteration.onlineContext = online_results

         elif this_iteration.data_source == ConversationCommand.Webpage:
             async for result in read_webpages(
@@ -224,57 +221,63 @@
                             webpages.append(webpage["link"])
                     yield send_status_func(f"**Read web pages**: {webpages}")
-            previous_iterations.append(
-                InformationCollectionIteration(
-                    data_source=this_iteration.data_source,
-                    query=this_iteration.query,
-                    onlineContext=online_results,
-                )
-            )
+            this_iteration.onlineContext = online_results

-        elif this_iteration.data_source == ConversationCommand.Summarize:
-            response_log = ""
-            agent_has_entries = await EntryAdapters.aagent_has_entries(agent)
-            if len(file_filters) == 0 and not agent_has_entries:
-                previous_iterations.append(
-                    InformationCollectionIteration(
-                        data_source=this_iteration.data_source,
-                        query=this_iteration.query,
-                        context="No files selected for summarization.",
-                    )
-                )
-            elif len(file_filters) > 1 and not agent_has_entries:
-                response_log = "Only one file can be selected for summarization."
-                previous_iterations.append(
-                    InformationCollectionIteration(
-                        data_source=this_iteration.data_source,
-                        query=this_iteration.query,
-                        context=response_log,
-                    )
-                )
-            else:
-                async for response in generate_summary_from_files(
-                    q=query,
-                    user=user,
-                    file_filters=file_filters,
-                    meta_log=conversation_history,
-                    subscribed=subscribed,
-                    send_status_func=send_status_func,
-                ):
-                    if isinstance(response, dict) and ChatEvent.STATUS in response:
-                        yield response[ChatEvent.STATUS]
-                    else:
-                        response_log = response  # type: ignore
-                        previous_iterations.append(
-                            InformationCollectionIteration(
-                                data_source=this_iteration.data_source,
-                                query=this_iteration.query,
-                                context=response_log,
-                            )
-                        )
+        # TODO: Fix summarize later
+        # elif this_iteration.data_source == ConversationCommand.Summarize:
+        #     response_log = ""
+        #     agent_has_entries = await EntryAdapters.aagent_has_entries(agent)
+        #     if len(file_filters) == 0 and not agent_has_entries:
+        #         previous_iterations.append(
+        #             InformationCollectionIteration(
+        #                 data_source=this_iteration.data_source,
+        #                 query=this_iteration.query,
+        #                 context="No files selected for summarization.",
+        #             )
+        #         )
+        #     elif len(file_filters) > 1 and not agent_has_entries:
+        #         response_log = "Only one file can be selected for summarization."
+        #         previous_iterations.append(
+        #             InformationCollectionIteration(
+        #                 data_source=this_iteration.data_source,
+        #                 query=this_iteration.query,
+        #                 context=response_log,
+        #             )
+        #         )
+        #     else:
+        #         async for response in generate_summary_from_files(
+        #             q=query,
+        #             user=user,
+        #             file_filters=file_filters,
+        #             meta_log=conversation_history,
+        #             subscribed=subscribed,
+        #             send_status_func=send_status_func,
+        #         ):
+        #             if isinstance(response, dict) and ChatEvent.STATUS in response:
+        #                 yield response[ChatEvent.STATUS]
+        #             else:
+        #                 response_log = response  # type: ignore
+        #                 previous_iterations.append(
+        #                     InformationCollectionIteration(
+        #                         data_source=this_iteration.data_source,
+        #                         query=this_iteration.query,
+        #                         context=response_log,
+        #                     )
+        #                 )
         else:
             iteration = MAX_ITERATIONS

         iteration += 1
-        for completed_iter in previous_iterations:
-            yield completed_iter
+        if compiled_references or online_results:
+            results_data = f"**Results**:\n"
+            if compiled_references:
+                results_data += f"**Document References**: {compiled_references}\n"
+            if online_results:
+                results_data += f"**Online Results**: {online_results}\n"
+
+            intermediate_result = await extract_relevant_info(this_iteration.query, results_data, agent)
+            this_iteration.summarizedResult = intermediate_result
+
+        previous_iterations.append(this_iteration)
+        yield this_iteration
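This closing block is the heart of the change: whatever a tool produced in an iteration is condensed by extract_relevant_info into summarizedResult before the iteration is recorded and yielded, so only the summary (not the full iteration history) feeds the next planning turn. A condensed sketch of the revised loop, where pick_tool and run_tool are stand-ins for apick_next_tool and the per-data-source branches above:

# Condensed sketch of the revised research loop: tools mutate
# this_iteration in place, the result is summarized once, and only
# then is the iteration appended to history and yielded. The
# extract_relevant_info helper is the one imported from
# khoj.routers.helpers; pick_tool/run_tool are placeholders.
async def research_loop(query, agent, pick_tool, run_tool, max_iterations=2):
    previous_iterations = []
    iteration = 0
    while iteration < max_iterations:
        this_iteration = await pick_tool(query, previous_iterations)
        compiled_references, online_results = await run_tool(this_iteration)
        iteration += 1
        if compiled_references or online_results:
            results_data = "**Results**:\n"
            if compiled_references:
                results_data += f"**Document References**: {compiled_references}\n"
            if online_results:
                results_data += f"**Online Results**: {online_results}\n"
            # Distill this iteration's raw output so later planning
            # prompts stay small.
            this_iteration.summarizedResult = await extract_relevant_info(
                this_iteration.query, results_data, agent
            )
        previous_iterations.append(this_iteration)
        yield this_iteration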