Make Chat and Online Search Resilient and Faster (#936)

## Overview
### New
- Support using Firecrawl(https://firecrawl.dev) to read web pages
- Add, switch and re-prioritize web page reader(s) to use via the admin panel

### Speed
- Improve response speed by aggregating web page read, extract queries to run only once for each web page

### Response Resilience
- Fallback through enabled web page readers until web page read
- Enable reading web pages on the internal network for self-hosted Khoj running in anonymous mode
- Try respond even if web search, web page read fails during chat
- Try respond even if document search via inference endpoint fails

### Fix
- Return data sources to use if exception in data source chat actor

## Details
### Configure web page readers to use
 - Only the web scraper set in Server Chat Settings via the Django admin panel, if set
 - Otherwise use the web scrapers added via the Django admin panel (in order of priority), if set
 - Otherwise, use all the web scrapers enabled by settings API keys via environment variables (e.g `FIRECRAWL_API_KEY', `JINA_API_KEY' env vars set), if set
 - Otherwise, use Jina to web scrape if no scrapers explicitly defined
 
For self-hosted setups running in anonymous-mode, the ability to directly read webpages is also enabled by default. This is especially useful for reading webpages in your internal network that the other web page readers will not be able to access.

### Aggregate webpage extract queries to run once for each distinct web page

Previously, we'd run separate webpage read and extract relevant
content pipes for each distinct (query, url) pair.

Now we aggregate all queries for each url to extract information from
and run the webpage read and extract relevant content pipes once for
each distinct URL.

Even though the webpage content extraction pipes were previously being
run in parallel. They increased the response time by
1. adding more ~duplicate context for the response generation step to read
2. being more susceptible to variability in web page read latency of the parallel jobs

The aggregated retrieval of context for all queries for a given
webpage could result in some hit to context quality. But it should
improve and reduce variability in response time, quality and costs. 

This should especially help with speed and quality of online search 
for offline or low context chat models.
This commit is contained in:
Debanjum 2024-10-17 17:57:44 -07:00 committed by GitHub
commit 7fb4c2939d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 493 additions and 88 deletions

View file

@ -1,6 +1,7 @@
import json
import logging
import math
import os
import random
import re
import secrets
@ -10,7 +11,6 @@ from enum import Enum
from typing import Callable, Iterable, List, Optional, Type
import cron_descriptor
import django
from apscheduler.job import Job
from asgiref.sync import sync_to_async
from django.contrib.sessions.backends.db import SessionStore
@ -52,6 +52,7 @@ from khoj.database.models import (
UserTextToImageModelConfig,
UserVoiceModelConfig,
VoiceModelOption,
WebScraper,
)
from khoj.processor.conversation import prompts
from khoj.search_filter.date_filter import DateFilter
@ -59,7 +60,12 @@ from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.word_filter import WordFilter
from khoj.utils import state
from khoj.utils.config import OfflineChatProcessorModel
from khoj.utils.helpers import generate_random_name, is_none_or_empty, timer
from khoj.utils.helpers import (
generate_random_name,
in_debug_mode,
is_none_or_empty,
timer,
)
logger = logging.getLogger(__name__)
@ -1031,6 +1037,70 @@ class ConversationAdapters:
return server_chat_settings.chat_advanced
return await ConversationAdapters.aget_default_conversation_config(user)
@staticmethod
async def aget_server_webscraper():
server_chat_settings = await ServerChatSettings.objects.filter().prefetch_related("web_scraper").afirst()
if server_chat_settings is not None and server_chat_settings.web_scraper is not None:
return server_chat_settings.web_scraper
return None
@staticmethod
async def aget_enabled_webscrapers() -> list[WebScraper]:
enabled_scrapers: list[WebScraper] = []
server_webscraper = await ConversationAdapters.aget_server_webscraper()
if server_webscraper:
# Only use the webscraper set in the server chat settings
enabled_scrapers = [server_webscraper]
if not enabled_scrapers:
# Use the enabled web scrapers, ordered by priority, until get web page content
enabled_scrapers = [scraper async for scraper in WebScraper.objects.all().order_by("priority").aiterator()]
if not enabled_scrapers:
# Use scrapers enabled via environment variables
if os.getenv("FIRECRAWL_API_KEY"):
api_url = os.getenv("FIRECRAWL_API_URL", "https://api.firecrawl.dev")
enabled_scrapers.append(
WebScraper(
type=WebScraper.WebScraperType.FIRECRAWL,
name=WebScraper.WebScraperType.FIRECRAWL.capitalize(),
api_key=os.getenv("FIRECRAWL_API_KEY"),
api_url=api_url,
)
)
if os.getenv("OLOSTEP_API_KEY"):
api_url = os.getenv("OLOSTEP_API_URL", "https://agent.olostep.com/olostep-p2p-incomingAPI")
enabled_scrapers.append(
WebScraper(
type=WebScraper.WebScraperType.OLOSTEP,
name=WebScraper.WebScraperType.OLOSTEP.capitalize(),
api_key=os.getenv("OLOSTEP_API_KEY"),
api_url=api_url,
)
)
# Jina is the default fallback scrapers to use as it does not require an API key
api_url = os.getenv("JINA_READER_API_URL", "https://r.jina.ai/")
enabled_scrapers.append(
WebScraper(
type=WebScraper.WebScraperType.JINA,
name=WebScraper.WebScraperType.JINA.capitalize(),
api_key=os.getenv("JINA_API_KEY"),
api_url=api_url,
)
)
# Only enable the direct web page scraper by default in self-hosted single user setups.
# Useful for reading webpages on your intranet.
if state.anonymous_mode or in_debug_mode():
enabled_scrapers.append(
WebScraper(
type=WebScraper.WebScraperType.DIRECT,
name=WebScraper.WebScraperType.DIRECT.capitalize(),
api_key=None,
api_url=None,
)
)
return enabled_scrapers
@staticmethod
def create_conversation_from_public_conversation(
user: KhojUser, public_conversation: PublicConversation, client_app: ClientApplication

View file

@ -31,6 +31,7 @@ from khoj.database.models import (
UserSearchModelConfig,
UserVoiceModelConfig,
VoiceModelOption,
WebScraper,
)
from khoj.utils.helpers import ImageIntentType
@ -198,9 +199,24 @@ class ServerChatSettingsAdmin(admin.ModelAdmin):
list_display = (
"chat_default",
"chat_advanced",
"web_scraper",
)
@admin.register(WebScraper)
class WebScraperAdmin(admin.ModelAdmin):
list_display = (
"priority",
"name",
"type",
"api_key",
"api_url",
"created_at",
)
search_fields = ("name", "api_key", "api_url", "type")
ordering = ("priority",)
@admin.register(Conversation)
class ConversationAdmin(admin.ModelAdmin):
list_display = (

View file

@ -0,0 +1,89 @@
# Generated by Django 5.0.8 on 2024-10-18 00:41
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0068_alter_agent_output_modes"),
]
operations = [
migrations.CreateModel(
name="WebScraper",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
(
"name",
models.CharField(
blank=True,
default=None,
help_text="Friendly name. If not set, it will be set to the type of the scraper.",
max_length=200,
null=True,
unique=True,
),
),
(
"type",
models.CharField(
choices=[
("Firecrawl", "Firecrawl"),
("Olostep", "Olostep"),
("Jina", "Jina"),
("Direct", "Direct"),
],
default="Jina",
max_length=20,
),
),
(
"api_key",
models.CharField(
blank=True,
default=None,
help_text="API key of the web scraper. Only set if scraper service requires an API key. Default is set from env var.",
max_length=200,
null=True,
),
),
(
"api_url",
models.URLField(
blank=True,
default=None,
help_text="API URL of the web scraper. Only set if scraper service on non-default URL.",
null=True,
),
),
(
"priority",
models.IntegerField(
blank=True,
default=None,
help_text="Priority of the web scraper. Lower numbers run first.",
null=True,
unique=True,
),
),
],
options={
"abstract": False,
},
),
migrations.AddField(
model_name="serverchatsettings",
name="web_scraper",
field=models.ForeignKey(
blank=True,
default=None,
null=True,
on_delete=django.db.models.deletion.CASCADE,
related_name="web_scraper",
to="database.webscraper",
),
),
]

View file

@ -1,3 +1,4 @@
import os
import re
import uuid
from random import choice
@ -11,8 +12,6 @@ from django.dispatch import receiver
from pgvector.django import VectorField
from phonenumber_field.modelfields import PhoneNumberField
from khoj.utils.helpers import ConversationCommand
class BaseModel(models.Model):
created_at = models.DateTimeField(auto_now_add=True)
@ -244,6 +243,79 @@ class GithubRepoConfig(BaseModel):
github_config = models.ForeignKey(GithubConfig, on_delete=models.CASCADE, related_name="githubrepoconfig")
class WebScraper(BaseModel):
class WebScraperType(models.TextChoices):
FIRECRAWL = "Firecrawl"
OLOSTEP = "Olostep"
JINA = "Jina"
DIRECT = "Direct"
name = models.CharField(
max_length=200,
default=None,
null=True,
blank=True,
unique=True,
help_text="Friendly name. If not set, it will be set to the type of the scraper.",
)
type = models.CharField(max_length=20, choices=WebScraperType.choices, default=WebScraperType.JINA)
api_key = models.CharField(
max_length=200,
default=None,
null=True,
blank=True,
help_text="API key of the web scraper. Only set if scraper service requires an API key. Default is set from env var.",
)
api_url = models.URLField(
max_length=200,
default=None,
null=True,
blank=True,
help_text="API URL of the web scraper. Only set if scraper service on non-default URL.",
)
priority = models.IntegerField(
default=None,
null=True,
blank=True,
unique=True,
help_text="Priority of the web scraper. Lower numbers run first.",
)
def clean(self):
error = {}
if self.name is None:
self.name = self.type.capitalize()
if self.api_url is None:
if self.type == self.WebScraperType.FIRECRAWL:
self.api_url = os.getenv("FIRECRAWL_API_URL", "https://api.firecrawl.dev")
elif self.type == self.WebScraperType.OLOSTEP:
self.api_url = os.getenv("OLOSTEP_API_URL", "https://agent.olostep.com/olostep-p2p-incomingAPI")
elif self.type == self.WebScraperType.JINA:
self.api_url = os.getenv("JINA_READER_API_URL", "https://r.jina.ai/")
if self.api_key is None:
if self.type == self.WebScraperType.FIRECRAWL:
self.api_key = os.getenv("FIRECRAWL_API_KEY")
if not self.api_key and self.api_url == "https://api.firecrawl.dev":
error["api_key"] = "Set API key to use default Firecrawl. Get API key from https://firecrawl.dev."
elif self.type == self.WebScraperType.OLOSTEP:
self.api_key = os.getenv("OLOSTEP_API_KEY")
if self.api_key is None:
error["api_key"] = "Set API key to use Olostep. Get API key from https://olostep.com/."
elif self.type == self.WebScraperType.JINA:
self.api_key = os.getenv("JINA_API_KEY")
if error:
raise ValidationError(error)
def save(self, *args, **kwargs):
self.clean()
if self.priority is None:
max_priority = WebScraper.objects.aggregate(models.Max("priority"))["priority__max"]
self.priority = max_priority + 1 if max_priority else 1
super().save(*args, **kwargs)
class ServerChatSettings(BaseModel):
chat_default = models.ForeignKey(
ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_default"
@ -251,6 +323,9 @@ class ServerChatSettings(BaseModel):
chat_advanced = models.ForeignKey(
ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_advanced"
)
web_scraper = models.ForeignKey(
WebScraper, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="web_scraper"
)
class LocalOrgConfig(BaseModel):

View file

@ -114,6 +114,7 @@ class CrossEncoderModel:
payload = {"inputs": {"query": query, "passages": [hit.additional[key] for hit in hits]}}
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
response = requests.post(target_url, json=payload, headers=headers)
response.raise_for_status()
return response.json()["scores"]
cross_inp = [[query, hit.additional[key]] for hit in hits]

View file

@ -10,14 +10,22 @@ import aiohttp
from bs4 import BeautifulSoup
from markdownify import markdownify
from khoj.database.models import Agent, KhojUser
from khoj.database.adapters import ConversationAdapters
from khoj.database.models import Agent, KhojUser, WebScraper
from khoj.processor.conversation import prompts
from khoj.routers.helpers import (
ChatEvent,
extract_relevant_info,
generate_online_subqueries,
infer_webpage_urls,
)
from khoj.utils.helpers import is_internet_connected, is_none_or_empty, timer
from khoj.utils.helpers import (
is_env_var_true,
is_internal_url,
is_internet_connected,
is_none_or_empty,
timer,
)
from khoj.utils.rawconfig import LocationData
logger = logging.getLogger(__name__)
@ -25,12 +33,11 @@ logger = logging.getLogger(__name__)
SERPER_DEV_API_KEY = os.getenv("SERPER_DEV_API_KEY")
SERPER_DEV_URL = "https://google.serper.dev/search"
JINA_READER_API_URL = "https://r.jina.ai/"
JINA_SEARCH_API_URL = "https://s.jina.ai/"
JINA_API_KEY = os.getenv("JINA_API_KEY")
OLOSTEP_API_KEY = os.getenv("OLOSTEP_API_KEY")
OLOSTEP_API_URL = "https://agent.olostep.com/olostep-p2p-incomingAPI"
FIRECRAWL_USE_LLM_EXTRACT = is_env_var_true("FIRECRAWL_USE_LLM_EXTRACT")
OLOSTEP_QUERY_PARAMS = {
"timeout": 35, # seconds
"waitBeforeScraping": 1, # seconds
@ -83,33 +90,36 @@ async def search_online(
search_results = await asyncio.gather(*search_tasks)
response_dict = {subquery: search_result for subquery, search_result in search_results}
# Gather distinct web page data from organic results of each subquery without an instant answer.
# Gather distinct web pages from organic results for subqueries without an instant answer.
# Content of web pages is directly available when Jina is used for search.
webpages = {
(organic.get("link"), subquery, organic.get("content"))
for subquery in response_dict
for organic in response_dict[subquery].get("organic", [])[:MAX_WEBPAGES_TO_READ]
if "answerBox" not in response_dict[subquery]
}
webpages: Dict[str, Dict] = {}
for subquery in response_dict:
if "answerBox" in response_dict[subquery]:
continue
for organic in response_dict[subquery].get("organic", [])[:MAX_WEBPAGES_TO_READ]:
link = organic.get("link")
if link in webpages:
webpages[link]["queries"].add(subquery)
else:
webpages[link] = {"queries": {subquery}, "content": organic.get("content")}
# Read, extract relevant info from the retrieved web pages
if webpages:
webpage_links = set([link for link, _, _ in webpages])
logger.info(f"Reading web pages at: {list(webpage_links)}")
logger.info(f"Reading web pages at: {webpages.keys()}")
if send_status_func:
webpage_links_str = "\n- " + "\n- ".join(list(webpage_links))
webpage_links_str = "\n- " + "\n- ".join(webpages.keys())
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event}
tasks = [
read_webpage_and_extract_content(subquery, link, content, user=user, agent=agent)
for link, subquery, content in webpages
read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent)
for link, data in webpages.items()
]
results = await asyncio.gather(*tasks)
# Collect extracted info from the retrieved web pages
for subquery, webpage_extract, url in results:
for subqueries, url, webpage_extract in results:
if webpage_extract is not None:
response_dict[subquery]["webpages"] = {"link": url, "snippet": webpage_extract}
response_dict[subqueries.pop()]["webpages"] = {"link": url, "snippet": webpage_extract}
yield response_dict
@ -156,29 +166,66 @@ async def read_webpages(
webpage_links_str = "\n- " + "\n- ".join(list(urls))
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event}
tasks = [read_webpage_and_extract_content(query, url, user=user, agent=agent) for url in urls]
tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent) for url in urls]
results = await asyncio.gather(*tasks)
response: Dict[str, Dict] = defaultdict(dict)
response[query]["webpages"] = [
{"query": q, "link": url, "snippet": web_extract} for q, web_extract, url in results if web_extract is not None
{"query": qs.pop(), "link": url, "snippet": extract} for qs, url, extract in results if extract is not None
]
yield response
async def read_webpage(
url, scraper_type=None, api_key=None, api_url=None, subqueries=None, agent=None
) -> Tuple[str | None, str | None]:
if scraper_type == WebScraper.WebScraperType.FIRECRAWL and FIRECRAWL_USE_LLM_EXTRACT:
return None, await query_webpage_with_firecrawl(url, subqueries, api_key, api_url, agent)
elif scraper_type == WebScraper.WebScraperType.FIRECRAWL:
return await read_webpage_with_firecrawl(url, api_key, api_url), None
elif scraper_type == WebScraper.WebScraperType.OLOSTEP:
return await read_webpage_with_olostep(url, api_key, api_url), None
elif scraper_type == WebScraper.WebScraperType.JINA:
return await read_webpage_with_jina(url, api_key, api_url), None
else:
return await read_webpage_at_url(url), None
async def read_webpage_and_extract_content(
subquery: str, url: str, content: str = None, user: KhojUser = None, agent: Agent = None
) -> Tuple[str, Union[None, str], str]:
try:
if is_none_or_empty(content):
with timer(f"Reading web page at '{url}' took", logger):
content = await read_webpage_with_olostep(url) if OLOSTEP_API_KEY else await read_webpage_with_jina(url)
with timer(f"Extracting relevant information from web page at '{url}' took", logger):
extracted_info = await extract_relevant_info(subquery, content, user=user, agent=agent)
return subquery, extracted_info, url
except Exception as e:
logger.error(f"Failed to read web page at '{url}' with {e}")
return subquery, None, url
subqueries: set[str], url: str, content: str = None, user: KhojUser = None, agent: Agent = None
) -> Tuple[set[str], str, Union[None, str]]:
# Select the web scrapers to use for reading the web page
web_scrapers = await ConversationAdapters.aget_enabled_webscrapers()
# Only use the direct web scraper for internal URLs
if is_internal_url(url):
web_scrapers = [scraper for scraper in web_scrapers if scraper.type == WebScraper.WebScraperType.DIRECT]
# Fallback through enabled web scrapers until we successfully read the web page
extracted_info = None
for scraper in web_scrapers:
try:
# Read the web page
if is_none_or_empty(content):
with timer(f"Reading web page with {scraper.type} at '{url}' took", logger, log_level=logging.INFO):
content, extracted_info = await read_webpage(
url, scraper.type, scraper.api_key, scraper.api_url, subqueries, agent
)
# Extract relevant information from the web page
if is_none_or_empty(extracted_info):
with timer(f"Extracting relevant information from web page at '{url}' took", logger):
extracted_info = await extract_relevant_info(subqueries, content, user=user, agent=agent)
# If we successfully extracted information, break the loop
if not is_none_or_empty(extracted_info):
break
except Exception as e:
logger.warning(f"Failed to read web page with {scraper.type} at '{url}' with {e}")
# If this is the last web scraper in the list, log an error
if scraper.name == web_scrapers[-1].name:
logger.error(f"All web scrapers failed for '{url}'")
return subqueries, url, extracted_info
async def read_webpage_at_url(web_url: str) -> str:
@ -195,23 +242,23 @@ async def read_webpage_at_url(web_url: str) -> str:
return markdownify(body)
async def read_webpage_with_olostep(web_url: str) -> str:
headers = {"Authorization": f"Bearer {OLOSTEP_API_KEY}"}
async def read_webpage_with_olostep(web_url: str, api_key: str, api_url: str) -> str:
headers = {"Authorization": f"Bearer {api_key}"}
web_scraping_params: Dict[str, Union[str, int, bool]] = OLOSTEP_QUERY_PARAMS.copy() # type: ignore
web_scraping_params["url"] = web_url
async with aiohttp.ClientSession() as session:
async with session.get(OLOSTEP_API_URL, params=web_scraping_params, headers=headers) as response:
async with session.get(api_url, params=web_scraping_params, headers=headers) as response:
response.raise_for_status()
response_json = await response.json()
return response_json["markdown_content"]
async def read_webpage_with_jina(web_url: str) -> str:
jina_reader_api_url = f"{JINA_READER_API_URL}/{web_url}"
async def read_webpage_with_jina(web_url: str, api_key: str, api_url: str) -> str:
jina_reader_api_url = f"{api_url}/{web_url}"
headers = {"Accept": "application/json", "X-Timeout": "30"}
if JINA_API_KEY:
headers["Authorization"] = f"Bearer {JINA_API_KEY}"
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
async with aiohttp.ClientSession() as session:
async with session.get(jina_reader_api_url, headers=headers) as response:
@ -220,6 +267,54 @@ async def read_webpage_with_jina(web_url: str) -> str:
return response_json["data"]["content"]
async def read_webpage_with_firecrawl(web_url: str, api_key: str, api_url: str) -> str:
firecrawl_api_url = f"{api_url}/v1/scrape"
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
params = {"url": web_url, "formats": ["markdown"], "excludeTags": ["script", ".ad"]}
async with aiohttp.ClientSession() as session:
async with session.post(firecrawl_api_url, json=params, headers=headers) as response:
response.raise_for_status()
response_json = await response.json()
return response_json["data"]["markdown"]
async def query_webpage_with_firecrawl(
web_url: str, queries: set[str], api_key: str, api_url: str, agent: Agent = None
) -> str:
firecrawl_api_url = f"{api_url}/v1/scrape"
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
schema = {
"type": "object",
"properties": {
"relevant_extract": {"type": "string"},
},
"required": [
"relevant_extract",
],
}
personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
)
system_prompt = f"""
{prompts.system_prompt_extract_relevant_information}
{personality_context}
User Query: {", ".join(queries)}
Collate only relevant information from the website to answer the target query and in the provided JSON schema.
""".strip()
params = {"url": web_url, "formats": ["extract"], "extract": {"systemPrompt": system_prompt, "schema": schema}}
async with aiohttp.ClientSession() as session:
async with session.post(firecrawl_api_url, json=params, headers=headers) as response:
response.raise_for_status()
response_json = await response.json()
return response_json["data"]["extract"]["relevant_extract"]
async def search_with_jina(query: str, location: LocationData) -> Tuple[str, Dict[str, List[Dict]]]:
encoded_query = urllib.parse.quote(query)
jina_search_api_url = f"{JINA_SEARCH_API_URL}/{encoded_query}"

