Deduplicate searches in normal mode & across research iterations

- Deduplicate online, doc search queries across research iterations.
  This avoids running previously run online, doc searches again and
  dedupes online, doc context seen by model to generate response.
- Deduplicate online search queries generated by chat model for each
  user query.
- Do not pass online, docs, code context separately when generating a
  response in research mode. These are already collected in the meta
  research passed with the user query.
- Improve formatting of context passed to generate research response
  - Use xml tags to delimit context. Pass per iteration queries in each
    iteration result
  - Put user query before meta research results in user message passed
    for generating response

These deduplications will improve the speed, cost & quality of research mode.
This commit is contained in:
Debanjum 2024-11-08 10:43:02 -08:00
parent 306f7a2132
commit 137687ee49
6 changed files with 86 additions and 42 deletions

View file

@ -112,6 +112,7 @@ class InformationCollectionIteration:
onlineContext: dict = None,
codeContext: dict = None,
summarizedResult: str = None,
warning: str = None,
):
self.tool = tool
self.query = query
@ -119,6 +120,7 @@ class InformationCollectionIteration:
self.onlineContext = onlineContext
self.codeContext = codeContext
self.summarizedResult = summarizedResult
self.warning = warning
def construct_iteration_history(

View file

@ -4,7 +4,7 @@ import logging
import os
import urllib.parse
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import aiohttp
from bs4 import BeautifulSoup
@ -66,6 +66,7 @@ async def search_online(
custom_filters: List[str] = [],
max_webpages_to_read: int = DEFAULT_MAX_WEBPAGES_TO_READ,
query_images: List[str] = None,
previous_subqueries: Set = set(),
agent: Agent = None,
tracer: dict = {},
):
@ -76,19 +77,24 @@ async def search_online(
return
# Breakdown the query into subqueries to get the correct answer
subqueries = await generate_online_subqueries(
new_subqueries = await generate_online_subqueries(
query, conversation_history, location, user, query_images=query_images, agent=agent, tracer=tracer
)
response_dict = {}
subqueries = list(new_subqueries - previous_subqueries)
response_dict: Dict[str, Dict[str, List[Dict] | Dict]] = {}
if subqueries:
logger.info(f"🌐 Searching the Internet for {list(subqueries)}")
if send_status_func:
subqueries_str = "\n- " + "\n- ".join(list(subqueries))
async for event in send_status_func(f"**Searching the Internet for**: {subqueries_str}"):
yield {ChatEvent.STATUS: event}
if is_none_or_empty(subqueries):
logger.info("No new subqueries to search online")
yield response_dict
return
with timer(f"Internet searches for {list(subqueries)} took", logger):
logger.info(f"🌐 Searching the Internet for {subqueries}")
if send_status_func:
subqueries_str = "\n- " + "\n- ".join(subqueries)
async for event in send_status_func(f"**Searching the Internet for**: {subqueries_str}"):
yield {ChatEvent.STATUS: event}
with timer(f"Internet searches for {subqueries} took", logger):
search_func = search_with_google if SERPER_DEV_API_KEY else search_with_jina
search_tasks = [search_func(subquery, location) for subquery in subqueries]
search_results = await asyncio.gather(*search_tasks)
@ -119,7 +125,9 @@ async def search_online(
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event}
tasks = [
read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent, tracer=tracer)
read_webpage_and_extract_content(
data["queries"], link, data.get("content"), user=user, agent=agent, tracer=tracer
)
for link, data in webpages.items()
]
results = await asyncio.gather(*tasks)

View file

@ -6,7 +6,7 @@ import os
import threading
import time
import uuid
from typing import Any, Callable, List, Optional, Union
from typing import Any, Callable, List, Optional, Set, Union
import cron_descriptor
import pytz
@ -349,6 +349,7 @@ async def extract_references_and_questions(
location_data: LocationData = None,
send_status_func: Optional[Callable] = None,
query_images: Optional[List[str]] = None,
previous_inferred_queries: Set = set(),
agent: Agent = None,
tracer: dict = {},
):
@ -477,6 +478,7 @@ async def extract_references_and_questions(
)
# Collate search results as context for GPT
inferred_queries = list(set(inferred_queries) - previous_inferred_queries)
with timer("Searching knowledge base took", logger):
search_results = []
logger.info(f"🔍 Searching knowledge base with queries: {inferred_queries}")

View file

@ -778,7 +778,8 @@ async def chat(
yield research_result
# researched_results = await extract_relevant_info(q, researched_results, agent)
logger.info(f"Researched Results: {researched_results}")
if state.verbose > 1:
logger.debug(f"Researched Results: {researched_results}")
used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
file_filters = conversation.file_filters if conversation else []

View file

@ -20,6 +20,7 @@ from typing import (
Iterator,
List,
Optional,
Set,
Tuple,
Union,
)
@ -494,7 +495,7 @@ async def generate_online_subqueries(
query_images: List[str] = None,
agent: Agent = None,
tracer: dict = {},
) -> List[str]:
) -> Set[str]:
"""
Generate subqueries from the given query
"""
@ -529,14 +530,14 @@ async def generate_online_subqueries(
try:
response = clean_json(response)
response = json.loads(response)
response = [q.strip() for q in response["queries"] if q.strip()]
if not isinstance(response, list) or not response or len(response) == 0:
response = {q.strip() for q in response["queries"] if q.strip()}
if not isinstance(response, set) or not response or len(response) == 0:
logger.error(f"Invalid response for constructing subqueries: {response}. Returning original query: {q}")
return [q]
return {q}
return response
except Exception as e:
logger.error(f"Invalid response for constructing subqueries: {response}. Returning original query: {q}")
return [q]
return {q}
async def schedule_query(
@ -1128,9 +1129,6 @@ def generate_chat_response(
metadata = {}
agent = AgentAdapters.get_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None
query_to_run = q
if meta_research:
query_to_run = f"AI Research: {meta_research} {q}"
try:
partial_completion = partial(
save_to_conversation_log,
@ -1148,6 +1146,13 @@ def generate_chat_response(
train_of_thought=train_of_thought,
)
query_to_run = q
if meta_research:
query_to_run = f"<query>{q}</query>\n<collected_research>\n{meta_research}\n</collected_research>"
compiled_references = []
online_results = {}
code_results = {}
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
vision_available = conversation_config.vision_enabled
if not vision_available and query_images:

View file

@ -43,38 +43,35 @@ async def apick_next_tool(
location: LocationData = None,
user_name: str = None,
agent: Agent = None,
previous_iterations_history: str = None,
previous_iterations: List[InformationCollectionIteration] = [],
max_iterations: int = 5,
send_status_func: Optional[Callable] = None,
tracer: dict = {},
):
"""
Given a query, determine which of the available tools the agent should use in order to answer appropriately. One at a time, and it's able to use subsequent iterations to refine the answer.
"""
"""Given a query, determine which of the available tools the agent should use in order to answer appropriately."""
# Construct tool options for the agent to choose from
tool_options = dict()
tool_options_str = ""
agent_tools = agent.input_tools if agent else []
for tool, description in function_calling_description_for_llm.items():
tool_options[tool.value] = description
if len(agent_tools) == 0 or tool.value in agent_tools:
tool_options_str += f'- "{tool.value}": "{description}"\n'
# Construct chat history with user and iteration history with researcher agent for context
chat_history = construct_chat_history(conversation_history, agent_name=agent.name if agent else "Khoj")
previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration)
if query_images:
query = f"[placeholder for user attached images]\n{query}"
today = datetime.today()
location_data = f"{location}" if location else "Unknown"
personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
)
# Extract Past User Message and Inferred Questions from Conversation Log
today = datetime.today()
location_data = f"{location}" if location else "Unknown"
function_planning_prompt = prompts.plan_function_execution.format(
tools=tool_options_str,
chat_history=chat_history,
@ -112,8 +109,15 @@ async def apick_next_tool(
selected_tool = response.get("tool", None)
generated_query = response.get("query", None)
scratchpad = response.get("scratchpad", None)
warning = None
logger.info(f"Response for determining relevant tools: {response}")
if send_status_func:
# Detect selection of previously used query, tool combination.
previous_tool_query_combinations = {(i.tool, i.query) for i in previous_iterations}
if (selected_tool, generated_query) in previous_tool_query_combinations:
warning = f"Repeated tool, query combination detected. Skipping iteration. Try something different."
# Only send client status updates if we'll execute this iteration
elif send_status_func:
determined_tool_message = "**Determined Tool**: "
determined_tool_message += f"{selected_tool}({generated_query})." if selected_tool else "respond."
determined_tool_message += f"\nReason: {scratchpad}" if scratchpad else ""
@ -123,13 +127,14 @@ async def apick_next_tool(
yield InformationCollectionIteration(
tool=selected_tool,
query=generated_query,
warning=warning,
)
except Exception as e:
logger.error(f"Invalid response for determining relevant tools: {response}. {e}", exc_info=True)
yield InformationCollectionIteration(
tool=None,
query=None,
warning=f"Invalid response for determining relevant tools: {response}. Skipping iteration. Fix error: {e}",
)
@ -156,7 +161,6 @@ async def execute_information_collection(
document_results: List[Dict[str, str]] = []
summarize_files: str = ""
this_iteration = InformationCollectionIteration(tool=None, query=query)
previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration)
async for result in apick_next_tool(
query,
@ -166,7 +170,7 @@ async def execute_information_collection(
location,
user_name,
agent,
previous_iterations_history,
previous_iterations,
MAX_ITERATIONS,
send_status_func,
tracer=tracer,
@ -176,9 +180,16 @@ async def execute_information_collection(
elif isinstance(result, InformationCollectionIteration):
this_iteration = result
if this_iteration.tool == ConversationCommand.Notes:
# Skip running iteration if warning present in iteration
if this_iteration.warning:
logger.warning(f"Research mode: {this_iteration.warning}.")
elif this_iteration.tool == ConversationCommand.Notes:
this_iteration.context = []
document_results = []
previous_inferred_queries = {
c["query"] for iteration in previous_iterations if iteration.context for c in iteration.context
}
async for result in extract_references_and_questions(
request,
construct_tool_chat_history(previous_iterations, ConversationCommand.Notes),
@ -190,6 +201,7 @@ async def execute_information_collection(
location,
send_status_func,
query_images,
previous_inferred_queries=previous_inferred_queries,
agent=agent,
tracer=tracer,
):
@ -213,6 +225,12 @@ async def execute_information_collection(
logger.error(f"Error extracting document references: {e}", exc_info=True)
elif this_iteration.tool == ConversationCommand.Online:
previous_subqueries = {
subquery
for iteration in previous_iterations
if iteration.onlineContext
for subquery in iteration.onlineContext.keys()
}
async for result in search_online(
this_iteration.query,
construct_tool_chat_history(previous_iterations, ConversationCommand.Online),
@ -222,11 +240,16 @@ async def execute_information_collection(
[],
max_webpages_to_read=0,
query_images=query_images,
previous_subqueries=previous_subqueries,
agent=agent,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
elif is_none_or_empty(result):
this_iteration.warning = (
"Detected previously run online search queries. Skipping iteration. Try something different."
)
else:
online_results: Dict[str, Dict] = result # type: ignore
this_iteration.onlineContext = online_results
@ -311,16 +334,19 @@ async def execute_information_collection(
current_iteration += 1
if document_results or online_results or code_results or summarize_files:
results_data = f"**Results**:\n"
if document_results or online_results or code_results or summarize_files or this_iteration.warning:
results_data = f"\n<iteration>{current_iteration}\n<tool>{this_iteration.tool}</tool>\n<query>{this_iteration.query}</query>\n<results>"
if document_results:
results_data += f"**Document References**:\n{yaml.dump(document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
results_data += f"\n<document_references>\n{yaml.dump(document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</document_references>"
if online_results:
results_data += f"**Online Results**:\n{yaml.dump(online_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
results_data += f"\n<online_results>\n{yaml.dump(online_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</online_results>"
if code_results:
results_data += f"**Code Results**:\n{yaml.dump(code_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
results_data += f"\n<code_results>\n{yaml.dump(code_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</code_results>"
if summarize_files:
results_data += f"**Summarized Files**:\n{yaml.dump(summarize_files, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
results_data += f"\n<summarized_files>\n{yaml.dump(summarize_files, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</summarized_files>"
if this_iteration.warning:
results_data += f"\n<warning>\n{this_iteration.warning}\n</warning>"
results_data += "\n</results>\n</iteration>"
# intermediate_result = await extract_relevant_info(this_iteration.query, results_data, agent)
this_iteration.summarizedResult = results_data