diff --git a/src/interface/web/app/settings/page.tsx b/src/interface/web/app/settings/page.tsx index d79eeff4..fe3e11e7 100644 --- a/src/interface/web/app/settings/page.tsx +++ b/src/interface/web/app/settings/page.tsx @@ -718,7 +718,7 @@ export default function SettingsView() { }; const updateModel = (name: string) => async (id: string) => { - if (!userConfig?.is_active && name !== "search") { + if (!userConfig?.is_active) { toast({ title: `Model Update`, description: `You need to be subscribed to update ${name} models`, @@ -1233,27 +1233,6 @@ export default function SettingsView() { )} - {userConfig.search_model_options.length > 0 && ( - - - - Search - - - - Pick the search model to find your documents - - - - - - )} {userConfig.paint_model_options.length > 0 && ( diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index d5312b29..b77eccdd 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -466,18 +466,26 @@ async def set_user_github_config(user: KhojUser, pat_token: str, repos: list): return config -def get_user_search_model_or_default(user=None): - if user and UserSearchModelConfig.objects.filter(user=user).exists(): - return UserSearchModelConfig.objects.filter(user=user).first().setting +def get_default_search_model() -> SearchModelConfig: + default_search_model = SearchModelConfig.objects.filter(name="default").first() - if SearchModelConfig.objects.filter(name="default").exists(): - return SearchModelConfig.objects.filter(name="default").first() + if default_search_model: + return default_search_model else: SearchModelConfig.objects.create() return SearchModelConfig.objects.first() +def get_user_default_search_model(user: KhojUser = None) -> SearchModelConfig: + if user: + user_search_model = UserSearchModelConfig.objects.filter(user=user).first() + if user_search_model: + return user_search_model.setting + + return get_default_search_model() + + def get_or_create_search_models(): search_models = SearchModelConfig.objects.all() if search_models.count() == 0: @@ -487,21 +495,6 @@ def get_or_create_search_models(): return search_models -async def aset_user_search_model(user: KhojUser, search_model_config_id: int): - config = await SearchModelConfig.objects.filter(id=search_model_config_id).afirst() - if not config: - return None - new_config, _ = await UserSearchModelConfig.objects.aupdate_or_create(user=user, defaults={"setting": config}) - return new_config - - -async def aget_user_search_model(user: KhojUser): - config = await UserSearchModelConfig.objects.filter(user=user).prefetch_related("setting").afirst() - if not config: - return None - return config.setting - - class ProcessLockAdapters: @staticmethod def get_process_lock(process_name: str): diff --git a/src/khoj/database/admin.py b/src/khoj/database/admin.py index 5aa9204b..3087f7b1 100644 --- a/src/khoj/database/admin.py +++ b/src/khoj/database/admin.py @@ -126,6 +126,7 @@ class EntryAdmin(admin.ModelAdmin): "created_at", "updated_at", "user", + "agent", "file_source", "file_type", "file_name", @@ -135,6 +136,7 @@ class EntryAdmin(admin.ModelAdmin): list_filter = ( "file_type", "user__email", + "search_model__name", ) ordering = ("-created_at",) diff --git a/src/khoj/database/management/commands/change_default_model.py b/src/khoj/database/management/commands/change_default_model.py new file mode 100644 index 00000000..cfa78581 --- /dev/null +++ b/src/khoj/database/management/commands/change_default_model.py @@ -0,0 +1,182 @@ +import logging +from typing import List + +from django.core.management.base import BaseCommand +from django.db import transaction +from django.db.models import Count, Q +from tqdm import tqdm + +from khoj.database.adapters import get_default_search_model +from khoj.database.models import ( + Agent, + Entry, + KhojUser, + SearchModelConfig, + UserSearchModelConfig, +) +from khoj.processor.embeddings import EmbeddingsModel + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class Command(BaseCommand): + help = "Convert all existing Entry objects to use a new default Search model." + + def add_arguments(self, parser): + # Pass default SearchModelConfig ID + parser.add_argument( + "--search_model_id", + action="store", + help="ID of the SearchModelConfig object to set as the default search model for all existing Entry objects and UserSearchModelConfig objects.", + required=True, + ) + + # Set the apply flag to apply the new default Search model to all existing Entry objects and UserSearchModelConfig objects. + parser.add_argument( + "--apply", + action="store_true", + help="Apply the new default Search model to all existing Entry objects and UserSearchModelConfig objects. Otherwise, only display the number of Entry objects and UserSearchModelConfig objects that will be affected.", + ) + + def handle(self, *args, **options): + @transaction.atomic + def regenerate_entries(entry_filter: Q, embeddings_model: EmbeddingsModel, search_model: SearchModelConfig): + entries = Entry.objects.filter(entry_filter).all() + compiled_entries = [entry.compiled for entry in entries] + updated_entries: List[Entry] = [] + try: + embeddings = embeddings_model.embed_documents(compiled_entries) + + except Exception as e: + logger.error(f"Error embedding documents: {e}") + return + + for i, entry in enumerate(tqdm(entries)): + entry.embeddings = embeddings[i] + entry.search_model_id = search_model.id + updated_entries.append(entry) + + Entry.objects.bulk_update(updated_entries, ["embeddings", "search_model_id", "file_path"]) + + search_model_config_id = options.get("search_model_id") + apply = options.get("apply") + + logger.info(f"SearchModelConfig ID: {search_model_config_id}") + logger.info(f"Apply: {apply}") + + embeddings_model = dict() + + search_models = SearchModelConfig.objects.all() + for model in search_models: + embeddings_model.update( + { + model.name: EmbeddingsModel( + model.bi_encoder, + model.embeddings_inference_endpoint, + model.embeddings_inference_endpoint_api_key, + query_encode_kwargs=model.bi_encoder_query_encode_config, + docs_encode_kwargs=model.bi_encoder_docs_encode_config, + model_kwargs=model.bi_encoder_model_config, + ) + } + ) + + new_default_search_model_config = SearchModelConfig.objects.get(id=search_model_config_id) + logger.info(f"New default Search model: {new_default_search_model_config}") + user_search_model_configs_to_update = UserSearchModelConfig.objects.exclude( + setting_id=search_model_config_id + ).all() + logger.info(f"Number of UserSearchModelConfig objects to update: {user_search_model_configs_to_update.count()}") + + for user_config in user_search_model_configs_to_update: + affected_user = user_config.user + entry_filter = Q(user=affected_user) + relevant_entries = Entry.objects.filter(entry_filter).all() + logger.info(f"Number of Entry objects to update for user {affected_user}: {relevant_entries.count()}") + + if apply: + try: + regenerate_entries( + entry_filter, + embeddings_model[new_default_search_model_config.name], + new_default_search_model_config, + ) + user_config.setting = new_default_search_model_config + user_config.save() + + logger.info( + f"Updated UserSearchModelConfig object for user {affected_user} to use the new default Search model." + ) + logger.info( + f"Updated {relevant_entries.count()} Entry objects for user {affected_user} to use the new default Search model." + ) + + except Exception as e: + logger.error(f"Error embedding documents: {e}") + + logger.info("----") + + # There are also plenty of users who have indexed documents without explicitly creating a UserSearchModelConfig object. You would have to migrate these users as well, if the default is different from search_model_config_id. + current_default = get_default_search_model() + if current_default.id != new_default_search_model_config.id: + users_without_user_search_model_config = KhojUser.objects.annotate( + user_search_model_config_count=Count("usersearchmodelconfig") + ).filter(user_search_model_config_count=0) + + logger.info(f"Number of User objects to update: {users_without_user_search_model_config.count()}") + for user in users_without_user_search_model_config: + entry_filter = Q(user=user) + relevant_entries = Entry.objects.filter(entry_filter).all() + logger.info(f"Number of Entry objects to update for user {user}: {relevant_entries.count()}") + + if apply: + try: + regenerate_entries( + entry_filter, + embeddings_model[new_default_search_model_config.name], + new_default_search_model_config, + ) + + UserSearchModelConfig.objects.create(user=user, setting=new_default_search_model_config) + + logger.info( + f"Created UserSearchModelConfig object for user {user} to use the new default Search model." + ) + logger.info( + f"Updated {relevant_entries.count()} Entry objects for user {user} to use the new default Search model." + ) + except Exception as e: + logger.error(f"Error embedding documents: {e}") + else: + logger.info("Default is the same as search_model_config_id.") + + all_agents = Agent.objects.all() + logger.info(f"Number of Agent objects to update: {all_agents.count()}") + for agent in all_agents: + entry_filter = Q(agent=agent) + relevant_entries = Entry.objects.filter(entry_filter).all() + logger.info(f"Number of Entry objects to update for agent {agent}: {relevant_entries.count()}") + + if apply: + try: + regenerate_entries( + entry_filter, + embeddings_model[new_default_search_model_config.name], + new_default_search_model_config, + ) + logger.info( + f"Updated {relevant_entries.count()} Entry objects for agent {agent} to use the new default Search model." + ) + except Exception as e: + logger.error(f"Error embedding documents: {e}") + if apply and current_default.id != new_default_search_model_config.id: + # Get the existing default SearchModelConfig object and update its name + current_default.name = f"prev_default_{current_default.id}" + current_default.save() + + # Update the new default SearchModelConfig object's name + new_default_search_model_config.name = "default" + new_default_search_model_config.save() + if not apply: + logger.info("Run the command with the --apply flag to apply the new default Search model.") diff --git a/src/khoj/database/migrations/0072_entry_search_model.py b/src/khoj/database/migrations/0072_entry_search_model.py new file mode 100644 index 00000000..545f1f62 --- /dev/null +++ b/src/khoj/database/migrations/0072_entry_search_model.py @@ -0,0 +1,24 @@ +# Generated by Django 5.0.8 on 2024-10-21 21:09 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0071_subscription_enabled_trial_at_and_more"), + ] + + operations = [ + migrations.AddField( + model_name="entry", + name="search_model", + field=models.ForeignKey( + blank=True, + default=None, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + to="database.searchmodelconfig", + ), + ), + ] diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index c89c409a..7b5bbd12 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -449,6 +449,7 @@ class UserVoiceModelConfig(BaseModel): setting = models.ForeignKey(VoiceModelOption, on_delete=models.CASCADE, default=None, null=True, blank=True) +# TODO Delete this model once all users have been migrated to the server's default settings class UserSearchModelConfig(BaseModel): user = models.OneToOneField(KhojUser, on_delete=models.CASCADE) setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE) @@ -535,6 +536,7 @@ class Entry(BaseModel): url = models.URLField(max_length=400, default=None, null=True, blank=True) hashed_value = models.CharField(max_length=100) corpus_id = models.UUIDField(default=uuid.uuid4, editable=False) + search_model = models.ForeignKey(SearchModelConfig, on_delete=models.SET_NULL, default=None, null=True, blank=True) def save(self, *args, **kwargs): if self.user and self.agent: diff --git a/src/khoj/processor/content/text_to_entries.py b/src/khoj/processor/content/text_to_entries.py index 6fee9c0c..b0b74996 100644 --- a/src/khoj/processor/content/text_to_entries.py +++ b/src/khoj/processor/content/text_to_entries.py @@ -12,7 +12,8 @@ from tqdm import tqdm from khoj.database.adapters import ( EntryAdapters, FileObjectAdapters, - get_user_search_model_or_default, + get_default_search_model, + get_user_default_search_model, ) from khoj.database.models import Entry as DbEntry from khoj.database.models import EntryDates, KhojUser @@ -148,10 +149,10 @@ class TextToEntries(ABC): hashes_to_process |= hashes_for_file - existing_entry_hashes embeddings = [] + model = get_user_default_search_model(user=user) with timer("Generated embeddings for entries to add to database in", logger): entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process] data_to_embed = [getattr(entry, key) for entry in entries_to_process] - model = get_user_search_model_or_default(user) embeddings += self.embeddings_model[model.name].embed_documents(data_to_embed) added_entries: list[DbEntry] = [] @@ -177,6 +178,7 @@ class TextToEntries(ABC): file_type=file_type, hashed_value=entry_hash, corpus_id=entry.corpus_id, + search_model=model, ) ) try: diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index c542b1f3..8254da4d 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -25,8 +25,9 @@ from khoj.database.adapters import ( AutomationAdapters, ConversationAdapters, EntryAdapters, + get_default_search_model, + get_user_default_search_model, get_user_photo, - get_user_search_model_or_default, ) from khoj.database.models import ( Agent, @@ -149,7 +150,7 @@ async def execute_search( encoded_asymmetric_query = None if t != SearchType.Image: with timer("Encoding query took", logger=logger): - search_model = await sync_to_async(get_user_search_model_or_default)(user) + search_model = await sync_to_async(get_user_default_search_model)(user) encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query) with concurrent.futures.ThreadPoolExecutor() as executor: diff --git a/src/khoj/routers/api_model.py b/src/khoj/routers/api_model.py index fc6be626..6d6b9e21 100644 --- a/src/khoj/routers/api_model.py +++ b/src/khoj/routers/api_model.py @@ -94,39 +94,6 @@ async def update_voice_model( return Response(status_code=202, content=json.dumps({"status": "ok"})) -@api_model.post("/search", status_code=200) -@requires(["authenticated"]) -async def update_search_model( - request: Request, - id: str, - client: Optional[str] = None, -): - user = request.user.object - - prev_config = await adapters.aget_user_search_model(user) - new_config = await adapters.aset_user_search_model(user, int(id)) - - if prev_config and int(id) != prev_config.id and new_config: - await EntryAdapters.adelete_all_entries(user) - - if not prev_config: - # If the use was just using the default config, delete all the entries and set the new config. - await EntryAdapters.adelete_all_entries(user) - - if new_config is None: - return {"status": "error", "message": "Model not found"} - else: - update_telemetry_state( - request=request, - telemetry_type="api", - api="set_search_model", - client=client, - metadata={"search_model": new_config.setting.name}, - ) - - return {"status": "ok"} - - @api_model.post("/paint", status_code=200) @requires(["authenticated"]) async def update_paint_model( diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index e9c752fb..b28bbe95 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1706,13 +1706,6 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False) for chat_model in chat_models: chat_model_options.append({"name": chat_model.chat_model, "id": chat_model.id}) - search_model_options = adapters.get_or_create_search_models().all() - all_search_model_options = list() - for search_model_option in search_model_options: - all_search_model_options.append({"name": search_model_option.name, "id": search_model_option.id}) - - current_search_model_option = adapters.get_user_search_model_or_default(user) - selected_paint_model_config = ConversationAdapters.get_user_text_to_image_model_config(user) paint_model_options = ConversationAdapters.get_text_to_image_model_options().all() all_paint_model_options = list() @@ -1745,8 +1738,6 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False) "has_documents": has_documents, "notion_token": notion_token, # user model settings - "search_model_options": all_search_model_options, - "selected_search_model_config": current_search_model_option.id, "chat_model_options": chat_model_options, "selected_chat_model_config": selected_chat_model_config.id if selected_chat_model_config else None, "paint_model_options": all_paint_model_options, diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py index b67132e4..fbc972a7 100644 --- a/src/khoj/search_type/text_search.py +++ b/src/khoj/search_type/text_search.py @@ -8,7 +8,11 @@ import torch from asgiref.sync import sync_to_async from sentence_transformers import util -from khoj.database.adapters import EntryAdapters, get_user_search_model_or_default +from khoj.database.adapters import ( + EntryAdapters, + get_default_search_model, + get_user_default_search_model, +) from khoj.database.models import Agent from khoj.database.models import Entry as DbEntry from khoj.database.models import KhojUser @@ -110,7 +114,7 @@ async def query( file_type = search_type_to_embeddings_type[type.value] query = raw_query - search_model = await sync_to_async(get_user_search_model_or_default)(user) + search_model = await sync_to_async(get_user_default_search_model)(user) if not max_distance: if search_model.bi_encoder_confidence_threshold: max_distance = search_model.bi_encoder_confidence_threshold
- Pick the search model to find your documents -