Make search model configurable on server

- Expose ability to modify search model via Django admin interface
- Previously the bi_encoder and cross_encoder models to use were set
  in code
- Now it is user-configurable, with a sensible default configuration
  created automatically
This commit is contained in:
Debanjum Singh Solanky 2023-11-14 16:56:26 -08:00
parent b734984d6d
commit 4af194d74b
10 changed files with 91 additions and 28 deletions

View file

@ -1,8 +1,8 @@
import math
from typing import Optional, Type, TypeVar, List
from datetime import date, datetime, timedelta
from typing import Optional, Type, List
from datetime import date, datetime
import secrets
from typing import Type, TypeVar, List
from typing import Type, List
from datetime import date, timezone
from django.db import models
@ -31,6 +31,7 @@ from database.models import (
GithubRepoConfig,
Conversation,
ChatModelOptions,
SearchModel,
Subscription,
UserConversationConfig,
OpenAIProcessorConversationConfig,
@ -41,15 +42,6 @@ from khoj.search_filter.word_filter import WordFilter
from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.date_filter import DateFilter
ModelType = TypeVar("ModelType", bound=models.Model)
async def retrieve_object(model_class: Type[ModelType], id: int) -> ModelType:
    """Fetch the instance of `model_class` with the given primary key.

    Raises:
        HTTPException: 404 when no row with this id exists.
    """
    obj = await model_class.objects.filter(id=id).afirst()
    if not obj:
        raise HTTPException(status_code=404, detail=f"{model_class.__name__} not found")
    return obj
async def set_notion_config(token: str, user: KhojUser):
notion_config = await NotionConfig.objects.filter(user=user).afirst()
@ -220,6 +212,10 @@ async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
return config
def get_or_create_search_model():
    """Return the configured SearchModel row, creating one with defaults if absent."""
    # NOTE(review): get_or_create() on an unfiltered queryset raises
    # MultipleObjectsReturned if more than one row exists — confirm a
    # single-row invariant is enforced elsewhere (e.g. via the admin).
    search_model, _created = SearchModel.objects.filter().get_or_create()
    return search_model
class ConversationAdapters:
@staticmethod
def get_conversation_by_user(user: KhojUser):

View file

@ -8,6 +8,7 @@ from database.models import (
ChatModelOptions,
OpenAIProcessorConversationConfig,
OfflineChatProcessorConversationConfig,
SearchModel,
Subscription,
)
@ -16,4 +17,5 @@ admin.site.register(KhojUser, UserAdmin)
admin.site.register(ChatModelOptions)
admin.site.register(OpenAIProcessorConversationConfig)
admin.site.register(OfflineChatProcessorConversationConfig)
admin.site.register(SearchModel)
admin.site.register(Subscription)

View file

@ -0,0 +1,32 @@
# Generated by Django 4.2.5 on 2023-11-14 23:25
from django.db import migrations, models
class Migration(migrations.Migration):
    """Create the SearchModel table so the search encoder models become
    server-configurable (via the Django admin) instead of hard-coded."""

    dependencies = [
        ("database", "0016_alter_subscription_renewal_date"),
    ]

    operations = [
        migrations.CreateModel(
            name="SearchModel",
            fields=[
                # Standard auto-increment primary key.
                ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
                # Row timestamps: set on insert / refreshed on every save.
                ("created_at", models.DateTimeField(auto_now_add=True)),
                ("updated_at", models.DateTimeField(auto_now=True)),
                # Human-readable label for this search-model configuration.
                ("name", models.CharField(default="default", max_length=200)),
                # Only "text" search is supported at this point.
                ("model_type", models.CharField(choices=[("text", "Text")], default="text", max_length=200)),
                # Sentence-transformer model used to embed queries and entries.
                ("bi_encoder", models.CharField(default="thenlper/gte-small", max_length=200)),
                # Optional cross-encoder used to re-rank search hits; nullable
                # so deployments can skip re-ranking entirely.
                (
                    "cross_encoder",
                    models.CharField(
                        blank=True, default="cross-encoder/ms-marco-MiniLM-L-6-v2", max_length=200, null=True
                    ),
                ),
            ],
            options={
                "abstract": False,
            },
        ),
    ]

View file

@ -102,6 +102,18 @@ class LocalPlaintextConfig(BaseModel):
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
class SearchModel(BaseModel):
    """Server-configurable search encoder settings, editable via the Django
    admin; replaces the previously hard-coded bi/cross encoder model names."""

    class ModelType(models.TextChoices):
        # Only text search is configurable for now.
        TEXT = "text"

    # Human-readable label for this configuration.
    name = models.CharField(max_length=200, default="default")
    # Which kind of search this config applies to; only "text" today.
    model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.TEXT)
    # Sentence-transformer model used to embed queries and entries.
    bi_encoder = models.CharField(max_length=200, default="thenlper/gte-small")
    # Optional cross-encoder used to re-rank search hits; nullable/blank so
    # deployments can opt out of re-ranking.
    cross_encoder = models.CharField(
        max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2", null=True, blank=True
    )
class OpenAIProcessorConversationConfig(BaseModel):
api_key = models.CharField(max_length=200)

View file

