Deprecate the UserSearchModelConfig and remove all references

- The server has moved to a model of standardization for the embeddings generation workflow. Remove references to the support for differentiated models.
- The migration script fo ra new model needs to be updated to accommodate full regeneration.
This commit is contained in:
sabaimran 2024-11-04 12:24:41 -08:00
parent 99c1d2831a
commit 1e89baca7b
8 changed files with 25 additions and 99 deletions

View file

@ -48,7 +48,6 @@ from khoj.database.models import (
TextToImageModelConfig,
UserConversationConfig,
UserRequests,
UserSearchModelConfig,
UserTextToImageModelConfig,
UserVoiceModelConfig,
VoiceModelOption,
@ -481,15 +480,6 @@ def get_default_search_model() -> SearchModelConfig:
return SearchModelConfig.objects.first()
def get_user_default_search_model(user: KhojUser = None) -> SearchModelConfig:
if user:
user_search_model = UserSearchModelConfig.objects.filter(user=user).first()
if user_search_model:
return user_search_model.setting
return get_default_search_model()
def get_or_create_search_models():
search_models = SearchModelConfig.objects.all()
if search_models.count() == 0:

View file

@ -28,7 +28,6 @@ from khoj.database.models import (
TextToImageModelConfig,
UserConversationConfig,
UserRequests,
UserSearchModelConfig,
UserVoiceModelConfig,
VoiceModelOption,
WebScraper,
@ -99,7 +98,6 @@ admin.site.register(KhojUser, KhojUserAdmin)
admin.site.register(ProcessLock)
admin.site.register(SpeechToTextModelOptions)
admin.site.register(ReflectiveQuestion)
admin.site.register(UserSearchModelConfig)
admin.site.register(ClientApplication)
admin.site.register(GithubConfig)
admin.site.register(NotionConfig)

View file

@ -7,13 +7,7 @@ from django.db.models import Count, Q
from tqdm import tqdm
from khoj.database.adapters import get_default_search_model
from khoj.database.models import (
Agent,
Entry,
KhojUser,
SearchModelConfig,
UserSearchModelConfig,
)
from khoj.database.models import Agent, Entry, KhojUser, SearchModelConfig
from khoj.processor.embeddings import EmbeddingsModel
logging.basicConfig(level=logging.INFO)
@ -30,15 +24,15 @@ class Command(BaseCommand):
parser.add_argument(
"--search_model_id",
action="store",
help="ID of the SearchModelConfig object to set as the default search model for all existing Entry objects and UserSearchModelConfig objects.",
help="ID of the SearchModelConfig object to set as the default search model for all existing Entry objects.",
required=True,
)
# Set the apply flag to apply the new default Search model to all existing Entry objects and UserSearchModelConfig objects.
# Set the apply flag to apply the new default Search model to all existing Entry objects.
parser.add_argument(
"--apply",
action="store_true",
help="Apply the new default Search model to all existing Entry objects and UserSearchModelConfig objects. Otherwise, only display the number of Entry objects and UserSearchModelConfig objects that will be affected.",
help="Apply the new default Search model to all existing Entry objects. Otherwise, only display the number of Entry objects that will be affected.",
)
def handle(self, *args, **options):
@ -88,72 +82,12 @@ class Command(BaseCommand):
new_default_search_model_config = SearchModelConfig.objects.get(id=search_model_config_id)
logger.info(f"New default Search model: {new_default_search_model_config}")
user_search_model_configs_to_update = UserSearchModelConfig.objects.exclude(
setting_id=search_model_config_id
).all()
logger.info(f"Number of UserSearchModelConfig objects to update: {user_search_model_configs_to_update.count()}")
for user_config in user_search_model_configs_to_update:
affected_user = user_config.user
entry_filter = Q(user=affected_user)
relevant_entries = Entry.objects.filter(entry_filter).all()
logger.info(f"Number of Entry objects to update for user {affected_user}: {relevant_entries.count()}")
if apply:
try:
regenerate_entries(
entry_filter,
embeddings_model[new_default_search_model_config.name],
new_default_search_model_config,
)
user_config.setting = new_default_search_model_config
user_config.save()
logger.info(
f"Updated UserSearchModelConfig object for user {affected_user} to use the new default Search model."
)
logger.info(
f"Updated {relevant_entries.count()} Entry objects for user {affected_user} to use the new default Search model."
)
except Exception as e:
logger.error(f"Error embedding documents: {e}")
logger.info("----")
# There are also plenty of users who have indexed documents without explicitly creating a UserSearchModelConfig object. You would have to migrate these users as well, if the default is different from search_model_config_id.
current_default = get_default_search_model()
if current_default.id != new_default_search_model_config.id:
users_without_user_search_model_config = KhojUser.objects.annotate(
user_search_model_config_count=Count("usersearchmodelconfig")
).filter(user_search_model_config_count=0)
logger.info(f"Number of User objects to update: {users_without_user_search_model_config.count()}")
for user in users_without_user_search_model_config:
entry_filter = Q(user=user)
relevant_entries = Entry.objects.filter(entry_filter).all()
logger.info(f"Number of Entry objects to update for user {user}: {relevant_entries.count()}")
if apply:
try:
regenerate_entries(
entry_filter,
embeddings_model[new_default_search_model_config.name],
new_default_search_model_config,
)
UserSearchModelConfig.objects.create(user=user, setting=new_default_search_model_config)
logger.info(
f"Created UserSearchModelConfig object for user {user} to use the new default Search model."
)
logger.info(
f"Updated {relevant_entries.count()} Entry objects for user {user} to use the new default Search model."
)
except Exception as e:
logger.error(f"Error embedding documents: {e}")
else:
logger.info("Default is the same as search_model_config_id.")
# TODO: Migrate all Entry objects to use the new default Search model
all_agents = Agent.objects.all()
logger.info(f"Number of Agent objects to update: {all_agents.count()}")
@ -174,6 +108,7 @@ class Command(BaseCommand):
)
except Exception as e:
logger.error(f"Error embedding documents: {e}")
if apply and current_default.id != new_default_search_model_config.id:
# Get the existing default SearchModelConfig object and update its name
current_default.name = f"prev_default_{current_default.id}"

View file

@ -0,0 +1,15 @@
# Generated by Django 5.0.9 on 2024-11-04 19:56
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
("database", "0072_entry_search_model"),
]
operations = [
migrations.DeleteModel(
name="UserSearchModelConfig",
),
]

