From 0f6e4ff683d5ae5baa5e6ca0960d96e309fa39cf Mon Sep 17 00:00:00 2001 From: sabaimran Date: Wed, 20 Dec 2023 09:22:26 +0530 Subject: [PATCH] Add a model that specifies the user's search model configuration - Update all endpoints that generate embeddings to use the new model. Incl. generating text embeddings, creating embeddings for a search query --- src/khoj/database/adapters/__init__.py | 6 +++- .../migrations/0023_usersearchmodelconfig.py | 33 +++++++++++++++++++ src/khoj/processor/content/text_to_entries.py | 4 +-- src/khoj/routers/api.py | 4 +-- src/khoj/search_type/text_search.py | 4 +-- 5 files changed, 44 insertions(+), 7 deletions(-) create mode 100644 src/khoj/database/migrations/0023_usersearchmodelconfig.py diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index b04e4828..0610d763 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -32,6 +32,7 @@ from khoj.database.models import ( SpeechToTextModelOptions, Subscription, UserConversationConfig, + UserSearchModelConfig, OpenAIProcessorConversationConfig, OfflineChatProcessorConversationConfig, ReflectiveQuestion, @@ -250,7 +251,10 @@ async def set_user_github_config(user: KhojUser, pat_token: str, repos: list): return config -def get_default_search_model(): +def get_user_search_model_or_default(user=None): + if user and UserSearchModelConfig.objects.filter(user=user).exists(): + return UserSearchModelConfig.objects.filter(user=user).first().setting + if SearchModelConfig.objects.filter(name="default").exists(): return SearchModelConfig.objects.filter(name="default").first() return SearchModelConfig.objects.first() diff --git a/src/khoj/database/migrations/0023_usersearchmodelconfig.py b/src/khoj/database/migrations/0023_usersearchmodelconfig.py new file mode 100644 index 00000000..0aec99f0 --- /dev/null +++ b/src/khoj/database/migrations/0023_usersearchmodelconfig.py @@ -0,0 +1,33 @@ +# Generated by Django 4.2.7 on 2023-12-19 15:44 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0022_texttoimagemodelconfig"), + ] + + operations = [ + migrations.CreateModel( + name="UserSearchModelConfig", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ( + "setting", + models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to="database.searchmodelconfig"), + ), + ( + "user", + models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL), + ), + ], + options={ + "abstract": False, + }, + ), + ] diff --git a/src/khoj/processor/content/text_to_entries.py b/src/khoj/processor/content/text_to_entries.py index bfcf37f7..0c092ed4 100644 --- a/src/khoj/processor/content/text_to_entries.py +++ b/src/khoj/processor/content/text_to_entries.py @@ -14,7 +14,7 @@ from khoj.utils.helpers import is_none_or_empty, timer, batcher from khoj.utils.rawconfig import Entry from khoj.search_filter.date_filter import DateFilter from khoj.database.models import KhojUser, Entry as DbEntry, EntryDates -from khoj.database.adapters import EntryAdapters, get_default_search_model +from khoj.database.adapters import EntryAdapters, get_user_search_model_or_default logger = logging.getLogger(__name__) @@ -112,7 +112,7 @@ class TextToEntries(ABC): with timer("Generated embeddings for entries to add to database in", logger): entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process] data_to_embed = [getattr(entry, key) for entry in entries_to_process] - model = get_default_search_model() + model = get_user_search_model_or_default(user) embeddings += self.embeddings_model[model.name].embed_documents(data_to_embed) added_entries: list[DbEntry] = [] diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index dfdfceda..ad360dd3 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -18,7 +18,7 @@ from starlette.authentication import requires # Internal Packages from khoj.configure import configure_server from khoj.database import adapters -from khoj.database.adapters import ConversationAdapters, EntryAdapters, get_default_search_model +from khoj.database.adapters import ConversationAdapters, EntryAdapters, get_user_search_model_or_default from khoj.database.models import ChatModelOptions, SpeechToTextModelOptions from khoj.database.models import Entry as DbEntry from khoj.database.models import ( @@ -416,7 +416,7 @@ async def search( ] if text_search_models: with timer("Encoding query took", logger=logger): - search_model = await sync_to_async(get_default_search_model)() + search_model = await sync_to_async(get_user_search_model_or_default)(user) encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query) with concurrent.futures.ThreadPoolExecutor() as executor: diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py index 1523473c..a4b8c9b1 100644 --- a/src/khoj/search_type/text_search.py +++ b/src/khoj/search_type/text_search.py @@ -19,7 +19,7 @@ from khoj.utils.state import SearchType from khoj.utils.rawconfig import SearchResponse, Entry from khoj.utils.jsonl import load_jsonl from khoj.processor.content.text_to_entries import TextToEntries -from khoj.database.adapters import EntryAdapters, get_default_search_model +from khoj.database.adapters import EntryAdapters, get_user_search_model_or_default from khoj.database.models import KhojUser, Entry as DbEntry logger = logging.getLogger(__name__) @@ -115,7 +115,7 @@ async def query( # Encode the query using the bi-encoder if question_embedding is None: with timer("Query Encode Time", logger, state.device): - search_model = await sync_to_async(get_default_search_model)() + search_model = await sync_to_async(get_user_search_model_or_default)(user) question_embedding = state.embeddings_model[search_model.name].embed_query(query) # Find relevant entries for the query