diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py
index 6d0aaace..a37e3c3a 100644
--- a/src/khoj/processor/tools/online_search.py
+++ b/src/khoj/processor/tools/online_search.py
@@ -6,7 +6,6 @@ from collections import defaultdict
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import aiohttp
-import requests
 from bs4 import BeautifulSoup
 from markdownify import markdownify
 
@@ -61,11 +60,16 @@ async def search_online(
     subqueries = await generate_online_subqueries(query, conversation_history, location)
     response_dict = {}
 
-    for subquery in subqueries:
+    if subqueries:
+        logger.info(f"🌐 Searching the Internet for {list(subqueries)}")
         if send_status_func:
-            await send_status_func(f"**🌐 Searching the Internet for**: {subquery}")
-        logger.info(f"🌐 Searching the Internet for '{subquery}'")
-        response_dict[subquery] = search_with_google(subquery)
+            subqueries_str = "\n- " + "\n- ".join(list(subqueries))
+            await send_status_func(f"**🌐 Searching the Internet for**: {subqueries_str}")
+
+    with timer(f"Internet searches for {list(subqueries)} took", logger):
+        search_tasks = [search_with_google(subquery) for subquery in subqueries]
+        search_results = await asyncio.gather(*search_tasks)
+        response_dict = {subquery: search_result for subquery, search_result in search_results}
 
     # Gather distinct web pages from organic search results of each subquery without an instant answer
     webpage_links = {
@@ -92,23 +96,24 @@ async def search_online(
     return response_dict
 
 
-def search_with_google(subquery: str):
-    payload = json.dumps({"q": subquery})
+async def search_with_google(query: str) -> Tuple[str, Dict[str, List[Dict]]]:
+    payload = json.dumps({"q": query})
     headers = {"X-API-KEY": SERPER_DEV_API_KEY, "Content-Type": "application/json"}
 
-    response = requests.request("POST", SERPER_DEV_URL, headers=headers, data=payload)
+    async with aiohttp.ClientSession() as session:
+        async with session.post(SERPER_DEV_URL, headers=headers, data=payload) as response:
+            if response.status != 200:
+                logger.error(await response.text())
+                return query, {}
+            json_response = await response.json()
+            extraction_fields = ["organic", "answerBox", "peopleAlsoAsk", "knowledgeGraph"]
+            extracted_search_result = {
+                field: json_response[field]
+                for field in extraction_fields
+                if not is_none_or_empty(json_response.get(field))
+            }
 
-    if response.status_code != 200:
-        logger.error(response.text)
-        return {}
-
-    json_response = response.json()
-    extraction_fields = ["organic", "answerBox", "peopleAlsoAsk", "knowledgeGraph"]
-    extracted_search_result = {
-        field: json_response[field] for field in extraction_fields if not is_none_or_empty(json_response.get(field))
-    }
-
-    return extracted_search_result
+    return query, extracted_search_result
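
For context on the concurrency pattern this diff adopts: each search coroutine returns a `(query, result)` tuple, so the results of `asyncio.gather` can be rebuilt into a query-keyed dict without relying on positional bookkeeping. Below is a minimal, self-contained sketch of that fan-out pattern; `mock_search` is a hypothetical stand-in for `search_with_google`, not part of the change.

```python
import asyncio
from typing import Dict, List, Tuple


async def mock_search(query: str) -> Tuple[str, List[str]]:
    # Hypothetical stand-in for search_with_google: tag each result with
    # the query that produced it, so results can be matched back to their
    # subquery when the gathered list is turned into a dict.
    await asyncio.sleep(0.1)  # simulate network latency
    return query, [f"result for '{query}'"]


async def main() -> None:
    subqueries = ["khoj offline chat", "khoj online search"]
    # Fan out all searches at once; total wait is roughly the slowest
    # request rather than the sum of all requests, as in the old loop.
    search_tasks = [mock_search(subquery) for subquery in subqueries]
    search_results = await asyncio.gather(*search_tasks)
    response_dict: Dict[str, List[str]] = dict(search_results)
    print(response_dict)


asyncio.run(main())
```

Because the searches now overlap, the wall-clock time logged by the `timer` block scales with the slowest subquery instead of growing linearly with the number of subqueries.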
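The rewritten `search_with_google` also swaps the blocking `requests.request` call for aiohttp's async client, which is what lets several searches share the event loop. A minimal sketch of that request shape follows, assuming `https://httpbin.org/post` merely as a convenient public echo endpoint for testing; it is not the Serper API.

```python
import asyncio
import json

import aiohttp


async def post_json(url: str, payload: dict) -> dict:
    # Same shape as the new search_with_google: open a session, POST the
    # JSON payload, return an empty result on a non-200 status, otherwise
    # parse the response body as JSON.
    async with aiohttp.ClientSession() as session:
        async with session.post(
            url, data=json.dumps(payload), headers={"Content-Type": "application/json"}
        ) as response:
            if response.status != 200:
                print(await response.text())
                return {}
            return await response.json()


print(asyncio.run(post_json("https://httpbin.org/post", {"q": "khoj"})))
```

One design note on the diff as written: a new `aiohttp.ClientSession` is opened per query, which keeps each coroutine self-contained at the cost of a fresh connection pool per request.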