From 4af194d74bcec1150f9e531aed2cc4bd293d5b06 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Tue, 14 Nov 2023 16:56:26 -0800 Subject: [PATCH] Make search model configurable on server - Expose ability to modify search model via Django admin interface - Previously the bi_encoder and cross_encoder models to use were set in code - Now it's user configurable but with a default config generated by default --- src/database/adapters/__init__.py | 20 ++++++------- src/database/admin.py | 2 ++ src/database/migrations/0017_searchmodel.py | 32 +++++++++++++++++++++ src/database/models/__init__.py | 12 ++++++++ src/khoj/configure.py | 11 ++++--- src/khoj/processor/embeddings.py | 12 ++++---- src/khoj/processor/text_to_entries.py | 5 ++-- src/khoj/utils/state.py | 7 ++--- tests/conftest.py | 7 +++++ tests/helpers.py | 11 +++++++ 10 files changed, 91 insertions(+), 28 deletions(-) create mode 100644 src/database/migrations/0017_searchmodel.py diff --git a/src/database/adapters/__init__.py b/src/database/adapters/__init__.py index 4b9b54ef..9aa7eb5c 100644 --- a/src/database/adapters/__init__.py +++ b/src/database/adapters/__init__.py @@ -1,8 +1,8 @@ import math -from typing import Optional, Type, TypeVar, List -from datetime import date, datetime, timedelta +from typing import Optional, Type, List +from datetime import date, datetime import secrets -from typing import Type, TypeVar, List +from typing import Type, List from datetime import date, timezone from django.db import models @@ -31,6 +31,7 @@ from database.models import ( GithubRepoConfig, Conversation, ChatModelOptions, + SearchModel, Subscription, UserConversationConfig, OpenAIProcessorConversationConfig, @@ -41,15 +42,6 @@ from khoj.search_filter.word_filter import WordFilter from khoj.search_filter.file_filter import FileFilter from khoj.search_filter.date_filter import DateFilter -ModelType = TypeVar("ModelType", bound=models.Model) - - -async def retrieve_object(model_class: Type[ModelType], id: int) -> ModelType: - instance = await model_class.objects.filter(id=id).afirst() - if not instance: - raise HTTPException(status_code=404, detail=f"{model_class.__name__} not found") - return instance - async def set_notion_config(token: str, user: KhojUser): notion_config = await NotionConfig.objects.filter(user=user).afirst() @@ -220,6 +212,10 @@ async def set_user_github_config(user: KhojUser, pat_token: str, repos: list): return config +def get_or_create_search_model(): + return SearchModel.objects.filter().get_or_create()[0] + + class ConversationAdapters: @staticmethod def get_conversation_by_user(user: KhojUser): diff --git a/src/database/admin.py b/src/database/admin.py index 03c2ca42..a2aa85e2 100644 --- a/src/database/admin.py +++ b/src/database/admin.py @@ -8,6 +8,7 @@ from database.models import ( ChatModelOptions, OpenAIProcessorConversationConfig, OfflineChatProcessorConversationConfig, + SearchModel, Subscription, ) @@ -16,4 +17,5 @@ admin.site.register(KhojUser, UserAdmin) admin.site.register(ChatModelOptions) admin.site.register(OpenAIProcessorConversationConfig) admin.site.register(OfflineChatProcessorConversationConfig) +admin.site.register(SearchModel) admin.site.register(Subscription) diff --git a/src/database/migrations/0017_searchmodel.py b/src/database/migrations/0017_searchmodel.py new file mode 100644 index 00000000..f150e12b --- /dev/null +++ b/src/database/migrations/0017_searchmodel.py @@ -0,0 +1,32 @@ +# Generated by Django 4.2.5 on 2023-11-14 23:25 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0016_alter_subscription_renewal_date"), + ] + + operations = [ + migrations.CreateModel( + name="SearchModel", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ("name", models.CharField(default="default", max_length=200)), + ("model_type", models.CharField(choices=[("text", "Text")], default="text", max_length=200)), + ("bi_encoder", models.CharField(default="thenlper/gte-small", max_length=200)), + ( + "cross_encoder", + models.CharField( + blank=True, default="cross-encoder/ms-marco-MiniLM-L-6-v2", max_length=200, null=True + ), + ), + ], + options={ + "abstract": False, + }, + ), + ] diff --git a/src/database/models/__init__.py b/src/database/models/__init__.py index 437d86ed..5571c5a7 100644 --- a/src/database/models/__init__.py +++ b/src/database/models/__init__.py @@ -102,6 +102,18 @@ class LocalPlaintextConfig(BaseModel): user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) +class SearchModel(BaseModel): + class ModelType(models.TextChoices): + TEXT = "text" + + name = models.CharField(max_length=200, default="default") + model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.TEXT) + bi_encoder = models.CharField(max_length=200, default="thenlper/gte-small") + cross_encoder = models.CharField( + max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2", null=True, blank=True + ) + + class OpenAIProcessorConversationConfig(BaseModel): api_key = models.CharField(max_length=200) diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 9fb1f019..ab992fb6 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -3,7 +3,6 @@ import logging import json from enum import Enum from typing import Optional -from fastapi import Request import requests import os @@ -21,15 +20,16 @@ from starlette.authentication import ( ) # Internal Packages +from database.models import KhojUser, Subscription +from database.adapters import get_all_users, get_or_create_search_model +from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel +from khoj.routers.indexer import configure_content, load_content, configure_search from khoj.utils import constants, state from khoj.utils.config import ( SearchType, ) from khoj.utils.fs_syncer import collect_files from khoj.utils.rawconfig import FullConfig -from khoj.routers.indexer import configure_content, load_content, configure_search -from database.models import KhojUser, Subscription -from database.adapters import get_all_users logger = logging.getLogger(__name__) @@ -113,6 +113,9 @@ def configure_server( # Initialize Search Models from Config and initialize content try: + state.embeddings_model = EmbeddingsModel(get_or_create_search_model().bi_encoder) + state.cross_encoder_model = CrossEncoderModel(get_or_create_search_model().cross_encoder) + state.config_lock.acquire() state.SearchType = configure_search_types(state.config) state.search_models = configure_search(state.search_models, state.config.search_type) diff --git a/src/khoj/processor/embeddings.py b/src/khoj/processor/embeddings.py index a4daa24f..59a61d05 100644 --- a/src/khoj/processor/embeddings.py +++ b/src/khoj/processor/embeddings.py @@ -7,10 +7,10 @@ from khoj.utils.rawconfig import SearchResponse class EmbeddingsModel: - def __init__(self): + def __init__(self, model_name: str = "thenlper/gte-small"): self.encode_kwargs = {"normalize_embeddings": True} self.model_kwargs = {"device": get_device()} - self.model_name = "thenlper/gte-small" + self.model_name = model_name self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs) def embed_query(self, query): @@ -21,11 +21,11 @@ class EmbeddingsModel: class CrossEncoderModel: - def __init__(self): - self.model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2" + def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"): + self.model_name = model_name self.cross_encoder_model = CrossEncoder(model_name=self.model_name, device=get_device()) - def predict(self, query, hits: List[SearchResponse]): - cross__inp = [[query, hit.additional["compiled"]] for hit in hits] + def predict(self, query, hits: List[SearchResponse], key: str = "compiled"): + cross__inp = [[query, hit.additional[key]] for hit in hits] cross_scores = self.cross_encoder_model.predict(cross__inp, apply_softmax=True) return cross_scores diff --git a/src/khoj/processor/text_to_entries.py b/src/khoj/processor/text_to_entries.py index 3d79e02e..66a489eb 100644 --- a/src/khoj/processor/text_to_entries.py +++ b/src/khoj/processor/text_to_entries.py @@ -14,7 +14,7 @@ from khoj.utils.rawconfig import Entry from khoj.processor.embeddings import EmbeddingsModel from khoj.search_filter.date_filter import DateFilter from database.models import KhojUser, Entry as DbEntry, EntryDates -from database.adapters import EntryAdapters +from database.adapters import EntryAdapters, get_or_create_search_model logger = logging.getLogger(__name__) @@ -22,7 +22,8 @@ logger = logging.getLogger(__name__) class TextToEntries(ABC): def __init__(self, config: Any = None): - self.embeddings_model = EmbeddingsModel() + bi_encoder_name = get_or_create_search_model().bi_encoder + self.embeddings_model = EmbeddingsModel(bi_encoder_name) self.config = config self.date_filter = DateFilter() diff --git a/src/khoj/utils/state.py b/src/khoj/utils/state.py index 098ae35e..3a03baef 100644 --- a/src/khoj/utils/state.py +++ b/src/khoj/utils/state.py @@ -5,21 +5,20 @@ from typing import List, Dict from collections import defaultdict # External Packages -import torch from pathlib import Path +from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel # Internal Packages from khoj.utils import config as utils_config from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel from khoj.utils.helpers import LRU, get_device from khoj.utils.rawconfig import FullConfig -from khoj.processor.embeddings import EmbeddingsModel, CrossEncoderModel # Application Global State config = FullConfig() search_models = SearchModels() -embeddings_model = EmbeddingsModel() -cross_encoder_model = CrossEncoderModel() +embeddings_model: EmbeddingsModel = None +cross_encoder_model: CrossEncoderModel = None content_index = ContentIndex() gpt4all_processor_config: GPT4AllProcessorModel = None config_file: Path = None diff --git a/tests/conftest.py b/tests/conftest.py index 95fa9a99..a169722e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,11 +8,13 @@ from fastapi import FastAPI import os from fastapi import FastAPI + app = FastAPI() # Internal Packages from khoj.configure import configure_routes, configure_search_types, configure_middleware +from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel from khoj.processor.plaintext.plaintext_to_entries import PlaintextToEntries from khoj.search_type import image_search, text_search from khoj.utils.config import SearchModels @@ -54,6 +56,9 @@ def enable_db_access_for_all_tests(db): @pytest.fixture(scope="session") def search_config() -> SearchConfig: + state.embeddings_model = EmbeddingsModel() + state.cross_encoder_model = CrossEncoderModel() + model_dir = resolve_absolute_path("~/.khoj/search") model_dir.mkdir(parents=True, exist_ok=True) search_config = SearchConfig() @@ -292,6 +297,8 @@ def client( state.config.content_type = content_config state.config.search_type = search_config state.SearchType = configure_search_types(state.config) + state.embeddings_model = EmbeddingsModel() + state.cross_encoder_model = CrossEncoderModel() # These lines help us Mock the Search models for these search types state.search_models.image_search = image_search.initialize_model(search_config.image) diff --git a/tests/helpers.py b/tests/helpers.py index 03f3f9c7..bf30a80d 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -7,6 +7,7 @@ from database.models import ( ChatModelOptions, OfflineChatProcessorConversationConfig, OpenAIProcessorConversationConfig, + SearchModel, UserConversationConfig, Conversation, Subscription, @@ -71,6 +72,16 @@ class ConversationFactory(factory.django.DjangoModelFactory): user = factory.SubFactory(UserFactory) +class SearchModelFactory(factory.django.DjangoModelFactory): + class Meta: + model = SearchModel + + name = "default" + model_type = "text" + bi_encoder = "thenlper/gte-small" + cross_encoder = "cross-encoder/ms-marco-MiniLM-L-6-v2" + + class SubscriptionFactory(factory.django.DjangoModelFactory): class Meta: model = Subscription