Deduplicate searches in normal mode & across research iterations

- Deduplicate online, doc search queries across research iterations.
  This avoids running previously run online, doc searches again and
  dedupes the online, doc context seen by the model when generating a response.
- Deduplicate online search queries generated by chat model for each
  user query.
- Do not pass online, docs, code context separately when generating a
  response in research mode. These are already collected in the meta
  research passed with the user query.
- Improve formatting of context passed to generate research response
  - Use xml tags to delimit context. Pass per iteration queries in each
    iteration result
  - Put user query before meta research results in user message passed
    for generating response

These deduplications will improve the speed, cost & quality of research mode.
This commit is contained in:
Debanjum 2024-11-08 10:43:02 -08:00
parent 306f7a2132
commit 137687ee49
6 changed files with 86 additions and 42 deletions

View file

@ -112,6 +112,7 @@ class InformationCollectionIteration:
onlineContext: dict = None, onlineContext: dict = None,
codeContext: dict = None, codeContext: dict = None,
summarizedResult: str = None, summarizedResult: str = None,
warning: str = None,
): ):
self.tool = tool self.tool = tool
self.query = query self.query = query
@ -119,6 +120,7 @@ class InformationCollectionIteration:
self.onlineContext = onlineContext self.onlineContext = onlineContext
self.codeContext = codeContext self.codeContext = codeContext
self.summarizedResult = summarizedResult self.summarizedResult = summarizedResult
self.warning = warning
def construct_iteration_history( def construct_iteration_history(

View file

@ -4,7 +4,7 @@ import logging
import os import os
import urllib.parse import urllib.parse
from collections import defaultdict from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import aiohttp import aiohttp
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
@ -66,6 +66,7 @@ async def search_online(
custom_filters: List[str] = [], custom_filters: List[str] = [],
max_webpages_to_read: int = DEFAULT_MAX_WEBPAGES_TO_READ, max_webpages_to_read: int = DEFAULT_MAX_WEBPAGES_TO_READ,
query_images: List[str] = None, query_images: List[str] = None,
previous_subqueries: Set = set(),
agent: Agent = None, agent: Agent = None,
tracer: dict = {}, tracer: dict = {},
): ):
@ -76,19 +77,24 @@ async def search_online(
return return
# Breakdown the query into subqueries to get the correct answer # Breakdown the query into subqueries to get the correct answer
subqueries = await generate_online_subqueries( new_subqueries = await generate_online_subqueries(
query, conversation_history, location, user, query_images=query_images, agent=agent, tracer=tracer query, conversation_history, location, user, query_images=query_images, agent=agent, tracer=tracer
) )
response_dict = {} subqueries = list(new_subqueries - previous_subqueries)
response_dict: Dict[str, Dict[str, List[Dict] | Dict]] = {}
if subqueries: if is_none_or_empty(subqueries):
logger.info(f"🌐 Searching the Internet for {list(subqueries)}") logger.info("No new subqueries to search online")
if send_status_func: yield response_dict
subqueries_str = "\n- " + "\n- ".join(list(subqueries)) return
async for event in send_status_func(f"**Searching the Internet for**: {subqueries_str}"):
yield {ChatEvent.STATUS: event}
with timer(f"Internet searches for {list(subqueries)} took", logger): logger.info(f"🌐 Searching the Internet for {subqueries}")
if send_status_func:
subqueries_str = "\n- " + "\n- ".join(subqueries)
async for event in send_status_func(f"**Searching the Internet for**: {subqueries_str}"):
yield {ChatEvent.STATUS: event}
with timer(f"Internet searches for {subqueries} took", logger):
search_func = search_with_google if SERPER_DEV_API_KEY else search_with_jina search_func = search_with_google if SERPER_DEV_API_KEY else search_with_jina
search_tasks = [search_func(subquery, location) for subquery in subqueries] search_tasks = [search_func(subquery, location) for subquery in subqueries]
search_results = await asyncio.gather(*search_tasks) search_results = await asyncio.gather(*search_tasks)
@ -119,7 +125,9 @@ async def search_online(
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"): async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event} yield {ChatEvent.STATUS: event}
tasks = [ tasks = [
read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent, tracer=tracer) read_webpage_and_extract_content(
data["queries"], link, data.get("content"), user=user, agent=agent, tracer=tracer
)
for link, data in webpages.items() for link, data in webpages.items()
] ]
results = await asyncio.gather(*tasks) results = await asyncio.gather(*tasks)

View file

@ -6,7 +6,7 @@ import os
import threading import threading
import time import time
import uuid import uuid
from typing import Any, Callable, List, Optional, Union from typing import Any, Callable, List, Optional, Set, Union
import cron_descriptor import cron_descriptor
import pytz import pytz
@ -349,6 +349,7 @@ async def extract_references_and_questions(
location_data: LocationData = None, location_data: LocationData = None,
send_status_func: Optional[Callable] = None, send_status_func: Optional[Callable] = None,
query_images: Optional[List[str]] = None, query_images: Optional[List[str]] = None,
previous_inferred_queries: Set = set(),
agent: Agent = None, agent: Agent = None,
tracer: dict = {}, tracer: dict = {},
): ):
@ -477,6 +478,7 @@ async def extract_references_and_questions(
) )
# Collate search results as context for GPT # Collate search results as context for GPT
inferred_queries = list(set(inferred_queries) - previous_inferred_queries)
with timer("Searching knowledge base took", logger): with timer("Searching knowledge base took", logger):
search_results = [] search_results = []
logger.info(f"🔍 Searching knowledge base with queries: {inferred_queries}") logger.info(f"🔍 Searching knowledge base with queries: {inferred_queries}")

