mirror of
https://github.com/khoj-ai/khoj.git
synced 2025-02-17 08:04:21 +00:00
Deduplicate searches in normal mode & across research iterations
- Chat model generates online search queries for each user query. These should be deduplicated before searching online - Research mode generates online search queries for each iteration. These need to be deduplicated across iterations, to avoid re-running online searches that were already performed and to dedupe the online context seen by the model when generating a response - Do not pass online, docs, code context separately when generating the response in research mode. These are already collected in the meta research passed with the user query - Improve formatting of context passed to generate research response - Use xml tags to delimit context. Pass per-iteration queries in each iteration result - Put user query before meta research results in user message passed for generating response This deduplication will improve the speed and reduce the cost of research mode.
This commit is contained in:
parent
750b4fa404
commit
5e8bb71413
4 changed files with 50 additions and 21 deletions
|
@ -4,7 +4,7 @@ import logging
|
|||
import os
|
||||
import urllib.parse
|
||||
from collections import defaultdict
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import aiohttp
|
||||
from bs4 import BeautifulSoup
|
||||
|
@ -66,6 +66,7 @@ async def search_online(
|
|||
custom_filters: List[str] = [],
|
||||
max_webpages_to_read: int = DEFAULT_MAX_WEBPAGES_TO_READ,
|
||||
query_images: List[str] = None,
|
||||
previous_subqueries: Set = set(),
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
|
@ -76,19 +77,20 @@ async def search_online(
|
|||
return
|
||||
|
||||
# Breakdown the query into subqueries to get the correct answer
|
||||
subqueries = await generate_online_subqueries(
|
||||
new_subqueries = await generate_online_subqueries(
|
||||
query, conversation_history, location, user, query_images=query_images, agent=agent, tracer=tracer
|
||||
)
|
||||
subqueries = list(new_subqueries - previous_subqueries)
|
||||
response_dict = {}
|
||||
|
||||
if subqueries:
|
||||
logger.info(f"🌐 Searching the Internet for {list(subqueries)}")
|
||||
logger.info(f"🌐 Searching the Internet for {subqueries}")
|
||||
if send_status_func:
|
||||
subqueries_str = "\n- " + "\n- ".join(list(subqueries))
|
||||
subqueries_str = "\n- " + "\n- ".join(subqueries)
|
||||
async for event in send_status_func(f"**Searching the Internet for**: {subqueries_str}"):
|
||||
yield {ChatEvent.STATUS: event}
|
||||
|
||||
with timer(f"Internet searches for {list(subqueries)} took", logger):
|
||||
with timer(f"Internet searches for {subqueries} took", logger):
|
||||
search_func = search_with_google if SERPER_DEV_API_KEY else search_with_jina
|
||||
search_tasks = [search_func(subquery, location) for subquery in subqueries]
|
||||
search_results = await asyncio.gather(*search_tasks)
|
||||
|
@ -115,7 +117,9 @@ async def search_online(
|
|||
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
|
||||
yield {ChatEvent.STATUS: event}
|
||||
tasks = [
|
||||
read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent, tracer=tracer)
|
||||
read_webpage_and_extract_content(
|
||||
data["queries"], link, data.get("content"), user=user, agent=agent, tracer=tracer
|
||||
)
|
||||
for link, data in webpages.items()
|
||||
]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
|
|
@ -6,7 +6,7 @@ import os
|
|||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Callable, List, Optional, Union
|
||||
from typing import Any, Callable, List, Optional, Set, Union
|
||||
|
||||
import cron_descriptor
|
||||
import pytz
|
||||
|
@ -349,6 +349,7 @@ async def extract_references_and_questions(
|
|||
location_data: LocationData = None,
|
||||
send_status_func: Optional[Callable] = None,
|
||||
query_images: Optional[List[str]] = None,
|
||||
previous_inferred_queries: Set = set(),
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
):
|
||||
|
@ -477,6 +478,7 @@ async def extract_references_and_questions(
|
|||
)
|
||||
|
||||
# Collate search results as context for GPT
|
||||
inferred_queries = list(set(inferred_queries) - previous_inferred_queries)
|
||||
with timer("Searching knowledge base took", logger):
|
||||
search_results = []
|
||||
logger.info(f"🔍 Searching knowledge base with queries: {inferred_queries}")
|
||||
|
|
|
@ -20,6 +20,7 @@ from typing import (
|
|||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
@ -494,7 +495,7 @@ async def generate_online_subqueries(
|
|||
query_images: List[str] = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
) -> List[str]:
|
||||
) -> Set[str]:
|
||||
"""
|
||||
Generate subqueries from the given query
|
||||
"""
|
||||
|
@ -529,14 +530,14 @@ async def generate_online_subqueries(
|
|||
try:
|
||||
response = clean_json(response)
|
||||
response = json.loads(response)
|
||||
response = [q.strip() for q in response["queries"] if q.strip()]
|
||||
if not isinstance(response, list) or not response or len(response) == 0:
|
||||
response = {q.strip() for q in response["queries"] if q.strip()}
|
||||
if not isinstance(response, set) or not response or len(response) == 0:
|
||||
logger.error(f"Invalid response for constructing subqueries: {response}. Returning original query: {q}")
|
||||
return [q]
|
||||
return {q}
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"Invalid response for constructing subqueries: {response}. Returning original query: {q}")
|
||||
return [q]
|
||||
return {q}
|
||||
|
||||
|
||||
async def schedule_query(
|
||||
|
@ -1128,9 +1129,6 @@ def generate_chat_response(
|
|||
|
||||
metadata = {}
|
||||
agent = AgentAdapters.get_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None
|
||||
query_to_run = q
|
||||
if meta_research:
|
||||
query_to_run = f"AI Research: {meta_research} {q}"
|
||||
try:
|
||||
partial_completion = partial(
|
||||
save_to_conversation_log,
|
||||
|
@ -1148,6 +1146,13 @@ def generate_chat_response(
|
|||
train_of_thought=train_of_thought,
|
||||
)
|
||||
|
||||
query_to_run = q
|
||||
if meta_research:
|
||||
query_to_run = f"<query>{q}</query>\n<collected_research>\n{meta_research}\n</collected_research>"
|
||||
compiled_references = []
|
||||
online_results = {}
|
||||
code_results = {}
|
||||
|
||||
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
|
||||
vision_available = conversation_config.vision_enabled
|
||||
if not vision_available and query_images:
|
||||
|
|
|
@ -146,6 +146,7 @@ async def execute_information_collection(
|
|||
code_results: Dict = dict()
|
||||
document_results: List[Dict[str, str]] = []
|
||||
summarize_files: str = ""
|
||||
warning_results: str = ""
|
||||
this_iteration = InformationCollectionIteration(tool=None, query=query)
|
||||
previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration)
|
||||
|
||||
|
@ -167,9 +168,15 @@ async def execute_information_collection(
|
|||
elif isinstance(result, InformationCollectionIteration):
|
||||
this_iteration = result
|
||||
|
||||
# Detect, skip running query, tool combinations already executed during a previous iteration.
|
||||
previous_tool_query_combinations = {(i.tool, i.query) for i in previous_iterations}
|
||||
if (this_iteration.tool, this_iteration.query) in previous_tool_query_combinations:
|
||||
warning_results = f"Repeated tool, query combination detected. Skipping."
|
||||
|
||||
if this_iteration.tool == ConversationCommand.Notes:
|
||||
this_iteration.context = []
|
||||
document_results = []
|
||||
previous_inferred_queries = {iteration.context for iteration in previous_iterations}
|
||||
async for result in extract_references_and_questions(
|
||||
request,
|
||||
construct_tool_chat_history(previous_iterations, ConversationCommand.Notes),
|
||||
|
@ -181,6 +188,7 @@ async def execute_information_collection(
|
|||
location,
|
||||
send_status_func,
|
||||
query_images,
|
||||
previous_inferred_queries=previous_inferred_queries,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
):
|
||||
|
@ -204,6 +212,12 @@ async def execute_information_collection(
|
|||
logger.error(f"Error extracting document references: {e}", exc_info=True)
|
||||
|
||||
elif this_iteration.tool == ConversationCommand.Online:
|
||||
previous_subqueries = {
|
||||
q
|
||||
for iteration in previous_iterations
|
||||
for q in iteration.onlineContext.keys()
|
||||
if iteration.onlineContext
|
||||
}
|
||||
async for result in search_online(
|
||||
this_iteration.query,
|
||||
construct_tool_chat_history(previous_iterations, ConversationCommand.Online),
|
||||
|
@ -213,6 +227,7 @@ async def execute_information_collection(
|
|||
[],
|
||||
max_webpages_to_read=0,
|
||||
query_images=query_images,
|
||||
previous_subqueries=previous_subqueries,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
):
|
||||
|
@ -302,16 +317,19 @@ async def execute_information_collection(
|
|||
|
||||
current_iteration += 1
|
||||
|
||||
if document_results or online_results or code_results or summarize_files:
|
||||
results_data = f"**Results**:\n"
|
||||
if document_results or online_results or code_results or summarize_files or warning_results:
|
||||
results_data = f"\n<iteration>{current_iteration}\n<tool>{this_iteration.tool}</tool>\n<query>{this_iteration.query}</query>\n<results>"
|
||||
if document_results:
|
||||
results_data += f"**Document References**:\n{yaml.dump(document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
|
||||
results_data += f"\n<document_references>\n{yaml.dump(document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</document_references>"
|
||||
if online_results:
|
||||
results_data += f"**Online Results**:\n{yaml.dump(online_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
|
||||
results_data += f"\n<online_results>\n{yaml.dump(online_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</online_results>"
|
||||
if code_results:
|
||||
results_data += f"**Code Results**:\n{yaml.dump(code_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
|
||||
results_data += f"\n<code_results>\n{yaml.dump(code_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</code_results>"
|
||||
if summarize_files:
|
||||
results_data += f"**Summarized Files**:\n{yaml.dump(summarize_files, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
|
||||
results_data += f"\n<summarized_files>\n{yaml.dump(summarize_files, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n</summarized_files>"
|
||||
if warning_results:
|
||||
results_data += f"\n<warning>\n{warning_results}\n</warning>"
|
||||
results_data += "\n</results>\n</iteration>"
|
||||
|
||||
# intermediate_result = await extract_relevant_info(this_iteration.query, results_data, agent)
|
||||
this_iteration.summarizedResult = results_data
|
||||
|
|
Loading…
Add table
Reference in a new issue