View file

@ -449,12 +449,6 @@ class UserVoiceModelConfig(BaseModel):
setting = models.ForeignKey(VoiceModelOption, on_delete=models.CASCADE, default=None, null=True, blank=True)
# TODO Delete this model once all users have been migrated to the server's default settings
class UserSearchModelConfig(BaseModel):
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE)
class UserTextToImageModelConfig(BaseModel):
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
setting = models.ForeignKey(TextToImageModelConfig, on_delete=models.CASCADE)

View file

@ -13,7 +13,6 @@ from khoj.database.adapters import (
EntryAdapters,
FileObjectAdapters,
get_default_search_model,
get_user_default_search_model,
)
from khoj.database.models import Entry as DbEntry
from khoj.database.models import EntryDates, KhojUser
@ -149,7 +148,7 @@ class TextToEntries(ABC):
hashes_to_process |= hashes_for_file - existing_entry_hashes
embeddings = []
model = get_user_default_search_model(user=user)
model = get_default_search_model()
with timer("Generated embeddings for entries to add to database in", logger):
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
data_to_embed = [getattr(entry, key) for entry in entries_to_process]

View file

@ -26,7 +26,6 @@ from khoj.database.adapters import (
ConversationAdapters,
EntryAdapters,
get_default_search_model,
get_user_default_search_model,
get_user_photo,
)
from khoj.database.models import (
@ -151,7 +150,7 @@ async def execute_search(
encoded_asymmetric_query = None
if t != SearchType.Image:
with timer("Encoding query took", logger=logger):
search_model = await sync_to_async(get_user_default_search_model)(user)
search_model = await sync_to_async(get_default_search_model)()
encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query)
with concurrent.futures.ThreadPoolExecutor() as executor:

View file

@ -8,11 +8,7 @@ import torch
from asgiref.sync import sync_to_async
from sentence_transformers import util
from khoj.database.adapters import (
EntryAdapters,
get_default_search_model,
get_user_default_search_model,
)
from khoj.database.adapters import EntryAdapters, get_default_search_model
from khoj.database.models import Agent
from khoj.database.models import Entry as DbEntry
from khoj.database.models import KhojUser
@ -114,7 +110,7 @@ async def query(
file_type = search_type_to_embeddings_type[type.value]
query = raw_query
search_model = await sync_to_async(get_user_default_search_model)(user)
search_model = await sync_to_async(get_default_search_model)()
if not max_distance:
if search_model.bi_encoder_confidence_threshold:
max_distance = search_model.bi_encoder_confidence_threshold