View file

@ -778,7 +778,8 @@ async def chat(
yield research_result yield research_result
# researched_results = await extract_relevant_info(q, researched_results, agent) # researched_results = await extract_relevant_info(q, researched_results, agent)
logger.info(f"Researched Results: {researched_results}") if state.verbose > 1:
logger.debug(f"Researched Results: {researched_results}")
used_slash_summarize = conversation_commands == [ConversationCommand.Summarize] used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
file_filters = conversation.file_filters if conversation else [] file_filters = conversation.file_filters if conversation else []

View file

@ -20,6 +20,7 @@ from typing import (
Iterator, Iterator,
List, List,
Optional, Optional,
Set,
Tuple, Tuple,
Union, Union,
) )
@ -494,7 +495,7 @@ async def generate_online_subqueries(
query_images: List[str] = None, query_images: List[str] = None,
agent: Agent = None, agent: Agent = None,
tracer: dict = {}, tracer: dict = {},
) -> List[str]: ) -> Set[str]:
""" """
Generate subqueries from the given query Generate subqueries from the given query
""" """
@ -529,14 +530,14 @@ async def generate_online_subqueries(
try: try:
response = clean_json(response) response = clean_json(response)
response = json.loads(response) response = json.loads(response)
response = [q.strip() for q in response["queries"] if q.strip()] response = {q.strip() for q in response["queries"] if q.strip()}
if not isinstance(response, list) or not response or len(response) == 0: if not isinstance(response, set) or not response or len(response) == 0:
logger.error(f"Invalid response for constructing subqueries: {response}. Returning original query: {q}") logger.error(f"Invalid response for constructing subqueries: {response}. Returning original query: {q}")
return [q] return {q}
return response return response
except Exception as e: except Exception as e:
logger.error(f"Invalid response for constructing subqueries: {response}. Returning original query: {q}") logger.error(f"Invalid response for constructing subqueries: {response}. Returning original query: {q}")
return [q] return {q}
async def schedule_query( async def schedule_query(
@ -1128,9 +1129,6 @@ def generate_chat_response(
metadata = {} metadata = {}
agent = AgentAdapters.get_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None agent = AgentAdapters.get_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None
query_to_run = q
if meta_research:
query_to_run = f"AI Research: {meta_research} {q}"
try: try:
partial_completion = partial( partial_completion = partial(
save_to_conversation_log, save_to_conversation_log,
@ -1148,6 +1146,13 @@ def generate_chat_response(
train_of_thought=train_of_thought, train_of_thought=train_of_thought,
) )
query_to_run = q
if meta_research:
query_to_run = f"<query>{q}</query>\n<collected_research>\n{meta_research}\n</collected_research>"
compiled_references = []
online_results = {}
code_results = {}
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation) conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
vision_available = conversation_config.vision_enabled vision_available = conversation_config.vision_enabled
if not vision_available and query_images: if not vision_available and query_images:

View file

@ -43,38 +43,35 @@ async def apick_next_tool(
location: LocationData = None, location: LocationData = None,
user_name: str = None, user_name: str = None,
agent: Agent = None, agent: Agent = None,
previous_iterations_history: str = None, previous_iterations: List[InformationCollectionIteration] = [],
max_iterations: int = 5, max_iterations: int = 5,
send_status_func: Optional[Callable] = None, send_status_func: Optional[Callable] = None,
tracer: dict = {}, tracer: dict = {},
): ):
""" """Given a query, determine which of the available tools the agent should use in order to answer appropriately."""
Given a query, determine which of the available tools the agent should use in order to answer appropriately. One at a time, and it's able to use subsequent iterations to refine the answer.
"""
# Construct tool options for the agent to choose from
tool_options = dict() tool_options = dict()
tool_options_str = "" tool_options_str = ""
agent_tools = agent.input_tools if agent else [] agent_tools = agent.input_tools if agent else []
for tool, description in function_calling_description_for_llm.items(): for tool, description in function_calling_description_for_llm.items():
tool_options[tool.value] = description tool_options[tool.value] = description
if len(agent_tools) == 0 or tool.value in agent_tools: if len(agent_tools) == 0 or tool.value in agent_tools:
tool_options_str += f'- "{tool.value}": "{description}"\n' tool_options_str += f'- "{tool.value}": "{description}"\n'
# Construct chat history with user and iteration history with researcher agent for context
chat_history = construct_chat_history(conversation_history, agent_name=agent.name if agent else "Khoj") chat_history = construct_chat_history(conversation_history, agent_name=agent.name if agent else "Khoj")
previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration)
if query_images: if query_images:
query = f"[placeholder for user attached images]\n{query}" query = f"[placeholder for user attached images]\n{query}"
today = datetime.today()
location_data = f"{location}" if location else "Unknown"
personality_context = ( personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else "" prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
) )
# Extract Past User Message and Inferred Questions from Conversation Log
today = datetime.today()
location_data = f"{location}" if location else "Unknown"
function_planning_prompt = prompts.plan_function_execution.format( function_planning_prompt = prompts.plan_function_execution.format(
tools=tool_options_str, tools=tool_options_str,
chat_history=chat_history, chat_history=chat_history,
@ -112,8 +109,15 @@ async def apick_next_tool(
selected_tool = response.get("tool", None) selected_tool = response.get("tool", None)
generated_query = response.get("query", None) generated_query = response.get("query", None)
scratchpad = response.get("scratchpad", None) scratchpad = response.get("scratchpad", None)
warning = None
logger.info(f"Response for determining relevant tools: {response}") logger.info(f"Response for determining relevant tools: {response}")
if send_status_func:
# Detect selection of previously used query, tool combination.
previous_tool_query_combinations = {(i.tool, i.query) for i in previous_iterations}
if (selected_tool, generated_query) in previous_tool_query_combinations:
warning = f"Repeated tool, query combination detected. Skipping iteration. Try something different."
# Only send client status updates if we'll execute this iteration
elif send_status_func:
determined_tool_message = "**Determined Tool**: " determined_tool_message = "**Determined Tool**: "
determined_tool_message += f"{selected_tool}({generated_query})." if selected_tool else "respond." determined_tool_message += f"{selected_tool}({generated_query})." if selected_tool else "respond."
determined_tool_message += f"\nReason: {scratchpad}" if scratchpad else "" determined_tool_message += f"\nReason: {scratchpad}" if scratchpad else ""
@ -123,13 +127,14 @@ async def apick_next_tool(
yield InformationCollectionIteration( yield InformationCollectionIteration(
tool=selected_tool, tool=selected_tool,
query=generated_query, query=generated_query,
warning=warning,
) )
except Exception as e: except Exception as e:
logger.error(f"Invalid response for determining relevant tools: {response}. {e}", exc_info=True) logger.error(f"Invalid response for determining relevant tools: {response}. {e}", exc_info=True)
yield InformationCollectionIteration( yield InformationCollectionIteration(
tool=None, tool=None,
query=None, query=None,
warning=f"Invalid response for determining relevant tools: {response}. Skipping iteration. Fix error: {e}",
) )
@ -156,7 +161,6 @@ async def execute_information_collection(
document_results: List[Dict[str, str]] = [] document_results: List[Dict[str, str]] = []
summarize_files: str = "" summarize_files: str = ""
this_iteration = InformationCollectionIteration(tool=None, query=query) this_iteration = InformationCollectionIteration(tool=None, query=query)
previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration)
async for result in apick_next_tool( async for result in apick_next_tool(
query, query,
@ -166,7 +170,7 @@ async def execute_information_collection(
location, location,
user_name, user_name,
agent, agent,
previous_iterations_history, previous_iterations,
MAX_ITERATIONS, MAX_ITERATIONS,
send_status_func, send_status_func,
tracer=tracer, tracer=tracer,
@ -176,9 +180,16 @@ async def execute_information_collection(
elif isinstance(result, InformationCollectionIteration): elif isinstance(result, InformationCollectionIteration):
this_iteration = result this_iteration = result
if this_iteration.tool == ConversationCommand.Notes: # Skip running iteration if warning present in iteration
if this_iteration.warning:
logger.warning(f"Research mode: {this_iteration.warning}.")
elif this_iteration.tool == ConversationCommand.Notes:
this_iteration.context = [] this_iteration.context = []
document_results = [] document_results = []
previous_inferred_queries = {
c["query"] for iteration in previous_iterations if iteration.context for c in iteration.context
}
async for result in extract_references_and_questions( async for result in extract_references_and_questions(
request, request,
construct_tool_chat_history(previous_iterations, ConversationCommand.Notes), construct_tool_chat_history(previous_iterations, ConversationCommand.Notes),
@ -190,6 +201,7 @@ async def execute_information_collection(
location, location,
send_status_func, send_status_func,
query_images, query_images,
previous_inferred_queries=previous_inferred_queries,
agent=agent, agent=agent,
tracer=tracer, tracer=tracer,
): ):
@ -213,6 +225,12 @@ async def execute_information_collection(
logger.error(f"Error extracting document references: {e}", exc_info=True) logger.error(f"Error extracting document references: {e}", exc_info=True)
elif this_iteration.tool == ConversationCommand.Online: elif this_iteration.tool == ConversationCommand.Online:
previous_subqueries = {
subquery
for iteration in previous_iterations
if iteration.onlineContext
for subquery in iteration.onlineContext.keys()
}
async for result in search_online( async for result in search_online(
this_iteration.query, this_iteration.query,
construct_tool_chat_history(previous_iterations, ConversationCommand.Online), construct_tool_chat_history(previous_iterations, ConversationCommand.Online),
@ -222,11 +240,16 @@ async def execute_information_collection(
[], [],
max_webpages_to_read=0, max_webpages_to_read=0,
query_images=query_images, query_images=query_images,
previous_subqueries=previous_subqueries,
agent=agent, agent=agent,
tracer=tracer, tracer=tracer,
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS] yield result[ChatEvent.STATUS]
elif is_none_or_empty(result):
this_iteration.warning = (
"Detected previously run online search queries. Skipping iteration. Try something different."
)
else: else:
online_results: Dict[str, Dict] = result # type: ignore online_results: Dict[str, Dict] = result # type: ignore
this_iteration.onlineContext = online_results this_iteration.onlineContext = online_results
@ -311,16 +334,19 @@ async def execute_information_collection(
current_iteration += 1 current_iteration += 1
if document_results or online_results or code_results or summarize_files: if document_results or online_results or code_results or summarize_files or this_iteration.warning:
results_data = f"**Results**:\n" results_data = f"\n<iteration>{current_iteration}\n<tool>{this_iteration.tool}</tool>\n<query>{this_iteration.query}</query>\n<results>"
if document_results: if document_results:
results_data += f"**Document References**:\n{yaml.dump(document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" results_data += f"\n<document_references>\n{yaml.dump(document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</document_references>"
if online_results: if online_results:
results_data += f"**Online Results**:\n{yaml.dump(online_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" results_data += f"\n<online_results>\n{yaml.dump(online_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</online_results>"
if code_results: if code_results:
results_data += f"**Code Results**:\n{yaml.dump(code_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" results_data += f"\n<code_results>\n{yaml.dump(code_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</code_results>"
if summarize_files: if summarize_files:
results_data += f"**Summarized Files**:\n{yaml.dump(summarize_files, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n" results_data += f"\n<summarized_files>\n{yaml.dump(summarize_files, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</summarized_files>"
if this_iteration.warning:
results_data += f"\n<warning>\n{this_iteration.warning}\n</warning>"
results_data += "\n</results>\n</iteration>"
# intermediate_result = await extract_relevant_info(this_iteration.query, results_data, agent) # intermediate_result = await extract_relevant_info(this_iteration.query, results_data, agent)
this_iteration.summarizedResult = results_data this_iteration.summarizedResult = results_data