From 5120597d4ee9a06a50b4a672abaa48f6afd512fa Mon Sep 17 00:00:00 2001
From: sabaimran <65192171+sabaimran@users.noreply.github.com>
Date: Wed, 23 Oct 2024 17:38:37 -0700
Subject: [PATCH] Remove user customized search model (#946)
- Use a single standard search model across the server. There's diminishing benefits for having multiple user-customizable search models.
- We may want to add server-level customization for specific tasks
- Store the search model used to generate a given entry on the `Entry` object
- Remove user-facing APIs and view
- Add a management command for migrating the default search model on the server
In a future PR (after running the migration), we'll also remove the `UserSearchModelConfig`
---
src/interface/web/app/settings/page.tsx | 23 +--
src/khoj/database/adapters/__init__.py | 33 ++--
src/khoj/database/admin.py | 2 +
.../commands/change_default_model.py | 182 ++++++++++++++++++
.../migrations/0072_entry_search_model.py | 24 +++
src/khoj/database/models/__init__.py | 2 +
src/khoj/processor/content/text_to_entries.py | 6 +-
src/khoj/routers/api.py | 5 +-
src/khoj/routers/api_model.py | 33 ----
src/khoj/routers/helpers.py | 9 -
src/khoj/search_type/text_search.py | 8 +-
11 files changed, 237 insertions(+), 90 deletions(-)
create mode 100644 src/khoj/database/management/commands/change_default_model.py
create mode 100644 src/khoj/database/migrations/0072_entry_search_model.py
diff --git a/src/interface/web/app/settings/page.tsx b/src/interface/web/app/settings/page.tsx
index d79eeff4..fe3e11e7 100644
--- a/src/interface/web/app/settings/page.tsx
+++ b/src/interface/web/app/settings/page.tsx
@@ -718,7 +718,7 @@ export default function SettingsView() {
};
const updateModel = (name: string) => async (id: string) => {
- if (!userConfig?.is_active && name !== "search") {
+ if (!userConfig?.is_active) {
toast({
title: `Model Update`,
description: `You need to be subscribed to update ${name} models`,
@@ -1233,27 +1233,6 @@ export default function SettingsView() {
)}
- {userConfig.search_model_options.length > 0 && (
-
-
-
- Search
-
-
-
- Pick the search model to find your documents
-
-
-
-
-
- )}
{userConfig.paint_model_options.length > 0 && (
diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py
index d5312b29..b77eccdd 100644
--- a/src/khoj/database/adapters/__init__.py
+++ b/src/khoj/database/adapters/__init__.py
@@ -466,18 +466,26 @@ async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
return config
-def get_user_search_model_or_default(user=None):
- if user and UserSearchModelConfig.objects.filter(user=user).exists():
- return UserSearchModelConfig.objects.filter(user=user).first().setting
+def get_default_search_model() -> SearchModelConfig:
+ default_search_model = SearchModelConfig.objects.filter(name="default").first()
- if SearchModelConfig.objects.filter(name="default").exists():
- return SearchModelConfig.objects.filter(name="default").first()
+ if default_search_model:
+ return default_search_model
else:
SearchModelConfig.objects.create()
return SearchModelConfig.objects.first()
+def get_user_default_search_model(user: KhojUser = None) -> SearchModelConfig:
+ if user:
+ user_search_model = UserSearchModelConfig.objects.filter(user=user).first()
+ if user_search_model:
+ return user_search_model.setting
+
+ return get_default_search_model()
+
+
def get_or_create_search_models():
search_models = SearchModelConfig.objects.all()
if search_models.count() == 0:
@@ -487,21 +495,6 @@ def get_or_create_search_models():
return search_models
-async def aset_user_search_model(user: KhojUser, search_model_config_id: int):
- config = await SearchModelConfig.objects.filter(id=search_model_config_id).afirst()
- if not config:
- return None
- new_config, _ = await UserSearchModelConfig.objects.aupdate_or_create(user=user, defaults={"setting": config})
- return new_config
-
-
-async def aget_user_search_model(user: KhojUser):
- config = await UserSearchModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()
- if not config:
- return None
- return config.setting
-
-
class ProcessLockAdapters:
@staticmethod
def get_process_lock(process_name: str):
diff --git a/src/khoj/database/admin.py b/src/khoj/database/admin.py
index 5aa9204b..3087f7b1 100644
--- a/src/khoj/database/admin.py
+++ b/src/khoj/database/admin.py
@@ -126,6 +126,7 @@ class EntryAdmin(admin.ModelAdmin):
"created_at",
"updated_at",
"user",
+ "agent",
"file_source",
"file_type",
"file_name",
@@ -135,6 +136,7 @@ class EntryAdmin(admin.ModelAdmin):
list_filter = (
"file_type",
"user__email",
+ "search_model__name",
)
ordering = ("-created_at",)
diff --git a/src/khoj/database/management/commands/change_default_model.py b/src/khoj/database/management/commands/change_default_model.py
new file mode 100644
index 00000000..cfa78581
--- /dev/null
+++ b/src/khoj/database/management/commands/change_default_model.py
@@ -0,0 +1,182 @@
+import logging
+from typing import List
+
+from django.core.management.base import BaseCommand
+from django.db import transaction
+from django.db.models import Count, Q
+from tqdm import tqdm
+
+from khoj.database.adapters import get_default_search_model
+from khoj.database.models import (
+ Agent,
+ Entry,
+ KhojUser,
+ SearchModelConfig,
+ UserSearchModelConfig,
+)
+from khoj.processor.embeddings import EmbeddingsModel
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+class Command(BaseCommand):
+ help = "Convert all existing Entry objects to use a new default Search model."
+
+ def add_arguments(self, parser):
+ # Pass default SearchModelConfig ID
+ parser.add_argument(
+ "--search_model_id",
+ action="store",
+ help="ID of the SearchModelConfig object to set as the default search model for all existing Entry objects and UserSearchModelConfig objects.",
+ required=True,
+ )
+
+ # Set the apply flag to apply the new default Search model to all existing Entry objects and UserSearchModelConfig objects.
+ parser.add_argument(
+ "--apply",
+ action="store_true",
+ help="Apply the new default Search model to all existing Entry objects and UserSearchModelConfig objects. Otherwise, only display the number of Entry objects and UserSearchModelConfig objects that will be affected.",
+ )
+
+ def handle(self, *args, **options):
+ @transaction.atomic
+ def regenerate_entries(entry_filter: Q, embeddings_model: EmbeddingsModel, search_model: SearchModelConfig):
+ entries = Entry.objects.filter(entry_filter).all()
+ compiled_entries = [entry.compiled for entry in entries]
+ updated_entries: List[Entry] = []
+ try:
+ embeddings = embeddings_model.embed_documents(compiled_entries)
+
+ except Exception as e:
+ logger.error(f"Error embedding documents: {e}")
+ return
+
+ for i, entry in enumerate(tqdm(entries)):
+ entry.embeddings = embeddings[i]
+ entry.search_model_id = search_model.id
+ updated_entries.append(entry)
+
+ Entry.objects.bulk_update(updated_entries, ["embeddings", "search_model_id", "file_path"])
+
+ search_model_config_id = options.get("search_model_id")
+ apply = options.get("apply")
+
+ logger.info(f"SearchModelConfig ID: {search_model_config_id}")
+ logger.info(f"Apply: {apply}")
+
+ embeddings_model = dict()
+
+ search_models = SearchModelConfig.objects.all()
+ for model in search_models:
+ embeddings_model.update(
+ {
+ model.name: EmbeddingsModel(
+ model.bi_encoder,
+ model.embeddings_inference_endpoint,
+ model.embeddings_inference_endpoint_api_key,
+ query_encode_kwargs=model.bi_encoder_query_encode_config,
+ docs_encode_kwargs=model.bi_encoder_docs_encode_config,
+ model_kwargs=model.bi_encoder_model_config,
+ )
+ }
+ )
+
+ new_default_search_model_config = SearchModelConfig.objects.get(id=search_model_config_id)
+ logger.info(f"New default Search model: {new_default_search_model_config}")
+ user_search_model_configs_to_update = UserSearchModelConfig.objects.exclude(
+ setting_id=search_model_config_id
+ ).all()
+ logger.info(f"Number of UserSearchModelConfig objects to update: {user_search_model_configs_to_update.count()}")
+
+ for user_config in user_search_model_configs_to_update:
+ affected_user = user_config.user
+ entry_filter = Q(user=affected_user)
+ relevant_entries = Entry.objects.filter(entry_filter).all()
+ logger.info(f"Number of Entry objects to update for user {affected_user}: {relevant_entries.count()}")
+
+ if apply:
+ try:
+ regenerate_entries(
+ entry_filter,
+ embeddings_model[new_default_search_model_config.name],
+ new_default_search_model_config,
+ )
+ user_config.setting = new_default_search_model_config
+ user_config.save()
+
+ logger.info(
+ f"Updated UserSearchModelConfig object for user {affected_user} to use the new default Search model."
+ )
+ logger.info(
+ f"Updated {relevant_entries.count()} Entry objects for user {affected_user} to use the new default Search model."
+ )
+
+ except Exception as e:
+ logger.error(f"Error embedding documents: {e}")
+
+ logger.info("----")
+
+ # There are also plenty of users who have indexed documents without explicitly creating a UserSearchModelConfig object. You would have to migrate these users as well, if the default is different from search_model_config_id.
+ current_default = get_default_search_model()
+ if current_default.id != new_default_search_model_config.id:
+ users_without_user_search_model_config = KhojUser.objects.annotate(
+ user_search_model_config_count=Count("usersearchmodelconfig")
+ ).filter(user_search_model_config_count=0)
+
+ logger.info(f"Number of User objects to update: {users_without_user_search_model_config.count()}")
+ for user in users_without_user_search_model_config:
+ entry_filter = Q(user=user)
+ relevant_entries = Entry.objects.filter(entry_filter).all()
+ logger.info(f"Number of Entry objects to update for user {user}: {relevant_entries.count()}")
+
+ if apply:
+ try:
+ regenerate_entries(
+ entry_filter,
+ embeddings_model[new_default_search_model_config.name],
+ new_default_search_model_config,
+ )
+
+ UserSearchModelConfig.objects.create(user=user, setting=new_default_search_model_config)
+
+ logger.info(
+ f"Created UserSearchModelConfig object for user {user} to use the new default Search model."
+ )
+ logger.info(
+ f"Updated {relevant_entries.count()} Entry objects for user {user} to use the new default Search model."
+ )
+ except Exception as e:
+ logger.error(f"Error embedding documents: {e}")
+ else:
+ logger.info("Default is the same as search_model_config_id.")
+
+ all_agents = Agent.objects.all()
+ logger.info(f"Number of Agent objects to update: {all_agents.count()}")
+ for agent in all_agents:
+ entry_filter = Q(agent=agent)
+ relevant_entries = Entry.objects.filter(entry_filter).all()
+ logger.info(f"Number of Entry objects to update for agent {agent}: {relevant_entries.count()}")
+
+ if apply:
+ try:
+ regenerate_entries(
+ entry_filter,
+ embeddings_model[new_default_search_model_config.name],
+ new_default_search_model_config,
+ )
+ logger.info(
+ f"Updated {relevant_entries.count()} Entry objects for agent {agent} to use the new default Search model."
+ )
+ except Exception as e:
+ logger.error(f"Error embedding documents: {e}")
+ if apply and current_default.id != new_default_search_model_config.id:
+ # Get the existing default SearchModelConfig object and update its name
+ current_default.name = f"prev_default_{current_default.id}"
+ current_default.save()
+
+ # Update the new default SearchModelConfig object's name
+ new_default_search_model_config.name = "default"
+ new_default_search_model_config.save()
+ if not apply:
+ logger.info("Run the command with the --apply flag to apply the new default Search model.")
diff --git a/src/khoj/database/migrations/0072_entry_search_model.py b/src/khoj/database/migrations/0072_entry_search_model.py
new file mode 100644
index 00000000..545f1f62
--- /dev/null
+++ b/src/khoj/database/migrations/0072_entry_search_model.py
@@ -0,0 +1,24 @@
+# Generated by Django 5.0.8 on 2024-10-21 21:09
+
+import django.db.models.deletion
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("database", "0071_subscription_enabled_trial_at_and_more"),
+ ]
+
+ operations = [
+ migrations.AddField(
+ model_name="entry",
+ name="search_model",
+ field=models.ForeignKey(
+ blank=True,
+ default=None,
+ null=True,
+ on_delete=django.db.models.deletion.SET_NULL,
+ to="database.searchmodelconfig",
+ ),
+ ),
+ ]
diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py
index c89c409a..7b5bbd12 100644
--- a/src/khoj/database/models/__init__.py
+++ b/src/khoj/database/models/__init__.py
@@ -449,6 +449,7 @@ class UserVoiceModelConfig(BaseModel):
setting = models.ForeignKey(VoiceModelOption, on_delete=models.CASCADE, default=None, null=True, blank=True)
+# TODO Delete this model once all users have been migrated to the server's default settings
class UserSearchModelConfig(BaseModel):
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE)
@@ -535,6 +536,7 @@ class Entry(BaseModel):
url = models.URLField(max_length=400, default=None, null=True, blank=True)
hashed_value = models.CharField(max_length=100)
corpus_id = models.UUIDField(default=uuid.uuid4, editable=False)
+ search_model = models.ForeignKey(SearchModelConfig, on_delete=models.SET_NULL, default=None, null=True, blank=True)
def save(self, *args, **kwargs):
if self.user and self.agent:
diff --git a/src/khoj/processor/content/text_to_entries.py b/src/khoj/processor/content/text_to_entries.py
index 6fee9c0c..b0b74996 100644
--- a/src/khoj/processor/content/text_to_entries.py
+++ b/src/khoj/processor/content/text_to_entries.py
@@ -12,7 +12,8 @@ from tqdm import tqdm
from khoj.database.adapters import (
EntryAdapters,
FileObjectAdapters,
- get_user_search_model_or_default,
+ get_default_search_model,
+ get_user_default_search_model,
)
from khoj.database.models import Entry as DbEntry
from khoj.database.models import EntryDates, KhojUser
@@ -148,10 +149,10 @@ class TextToEntries(ABC):
hashes_to_process |= hashes_for_file - existing_entry_hashes
embeddings = []
+ model = get_user_default_search_model(user=user)
with timer("Generated embeddings for entries to add to database in", logger):
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
- model = get_user_search_model_or_default(user)
embeddings += self.embeddings_model[model.name].embed_documents(data_to_embed)
added_entries: list[DbEntry] = []
@@ -177,6 +178,7 @@ class TextToEntries(ABC):
file_type=file_type,
hashed_value=entry_hash,
corpus_id=entry.corpus_id,
+ search_model=model,
)
)
try:
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index c542b1f3..8254da4d 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -25,8 +25,9 @@ from khoj.database.adapters import (
AutomationAdapters,
ConversationAdapters,
EntryAdapters,
+ get_default_search_model,
+ get_user_default_search_model,
get_user_photo,
- get_user_search_model_or_default,
)
from khoj.database.models import (
Agent,
@@ -149,7 +150,7 @@ async def execute_search(
encoded_asymmetric_query = None
if t != SearchType.Image:
with timer("Encoding query took", logger=logger):
- search_model = await sync_to_async(get_user_search_model_or_default)(user)
+ search_model = await sync_to_async(get_user_default_search_model)(user)
encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query)
with concurrent.futures.ThreadPoolExecutor() as executor:
diff --git a/src/khoj/routers/api_model.py b/src/khoj/routers/api_model.py
index fc6be626..6d6b9e21 100644
--- a/src/khoj/routers/api_model.py
+++ b/src/khoj/routers/api_model.py
@@ -94,39 +94,6 @@ async def update_voice_model(
return Response(status_code=202, content=json.dumps({"status": "ok"}))
-@api_model.post("/search", status_code=200)
-@requires(["authenticated"])
-async def update_search_model(
- request: Request,
- id: str,
- client: Optional[str] = None,
-):
- user = request.user.object
-
- prev_config = await adapters.aget_user_search_model(user)
- new_config = await adapters.aset_user_search_model(user, int(id))
-
- if prev_config and int(id) != prev_config.id and new_config:
- await EntryAdapters.adelete_all_entries(user)
-
- if not prev_config:
- # If the use was just using the default config, delete all the entries and set the new config.
- await EntryAdapters.adelete_all_entries(user)
-
- if new_config is None:
- return {"status": "error", "message": "Model not found"}
- else:
- update_telemetry_state(
- request=request,
- telemetry_type="api",
- api="set_search_model",
- client=client,
- metadata={"search_model": new_config.setting.name},
- )
-
- return {"status": "ok"}
-
-
@api_model.post("/paint", status_code=200)
@requires(["authenticated"])
async def update_paint_model(
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index e9c752fb..b28bbe95 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -1706,13 +1706,6 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
for chat_model in chat_models:
chat_model_options.append({"name": chat_model.chat_model, "id": chat_model.id})
- search_model_options = adapters.get_or_create_search_models().all()
- all_search_model_options = list()
- for search_model_option in search_model_options:
- all_search_model_options.append({"name": search_model_option.name, "id": search_model_option.id})
-
- current_search_model_option = adapters.get_user_search_model_or_default(user)
-
selected_paint_model_config = ConversationAdapters.get_user_text_to_image_model_config(user)
paint_model_options = ConversationAdapters.get_text_to_image_model_options().all()
all_paint_model_options = list()
@@ -1745,8 +1738,6 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
"has_documents": has_documents,
"notion_token": notion_token,
# user model settings
- "search_model_options": all_search_model_options,
- "selected_search_model_config": current_search_model_option.id,
"chat_model_options": chat_model_options,
"selected_chat_model_config": selected_chat_model_config.id if selected_chat_model_config else None,
"paint_model_options": all_paint_model_options,
diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py
index b67132e4..fbc972a7 100644
--- a/src/khoj/search_type/text_search.py
+++ b/src/khoj/search_type/text_search.py
@@ -8,7 +8,11 @@ import torch
from asgiref.sync import sync_to_async
from sentence_transformers import util
-from khoj.database.adapters import EntryAdapters, get_user_search_model_or_default
+from khoj.database.adapters import (
+ EntryAdapters,
+ get_default_search_model,
+ get_user_default_search_model,
+)
from khoj.database.models import Agent
from khoj.database.models import Entry as DbEntry
from khoj.database.models import KhojUser
@@ -110,7 +114,7 @@ async def query(
file_type = search_type_to_embeddings_type[type.value]
query = raw_query
- search_model = await sync_to_async(get_user_search_model_or_default)(user)
+ search_model = await sync_to_async(get_user_default_search_model)(user)
if not max_distance:
if search_model.bi_encoder_confidence_threshold:
max_distance = search_model.bi_encoder_confidence_threshold