Mirror of https://github.com/khoj-ai/khoj.git (synced 2024-11-23 23:48:56 +01:00)
Deduplicate searches in normal mode & across research iterations
- Deduplicate online, doc search queries across research iterations. This avoids
  re-running previously run online, doc searches and dedupes the online, doc
  context seen by the model when generating a response.
- Deduplicate online search queries generated by the chat model for each user query.
- Do not pass online, docs, code context separately when generating the response in
  research mode. These are already collected in the meta research passed with the
  user query.
- Improve formatting of context passed to generate the research response:
  - Use XML tags to delimit context. Pass per-iteration queries in each iteration
    result.
  - Put the user query before the meta research results in the user message passed
    for generating the response.

These deduplications will improve the speed, cost & quality of research mode.
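The deduplication itself is plain set arithmetic over query strings. A minimal, self-contained sketch of the idea (run_search and research below are illustrative stand-ins, not khoj's actual API):

# Hypothetical stand-ins for khoj's search plumbing, for illustration only.
def run_search(query: str) -> str:
    return f"results for {query!r}"

def research(user_query: str, max_iterations: int = 3) -> dict[str, str]:
    previous_queries: set[str] = set()
    context: dict[str, str] = {}
    for iteration in range(max_iterations):
        # A chat model would generate candidate subqueries here; we fake it.
        candidates = {f"{user_query} angle {n}" for n in range(iteration + 2)}
        # Set difference drops subqueries already run in earlier iterations.
        new_queries = candidates - previous_queries
        if not new_queries:
            continue  # nothing new to search, skip this iteration
        for subquery in new_queries:
            context[subquery] = run_search(subquery)
        previous_queries |= new_queries
    return context

print(research("khoj research mode"))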
Parent: 306f7a2132
Commit: 137687ee49

6 changed files with 86 additions and 42 deletions
@@ -112,6 +112,7 @@ class InformationCollectionIteration:
         onlineContext: dict = None,
         codeContext: dict = None,
         summarizedResult: str = None,
+        warning: str = None,
     ):
         self.tool = tool
         self.query = query
@@ -119,6 +120,7 @@ class InformationCollectionIteration:
         self.onlineContext = onlineContext
         self.codeContext = codeContext
         self.summarizedResult = summarizedResult
+        self.warning = warning


 def construct_iteration_history(
@@ -4,7 +4,7 @@ import logging
 import os
 import urllib.parse
 from collections import defaultdict
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

 import aiohttp
 from bs4 import BeautifulSoup
@@ -66,6 +66,7 @@ async def search_online(
     custom_filters: List[str] = [],
     max_webpages_to_read: int = DEFAULT_MAX_WEBPAGES_TO_READ,
     query_images: List[str] = None,
+    previous_subqueries: Set = set(),
     agent: Agent = None,
     tracer: dict = {},
 ):
@@ -76,19 +77,24 @@ async def search_online(
         return

     # Breakdown the query into subqueries to get the correct answer
-    subqueries = await generate_online_subqueries(
+    new_subqueries = await generate_online_subqueries(
         query, conversation_history, location, user, query_images=query_images, agent=agent, tracer=tracer
     )
-    response_dict = {}
+    subqueries = list(new_subqueries - previous_subqueries)
+    response_dict: Dict[str, Dict[str, List[Dict] | Dict]] = {}

-    if subqueries:
-        logger.info(f"🌐 Searching the Internet for {list(subqueries)}")
+    if is_none_or_empty(subqueries):
+        logger.info("No new subqueries to search online")
+        yield response_dict
+        return
+
+    logger.info(f"🌐 Searching the Internet for {subqueries}")
     if send_status_func:
-        subqueries_str = "\n- " + "\n- ".join(list(subqueries))
+        subqueries_str = "\n- " + "\n- ".join(subqueries)
         async for event in send_status_func(f"**Searching the Internet for**: {subqueries_str}"):
             yield {ChatEvent.STATUS: event}

-    with timer(f"Internet searches for {list(subqueries)} took", logger):
+    with timer(f"Internet searches for {subqueries} took", logger):
         search_func = search_with_google if SERPER_DEV_API_KEY else search_with_jina
         search_tasks = [search_func(subquery, location) for subquery in subqueries]
         search_results = await asyncio.gather(*search_tasks)
@@ -119,7 +125,9 @@ async def search_online(
         async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
             yield {ChatEvent.STATUS: event}
     tasks = [
-        read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent, tracer=tracer)
+        read_webpage_and_extract_content(
+            data["queries"], link, data.get("content"), user=user, agent=agent, tracer=tracer
+        )
         for link, data in webpages.items()
     ]
     results = await asyncio.gather(*tasks)
@@ -6,7 +6,7 @@ import os
 import threading
 import time
 import uuid
-from typing import Any, Callable, List, Optional, Union
+from typing import Any, Callable, List, Optional, Set, Union

 import cron_descriptor
 import pytz
@@ -349,6 +349,7 @@ async def extract_references_and_questions(
     location_data: LocationData = None,
     send_status_func: Optional[Callable] = None,
     query_images: Optional[List[str]] = None,
+    previous_inferred_queries: Set = set(),
     agent: Agent = None,
     tracer: dict = {},
 ):
@@ -477,6 +478,7 @@ async def extract_references_and_questions(
     )

     # Collate search results as context for GPT
+    inferred_queries = list(set(inferred_queries) - previous_inferred_queries)
     with timer("Searching knowledge base took", logger):
         search_results = []
         logger.info(f"🔍 Searching knowledge base with queries: {inferred_queries}")
@@ -778,7 +778,8 @@ async def chat(
                 yield research_result

             # researched_results = await extract_relevant_info(q, researched_results, agent)
-            logger.info(f"Researched Results: {researched_results}")
+            if state.verbose > 1:
+                logger.debug(f"Researched Results: {researched_results}")

             used_slash_summarize = conversation_commands == [ConversationCommand.Summarize]
             file_filters = conversation.file_filters if conversation else []
@@ -20,6 +20,7 @@ from typing import (
     Iterator,
     List,
     Optional,
+    Set,
     Tuple,
     Union,
 )
@@ -494,7 +495,7 @@ async def generate_online_subqueries(
     query_images: List[str] = None,
     agent: Agent = None,
     tracer: dict = {},
-) -> List[str]:
+) -> Set[str]:
     """
     Generate subqueries from the given query
     """
@@ -529,14 +530,14 @@ async def generate_online_subqueries(
     try:
         response = clean_json(response)
         response = json.loads(response)
-        response = [q.strip() for q in response["queries"] if q.strip()]
-        if not isinstance(response, list) or not response or len(response) == 0:
+        response = {q.strip() for q in response["queries"] if q.strip()}
+        if not isinstance(response, set) or not response or len(response) == 0:
             logger.error(f"Invalid response for constructing subqueries: {response}. Returning original query: {q}")
-            return [q]
+            return {q}
         return response
     except Exception as e:
         logger.error(f"Invalid response for constructing subqueries: {response}. Returning original query: {q}")
-        return [q]
+        return {q}


 async def schedule_query(
@@ -1128,9 +1129,6 @@ def generate_chat_response(

     metadata = {}
     agent = AgentAdapters.get_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None
-    query_to_run = q
-    if meta_research:
-        query_to_run = f"AI Research: {meta_research} {q}"
     try:
         partial_completion = partial(
             save_to_conversation_log,
@@ -1148,6 +1146,13 @@ def generate_chat_response(
             train_of_thought=train_of_thought,
         )

+        query_to_run = q
+        if meta_research:
+            query_to_run = f"<query>{q}</query>\n<collected_research>\n{meta_research}\n</collected_research>"
+        compiled_references = []
+        online_results = {}
+        code_results = {}
+
         conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
         vision_available = conversation_config.vision_enabled
         if not vision_available and query_images:
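As the hunk above shows, the research prompt now leads with the user query and wraps the collected research in XML tags. A quick sketch of the resulting string (the query and research values here are made up):

q = "What is the population of Lagos?"
meta_research = "1. Online search found ~15 million residents (2023 estimate)."

# User query first, then the accumulated research, delimited by XML tags.
query_to_run = f"<query>{q}</query>\n<collected_research>\n{meta_research}\n</collected_research>"
print(query_to_run)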
@@ -43,38 +43,35 @@ async def apick_next_tool(
     location: LocationData = None,
     user_name: str = None,
     agent: Agent = None,
-    previous_iterations_history: str = None,
+    previous_iterations: List[InformationCollectionIteration] = [],
     max_iterations: int = 5,
     send_status_func: Optional[Callable] = None,
     tracer: dict = {},
 ):
-    """
-    Given a query, determine which of the available tools the agent should use in order to answer appropriately. One at a time, and it's able to use subsequent iterations to refine the answer.
-    """
+    """Given a query, determine which of the available tools the agent should use in order to answer appropriately."""

     # Construct tool options for the agent to choose from
     tool_options = dict()
     tool_options_str = ""

     agent_tools = agent.input_tools if agent else []

     for tool, description in function_calling_description_for_llm.items():
         tool_options[tool.value] = description
         if len(agent_tools) == 0 or tool.value in agent_tools:
             tool_options_str += f'- "{tool.value}": "{description}"\n'

+    # Construct chat history with user and iteration history with researcher agent for context
     chat_history = construct_chat_history(conversation_history, agent_name=agent.name if agent else "Khoj")
+    previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration)

     if query_images:
         query = f"[placeholder for user attached images]\n{query}"

+    today = datetime.today()
+    location_data = f"{location}" if location else "Unknown"
     personality_context = (
         prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
     )

-    # Extract Past User Message and Inferred Questions from Conversation Log
-    today = datetime.today()
-    location_data = f"{location}" if location else "Unknown"
-
     function_planning_prompt = prompts.plan_function_execution.format(
         tools=tool_options_str,
         chat_history=chat_history,
@@ -112,8 +109,15 @@ async def apick_next_tool(
         selected_tool = response.get("tool", None)
         generated_query = response.get("query", None)
         scratchpad = response.get("scratchpad", None)
+        warning = None
         logger.info(f"Response for determining relevant tools: {response}")
-        if send_status_func:
+
+        # Detect selection of previously used query, tool combination.
+        previous_tool_query_combinations = {(i.tool, i.query) for i in previous_iterations}
+        if (selected_tool, generated_query) in previous_tool_query_combinations:
+            warning = f"Repeated tool, query combination detected. Skipping iteration. Try something different."
+        # Only send client status updates if we'll execute this iteration
+        elif send_status_func:
             determined_tool_message = "**Determined Tool**: "
             determined_tool_message += f"{selected_tool}({generated_query})." if selected_tool else "respond."
             determined_tool_message += f"\nReason: {scratchpad}" if scratchpad else ""
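The repeat detection above is a set membership test over (tool, query) tuples. A standalone sketch, with a dataclass standing in for khoj's InformationCollectionIteration:

from dataclasses import dataclass

@dataclass
class Iteration:  # stand-in for khoj's InformationCollectionIteration
    tool: str
    query: str

previous_iterations = [Iteration("online", "weather in Paris"), Iteration("notes", "trip plans")]

# Collect previously used (tool, query) pairs for O(1) repeat detection.
previous_tool_query_combinations = {(i.tool, i.query) for i in previous_iterations}

selected_tool, generated_query = "online", "weather in Paris"
if (selected_tool, generated_query) in previous_tool_query_combinations:
    print("Repeated tool, query combination detected. Skipping iteration.")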
@@ -123,13 +127,14 @@ async def apick_next_tool(
         yield InformationCollectionIteration(
             tool=selected_tool,
             query=generated_query,
+            warning=warning,
         )

     except Exception as e:
         logger.error(f"Invalid response for determining relevant tools: {response}. {e}", exc_info=True)
         yield InformationCollectionIteration(
             tool=None,
             query=None,
+            warning=f"Invalid response for determining relevant tools: {response}. Skipping iteration. Fix error: {e}",
         )

@@ -156,7 +161,6 @@ async def execute_information_collection(
         document_results: List[Dict[str, str]] = []
         summarize_files: str = ""
         this_iteration = InformationCollectionIteration(tool=None, query=query)
-        previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration)

         async for result in apick_next_tool(
             query,
@@ -166,7 +170,7 @@ async def execute_information_collection(
             location,
             user_name,
             agent,
-            previous_iterations_history,
+            previous_iterations,
             MAX_ITERATIONS,
             send_status_func,
             tracer=tracer,
@@ -176,9 +180,16 @@ async def execute_information_collection(
             elif isinstance(result, InformationCollectionIteration):
                 this_iteration = result

-        if this_iteration.tool == ConversationCommand.Notes:
+        # Skip running iteration if warning present in iteration
+        if this_iteration.warning:
+            logger.warning(f"Research mode: {this_iteration.warning}.")
+
+        elif this_iteration.tool == ConversationCommand.Notes:
             this_iteration.context = []
             document_results = []
+            previous_inferred_queries = {
+                c["query"] for iteration in previous_iterations if iteration.context for c in iteration.context
+            }
             async for result in extract_references_and_questions(
                 request,
                 construct_tool_chat_history(previous_iterations, ConversationCommand.Notes),
@@ -190,6 +201,7 @@ async def execute_information_collection(
                 location,
                 send_status_func,
                 query_images,
+                previous_inferred_queries=previous_inferred_queries,
                 agent=agent,
                 tracer=tracer,
             ):
@@ -213,6 +225,12 @@ async def execute_information_collection(
                 logger.error(f"Error extracting document references: {e}", exc_info=True)

         elif this_iteration.tool == ConversationCommand.Online:
+            previous_subqueries = {
+                subquery
+                for iteration in previous_iterations
+                if iteration.onlineContext
+                for subquery in iteration.onlineContext.keys()
+            }
             async for result in search_online(
                 this_iteration.query,
                 construct_tool_chat_history(previous_iterations, ConversationCommand.Online),
@@ -222,11 +240,16 @@ async def execute_information_collection(
                 [],
                 max_webpages_to_read=0,
                 query_images=query_images,
+                previous_subqueries=previous_subqueries,
                 agent=agent,
                 tracer=tracer,
             ):
                 if isinstance(result, dict) and ChatEvent.STATUS in result:
                     yield result[ChatEvent.STATUS]
+                elif is_none_or_empty(result):
+                    this_iteration.warning = (
+                        "Detected previously run online search queries. Skipping iteration. Try something different."
+                    )
                 else:
                     online_results: Dict[str, Dict] = result  # type: ignore
                     this_iteration.onlineContext = online_results
@@ -311,16 +334,19 @@ async def execute_information_collection(

         current_iteration += 1

-        if document_results or online_results or code_results or summarize_files:
-            results_data = f"**Results**:\n"
+        if document_results or online_results or code_results or summarize_files or this_iteration.warning:
+            results_data = f"\n<iteration>{current_iteration}\n<tool>{this_iteration.tool}</tool>\n<query>{this_iteration.query}</query>\n<results>"
             if document_results:
-                results_data += f"**Document References**:\n{yaml.dump(document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
+                results_data += f"\n<document_references>\n{yaml.dump(document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</document_references>"
             if online_results:
-                results_data += f"**Online Results**:\n{yaml.dump(online_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
+                results_data += f"\n<online_results>\n{yaml.dump(online_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</online_results>"
             if code_results:
-                results_data += f"**Code Results**:\n{yaml.dump(code_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
+                results_data += f"\n<code_results>\n{yaml.dump(code_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</code_results>"
             if summarize_files:
-                results_data += f"**Summarized Files**:\n{yaml.dump(summarize_files, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
+                results_data += f"\n<summarized_files>\n{yaml.dump(summarize_files, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</summarized_files>"
+            if this_iteration.warning:
+                results_data += f"\n<warning>\n{this_iteration.warning}\n</warning>"
+            results_data += "\n</results>\n</iteration>"

             # intermediate_result = await extract_relevant_info(this_iteration.query, results_data, agent)
             this_iteration.summarizedResult = results_data
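Each iteration's results are now serialized as an XML-delimited block instead of markdown headings, so the model can tell iterations and their queries apart. A sketch of the format with toy data (requires PyYAML, which khoj already uses; all values below are made up):

import yaml

# Toy stand-ins for one research iteration's outputs.
current_iteration = 1
tool, query = "online", "population of Lagos"
online_results = {"population of Lagos": {"answerBox": {"answer": "~15 million"}}}
warning = None

results_data = f"\n<iteration>{current_iteration}\n<tool>{tool}</tool>\n<query>{query}</query>\n<results>"
if online_results:
    results_data += f"\n<online_results>\n{yaml.dump(online_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</online_results>"
if warning:
    results_data += f"\n<warning>\n{warning}\n</warning>"
results_data += "\n</results>\n</iteration>"
print(results_data)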