From c841abe13f3cde06690e5c818ccd05dc40b4e74f Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Tue, 15 Oct 2024 17:17:36 -0700
Subject: [PATCH] Change webpage scraper to use via server admin panel

---
Review notes (this section is ignored when the patch is applied):
 * Line structure and indentation were reconstructed from the hunk line
   counts; verify context whitespace against the target tree before applying.
 * aget_webscraper falls back to JinaAI when the stored web_scraper value is
   not a known choice, instead of raising ValueError from the enum lookup.
 * read_webpage_and_extract_content selects the scraper inside its try block,
   so a failed settings lookup is logged like any other read failure.
 * Post-image blob ids on the "index" lines of adapters/__init__.py and
   online_search.py are carried over from the pre-review patch.

 src/khoj/database/adapters/__init__.py        | 23 ++++++++++++++++++++
 src/khoj/database/admin.py                    |  1 +
 .../0068_serverchatsettings_web_scraper.py    | 21 ++++++++++++++++++
 src/khoj/database/models/__init__.py          |  7 ++++++
 src/khoj/processor/tools/online_search.py     | 15 ++++++++-----
 5 files changed, 62 insertions(+), 5 deletions(-)
 create mode 100644 src/khoj/database/migrations/0068_serverchatsettings_web_scraper.py

diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py
index 182ce701..51b8afe6 100644
--- a/src/khoj/database/adapters/__init__.py
+++ b/src/khoj/database/adapters/__init__.py
@@ -1031,6 +1031,29 @@ class ConversationAdapters:
             return server_chat_settings.chat_advanced
         return await ConversationAdapters.aget_default_conversation_config(user)
 
+    @staticmethod
+    async def aget_webscraper(FIRECRAWL_API_KEY: str = None, OLOSTEP_API_KEY: str = None):
+        """
+        Get the web scraper selected in the server chat settings, if it is usable.
+
+        Falls back to JinaAI when no server chat settings exist, the stored value is not a
+        known scraper, or the API key required by the selected scraper was not passed in.
+        """
+        server_chat_settings: ServerChatSettings = await ServerChatSettings.objects.filter().afirst()
+        if server_chat_settings is not None and server_chat_settings.web_scraper is not None:
+            try:
+                web_scraper = ServerChatSettings.WebScraper(server_chat_settings.web_scraper)
+            except ValueError:
+                # Unknown value stored in the database. Use the default scraper instead of failing
+                web_scraper = ServerChatSettings.WebScraper.JINAAI
+            if (web_scraper == ServerChatSettings.WebScraper.FIRECRAWL and FIRECRAWL_API_KEY) or (
+                web_scraper == ServerChatSettings.WebScraper.OLOSTEP and OLOSTEP_API_KEY
+            ):
+                return web_scraper
+        # Fallback to JinaAI if the API keys for the other providers are not set
+        # JinaAI is the default web scraper as it does not require an API key
+        return ServerChatSettings.WebScraper.JINAAI
+
     @staticmethod
     def create_conversation_from_public_conversation(
         user: KhojUser, public_conversation: PublicConversation, client_app: ClientApplication
diff --git a/src/khoj/database/admin.py b/src/khoj/database/admin.py
index 3e192952..51988752 100644
--- a/src/khoj/database/admin.py
+++ b/src/khoj/database/admin.py
@@ -198,6 +198,7 @@ class ServerChatSettingsAdmin(admin.ModelAdmin):
     list_display = (
         "chat_default",
         "chat_advanced",
+        "web_scraper",
     )
 
 
diff --git a/src/khoj/database/migrations/0068_serverchatsettings_web_scraper.py b/src/khoj/database/migrations/0068_serverchatsettings_web_scraper.py
new file mode 100644
index 00000000..89482dbd
--- /dev/null
+++ b/src/khoj/database/migrations/0068_serverchatsettings_web_scraper.py
@@ -0,0 +1,21 @@
+# Generated by Django 5.0.8 on 2024-10-16 00:06
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+    dependencies = [
+        ("database", "0067_alter_agent_style_icon"),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name="serverchatsettings",
+            name="web_scraper",
+            field=models.CharField(
+                choices=[("firecrawl", "Firecrawl"), ("olostep", "Olostep"), ("jinaai", "JinaAI")],
+                default="jinaai",
+                max_length=20,
+            ),
+        ),
+    ]
diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py
index ec4b61d1..7c4a16fa 100644
--- a/src/khoj/database/models/__init__.py
+++ b/src/khoj/database/models/__init__.py
@@ -8,6 +8,7 @@ from django.core.exceptions import ValidationError
 from django.db import models
 from django.db.models.signals import pre_save
 from django.dispatch import receiver
