Mirror of https://github.com/khoj-ai/khoj.git, synced 2024-11-30 19:03:01 +01:00

Add intermediate summarization of results when planning with o1

This commit is contained in:
parent 7b288a1179
commit 5b8d663cf1

3 changed files with 84 additions and 78 deletions
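At a high level, the change makes each research iteration summarize its own output before the next planning step, instead of carrying raw document and online context forward. The sketch below illustrates that pattern only; it is not the khoj code, and every name in it (Step, pick_next_tool_stub, run_data_source_stub, summarize_stub) is an illustrative stand-in.

# Minimal, self-contained sketch of per-iteration summarization; all names are stand-ins.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Step:
    data_source: str
    query: str
    summarized_result: Optional[str] = None


def pick_next_tool_stub(query: str, history: List[Step]) -> Optional[Step]:
    # Stand-in planner: run a single online lookup, then stop.
    return Step("online", query) if not history else None


def run_data_source_stub(step: Step) -> str:
    # Stand-in for the document / online / webpage searches.
    return f"raw results for {step.query!r}"


def summarize_stub(query: str, raw: str) -> str:
    # Stand-in for condensing raw results down to what answers the query.
    return f"summary of the results relevant to {query!r}"


def research(query: str, max_iterations: int = 5) -> List[Step]:
    history: List[Step] = []
    for _ in range(max_iterations):
        step = pick_next_tool_stub(query, history)
        if step is None:  # the planner has enough context, so exit execution
            break
        raw = run_data_source_stub(step)
        # Keep only a compact summary of this iteration's output and feed that
        # summary, not the raw results, into the next planning step.
        step.summarized_result = summarize_stub(step.query, raw)
        history.append(step)
    return history


print(research("Look up the top city by population."))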
@@ -490,9 +490,11 @@ plan_function_execution = PromptTemplate.from_template(

You are an extremely methodical planner. Your goal is to make a plan to execute a function based on the user's query.
{personality_context}

- You have access to a variety of data sources to help you answer the user's question
- You can use the data sources listed below to collect more relevant information, one at a time
- You can use the data sources listed below to collect more relevant information, one at a time. The outputs will be chained.
- You are given multiple iterations to work with these data sources to answer the user's question
- You are provided with additional context. If you have enough context to answer the question, then exit execution
- Each query is self-contained and you can use the data source to answer the user's question. There will be no additional data injected between queries, so make sure the query you're asking is answered in the current iteration.
- Limit each query to a *single* intention. For example, do not say "Look up the top city by population and output the GDP." Instead, say "Look up the top city by population." and then "Tell me the GDP of <the city>."

If you already know the answer to the question, return an empty response, e.g., {{}}.
@@ -500,7 +502,7 @@ Which of the data sources listed below you would use to answer the user's question

{tools}

Now it's your turn to pick the data sources you would like to use to answer the user's question. Provide the data source and associated query in a JSON object. Do not say anything else.
Provide the data source and associated query in a JSON object. Do not say anything else.

Previous Iterations:
{previous_iterations}
@@ -520,8 +522,7 @@ previous_iteration = PromptTemplate.from_template(

"""
data_source: {data_source}
query: {query}
context: {context}
onlineContext: {onlineContext}
summary: {summary}
---
""".strip()
)
@@ -718,18 +718,20 @@ async def chat(

        ):
            if type(research_result) == InformationCollectionIteration:
                pending_research = False
                if research_result.onlineContext:
                    researched_results += str(research_result.onlineContext)
                    online_results.update(research_result.onlineContext)
                # if research_result.onlineContext:
                #     researched_results += str(research_result.onlineContext)
                #     online_results.update(research_result.onlineContext)

                if research_result.context:
                    researched_results += str(research_result.context)
                    compiled_references.extend(research_result.context)
                # if research_result.context:
                #     researched_results += str(research_result.context)
                #     compiled_references.extend(research_result.context)

                researched_results += research_result.summarizedResult

            else:
                yield research_result

        researched_results = await extract_relevant_info(q, researched_results, agent)
        # researched_results = await extract_relevant_info(q, researched_results, agent)

        logger.info(f"Researched Results: {researched_results}")
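In the chat handler above, the accumulated research context now comes from each iteration's summarizedResult rather than from the raw onlineContext and context payloads. A rough, self-contained sketch of that consumption loop, with a stand-in generator and event type in place of the actual khoj objects:

# Rough sketch of consuming the research generator; names are stand-ins.
class ResearchStep:
    def __init__(self, summarized_result: str):
        self.summarized_result = summarized_result


def research_generator():
    yield "status: searching documents"            # status events stream to the client
    yield ResearchStep("<summary of document search>")
    yield "status: reading web pages"
    yield ResearchStep("<summary of online results>")


researched_results = ""
for event in research_generator():
    if isinstance(event, ResearchStep):
        # Only the intermediate summary is carried forward into the final prompt.
        researched_results += event.summarized_result + "\n"
    else:
        print(event)  # forward status updates

print(researched_results)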
@@ -13,6 +13,7 @@ from khoj.routers.api import extract_references_and_questions

from khoj.routers.helpers import (
    ChatEvent,
    construct_chat_history,
    extract_relevant_info,
    generate_summary_from_files,
    send_message_to_model_wrapper,
)
@@ -27,11 +28,19 @@ logger = logging.getLogger(__name__)


class InformationCollectionIteration:
    def __init__(self, data_source: str, query: str, context: str = None, onlineContext: dict = None):
    def __init__(
        self,
        data_source: str,
        query: str,
        context: str = None,
        onlineContext: dict = None,
        summarizedResult: str = None,
    ):
        self.data_source = data_source
        self.query = query
        self.context = context
        self.onlineContext = onlineContext
        self.summarizedResult = summarizedResult


async def apick_next_tool(
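A small usage sketch of the widened constructor, assuming the InformationCollectionIteration class above is in scope; the field values are placeholders:

# Usage sketch; placeholder values, assumes the class defined above is importable.
step = InformationCollectionIteration(
    data_source="online",
    query="Look up the top city by population.",
    onlineContext={"<search query>": {"webpages": []}},
    summarizedResult="<compact summary of this iteration's results>",
)
print(step.data_source, step.summarizedResult)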
@@ -63,8 +72,7 @@ async def apick_next_tool(

        iteration_data = prompts.previous_iteration.format(
            query=iteration.query,
            data_source=iteration.data_source,
            context=str(iteration.context),
            onlineContext=str(iteration.onlineContext),
            summary=iteration.summarizedResult,
        )

        previous_iterations_history += iteration_data
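The summary recorded on each iteration is what the planner sees in its {previous_iterations} history. A self-contained sketch of that rendering step, using a plain format string as a stand-in for prompts.previous_iteration:

# Sketch of rendering past iterations into the planner prompt.
# PREVIOUS_ITERATION is a stand-in, not the actual khoj template.
PREVIOUS_ITERATION = """
data_source: {data_source}
query: {query}
summary: {summary}
---
""".strip()

iterations = [
    {"data_source": "online", "query": "Look up the top city by population.",
     "summary": "<summary of the population lookup>"},
    {"data_source": "online", "query": "Tell me the GDP of <the city>.",
     "summary": "<summary of the GDP lookup>"},
]

previous_iterations_history = "\n".join(
    PREVIOUS_ITERATION.format(**it) for it in iterations
)
print(previous_iterations_history)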
@@ -138,7 +146,8 @@ async def execute_information_collection(

    compiled_references: List[Any] = []
    inferred_queries: List[Any] = []
    defiltered_query = None

    result: str = ""

    this_iteration = await apick_next_tool(
        query, conversation_history, subscribed, uploaded_image_url, agent, previous_iterations
@@ -165,13 +174,7 @@ async def execute_information_collection(

                    compiled_references.extend(result[0])
                    inferred_queries.extend(result[1])
                    defiltered_query = result[2]
                    previous_iterations.append(
                        InformationCollectionIteration(
                            data_source=this_iteration.data_source,
                            query=this_iteration.query,
                            context=str(compiled_references),
                        )
                    )
                    this_iteration.context = str(compiled_references)

        elif this_iteration.data_source == ConversationCommand.Online:
            async for result in search_online(
@@ -189,13 +192,7 @@ async def execute_information_collection(

                    yield result[ChatEvent.STATUS]
                else:
                    online_results = result
                    previous_iterations.append(
                        InformationCollectionIteration(
                            data_source=this_iteration.data_source,
                            query=this_iteration.query,
                            onlineContext=online_results,
                        )
                    )
                    this_iteration.onlineContext = online_results

        elif this_iteration.data_source == ConversationCommand.Webpage:
            async for result in read_webpages(
@@ -224,57 +221,63 @@ async def execute_information_collection(

                    webpages.append(webpage["link"])
                yield send_status_func(f"**Read web pages**: {webpages}")

                previous_iterations.append(
                    InformationCollectionIteration(
                        data_source=this_iteration.data_source,
                        query=this_iteration.query,
                        onlineContext=online_results,
                    )
                )
                this_iteration.onlineContext = online_results

        elif this_iteration.data_source == ConversationCommand.Summarize:
            response_log = ""
            agent_has_entries = await EntryAdapters.aagent_has_entries(agent)
            if len(file_filters) == 0 and not agent_has_entries:
                previous_iterations.append(
                    InformationCollectionIteration(
                        data_source=this_iteration.data_source,
                        query=this_iteration.query,
                        context="No files selected for summarization.",
                    )
                )
            elif len(file_filters) > 1 and not agent_has_entries:
                response_log = "Only one file can be selected for summarization."
                previous_iterations.append(
                    InformationCollectionIteration(
                        data_source=this_iteration.data_source,
                        query=this_iteration.query,
                        context=response_log,
                    )
                )
            else:
                async for response in generate_summary_from_files(
                    q=query,
                    user=user,
                    file_filters=file_filters,
                    meta_log=conversation_history,
                    subscribed=subscribed,
                    send_status_func=send_status_func,
                ):
                    if isinstance(response, dict) and ChatEvent.STATUS in response:
                        yield response[ChatEvent.STATUS]
                    else:
                        response_log = response  # type: ignore
                        previous_iterations.append(
                            InformationCollectionIteration(
                                data_source=this_iteration.data_source,
                                query=this_iteration.query,
                                context=response_log,
                            )
                        )
        # TODO: Fix summarize later
        # elif this_iteration.data_source == ConversationCommand.Summarize:
        #     response_log = ""
        #     agent_has_entries = await EntryAdapters.aagent_has_entries(agent)
        #     if len(file_filters) == 0 and not agent_has_entries:
        #         previous_iterations.append(
        #             InformationCollectionIteration(
        #                 data_source=this_iteration.data_source,
        #                 query=this_iteration.query,
        #                 context="No files selected for summarization.",
        #             )
        #         )
        #     elif len(file_filters) > 1 and not agent_has_entries:
        #         response_log = "Only one file can be selected for summarization."
        #         previous_iterations.append(
        #             InformationCollectionIteration(
        #                 data_source=this_iteration.data_source,
        #                 query=this_iteration.query,
        #                 context=response_log,
        #             )
        #         )
        #     else:
        #         async for response in generate_summary_from_files(
        #             q=query,
        #             user=user,
        #             file_filters=file_filters,
        #             meta_log=conversation_history,
        #             subscribed=subscribed,
        #             send_status_func=send_status_func,
        #         ):
        #             if isinstance(response, dict) and ChatEvent.STATUS in response:
        #                 yield response[ChatEvent.STATUS]
        #             else:
        #                 response_log = response  # type: ignore
        #                 previous_iterations.append(
        #                     InformationCollectionIteration(
        #                         data_source=this_iteration.data_source,
        #                         query=this_iteration.query,
        #                         context=response_log,
        #                     )
        #                 )
        else:
            iteration = MAX_ITERATIONS

        iteration += 1
        for completed_iter in previous_iterations:
            yield completed_iter

        if compiled_references or online_results:
            results_data = f"**Results**:\n"
            if compiled_references:
                results_data += f"**Document References**: {compiled_references}\n"
            if online_results:
                results_data += f"**Online Results**: {online_results}\n"

            intermediate_result = await extract_relevant_info(this_iteration.query, results_data, agent)
            this_iteration.summarizedResult = intermediate_result

        previous_iterations.append(this_iteration)
        yield this_iteration
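To make the end-of-iteration step above concrete, here is a self-contained sketch of how the results block is assembled before being summarized; extract_relevant_info is stubbed out and the data values are placeholders:

# Sketch of the end-of-iteration summarization input; stubbed and simplified.
def extract_relevant_info_stub(query: str, results_data: str) -> str:
    return f"<the parts of the results that answer {query!r}>"


compiled_references = ["<document snippet 1>", "<document snippet 2>"]
online_results = {"top city by population": {"webpages": []}}

results_data = "**Results**:\n"
if compiled_references:
    results_data += f"**Document References**: {compiled_references}\n"
if online_results:
    results_data += f"**Online Results**: {online_results}\n"

summarized_result = extract_relevant_info_stub("top city by population", results_data)
print(summarized_result)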