Add a model that specifies the user's search model configuration

- Update all endpoints that generate embeddings to use the new model, including generating text embeddings and creating embeddings for a search query
This commit is contained in:
sabaimran 2023-12-20 09:22:26 +05:30
parent 6dd2b05bf5
commit 0f6e4ff683
5 changed files with 44 additions and 7 deletions

View file

@ -32,6 +32,7 @@ from khoj.database.models import (
SpeechToTextModelOptions,
Subscription,
UserConversationConfig,
UserSearchModelConfig,
OpenAIProcessorConversationConfig,
OfflineChatProcessorConversationConfig,
ReflectiveQuestion,
@ -250,7 +251,10 @@ async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
return config
def get_default_search_model():
def get_user_search_model_or_default(user=None):
    """Return the search model config to use for ``user``.

    Resolution order:
      1. The ``setting`` of the user's ``UserSearchModelConfig``, when ``user``
         is truthy and such a row exists.
      2. The ``SearchModelConfig`` named ``"default"``.
      3. The first ``SearchModelConfig`` row, or ``None`` if there are none.

    Args:
        user: The user whose preference should be honoured. When omitted or
            falsy, only the default lookups (steps 2 and 3) are performed.

    Returns:
        The selected ``SearchModelConfig`` instance, or ``None`` when no
        search model config exists at all.
    """
    if user:
        # Single query instead of exists() + first(): avoids a redundant round
        # trip and the window in which the row could vanish between the two
        # calls (which would make `.first().setting` raise AttributeError).
        # select_related pulls the referenced config in the same query.
        user_config = (
            UserSearchModelConfig.objects.filter(user=user)
            .select_related("setting")
            .first()
        )
        if user_config is not None:
            return user_config.setting

    default_config = SearchModelConfig.objects.filter(name="default").first()
    if default_config is not None:
        return default_config

    # Last resort: whichever config comes first (None when the table is empty).
    return SearchModelConfig.objects.first()

View file

@ -0,0 +1,33 @@
# Generated by Django 4.2.7 on 2023-12-19 15:44
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
    """Create the ``UserSearchModelConfig`` table."""

    # Must run after the migration that precedes it in the "database" app.
    dependencies = [
        ("database", "0022_texttoimagemodelconfig"),
    ]

    operations = [
        migrations.CreateModel(
            name="UserSearchModelConfig",
            fields=[
                (
                    "id",
                    models.BigAutoField(
                        auto_created=True,
                        primary_key=True,
                        serialize=False,
                        verbose_name="ID",
                    ),
                ),
                # Timestamps are maintained automatically on insert / save.
                ("created_at", models.DateTimeField(auto_now_add=True)),
                ("updated_at", models.DateTimeField(auto_now=True)),
                # The search model config selected for the user; deleting that
                # config cascades to this row.
                (
                    "setting",
                    models.ForeignKey(
                        on_delete=models.CASCADE,
                        to="database.searchmodelconfig",
                    ),
                ),
                # One-to-one: at most one row per user; deleting the user
                # cascades to this row.
                (
                    "user",
                    models.OneToOneField(
                        on_delete=models.CASCADE,
                        to=settings.AUTH_USER_MODEL,
                    ),
                ),
            ],
            options={
                "abstract": False,
            },
        ),
    ]

View file

@ -14,7 +14,7 @@ from khoj.utils.helpers import is_none_or_empty, timer, batcher
from khoj.utils.rawconfig import Entry
from khoj.search_filter.date_filter import DateFilter
from khoj.database.models import KhojUser, Entry as DbEntry, EntryDates
from khoj.database.adapters import EntryAdapters, get_default_search_model
from khoj.database.adapters import EntryAdapters, get_user_search_model_or_default
logger = logging.getLogger(__name__)
@ -112,7 +112,7 @@ class TextToEntries(ABC):
with timer("Generated embeddings for entries to add to database in", logger):
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
model = get_default_search_model()
model = get_user_search_model_or_default(user)
embeddings += self.embeddings_model[model.name].embed_documents(data_to_embed)
added_entries: list[DbEntry] = []

View file

@ -18,7 +18,7 @@ from starlette.authentication import requires
# Internal Packages
from khoj.configure import configure_server
from khoj.database import adapters
from khoj.database.adapters import ConversationAdapters, EntryAdapters, get_default_search_model
from khoj.database.adapters import ConversationAdapters, EntryAdapters, get_user_search_model_or_default
from khoj.database.models import ChatModelOptions, SpeechToTextModelOptions
from khoj.database.models import Entry as DbEntry
from khoj.database.models import (
@ -416,7 +416,7 @@ async def search(
]
if text_search_models:
with timer("Encoding query took", logger=logger):
search_model = await sync_to_async(get_default_search_model)()
search_model = await sync_to_async(get_user_search_model_or_default)(user)
encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query)
with concurrent.futures.ThreadPoolExecutor() as executor:

View file

@ -19,7 +19,7 @@ from khoj.utils.state import SearchType
from khoj.utils.rawconfig import SearchResponse, Entry
from khoj.utils.jsonl import load_jsonl
from khoj.processor.content.text_to_entries import TextToEntries
from khoj.database.adapters import EntryAdapters, get_default_search_model
from khoj.database.adapters import EntryAdapters, get_user_search_model_or_default
from khoj.database.models import KhojUser, Entry as DbEntry
logger = logging.getLogger(__name__)
@ -115,7 +115,7 @@ async def query(
# Encode the query using the bi-encoder
if question_embedding is None:
with timer("Query Encode Time", logger, state.device):
search_model = await sync_to_async(get_default_search_model)()
search_model = await sync_to_async(get_user_search_model_or_default)(user)
question_embedding = state.embeddings_model[search_model.name].embed_query(query)
# Find relevant entries for the query