Simplify, modularize and add type hints to online search functions

- Simplify content arg to `extract_relevant_info' function. Validate,
  clean the content arg inside the `extract_relevant_info' function

- Extract `search_with_google' function outside the parent function
- Call the parent function a more appropriate `search_online' instead
  of `search_with_google'
- Simplify the `search_with_google' function using list comprehension.
  Drop empty search result fields from chat model context for response
  to reduce cost and response latency

- No need to show stacktrace when unable to read webpage, basic error
  is enough
- Add type hints to online search functions to catch issues with mypy
This commit is contained in:
Debanjum Singh Solanky 2024-03-10 02:09:11 +05:30
parent 88f096977b
commit d136a6be44
5 changed files with 38 additions and 43 deletions

View file

@ -247,7 +247,7 @@ def llm_thread(g, messages: List[ChatMessage], model: Any):
def send_message_to_model_offline(
message, loaded_model=None, model="mistral-7b-instruct-v0.1.Q4_0.gguf", streaming=False, system_message=""
):
) -> str:
try:
from gpt4all import GPT4All
except ModuleNotFoundError as e:

View file

@ -43,7 +43,7 @@ class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
def completion_with_backoff(**kwargs):
def completion_with_backoff(**kwargs) -> str:
messages = kwargs.pop("messages")
if not "openai_api_key" in kwargs:
kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY")

View file

@ -2,7 +2,7 @@ import asyncio
import json
import logging
import os
from typing import Dict, Union
from typing import Dict, Tuple, Union
import aiohttp
import requests
@ -16,12 +16,10 @@ from khoj.utils.rawconfig import LocationData
logger = logging.getLogger(__name__)
SERPER_DEV_API_KEY = os.getenv("SERPER_DEV_API_KEY")
OLOSTEP_API_KEY = os.getenv("OLOSTEP_API_KEY")
SERPER_DEV_URL = "https://google.serper.dev/search"
OLOSTEP_API_KEY = os.getenv("OLOSTEP_API_KEY")
OLOSTEP_API_URL = "https://agent.olostep.com/olostep-p2p-incomingAPI"
OLOSTEP_QUERY_PARAMS = {
"timeout": 35, # seconds
"waitBeforeScraping": 1, # seconds
@ -39,31 +37,7 @@ OLOSTEP_QUERY_PARAMS = {
MAX_WEBPAGES_TO_READ = 1
async def search_with_google(query: str, conversation_history: dict, location: LocationData):
def _search_with_google(subquery: str):
payload = json.dumps(
{
"q": subquery,
}
)
headers = {"X-API-KEY": SERPER_DEV_API_KEY, "Content-Type": "application/json"}
response = requests.request("POST", SERPER_DEV_URL, headers=headers, data=payload)
if response.status_code != 200:
logger.error(response.text)
return {}
json_response = response.json()
sub_response_dict = {}
sub_response_dict["knowledgeGraph"] = json_response.get("knowledgeGraph", {})
sub_response_dict["organic"] = json_response.get("organic", [])
sub_response_dict["answerBox"] = json_response.get("answerBox", [])
sub_response_dict["peopleAlsoAsk"] = json_response.get("peopleAlsoAsk", [])
return sub_response_dict
async def search_online(query: str, conversation_history: dict, location: LocationData):
if SERPER_DEV_API_KEY is None:
logger.warn("SERPER_DEV_API_KEY is not set")
return {}
@ -74,14 +48,14 @@ async def search_with_google(query: str, conversation_history: dict, location: L
for subquery in subqueries:
logger.info(f"Searching with Google for '{subquery}'")
response_dict[subquery] = _search_with_google(subquery)
response_dict[subquery] = search_with_google(subquery)
# Gather distinct web pages from organic search results of each subquery without an instant answer
webpage_links = {
result["link"]
for subquery in response_dict
for result in response_dict[subquery].get("organic")[:MAX_WEBPAGES_TO_READ]
if is_none_or_empty(response_dict[subquery].get("answerBox"))
for result in response_dict[subquery].get("organic", [])[:MAX_WEBPAGES_TO_READ]
if "answerBox" not in response_dict[subquery]
}
# Read, extract relevant info from the retrieved web pages
@ -100,15 +74,34 @@ async def search_with_google(query: str, conversation_history: dict, location: L
return response_dict
async def read_webpage_and_extract_content(subquery, url):
def search_with_google(subquery: str):
payload = json.dumps({"q": subquery})
headers = {"X-API-KEY": SERPER_DEV_API_KEY, "Content-Type": "application/json"}
response = requests.request("POST", SERPER_DEV_URL, headers=headers, data=payload)
if response.status_code != 200:
logger.error(response.text)
return {}
json_response = response.json()
extraction_fields = ["organic", "answerBox", "peopleAlsoAsk", "knowledgeGraph"]
extracted_search_result = {
field: json_response[field] for field in extraction_fields if not is_none_or_empty(json_response.get(field))
}
return extracted_search_result
async def read_webpage_and_extract_content(subquery: str, url: str) -> Tuple[str, Union[None, str]]:
try:
with timer(f"Reading web page at '{url}' took", logger):
content = await read_webpage_with_olostep(url) if OLOSTEP_API_KEY else await read_webpage(url)
with timer(f"Extracting relevant information from web page at '{url}' took", logger):
extracted_info = await extract_relevant_info(subquery, {subquery: [content.strip()]}) if content else None
extracted_info = await extract_relevant_info(subquery, content)
return subquery, extracted_info
except Exception as e:
logger.error(f"Failed to read web page at '{url}': {e}", exc_info=True)
logger.error(f"Failed to read web page at '{url}' with {e}")
return subquery, None

View file

@ -14,7 +14,7 @@ from khoj.database.adapters import ConversationAdapters, EntryAdapters, aget_use
from khoj.database.models import KhojUser
from khoj.processor.conversation.prompts import help_message, no_entries_found
from khoj.processor.conversation.utils import save_to_conversation_log
from khoj.processor.tools.online_search import search_with_google
from khoj.processor.tools.online_search import search_online
from khoj.routers.api import extract_references_and_questions
from khoj.routers.helpers import (
ApiUserRateLimiter,
@ -284,7 +284,7 @@ async def chat(
if ConversationCommand.Online in conversation_commands:
try:
online_results = await search_with_google(defiltered_query, meta_log, location)
online_results = await search_online(defiltered_query, meta_log, location)
except ValueError as e:
return StreamingResponse(
iter(["Please set your SERPER_DEV_API_KEY to get started with online searches 🌐"]),

View file

@ -256,15 +256,17 @@ async def generate_online_subqueries(q: str, conversation_history: dict, locatio
return [q]
async def extract_relevant_info(q: str, corpus: dict) -> List[str]:
async def extract_relevant_info(q: str, corpus: str) -> Union[str, None]:
"""
Given a target corpus, extract the most relevant info given a query
Extract relevant information for a given query from the target corpus
"""
key = list(corpus.keys())[0]
if is_none_or_empty(corpus) or is_none_or_empty(q):
return None
extract_relevant_information = prompts.extract_relevant_information.format(
query=q,
corpus=corpus[key],
corpus=corpus.strip(),
)
response = await send_message_to_model_wrapper(