Mirror of https://github.com/khoj-ai/khoj.git, synced 2024-11-27 17:35:07 +01:00
Make Chat and Online Search Resilient and Faster (#936)
## Overview

### New
- Support using [Firecrawl](https://firecrawl.dev) to read web pages
- Add, switch, and re-prioritize the web page reader(s) to use via the admin panel

### Speed
- Improve response speed by aggregating web page read and extract queries so they run only once per web page

### Response Resilience
- Fall back through the enabled web page readers until the web page is read
- Enable reading web pages on the internal network for self-hosted Khoj running in anonymous mode
- Try to respond even if web search or web page read fails during chat
- Try to respond even if document search via the inference endpoint fails

### Fix
- Return the data sources to use if an exception occurs in the data source chat actor

## Details

### Configure web page readers to use
- Use only the web scraper set in Server Chat Settings via the Django admin panel, if set
- Otherwise, use the web scrapers added via the Django admin panel (in order of priority), if any
- Otherwise, use all the web scrapers enabled by setting API keys via environment variables (e.g. the `FIRECRAWL_API_KEY` or `JINA_API_KEY` env vars), if any
- Otherwise, use Jina to scrape web pages if no scrapers are explicitly defined

A sketch of this selection order follows this description.

For self-hosted setups running in anonymous mode, the ability to directly read web pages is also enabled by default. This is especially useful for reading web pages on your internal network that the other web page readers cannot access.

### Aggregate webpage extract queries to run once for each distinct web page
Previously, we ran separate web page read and extract-relevant-content pipelines for each distinct (query, url) pair. Now we aggregate all the queries for each URL and run the read and extract pipelines once per distinct URL (see the grouping sketch after this description).

Even though the web page content extraction pipelines previously ran in parallel, they still increased response time by:
1. adding near-duplicate context for the response generation step to read
2. being more susceptible to variability in the web page read latency of the parallel jobs

Aggregating context retrieval across all the queries for a given web page could cost some context quality, but it should improve, and reduce variability in, response time, quality, and cost. This should especially help the speed and quality of online search for offline or low-context chat models.
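The reader selection order can be summarized in a short sketch. This is a simplified illustration, not code from this commit: `pick_web_scrapers`, its arguments, and the dict-based scraper records are hypothetical stand-ins for the Django models used in the actual change.

```python
import os

def pick_web_scrapers(server_setting=None, admin_scrapers=None):
    """Return web page readers in the order they should be tried."""
    if server_setting:  # 1. The scraper set in Server Chat Settings wins outright
        return [server_setting]
    if admin_scrapers:  # 2. Scrapers added via the admin panel, ordered by priority
        return sorted(admin_scrapers, key=lambda s: s["priority"])
    scrapers = []  # 3. Scrapers enabled via environment variables
    if os.getenv("FIRECRAWL_API_KEY"):
        scrapers.append({"type": "Firecrawl", "api_key": os.getenv("FIRECRAWL_API_KEY")})
    if os.getenv("OLOSTEP_API_KEY"):
        scrapers.append({"type": "Olostep", "api_key": os.getenv("OLOSTEP_API_KEY")})
    # 4. Jina is always appended as the default fallback; it works without an API key
    scrapers.append({"type": "Jina", "api_key": os.getenv("JINA_API_KEY")})
    return scrapers
```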
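The query aggregation works roughly as sketched below: instead of one read-and-extract pass per (query, url) pair, all subqueries are grouped by URL so each page is fetched and summarized once. The `group_queries_by_url` helper and the example data are illustrative, not part of the commit.

```python
from collections import defaultdict

def group_queries_by_url(organic_results: dict[str, list[dict]]) -> dict[str, set[str]]:
    """Map each distinct link to the set of subqueries that surfaced it."""
    webpages: dict[str, set[str]] = defaultdict(set)
    for subquery, hits in organic_results.items():
        for hit in hits:
            webpages[hit["link"]].add(subquery)
    return webpages

# Two subqueries pointing at the same page now trigger a single read + extract.
results = {
    "khoj self hosting": [{"link": "https://docs.khoj.dev/setup"}],
    "khoj docker setup": [{"link": "https://docs.khoj.dev/setup"}],
}
print(group_queries_by_url(results))
```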
Commit 7fb4c2939d: 10 changed files with 493 additions and 88 deletions
@@ -1,6 +1,7 @@
import json
import logging
import math
import os
import random
import re
import secrets
@@ -10,7 +11,6 @@ from enum import Enum
from typing import Callable, Iterable, List, Optional, Type

import cron_descriptor
import django
from apscheduler.job import Job
from asgiref.sync import sync_to_async
from django.contrib.sessions.backends.db import SessionStore
@@ -52,6 +52,7 @@ from khoj.database.models import (
    UserTextToImageModelConfig,
    UserVoiceModelConfig,
    VoiceModelOption,
    WebScraper,
)
from khoj.processor.conversation import prompts
from khoj.search_filter.date_filter import DateFilter
@@ -59,7 +60,12 @@ from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.word_filter import WordFilter
from khoj.utils import state
from khoj.utils.config import OfflineChatProcessorModel
from khoj.utils.helpers import generate_random_name, is_none_or_empty, timer
from khoj.utils.helpers import (
    generate_random_name,
    in_debug_mode,
    is_none_or_empty,
    timer,
)

logger = logging.getLogger(__name__)

@@ -1031,6 +1037,70 @@ class ConversationAdapters:
            return server_chat_settings.chat_advanced
        return await ConversationAdapters.aget_default_conversation_config(user)

    @staticmethod
    async def aget_server_webscraper():
        server_chat_settings = await ServerChatSettings.objects.filter().prefetch_related("web_scraper").afirst()
        if server_chat_settings is not None and server_chat_settings.web_scraper is not None:
            return server_chat_settings.web_scraper
        return None

    @staticmethod
    async def aget_enabled_webscrapers() -> list[WebScraper]:
        enabled_scrapers: list[WebScraper] = []
        server_webscraper = await ConversationAdapters.aget_server_webscraper()
        if server_webscraper:
            # Only use the webscraper set in the server chat settings
            enabled_scrapers = [server_webscraper]
        if not enabled_scrapers:
            # Use the enabled web scrapers, ordered by priority, until get web page content
            enabled_scrapers = [scraper async for scraper in WebScraper.objects.all().order_by("priority").aiterator()]
        if not enabled_scrapers:
            # Use scrapers enabled via environment variables
            if os.getenv("FIRECRAWL_API_KEY"):
                api_url = os.getenv("FIRECRAWL_API_URL", "https://api.firecrawl.dev")
                enabled_scrapers.append(
                    WebScraper(
                        type=WebScraper.WebScraperType.FIRECRAWL,
                        name=WebScraper.WebScraperType.FIRECRAWL.capitalize(),
                        api_key=os.getenv("FIRECRAWL_API_KEY"),
                        api_url=api_url,
                    )
                )
            if os.getenv("OLOSTEP_API_KEY"):
                api_url = os.getenv("OLOSTEP_API_URL", "https://agent.olostep.com/olostep-p2p-incomingAPI")
                enabled_scrapers.append(
                    WebScraper(
                        type=WebScraper.WebScraperType.OLOSTEP,
                        name=WebScraper.WebScraperType.OLOSTEP.capitalize(),
                        api_key=os.getenv("OLOSTEP_API_KEY"),
                        api_url=api_url,
                    )
                )
            # Jina is the default fallback scrapers to use as it does not require an API key
            api_url = os.getenv("JINA_READER_API_URL", "https://r.jina.ai/")
            enabled_scrapers.append(
                WebScraper(
                    type=WebScraper.WebScraperType.JINA,
                    name=WebScraper.WebScraperType.JINA.capitalize(),
                    api_key=os.getenv("JINA_API_KEY"),
                    api_url=api_url,
                )
            )

            # Only enable the direct web page scraper by default in self-hosted single user setups.
            # Useful for reading webpages on your intranet.
            if state.anonymous_mode or in_debug_mode():
                enabled_scrapers.append(
                    WebScraper(
                        type=WebScraper.WebScraperType.DIRECT,
                        name=WebScraper.WebScraperType.DIRECT.capitalize(),
                        api_key=None,
                        api_url=None,
                    )
                )

        return enabled_scrapers

    @staticmethod
    def create_conversation_from_public_conversation(
        user: KhojUser, public_conversation: PublicConversation, client_app: ClientApplication

@@ -31,6 +31,7 @@ from khoj.database.models import (
    UserSearchModelConfig,
    UserVoiceModelConfig,
    VoiceModelOption,
    WebScraper,
)
from khoj.utils.helpers import ImageIntentType

@@ -198,9 +199,24 @@ class ServerChatSettingsAdmin(admin.ModelAdmin):
    list_display = (
        "chat_default",
        "chat_advanced",
        "web_scraper",
    )


@admin.register(WebScraper)
class WebScraperAdmin(admin.ModelAdmin):
    list_display = (
        "priority",
        "name",
        "type",
        "api_key",
        "api_url",
        "created_at",
    )
    search_fields = ("name", "api_key", "api_url", "type")
    ordering = ("priority",)


@admin.register(Conversation)
class ConversationAdmin(admin.ModelAdmin):
    list_display = (

@@ -0,0 +1,89 @@
# Generated by Django 5.0.8 on 2024-10-18 00:41

import django.db.models.deletion
from django.db import migrations, models


class Migration(migrations.Migration):
    dependencies = [
        ("database", "0068_alter_agent_output_modes"),
    ]

    operations = [
        migrations.CreateModel(
            name="WebScraper",
            fields=[
                ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
                ("created_at", models.DateTimeField(auto_now_add=True)),
                ("updated_at", models.DateTimeField(auto_now=True)),
                (
                    "name",
                    models.CharField(
                        blank=True,
                        default=None,
                        help_text="Friendly name. If not set, it will be set to the type of the scraper.",
                        max_length=200,
                        null=True,
                        unique=True,
                    ),
                ),
                (
                    "type",
                    models.CharField(
                        choices=[
                            ("Firecrawl", "Firecrawl"),
                            ("Olostep", "Olostep"),
                            ("Jina", "Jina"),
                            ("Direct", "Direct"),
                        ],
                        default="Jina",
                        max_length=20,
                    ),
                ),
                (
                    "api_key",
                    models.CharField(
                        blank=True,
                        default=None,
                        help_text="API key of the web scraper. Only set if scraper service requires an API key. Default is set from env var.",
                        max_length=200,
                        null=True,
                    ),
                ),
                (
                    "api_url",
                    models.URLField(
                        blank=True,
                        default=None,
                        help_text="API URL of the web scraper. Only set if scraper service on non-default URL.",
                        null=True,
                    ),
                ),
                (
                    "priority",
                    models.IntegerField(
                        blank=True,
                        default=None,
                        help_text="Priority of the web scraper. Lower numbers run first.",
                        null=True,
                        unique=True,
                    ),
                ),
            ],
            options={
                "abstract": False,
            },
        ),
        migrations.AddField(
            model_name="serverchatsettings",
            name="web_scraper",
            field=models.ForeignKey(
                blank=True,
                default=None,
                null=True,
                on_delete=django.db.models.deletion.CASCADE,
                related_name="web_scraper",
                to="database.webscraper",
            ),
        ),
    ]

@@ -1,3 +1,4 @@
import os
import re
import uuid
from random import choice
@@ -11,8 +12,6 @@ from django.dispatch import receiver
from pgvector.django import VectorField
from phonenumber_field.modelfields import PhoneNumberField

from khoj.utils.helpers import ConversationCommand


class BaseModel(models.Model):
    created_at = models.DateTimeField(auto_now_add=True)
@@ -244,6 +243,79 @@ class GithubRepoConfig(BaseModel):
    github_config = models.ForeignKey(GithubConfig, on_delete=models.CASCADE, related_name="githubrepoconfig")


class WebScraper(BaseModel):
    class WebScraperType(models.TextChoices):
        FIRECRAWL = "Firecrawl"
        OLOSTEP = "Olostep"
        JINA = "Jina"
        DIRECT = "Direct"

    name = models.CharField(
        max_length=200,
        default=None,
        null=True,
        blank=True,
        unique=True,
        help_text="Friendly name. If not set, it will be set to the type of the scraper.",
    )
    type = models.CharField(max_length=20, choices=WebScraperType.choices, default=WebScraperType.JINA)
    api_key = models.CharField(
        max_length=200,
        default=None,
        null=True,
        blank=True,
        help_text="API key of the web scraper. Only set if scraper service requires an API key. Default is set from env var.",
    )
    api_url = models.URLField(
        max_length=200,
        default=None,
        null=True,
        blank=True,
        help_text="API URL of the web scraper. Only set if scraper service on non-default URL.",
    )
    priority = models.IntegerField(
        default=None,
        null=True,
        blank=True,
        unique=True,
        help_text="Priority of the web scraper. Lower numbers run first.",
    )

    def clean(self):
        error = {}
        if self.name is None:
            self.name = self.type.capitalize()
        if self.api_url is None:
            if self.type == self.WebScraperType.FIRECRAWL:
                self.api_url = os.getenv("FIRECRAWL_API_URL", "https://api.firecrawl.dev")
            elif self.type == self.WebScraperType.OLOSTEP:
                self.api_url = os.getenv("OLOSTEP_API_URL", "https://agent.olostep.com/olostep-p2p-incomingAPI")
            elif self.type == self.WebScraperType.JINA:
                self.api_url = os.getenv("JINA_READER_API_URL", "https://r.jina.ai/")
        if self.api_key is None:
            if self.type == self.WebScraperType.FIRECRAWL:
                self.api_key = os.getenv("FIRECRAWL_API_KEY")
                if not self.api_key and self.api_url == "https://api.firecrawl.dev":
                    error["api_key"] = "Set API key to use default Firecrawl. Get API key from https://firecrawl.dev."
            elif self.type == self.WebScraperType.OLOSTEP:
                self.api_key = os.getenv("OLOSTEP_API_KEY")
                if self.api_key is None:
                    error["api_key"] = "Set API key to use Olostep. Get API key from https://olostep.com/."
            elif self.type == self.WebScraperType.JINA:
                self.api_key = os.getenv("JINA_API_KEY")
        if error:
            raise ValidationError(error)

    def save(self, *args, **kwargs):
        self.clean()

        if self.priority is None:
            max_priority = WebScraper.objects.aggregate(models.Max("priority"))["priority__max"]
            self.priority = max_priority + 1 if max_priority else 1

        super().save(*args, **kwargs)


class ServerChatSettings(BaseModel):
    chat_default = models.ForeignKey(
        ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_default"
@@ -251,6 +323,9 @@ class ServerChatSettings(BaseModel):
    chat_advanced = models.ForeignKey(
        ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_advanced"
    )
    web_scraper = models.ForeignKey(
        WebScraper, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="web_scraper"
    )


class LocalOrgConfig(BaseModel):

@@ -114,6 +114,7 @@ class CrossEncoderModel:
            payload = {"inputs": {"query": query, "passages": [hit.additional[key] for hit in hits]}}
            headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
            response = requests.post(target_url, json=payload, headers=headers)
            response.raise_for_status()
            return response.json()["scores"]

        cross_inp = [[query, hit.additional[key]] for hit in hits]

@@ -10,14 +10,22 @@ import aiohttp
from bs4 import BeautifulSoup
from markdownify import markdownify

from khoj.database.models import Agent, KhojUser
from khoj.database.adapters import ConversationAdapters
from khoj.database.models import Agent, KhojUser, WebScraper
from khoj.processor.conversation import prompts
from khoj.routers.helpers import (
    ChatEvent,
    extract_relevant_info,
    generate_online_subqueries,
    infer_webpage_urls,
)
from khoj.utils.helpers import is_internet_connected, is_none_or_empty, timer
from khoj.utils.helpers import (
    is_env_var_true,
    is_internal_url,
    is_internet_connected,
    is_none_or_empty,
    timer,
)
from khoj.utils.rawconfig import LocationData

logger = logging.getLogger(__name__)
@@ -25,12 +33,11 @@ logger = logging.getLogger(__name__)
SERPER_DEV_API_KEY = os.getenv("SERPER_DEV_API_KEY")
SERPER_DEV_URL = "https://google.serper.dev/search"

JINA_READER_API_URL = "https://r.jina.ai/"
JINA_SEARCH_API_URL = "https://s.jina.ai/"
JINA_API_KEY = os.getenv("JINA_API_KEY")

OLOSTEP_API_KEY = os.getenv("OLOSTEP_API_KEY")
OLOSTEP_API_URL = "https://agent.olostep.com/olostep-p2p-incomingAPI"
FIRECRAWL_USE_LLM_EXTRACT = is_env_var_true("FIRECRAWL_USE_LLM_EXTRACT")

OLOSTEP_QUERY_PARAMS = {
    "timeout": 35,  # seconds
    "waitBeforeScraping": 1,  # seconds
@@ -83,33 +90,36 @@ async def search_online(
    search_results = await asyncio.gather(*search_tasks)
    response_dict = {subquery: search_result for subquery, search_result in search_results}

    # Gather distinct web page data from organic results of each subquery without an instant answer.
    # Gather distinct web pages from organic results for subqueries without an instant answer.
    # Content of web pages is directly available when Jina is used for search.
    webpages = {
        (organic.get("link"), subquery, organic.get("content"))
        for subquery in response_dict
        for organic in response_dict[subquery].get("organic", [])[:MAX_WEBPAGES_TO_READ]
        if "answerBox" not in response_dict[subquery]
    }
    webpages: Dict[str, Dict] = {}
    for subquery in response_dict:
        if "answerBox" in response_dict[subquery]:
            continue
        for organic in response_dict[subquery].get("organic", [])[:MAX_WEBPAGES_TO_READ]:
            link = organic.get("link")
            if link in webpages:
                webpages[link]["queries"].add(subquery)
            else:
                webpages[link] = {"queries": {subquery}, "content": organic.get("content")}

    # Read, extract relevant info from the retrieved web pages
    if webpages:
        webpage_links = set([link for link, _, _ in webpages])
        logger.info(f"Reading web pages at: {list(webpage_links)}")
        logger.info(f"Reading web pages at: {webpages.keys()}")
        if send_status_func:
            webpage_links_str = "\n- " + "\n- ".join(list(webpage_links))
            webpage_links_str = "\n- " + "\n- ".join(webpages.keys())
            async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
                yield {ChatEvent.STATUS: event}
        tasks = [
            read_webpage_and_extract_content(subquery, link, content, user=user, agent=agent)
            for link, subquery, content in webpages
            read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent)
            for link, data in webpages.items()
        ]
        results = await asyncio.gather(*tasks)

        # Collect extracted info from the retrieved web pages
        for subquery, webpage_extract, url in results:
        for subqueries, url, webpage_extract in results:
            if webpage_extract is not None:
                response_dict[subquery]["webpages"] = {"link": url, "snippet": webpage_extract}
                response_dict[subqueries.pop()]["webpages"] = {"link": url, "snippet": webpage_extract}

    yield response_dict

@@ -156,29 +166,66 @@ async def read_webpages(
        webpage_links_str = "\n- " + "\n- ".join(list(urls))
        async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
            yield {ChatEvent.STATUS: event}
    tasks = [read_webpage_and_extract_content(query, url, user=user, agent=agent) for url in urls]
    tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent) for url in urls]
    results = await asyncio.gather(*tasks)

    response: Dict[str, Dict] = defaultdict(dict)
    response[query]["webpages"] = [
        {"query": q, "link": url, "snippet": web_extract} for q, web_extract, url in results if web_extract is not None
        {"query": qs.pop(), "link": url, "snippet": extract} for qs, url, extract in results if extract is not None
    ]
    yield response


async def read_webpage(
    url, scraper_type=None, api_key=None, api_url=None, subqueries=None, agent=None
) -> Tuple[str | None, str | None]:
    if scraper_type == WebScraper.WebScraperType.FIRECRAWL and FIRECRAWL_USE_LLM_EXTRACT:
        return None, await query_webpage_with_firecrawl(url, subqueries, api_key, api_url, agent)
    elif scraper_type == WebScraper.WebScraperType.FIRECRAWL:
        return await read_webpage_with_firecrawl(url, api_key, api_url), None
    elif scraper_type == WebScraper.WebScraperType.OLOSTEP:
        return await read_webpage_with_olostep(url, api_key, api_url), None
    elif scraper_type == WebScraper.WebScraperType.JINA:
        return await read_webpage_with_jina(url, api_key, api_url), None
    else:
        return await read_webpage_at_url(url), None


async def read_webpage_and_extract_content(
    subquery: str, url: str, content: str = None, user: KhojUser = None, agent: Agent = None
) -> Tuple[str, Union[None, str], str]:
    subqueries: set[str], url: str, content: str = None, user: KhojUser = None, agent: Agent = None
) -> Tuple[set[str], str, Union[None, str]]:
    # Select the web scrapers to use for reading the web page
    web_scrapers = await ConversationAdapters.aget_enabled_webscrapers()
    # Only use the direct web scraper for internal URLs
    if is_internal_url(url):
        web_scrapers = [scraper for scraper in web_scrapers if scraper.type == WebScraper.WebScraperType.DIRECT]

    # Fallback through enabled web scrapers until we successfully read the web page
    extracted_info = None
    for scraper in web_scrapers:
        try:
            # Read the web page
            if is_none_or_empty(content):
                with timer(f"Reading web page at '{url}' took", logger):
                    content = await read_webpage_with_olostep(url) if OLOSTEP_API_KEY else await read_webpage_with_jina(url)
                with timer(f"Reading web page with {scraper.type} at '{url}' took", logger, log_level=logging.INFO):
                    content, extracted_info = await read_webpage(
                        url, scraper.type, scraper.api_key, scraper.api_url, subqueries, agent
                    )

            # Extract relevant information from the web page
            if is_none_or_empty(extracted_info):
                with timer(f"Extracting relevant information from web page at '{url}' took", logger):
                    extracted_info = await extract_relevant_info(subquery, content, user=user, agent=agent)
                    return subquery, extracted_info, url
                    extracted_info = await extract_relevant_info(subqueries, content, user=user, agent=agent)

            # If we successfully extracted information, break the loop
            if not is_none_or_empty(extracted_info):
                break
        except Exception as e:
            logger.error(f"Failed to read web page at '{url}' with {e}")
            return subquery, None, url
            logger.warning(f"Failed to read web page with {scraper.type} at '{url}' with {e}")
            # If this is the last web scraper in the list, log an error
            if scraper.name == web_scrapers[-1].name:
                logger.error(f"All web scrapers failed for '{url}'")

    return subqueries, url, extracted_info


async def read_webpage_at_url(web_url: str) -> str:
@@ -195,23 +242,23 @@ async def read_webpage_at_url(web_url: str) -> str:
            return markdownify(body)


async def read_webpage_with_olostep(web_url: str) -> str:
    headers = {"Authorization": f"Bearer {OLOSTEP_API_KEY}"}
async def read_webpage_with_olostep(web_url: str, api_key: str, api_url: str) -> str:
    headers = {"Authorization": f"Bearer {api_key}"}
    web_scraping_params: Dict[str, Union[str, int, bool]] = OLOSTEP_QUERY_PARAMS.copy()  # type: ignore
    web_scraping_params["url"] = web_url

    async with aiohttp.ClientSession() as session:
        async with session.get(OLOSTEP_API_URL, params=web_scraping_params, headers=headers) as response:
        async with session.get(api_url, params=web_scraping_params, headers=headers) as response:
            response.raise_for_status()
            response_json = await response.json()
            return response_json["markdown_content"]


async def read_webpage_with_jina(web_url: str) -> str:
    jina_reader_api_url = f"{JINA_READER_API_URL}/{web_url}"
async def read_webpage_with_jina(web_url: str, api_key: str, api_url: str) -> str:
    jina_reader_api_url = f"{api_url}/{web_url}"
    headers = {"Accept": "application/json", "X-Timeout": "30"}
    if JINA_API_KEY:
        headers["Authorization"] = f"Bearer {JINA_API_KEY}"
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    async with aiohttp.ClientSession() as session:
        async with session.get(jina_reader_api_url, headers=headers) as response:
@@ -220,6 +267,54 @@ async def read_webpage_with_jina(web_url: str) -> str:
            return response_json["data"]["content"]


async def read_webpage_with_firecrawl(web_url: str, api_key: str, api_url: str) -> str:
    firecrawl_api_url = f"{api_url}/v1/scrape"
    headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
    params = {"url": web_url, "formats": ["markdown"], "excludeTags": ["script", ".ad"]}

    async with aiohttp.ClientSession() as session:
        async with session.post(firecrawl_api_url, json=params, headers=headers) as response:
            response.raise_for_status()
            response_json = await response.json()
            return response_json["data"]["markdown"]


async def query_webpage_with_firecrawl(
    web_url: str, queries: set[str], api_key: str, api_url: str, agent: Agent = None
) -> str:
    firecrawl_api_url = f"{api_url}/v1/scrape"
    headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
    schema = {
        "type": "object",
        "properties": {
            "relevant_extract": {"type": "string"},
        },
        "required": [
            "relevant_extract",
        ],
    }

    personality_context = (
        prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
    )
    system_prompt = f"""
{prompts.system_prompt_extract_relevant_information}

{personality_context}
User Query: {", ".join(queries)}

Collate only relevant information from the website to answer the target query and in the provided JSON schema.
""".strip()

    params = {"url": web_url, "formats": ["extract"], "extract": {"systemPrompt": system_prompt, "schema": schema}}

    async with aiohttp.ClientSession() as session:
        async with session.post(firecrawl_api_url, json=params, headers=headers) as response:
            response.raise_for_status()
            response_json = await response.json()
            return response_json["data"]["extract"]["relevant_extract"]


async def search_with_jina(query: str, location: LocationData) -> Tuple[str, Dict[str, List[Dict]]]:
    encoded_query = urllib.parse.quote(query)
    jina_search_api_url = f"{JINA_SEARCH_API_URL}/{encoded_query}"

@@ -3,7 +3,6 @@ import base64
import json
import logging
import time
import warnings
from datetime import datetime
from functools import partial
from typing import Dict, Optional
@@ -574,7 +573,6 @@ async def chat(
        chat_metadata: dict = {}
        connection_alive = True
        user: KhojUser = request.user.object
        subscribed: bool = has_required_scope(request, ["premium"])
        event_delimiter = "␃🔚␗"
        q = unquote(q)
        nonlocal conversation_id
@@ -641,7 +639,7 @@ async def chat(
            request=request,
            telemetry_type="api",
            api="chat",
            client=request.user.client_app,
            client=common.client,
            user_agent=request.headers.get("user-agent"),
            host=request.headers.get("host"),
            metadata=chat_metadata,
@@ -840,6 +838,7 @@ async def chat(
        # Gather Context
        ## Extract Document References
        compiled_references, inferred_queries, defiltered_query = [], [], None
        try:
            async for result in extract_references_and_questions(
                request,
                meta_log,
@@ -859,6 +858,13 @@ async def chat(
                compiled_references.extend(result[0])
                inferred_queries.extend(result[1])
                defiltered_query = result[2]
        except Exception as e:
            error_message = f"Error searching knowledge base: {e}. Attempting to respond without document references."
            logger.warning(error_message)
            async for result in send_event(
                ChatEvent.STATUS, "Document search failed. I'll try respond without document references"
            ):
                yield result

        if not is_none_or_empty(compiled_references):
            headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
@@ -894,12 +900,13 @@ async def chat(
                    yield result[ChatEvent.STATUS]
                else:
                    online_results = result
            except ValueError as e:
            except Exception as e:
                error_message = f"Error searching online: {e}. Attempting to respond without online results"
                logger.warning(error_message)
                async for result in send_llm_response(error_message):
                async for result in send_event(
                    ChatEvent.STATUS, "Online search failed. I'll try respond without online references"
                ):
                    yield result
                return

        ## Gather Webpage References
        if ConversationCommand.Webpage in conversation_commands:
@@ -928,11 +935,15 @@ async def chat(
                    webpages.append(webpage["link"])
                async for result in send_event(ChatEvent.STATUS, f"**Read web pages**: {webpages}"):
                    yield result
            except ValueError as e:
            except Exception as e:
                logger.warning(
                    f"Error directly reading webpages: {e}. Attempting to respond without online results",
                    f"Error reading webpages: {e}. Attempting to respond without webpage results",
                    exc_info=True,
                )
                async for result in send_event(
                    ChatEvent.STATUS, "Webpage read failed. I'll try respond without webpage references"
                ):
                    yield result

        ## Send Gathered References
        async for result in send_event(

@@ -353,13 +353,13 @@ async def aget_relevant_information_sources(
            final_response = [ConversationCommand.Default]
        else:
            final_response = [ConversationCommand.General]
        return final_response
    except Exception as e:
    except Exception:
        logger.error(f"Invalid response for determining relevant tools: {response}")
        if len(agent_tools) == 0:
            final_response = [ConversationCommand.Default]
        else:
            final_response = agent_tools
        return final_response


async def aget_relevant_output_modes(
@@ -551,12 +551,14 @@ async def schedule_query(
    raise AssertionError(f"Invalid response for scheduling query: {raw_response}")


async def extract_relevant_info(q: str, corpus: str, user: KhojUser = None, agent: Agent = None) -> Union[str, None]:
async def extract_relevant_info(
    qs: set[str], corpus: str, user: KhojUser = None, agent: Agent = None
) -> Union[str, None]:
    """
    Extract relevant information for a given query from the target corpus
    """

    if is_none_or_empty(corpus) or is_none_or_empty(q):
    if is_none_or_empty(corpus) or is_none_or_empty(qs):
        return None

    personality_context = (
@@ -564,12 +566,11 @@ async def extract_relevant_info(q: str, corpus: str, user: KhojUser = None, agen
    )

    extract_relevant_information = prompts.extract_relevant_information.format(
        query=q,
        query=", ".join(qs),
        corpus=corpus.strip(),
        personality_context=personality_context,
    )

    with timer("Chat actor: Extract relevant information from data", logger):
        response = await send_message_to_model_wrapper(
            extract_relevant_information,
            prompts.system_prompt_extract_relevant_information,

@@ -3,6 +3,7 @@ import math
from pathlib import Path
from typing import List, Optional, Tuple, Type, Union

import requests
import torch
from asgiref.sync import sync_to_async
from sentence_transformers import util
@@ -231,8 +232,12 @@ def setup(

def cross_encoder_score(query: str, hits: List[SearchResponse], search_model_name: str) -> List[SearchResponse]:
    """Score all retrieved entries using the cross-encoder"""
    try:
        with timer("Cross-Encoder Predict Time", logger, state.device):
            cross_scores = state.cross_encoder_model[search_model_name].predict(query, hits)
    except requests.exceptions.HTTPError as e:
        logger.error(f"Failed to rerank documents using the inference endpoint. Error: {e}.", exc_info=True)
        cross_scores = [0.0] * len(hits)

    # Convert cross-encoder scores to distances and pass in hits for reranking
    for idx in range(len(cross_scores)):

@@ -2,10 +2,12 @@ from __future__ import annotations  # to avoid quoting type hints

import datetime
import io
import ipaddress
import logging
import os
import platform
import random
import urllib.parse
import uuid
from collections import OrderedDict
from enum import Enum
@@ -164,9 +166,9 @@ def get_class_by_name(name: str) -> object:
class timer:
    """Context manager to log time taken for a block of code to run"""

    def __init__(self, message: str, logger: logging.Logger, device: torch.device = None):
    def __init__(self, message: str, logger: logging.Logger, device: torch.device = None, log_level=logging.DEBUG):
        self.message = message
        self.logger = logger
        self.logger = logger.debug if log_level == logging.DEBUG else logger.info
        self.device = device

    def __enter__(self):
@@ -176,9 +178,9 @@ class timer:
    def __exit__(self, *_):
        elapsed = perf_counter() - self.start
        if self.device is None:
            self.logger.debug(f"{self.message}: {elapsed:.3f} seconds")
            self.logger(f"{self.message}: {elapsed:.3f} seconds")
        else:
            self.logger.debug(f"{self.message}: {elapsed:.3f} seconds on device: {self.device}")
            self.logger(f"{self.message}: {elapsed:.3f} seconds on device: {self.device}")


class LRU(OrderedDict):
@@ -436,6 +438,46 @@ def is_internet_connected():
    return False


def is_internal_url(url: str) -> bool:
    """
    Check if a URL is likely to be internal/non-public.

    Args:
        url (str): The URL to check.

    Returns:
        bool: True if the URL is likely internal, False otherwise.
    """
    try:
        parsed_url = urllib.parse.urlparse(url)
        hostname = parsed_url.hostname

        # Check for localhost
        if hostname in ["localhost", "127.0.0.1", "::1"]:
            return True

        # Check for IP addresses in private ranges
        try:
            ip = ipaddress.ip_address(hostname)
            return ip.is_private
        except ValueError:
            pass  # Not an IP address, continue with other checks

        # Check for common internal TLDs
        internal_tlds = [".local", ".internal", ".private", ".corp", ".home", ".lan"]
        if any(hostname.endswith(tld) for tld in internal_tlds):
            return True

        # Check for URLs without a TLD
        if "." not in hostname:
            return True

        return False
    except Exception:
        # If we can't parse the URL or something else goes wrong, assume it's not internal
        return False


def convert_image_to_webp(image_bytes):
    """Convert image bytes to webp format for faster loading"""
    image_io = io.BytesIO(image_bytes)