+from django.utils.translation import gettext_lazy
 from pgvector.django import VectorField
 from phonenumber_field.modelfields import PhoneNumberField
 
@@ -245,12 +246,18 @@ class GithubRepoConfig(BaseModel):
 
 
 class ServerChatSettings(BaseModel):
+    class WebScraper(models.TextChoices):
+        FIRECRAWL = "firecrawl", gettext_lazy("Firecrawl")
+        OLOSTEP = "olostep", gettext_lazy("Olostep")
+        JINAAI = "jinaai", gettext_lazy("JinaAI")
+
     chat_default = models.ForeignKey(
         ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_default"
     )
     chat_advanced = models.ForeignKey(
         ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_advanced"
     )
+    web_scraper = models.CharField(max_length=20, choices=WebScraper.choices, default=WebScraper.JINAAI)
 
 
 class LocalOrgConfig(BaseModel):
diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py
index ea45846b..df9b180f 100644
--- a/src/khoj/processor/tools/online_search.py
+++ b/src/khoj/processor/tools/online_search.py
@@ -10,7 +10,8 @@ import aiohttp
 from bs4 import BeautifulSoup
 from markdownify import markdownify
 
-from khoj.database.models import Agent, KhojUser
+from khoj.database.adapters import ConversationAdapters
+from khoj.database.models import Agent, KhojUser, ServerChatSettings
 from khoj.processor.conversation import prompts
 from khoj.routers.helpers import (
     ChatEvent,
@@ -177,16 +178,20 @@ async def read_webpages(
 async def read_webpage_and_extract_content(
     subqueries: set[str], url: str, content: str = None, user: KhojUser = None, agent: Agent = None
 ) -> Tuple[set[str], str, Union[None, str]]:
     extracted_info = None
+    # Pre-bind the default scraper so the error log below can name one even if selection fails
+    web_scraper = ServerChatSettings.WebScraper.JINAAI
     try:
+        # Select the web scraper to use for reading the web page
+        web_scraper = await ConversationAdapters.aget_webscraper(FIRECRAWL_API_KEY, OLOSTEP_API_KEY)
         if is_none_or_empty(content):
-            with timer(f"Reading web page at '{url}' took", logger):
-                if FIRECRAWL_API_KEY:
+            with timer(f"Reading web page with {web_scraper.value} at '{url}' took", logger):
+                if web_scraper == ServerChatSettings.WebScraper.FIRECRAWL:
                     if FIRECRAWL_TO_EXTRACT:
                         extracted_info = await read_webpage_and_extract_content_with_firecrawl(url, subqueries, agent)
                     else:
                         content = await read_webpage_with_firecrawl(url)
-                elif OLOSTEP_API_KEY:
+                elif web_scraper == ServerChatSettings.WebScraper.OLOSTEP:
                     content = await read_webpage_with_olostep(url)
                 else:
                     content = await read_webpage_with_jina(url)
@@ -194,6 +199,6 @@ async def read_webpage_and_extract_content(
             with timer(f"Extracting relevant information from web page at '{url}' took", logger):
                 extracted_info = await extract_relevant_info(subqueries, content, user=user, agent=agent)
     except Exception as e:
-        logger.error(f"Failed to read web page at '{url}' with {e}")
+        logger.error(f"Failed to read web page with {web_scraper.value} at '{url}' with {e}")
 
     return subqueries, url, extracted_info