Mirror of https://github.com/khoj-ai/khoj.git · synced 2024-11-27 17:35:07 +01:00
Simplify, modularize and add type hints to online search functions
- Simplify the content arg to the `extract_relevant_info` function; validate and clean it inside `extract_relevant_info` itself
- Extract the `search_with_google` function out of its parent function
- Rename the parent function to the more appropriate `search_online`
- Simplify `search_with_google` using a dict comprehension; drop empty search result fields from the chat model context to reduce cost and response latency
- No need to show a stacktrace when unable to read a webpage; a basic error message is enough
- Add type hints to the online search functions to catch issues with mypy
Parent: 88f096977b
Commit: d136a6be44

5 changed files with 38 additions and 43 deletions
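For reviewers, the caller-visible change is the renamed online search entry point. A minimal calling sketch under stated assumptions (the `LocationData` field names come from khoj's config models, not from this diff, and the conversation history is a stand-in):

```python
import asyncio

from khoj.processor.tools.online_search import search_online
from khoj.utils.rawconfig import LocationData


async def demo():
    # Stand-in conversation history and location, for illustration only.
    meta_log: dict = {}
    location = LocationData(city="Lyon", region="Auvergne-Rhône-Alpes", country="France")
    results = await search_online("best hikes near Lyon", meta_log, location)
    print(results)  # dict of search results keyed by generated subquery


asyncio.run(demo())
```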
```diff
@@ -247,7 +247,7 @@ def llm_thread(g, messages: List[ChatMessage], model: Any):
 def send_message_to_model_offline(
     message, loaded_model=None, model="mistral-7b-instruct-v0.1.Q4_0.gguf", streaming=False, system_message=""
-):
+) -> str:
     try:
         from gpt4all import GPT4All
     except ModuleNotFoundError as e:
```
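The `try`/`except ModuleNotFoundError` around the `gpt4all` import is the standard deferred-import guard for an optional dependency. A generic sketch of the pattern — the error message and fallback behavior are illustrative assumptions, not khoj's actual wording:

```python
def send_message_offline_sketch(message: str) -> str:
    # Import lazily so the heavy optional package is only needed when offline chat runs.
    try:
        from gpt4all import GPT4All
    except ModuleNotFoundError as e:
        raise RuntimeError("Offline chat requires the gpt4all package: pip install gpt4all") from e
    model = GPT4All("mistral-7b-instruct-v0.1.Q4_0.gguf")
    return model.generate(message)
```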
```diff
@@ -43,7 +43,7 @@ class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
     before_sleep=before_sleep_log(logger, logging.DEBUG),
     reraise=True,
 )
-def completion_with_backoff(**kwargs):
+def completion_with_backoff(**kwargs) -> str:
     messages = kwargs.pop("messages")
     if not "openai_api_key" in kwargs:
         kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY")
```
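The `before_sleep` and `reraise` arguments visible in this hunk belong to a tenacity `@retry` decorator. A self-contained sketch of the backoff pattern — the wait/stop policy here is an assumption; only the two visible arguments come from the diff:

```python
import logging
import random

from tenacity import before_sleep_log, retry, stop_after_attempt, wait_random_exponential

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


@retry(
    wait=wait_random_exponential(min=1, max=10),  # assumed policy
    stop=stop_after_attempt(3),                   # assumed policy
    before_sleep=before_sleep_log(logger, logging.DEBUG),
    reraise=True,  # raise the last underlying error, not tenacity's RetryError
)
def flaky_completion() -> str:
    if random.random() < 0.5:
        raise ConnectionError("transient upstream failure")
    return "ok"


print(flaky_completion())
```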
```diff
@@ -2,7 +2,7 @@ import asyncio
 import json
 import logging
 import os
-from typing import Dict, Union
+from typing import Dict, Tuple, Union
 
 import aiohttp
 import requests
```
```diff
@@ -16,12 +16,10 @@ from khoj.utils.rawconfig import LocationData
 logger = logging.getLogger(__name__)
 
 SERPER_DEV_API_KEY = os.getenv("SERPER_DEV_API_KEY")
-OLOSTEP_API_KEY = os.getenv("OLOSTEP_API_KEY")
-
 SERPER_DEV_URL = "https://google.serper.dev/search"
 
+OLOSTEP_API_KEY = os.getenv("OLOSTEP_API_KEY")
 OLOSTEP_API_URL = "https://agent.olostep.com/olostep-p2p-incomingAPI"
-
 OLOSTEP_QUERY_PARAMS = {
     "timeout": 35,  # seconds
     "waitBeforeScraping": 1,  # seconds
```
```diff
@@ -39,31 +37,7 @@ OLOSTEP_QUERY_PARAMS = {
 MAX_WEBPAGES_TO_READ = 1
 
 
-async def search_with_google(query: str, conversation_history: dict, location: LocationData):
-    def _search_with_google(subquery: str):
-        payload = json.dumps(
-            {
-                "q": subquery,
-            }
-        )
-
-        headers = {"X-API-KEY": SERPER_DEV_API_KEY, "Content-Type": "application/json"}
-
-        response = requests.request("POST", SERPER_DEV_URL, headers=headers, data=payload)
-
-        if response.status_code != 200:
-            logger.error(response.text)
-            return {}
-
-        json_response = response.json()
-        sub_response_dict = {}
-        sub_response_dict["knowledgeGraph"] = json_response.get("knowledgeGraph", {})
-        sub_response_dict["organic"] = json_response.get("organic", [])
-        sub_response_dict["answerBox"] = json_response.get("answerBox", [])
-        sub_response_dict["peopleAlsoAsk"] = json_response.get("peopleAlsoAsk", [])
-
-        return sub_response_dict
-
+async def search_online(query: str, conversation_history: dict, location: LocationData):
     if SERPER_DEV_API_KEY is None:
         logger.warn("SERPER_DEV_API_KEY is not set")
         return {}
```
```diff
@@ -74,14 +48,14 @@ async def search_with_google(query: str, conversation_history: dict, location: L
 
     for subquery in subqueries:
         logger.info(f"Searching with Google for '{subquery}'")
-        response_dict[subquery] = _search_with_google(subquery)
+        response_dict[subquery] = search_with_google(subquery)
 
     # Gather distinct web pages from organic search results of each subquery without an instant answer
     webpage_links = {
         result["link"]
         for subquery in response_dict
-        for result in response_dict[subquery].get("organic")[:MAX_WEBPAGES_TO_READ]
-        if is_none_or_empty(response_dict[subquery].get("answerBox"))
+        for result in response_dict[subquery].get("organic", [])[:MAX_WEBPAGES_TO_READ]
+        if "answerBox" not in response_dict[subquery]
     }
 
     # Read, extract relevant info from the retrieved web pages
```
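The set comprehension above dedupes organic result links across subqueries and skips any subquery that already produced an instant answer. A standalone run with fabricated Serper-style responses:

```python
MAX_WEBPAGES_TO_READ = 1

# Fabricated responses for two subqueries; the second has an answerBox.
response_dict = {
    "hikes near lyon": {
        "organic": [{"link": "https://example.com/hikes"}, {"link": "https://example.com/trails"}],
    },
    "lyon weather today": {
        "answerBox": {"answer": "18°C, clear"},
        "organic": [{"link": "https://example.com/weather"}],
    },
}

webpage_links = {
    result["link"]
    for subquery in response_dict
    for result in response_dict[subquery].get("organic", [])[:MAX_WEBPAGES_TO_READ]
    if "answerBox" not in response_dict[subquery]
}

print(webpage_links)  # {'https://example.com/hikes'}
```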
```diff
@@ -100,15 +74,34 @@ async def search_with_google(query: str, conversation_history: dict, location: L
     return response_dict
 
 
-async def read_webpage_and_extract_content(subquery, url):
+def search_with_google(subquery: str):
+    payload = json.dumps({"q": subquery})
+    headers = {"X-API-KEY": SERPER_DEV_API_KEY, "Content-Type": "application/json"}
+
+    response = requests.request("POST", SERPER_DEV_URL, headers=headers, data=payload)
+
+    if response.status_code != 200:
+        logger.error(response.text)
+        return {}
+
+    json_response = response.json()
+    extraction_fields = ["organic", "answerBox", "peopleAlsoAsk", "knowledgeGraph"]
+    extracted_search_result = {
+        field: json_response[field] for field in extraction_fields if not is_none_or_empty(json_response.get(field))
+    }
+
+    return extracted_search_result
+
+
+async def read_webpage_and_extract_content(subquery: str, url: str) -> Tuple[str, Union[None, str]]:
     try:
         with timer(f"Reading web page at '{url}' took", logger):
             content = await read_webpage_with_olostep(url) if OLOSTEP_API_KEY else await read_webpage(url)
         with timer(f"Extracting relevant information from web page at '{url}' took", logger):
-            extracted_info = await extract_relevant_info(subquery, {subquery: [content.strip()]}) if content else None
+            extracted_info = await extract_relevant_info(subquery, content)
         return subquery, extracted_info
     except Exception as e:
-        logger.error(f"Failed to read web page at '{url}': {e}", exc_info=True)
+        logger.error(f"Failed to read web page at '{url}' with {e}")
         return subquery, None
```
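The extraction dict comprehension is what drops empty fields from the chat model context. A quick standalone illustration, where `is_none_or_empty` is a stand-in for khoj's helper:

```python
def is_none_or_empty(item) -> bool:
    # Stand-in for khoj's helper: treat None and empty containers/strings as empty.
    return item is None or (hasattr(item, "__len__") and len(item) == 0)


json_response = {
    "organic": [{"title": "Hiking in Lyon", "link": "https://example.com/hikes"}],
    "answerBox": {},      # empty, so it gets dropped
    "peopleAlsoAsk": [],  # empty, so it gets dropped
    "knowledgeGraph": {"title": "Lyon"},
}

extraction_fields = ["organic", "answerBox", "peopleAlsoAsk", "knowledgeGraph"]
extracted_search_result = {
    field: json_response[field]
    for field in extraction_fields
    if not is_none_or_empty(json_response.get(field))
}

print(list(extracted_search_result))  # ['organic', 'knowledgeGraph']
```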
```diff
@@ -14,7 +14,7 @@ from khoj.database.adapters import ConversationAdapters, EntryAdapters, aget_use
 from khoj.database.models import KhojUser
 from khoj.processor.conversation.prompts import help_message, no_entries_found
 from khoj.processor.conversation.utils import save_to_conversation_log
-from khoj.processor.tools.online_search import search_with_google
+from khoj.processor.tools.online_search import search_online
 from khoj.routers.api import extract_references_and_questions
 from khoj.routers.helpers import (
     ApiUserRateLimiter,
```
```diff
@@ -284,7 +284,7 @@ async def chat(
 
     if ConversationCommand.Online in conversation_commands:
         try:
-            online_results = await search_with_google(defiltered_query, meta_log, location)
+            online_results = await search_online(defiltered_query, meta_log, location)
         except ValueError as e:
             return StreamingResponse(
                 iter(["Please set your SERPER_DEV_API_KEY to get started with online searches 🌐"]),
```
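Wrapping the fixed hint in `iter([...])` lets FastAPI's `StreamingResponse` push it through the same streaming path as normal chat output. A minimal standalone sketch of that mechanism, with an assumed route name:

```python
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


@app.get("/chat-error-demo")
async def chat_error_demo():
    # A one-element iterator streams the message as a single chunk.
    return StreamingResponse(iter(["Please set your SERPER_DEV_API_KEY to get started with online searches 🌐"]))
```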
```diff
@@ -256,15 +256,17 @@ async def generate_online_subqueries(q: str, conversation_history: dict, locatio
         return [q]
 
 
-async def extract_relevant_info(q: str, corpus: dict) -> List[str]:
+async def extract_relevant_info(q: str, corpus: str) -> Union[str, None]:
     """
-    Given a target corpus, extract the most relevant info given a query
+    Extract relevant information for a given query from the target corpus
     """
 
-    key = list(corpus.keys())[0]
+    if is_none_or_empty(corpus) or is_none_or_empty(q):
+        return None
+
     extract_relevant_information = prompts.extract_relevant_information.format(
         query=q,
-        corpus=corpus[key],
+        corpus=corpus.strip(),
     )
 
     response = await send_message_to_model_wrapper(
```
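With validation moved inside `extract_relevant_info`, callers can pass raw page text directly. A behavioral sketch of just the new guard, with the model call stubbed out and `is_none_or_empty` again standing in for khoj's helper:

```python
import asyncio
from typing import Union


def is_none_or_empty(item) -> bool:
    return item is None or (hasattr(item, "__len__") and len(item) == 0)


async def extract_relevant_info(q: str, corpus: str) -> Union[str, None]:
    if is_none_or_empty(corpus) or is_none_or_empty(q):
        return None
    # The real function formats prompts.extract_relevant_information with
    # corpus.strip() and sends it to the chat model; stubbed here.
    return f"[extracted info for '{q}' from {len(corpus.strip())} chars]"


print(asyncio.run(extract_relevant_info("what is khoj?", "")))         # None
print(asyncio.run(extract_relevant_info("what is khoj?", " a page "))) # stubbed summary
```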