View file

@ -3,7 +3,6 @@ import base64
import json
import logging
import time
import warnings
from datetime import datetime
from functools import partial
from typing import Dict, Optional
@ -574,7 +573,6 @@ async def chat(
chat_metadata: dict = {}
connection_alive = True
user: KhojUser = request.user.object
subscribed: bool = has_required_scope(request, ["premium"])
event_delimiter = "␃🔚␗"
q = unquote(q)
nonlocal conversation_id
@ -641,7 +639,7 @@ async def chat(
request=request,
telemetry_type="api",
api="chat",
client=request.user.client_app,
client=common.client,
user_agent=request.headers.get("user-agent"),
host=request.headers.get("host"),
metadata=chat_metadata,
@ -840,25 +838,33 @@ async def chat(
# Gather Context
## Extract Document References
compiled_references, inferred_queries, defiltered_query = [], [], None
async for result in extract_references_and_questions(
request,
meta_log,
q,
(n or 7),
d,
conversation_id,
conversation_commands,
location,
partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url,
agent=agent,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
compiled_references.extend(result[0])
inferred_queries.extend(result[1])
defiltered_query = result[2]
try:
async for result in extract_references_and_questions(
request,
meta_log,
q,
(n or 7),
d,
conversation_id,
conversation_commands,
location,
partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url,
agent=agent,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
compiled_references.extend(result[0])
inferred_queries.extend(result[1])
defiltered_query = result[2]
except Exception as e:
error_message = f"Error searching knowledge base: {e}. Attempting to respond without document references."
logger.warning(error_message)
async for result in send_event(
ChatEvent.STATUS, "Document search failed. I'll try respond without document references"
):
yield result
if not is_none_or_empty(compiled_references):
headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
@ -894,12 +900,13 @@ async def chat(
yield result[ChatEvent.STATUS]
else:
online_results = result
except ValueError as e:
except Exception as e:
error_message = f"Error searching online: {e}. Attempting to respond without online results"
logger.warning(error_message)
async for result in send_llm_response(error_message):
async for result in send_event(
ChatEvent.STATUS, "Online search failed. I'll try respond without online references"
):
yield result
return
## Gather Webpage References
if ConversationCommand.Webpage in conversation_commands:
@ -928,11 +935,15 @@ async def chat(
webpages.append(webpage["link"])
async for result in send_event(ChatEvent.STATUS, f"**Read web pages**: {webpages}"):
yield result
except ValueError as e:
except Exception as e:
logger.warning(
f"Error directly reading webpages: {e}. Attempting to respond without online results",
f"Error reading webpages: {e}. Attempting to respond without webpage results",
exc_info=True,
)
async for result in send_event(
ChatEvent.STATUS, "Webpage read failed. I'll try respond without webpage references"
):
yield result
## Send Gathered References
async for result in send_event(

View file

@ -353,13 +353,13 @@ async def aget_relevant_information_sources(
final_response = [ConversationCommand.Default]
else:
final_response = [ConversationCommand.General]
return final_response
except Exception as e:
except Exception:
logger.error(f"Invalid response for determining relevant tools: {response}")
if len(agent_tools) == 0:
final_response = [ConversationCommand.Default]
else:
final_response = agent_tools
return final_response
async def aget_relevant_output_modes(
@ -551,12 +551,14 @@ async def schedule_query(
raise AssertionError(f"Invalid response for scheduling query: {raw_response}")
async def extract_relevant_info(q: str, corpus: str, user: KhojUser = None, agent: Agent = None) -> Union[str, None]:
async def extract_relevant_info(
qs: set[str], corpus: str, user: KhojUser = None, agent: Agent = None
) -> Union[str, None]:
"""
Extract relevant information for a given query from the target corpus
"""
if is_none_or_empty(corpus) or is_none_or_empty(q):
if is_none_or_empty(corpus) or is_none_or_empty(qs):
return None
personality_context = (
@ -564,17 +566,16 @@ async def extract_relevant_info(q: str, corpus: str, user: KhojUser = None, agen
)
extract_relevant_information = prompts.extract_relevant_information.format(
query=q,
query=", ".join(qs),
corpus=corpus.strip(),
personality_context=personality_context,
)
with timer("Chat actor: Extract relevant information from data", logger):
response = await send_message_to_model_wrapper(
extract_relevant_information,
prompts.system_prompt_extract_relevant_information,
user=user,
)
response = await send_message_to_model_wrapper(
extract_relevant_information,
prompts.system_prompt_extract_relevant_information,
user=user,
)
return response.strip()

View file

@ -3,6 +3,7 @@ import math
from pathlib import Path
from typing import List, Optional, Tuple, Type, Union
import requests
import torch
from asgiref.sync import sync_to_async
from sentence_transformers import util
@ -231,8 +232,12 @@ def setup(
def cross_encoder_score(query: str, hits: List[SearchResponse], search_model_name: str) -> List[SearchResponse]:
"""Score all retrieved entries using the cross-encoder"""
with timer("Cross-Encoder Predict Time", logger, state.device):
cross_scores = state.cross_encoder_model[search_model_name].predict(query, hits)
try:
with timer("Cross-Encoder Predict Time", logger, state.device):
cross_scores = state.cross_encoder_model[search_model_name].predict(query, hits)
except requests.exceptions.HTTPError as e:
logger.error(f"Failed to rerank documents using the inference endpoint. Error: {e}.", exc_info=True)
cross_scores = [0.0] * len(hits)
# Convert cross-encoder scores to distances and pass in hits for reranking
for idx in range(len(cross_scores)):

View file

@ -2,10 +2,12 @@ from __future__ import annotations # to avoid quoting type hints
import datetime
import io
import ipaddress
import logging
import os
import platform
import random
import urllib.parse
import uuid
from collections import OrderedDict
from enum import Enum
@ -164,9 +166,9 @@ def get_class_by_name(name: str) -> object:
class timer:
"""Context manager to log time taken for a block of code to run"""
def __init__(self, message: str, logger: logging.Logger, device: torch.device = None):
def __init__(self, message: str, logger: logging.Logger, device: torch.device = None, log_level=logging.DEBUG):
self.message = message
self.logger = logger
self.logger = logger.debug if log_level == logging.DEBUG else logger.info
self.device = device
def __enter__(self):
@ -176,9 +178,9 @@ class timer:
def __exit__(self, *_):
elapsed = perf_counter() - self.start
if self.device is None:
self.logger.debug(f"{self.message}: {elapsed:.3f} seconds")
self.logger(f"{self.message}: {elapsed:.3f} seconds")
else:
self.logger.debug(f"{self.message}: {elapsed:.3f} seconds on device: {self.device}")
self.logger(f"{self.message}: {elapsed:.3f} seconds on device: {self.device}")
class LRU(OrderedDict):
@ -436,6 +438,46 @@ def is_internet_connected():
return False
def is_internal_url(url: str) -> bool:
"""
Check if a URL is likely to be internal/non-public.
Args:
url (str): The URL to check.
Returns:
bool: True if the URL is likely internal, False otherwise.
"""
try:
parsed_url = urllib.parse.urlparse(url)
hostname = parsed_url.hostname
# Check for localhost
if hostname in ["localhost", "127.0.0.1", "::1"]:
return True
# Check for IP addresses in private ranges
try:
ip = ipaddress.ip_address(hostname)
return ip.is_private
except ValueError:
pass # Not an IP address, continue with other checks
# Check for common internal TLDs
internal_tlds = [".local", ".internal", ".private", ".corp", ".home", ".lan"]
if any(hostname.endswith(tld) for tld in internal_tlds):
return True
# Check for URLs without a TLD
if "." not in hostname:
return True
return False
except Exception:
# If we can't parse the URL or something else goes wrong, assume it's not internal
return False
def convert_image_to_webp(image_bytes):
"""Convert image bytes to webp format for faster loading"""
image_io = io.BytesIO(image_bytes)