@ -3,7 +3,6 @@ import logging
import json
from enum import Enum
from typing import Optional
from fastapi import Request
import requests
import os
@ -21,15 +20,16 @@ from starlette.authentication import (
)
# Internal Packages
from database.models import KhojUser, Subscription
from database.adapters import get_all_users, get_or_create_search_model
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from khoj.routers.indexer import configure_content, load_content, configure_search
from khoj.utils import constants, state
from khoj.utils.config import (
SearchType,
)
from khoj.utils.fs_syncer import collect_files
from khoj.utils.rawconfig import FullConfig
from khoj.routers.indexer import configure_content, load_content, configure_search
from database.models import KhojUser, Subscription
from database.adapters import get_all_users
logger = logging.getLogger(__name__)
@ -113,6 +113,9 @@ def configure_server(
# Initialize Search Models from Config and initialize content
try:
state.embeddings_model = EmbeddingsModel(get_or_create_search_model().bi_encoder)
state.cross_encoder_model = CrossEncoderModel(get_or_create_search_model().cross_encoder)
state.config_lock.acquire()
state.SearchType = configure_search_types(state.config)
state.search_models = configure_search(state.search_models, state.config.search_type)

View file

@ -7,10 +7,10 @@ from khoj.utils.rawconfig import SearchResponse
class EmbeddingsModel:
def __init__(self):
def __init__(self, model_name: str = "thenlper/gte-small"):
self.encode_kwargs = {"normalize_embeddings": True}
self.model_kwargs = {"device": get_device()}
self.model_name = "thenlper/gte-small"
self.model_name = model_name
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
def embed_query(self, query):
@ -21,11 +21,11 @@ class EmbeddingsModel:
class CrossEncoderModel:
def __init__(self):
self.model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
self.model_name = model_name
self.cross_encoder_model = CrossEncoder(model_name=self.model_name, device=get_device())
def predict(self, query, hits: List[SearchResponse]):
cross__inp = [[query, hit.additional["compiled"]] for hit in hits]
def predict(self, query, hits: List[SearchResponse], key: str = "compiled"):
cross__inp = [[query, hit.additional[key]] for hit in hits]
cross_scores = self.cross_encoder_model.predict(cross__inp, apply_softmax=True)
return cross_scores

View file

@ -14,7 +14,7 @@ from khoj.utils.rawconfig import Entry
from khoj.processor.embeddings import EmbeddingsModel
from khoj.search_filter.date_filter import DateFilter
from database.models import KhojUser, Entry as DbEntry, EntryDates
from database.adapters import EntryAdapters
from database.adapters import EntryAdapters, get_or_create_search_model
logger = logging.getLogger(__name__)
@ -22,7 +22,8 @@ logger = logging.getLogger(__name__)
class TextToEntries(ABC):
def __init__(self, config: Any = None):
self.embeddings_model = EmbeddingsModel()
bi_encoder_name = get_or_create_search_model().bi_encoder
self.embeddings_model = EmbeddingsModel(bi_encoder_name)
self.config = config
self.date_filter = DateFilter()

View file

@ -5,21 +5,20 @@ from typing import List, Dict
from collections import defaultdict
# External Packages
import torch
from pathlib import Path
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
# Internal Packages
from khoj.utils import config as utils_config
from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel
from khoj.utils.helpers import LRU, get_device
from khoj.utils.rawconfig import FullConfig
from khoj.processor.embeddings import EmbeddingsModel, CrossEncoderModel
# Application Global State
config = FullConfig()
search_models = SearchModels()
embeddings_model = EmbeddingsModel()
cross_encoder_model = CrossEncoderModel()
embeddings_model: EmbeddingsModel = None
cross_encoder_model: CrossEncoderModel = None
content_index = ContentIndex()
gpt4all_processor_config: GPT4AllProcessorModel = None
config_file: Path = None

View file

@ -8,11 +8,13 @@ from fastapi import FastAPI
import os
from fastapi import FastAPI
app = FastAPI()
# Internal Packages
from khoj.configure import configure_routes, configure_search_types, configure_middleware
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from khoj.processor.plaintext.plaintext_to_entries import PlaintextToEntries
from khoj.search_type import image_search, text_search
from khoj.utils.config import SearchModels
@ -54,6 +56,9 @@ def enable_db_access_for_all_tests(db):
@pytest.fixture(scope="session")
def search_config() -> SearchConfig:
state.embeddings_model = EmbeddingsModel()
state.cross_encoder_model = CrossEncoderModel()
model_dir = resolve_absolute_path("~/.khoj/search")
model_dir.mkdir(parents=True, exist_ok=True)
search_config = SearchConfig()
@ -292,6 +297,8 @@ def client(
state.config.content_type = content_config
state.config.search_type = search_config
state.SearchType = configure_search_types(state.config)
state.embeddings_model = EmbeddingsModel()
state.cross_encoder_model = CrossEncoderModel()
# These lines help us Mock the Search models for these search types
state.search_models.image_search = image_search.initialize_model(search_config.image)

View file

@ -7,6 +7,7 @@ from database.models import (
ChatModelOptions,
OfflineChatProcessorConversationConfig,
OpenAIProcessorConversationConfig,
SearchModel,
UserConversationConfig,
Conversation,
Subscription,
@ -71,6 +72,16 @@ class ConversationFactory(factory.django.DjangoModelFactory):
user = factory.SubFactory(UserFactory)
class SearchModelFactory(factory.django.DjangoModelFactory):
    """Test factory producing SearchModel rows that mirror the server's
    default search configuration."""

    class Meta:
        model = SearchModel

    name = "default"
    model_type = "text"
    bi_encoder = "thenlper/gte-small"
    cross_encoder = "cross-encoder/ms-marco-MiniLM-L-6-v2"
class SubscriptionFactory(factory.django.DjangoModelFactory):
class Meta:
model = Subscription