mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Make search model configurable on server
- Expose ability to modify search model via Django admin interface - Previously the bi_encoder and cross_encoder models to use were set in code - Now it's user configurable but with a default config generated by default
This commit is contained in:
parent
b734984d6d
commit
4af194d74b
10 changed files with 91 additions and 28 deletions
|
@ -1,8 +1,8 @@
|
|||
import math
|
||||
from typing import Optional, Type, TypeVar, List
|
||||
from datetime import date, datetime, timedelta
|
||||
from typing import Optional, Type, List
|
||||
from datetime import date, datetime
|
||||
import secrets
|
||||
from typing import Type, TypeVar, List
|
||||
from typing import Type, List
|
||||
from datetime import date, timezone
|
||||
|
||||
from django.db import models
|
||||
|
@ -31,6 +31,7 @@ from database.models import (
|
|||
GithubRepoConfig,
|
||||
Conversation,
|
||||
ChatModelOptions,
|
||||
SearchModel,
|
||||
Subscription,
|
||||
UserConversationConfig,
|
||||
OpenAIProcessorConversationConfig,
|
||||
|
@ -41,15 +42,6 @@ from khoj.search_filter.word_filter import WordFilter
|
|||
from khoj.search_filter.file_filter import FileFilter
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=models.Model)
|
||||
|
||||
|
||||
async def retrieve_object(model_class: Type[ModelType], id: int) -> ModelType:
|
||||
instance = await model_class.objects.filter(id=id).afirst()
|
||||
if not instance:
|
||||
raise HTTPException(status_code=404, detail=f"{model_class.__name__} not found")
|
||||
return instance
|
||||
|
||||
|
||||
async def set_notion_config(token: str, user: KhojUser):
|
||||
notion_config = await NotionConfig.objects.filter(user=user).afirst()
|
||||
|
@ -220,6 +212,10 @@ async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
|
|||
return config
|
||||
|
||||
|
||||
def get_or_create_search_model():
|
||||
return SearchModel.objects.filter().get_or_create()[0]
|
||||
|
||||
|
||||
class ConversationAdapters:
|
||||
@staticmethod
|
||||
def get_conversation_by_user(user: KhojUser):
|
||||
|
|
|
@ -8,6 +8,7 @@ from database.models import (
|
|||
ChatModelOptions,
|
||||
OpenAIProcessorConversationConfig,
|
||||
OfflineChatProcessorConversationConfig,
|
||||
SearchModel,
|
||||
Subscription,
|
||||
)
|
||||
|
||||
|
@ -16,4 +17,5 @@ admin.site.register(KhojUser, UserAdmin)
|
|||
admin.site.register(ChatModelOptions)
|
||||
admin.site.register(OpenAIProcessorConversationConfig)
|
||||
admin.site.register(OfflineChatProcessorConversationConfig)
|
||||
admin.site.register(SearchModel)
|
||||
admin.site.register(Subscription)
|
||||
|
|
32
src/database/migrations/0017_searchmodel.py
Normal file
32
src/database/migrations/0017_searchmodel.py
Normal file
|
@ -0,0 +1,32 @@
|
|||
# Generated by Django 4.2.5 on 2023-11-14 23:25
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0016_alter_subscription_renewal_date"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name="SearchModel",
|
||||
fields=[
|
||||
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||
("updated_at", models.DateTimeField(auto_now=True)),
|
||||
("name", models.CharField(default="default", max_length=200)),
|
||||
("model_type", models.CharField(choices=[("text", "Text")], default="text", max_length=200)),
|
||||
("bi_encoder", models.CharField(default="thenlper/gte-small", max_length=200)),
|
||||
(
|
||||
"cross_encoder",
|
||||
models.CharField(
|
||||
blank=True, default="cross-encoder/ms-marco-MiniLM-L-6-v2", max_length=200, null=True
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"abstract": False,
|
||||
},
|
||||
),
|
||||
]
|
|
@ -102,6 +102,18 @@ class LocalPlaintextConfig(BaseModel):
|
|||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||
|
||||
|
||||
class SearchModel(BaseModel):
|
||||
class ModelType(models.TextChoices):
|
||||
TEXT = "text"
|
||||
|
||||
name = models.CharField(max_length=200, default="default")
|
||||
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.TEXT)
|
||||
bi_encoder = models.CharField(max_length=200, default="thenlper/gte-small")
|
||||
cross_encoder = models.CharField(
|
||||
max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2", null=True, blank=True
|
||||
)
|
||||
|
||||
|
||||
class OpenAIProcessorConversationConfig(BaseModel):
|
||||
api_key = models.CharField(max_length=200)
|
||||
|
||||
|
|
|
@ -3,7 +3,6 @@ import logging
|
|||
import json
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from fastapi import Request
|
||||
import requests
|
||||
import os
|
||||
|
||||
|
@ -21,15 +20,16 @@ from starlette.authentication import (
|
|||
)
|
||||
|
||||
# Internal Packages
|
||||
from database.models import KhojUser, Subscription
|
||||
from database.adapters import get_all_users, get_or_create_search_model
|
||||
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
|
||||
from khoj.routers.indexer import configure_content, load_content, configure_search
|
||||
from khoj.utils import constants, state
|
||||
from khoj.utils.config import (
|
||||
SearchType,
|
||||
)
|
||||
from khoj.utils.fs_syncer import collect_files
|
||||
from khoj.utils.rawconfig import FullConfig
|
||||
from khoj.routers.indexer import configure_content, load_content, configure_search
|
||||
from database.models import KhojUser, Subscription
|
||||
from database.adapters import get_all_users
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -113,6 +113,9 @@ def configure_server(
|
|||
|
||||
# Initialize Search Models from Config and initialize content
|
||||
try:
|
||||
state.embeddings_model = EmbeddingsModel(get_or_create_search_model().bi_encoder)
|
||||
state.cross_encoder_model = CrossEncoderModel(get_or_create_search_model().cross_encoder)
|
||||
|
||||
state.config_lock.acquire()
|
||||
state.SearchType = configure_search_types(state.config)
|
||||
state.search_models = configure_search(state.search_models, state.config.search_type)
|
||||
|
|
|
@ -7,10 +7,10 @@ from khoj.utils.rawconfig import SearchResponse
|
|||
|
||||
|
||||
class EmbeddingsModel:
|
||||
def __init__(self):
|
||||
def __init__(self, model_name: str = "thenlper/gte-small"):
|
||||
self.encode_kwargs = {"normalize_embeddings": True}
|
||||
self.model_kwargs = {"device": get_device()}
|
||||
self.model_name = "thenlper/gte-small"
|
||||
self.model_name = model_name
|
||||
self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)
|
||||
|
||||
def embed_query(self, query):
|
||||
|
@ -21,11 +21,11 @@ class EmbeddingsModel:
|
|||
|
||||
|
||||
class CrossEncoderModel:
|
||||
def __init__(self):
|
||||
self.model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
||||
def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
|
||||
self.model_name = model_name
|
||||
self.cross_encoder_model = CrossEncoder(model_name=self.model_name, device=get_device())
|
||||
|
||||
def predict(self, query, hits: List[SearchResponse]):
|
||||
cross__inp = [[query, hit.additional["compiled"]] for hit in hits]
|
||||
def predict(self, query, hits: List[SearchResponse], key: str = "compiled"):
|
||||
cross__inp = [[query, hit.additional[key]] for hit in hits]
|
||||
cross_scores = self.cross_encoder_model.predict(cross__inp, apply_softmax=True)
|
||||
return cross_scores
|
||||
|
|
|
@ -14,7 +14,7 @@ from khoj.utils.rawconfig import Entry
|
|||
from khoj.processor.embeddings import EmbeddingsModel
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
from database.models import KhojUser, Entry as DbEntry, EntryDates
|
||||
from database.adapters import EntryAdapters
|
||||
from database.adapters import EntryAdapters, get_or_create_search_model
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -22,7 +22,8 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class TextToEntries(ABC):
|
||||
def __init__(self, config: Any = None):
|
||||
self.embeddings_model = EmbeddingsModel()
|
||||
bi_encoder_name = get_or_create_search_model().bi_encoder
|
||||
self.embeddings_model = EmbeddingsModel(bi_encoder_name)
|
||||
self.config = config
|
||||
self.date_filter = DateFilter()
|
||||
|
||||
|
|
|
@ -5,21 +5,20 @@ from typing import List, Dict
|
|||
from collections import defaultdict
|
||||
|
||||
# External Packages
|
||||
import torch
|
||||
from pathlib import Path
|
||||
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
|
||||
|
||||
# Internal Packages
|
||||
from khoj.utils import config as utils_config
|
||||
from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel
|
||||
from khoj.utils.helpers import LRU, get_device
|
||||
from khoj.utils.rawconfig import FullConfig
|
||||
from khoj.processor.embeddings import EmbeddingsModel, CrossEncoderModel
|
||||
|
||||
# Application Global State
|
||||
config = FullConfig()
|
||||
search_models = SearchModels()
|
||||
embeddings_model = EmbeddingsModel()
|
||||
cross_encoder_model = CrossEncoderModel()
|
||||
embeddings_model: EmbeddingsModel = None
|
||||
cross_encoder_model: CrossEncoderModel = None
|
||||
content_index = ContentIndex()
|
||||
gpt4all_processor_config: GPT4AllProcessorModel = None
|
||||
config_file: Path = None
|
||||
|
|
|
@ -8,11 +8,13 @@ from fastapi import FastAPI
|
|||
import os
|
||||
from fastapi import FastAPI
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
# Internal Packages
|
||||
from khoj.configure import configure_routes, configure_search_types, configure_middleware
|
||||
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
|
||||
from khoj.processor.plaintext.plaintext_to_entries import PlaintextToEntries
|
||||
from khoj.search_type import image_search, text_search
|
||||
from khoj.utils.config import SearchModels
|
||||
|
@ -54,6 +56,9 @@ def enable_db_access_for_all_tests(db):
|
|||
|
||||
@pytest.fixture(scope="session")
|
||||
def search_config() -> SearchConfig:
|
||||
state.embeddings_model = EmbeddingsModel()
|
||||
state.cross_encoder_model = CrossEncoderModel()
|
||||
|
||||
model_dir = resolve_absolute_path("~/.khoj/search")
|
||||
model_dir.mkdir(parents=True, exist_ok=True)
|
||||
search_config = SearchConfig()
|
||||
|
@ -292,6 +297,8 @@ def client(
|
|||
state.config.content_type = content_config
|
||||
state.config.search_type = search_config
|
||||
state.SearchType = configure_search_types(state.config)
|
||||
state.embeddings_model = EmbeddingsModel()
|
||||
state.cross_encoder_model = CrossEncoderModel()
|
||||
|
||||
# These lines help us Mock the Search models for these search types
|
||||
state.search_models.image_search = image_search.initialize_model(search_config.image)
|
||||
|
|
|
@ -7,6 +7,7 @@ from database.models import (
|
|||
ChatModelOptions,
|
||||
OfflineChatProcessorConversationConfig,
|
||||
OpenAIProcessorConversationConfig,
|
||||
SearchModel,
|
||||
UserConversationConfig,
|
||||
Conversation,
|
||||
Subscription,
|
||||
|
@ -71,6 +72,16 @@ class ConversationFactory(factory.django.DjangoModelFactory):
|
|||
user = factory.SubFactory(UserFactory)
|
||||
|
||||
|
||||
class SearchModelFactory(factory.django.DjangoModelFactory):
|
||||
class Meta:
|
||||
model = SearchModel
|
||||
|
||||
name = "default"
|
||||
model_type = "text"
|
||||
bi_encoder = "thenlper/gte-small"
|
||||
cross_encoder = "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
||||
|
||||
|
||||
class SubscriptionFactory(factory.django.DjangoModelFactory):
|
||||
class Meta:
|
||||
model = Subscription
|
||||
|
|
Loading…
Reference in a new issue