Remove user customized search model (#946)

- Use a single standard search model across the server. There are diminishing benefits to having multiple user-customizable search models.
- We may want to add server-level customization for specific tasks
- Store the search model used to generate a given entry on the `Entry` object
- Remove user-facing APIs and view
- Add a management command for migrating the default search model on the server

In a future PR (after running the migration), we'll also remove the `UserSearchModelConfig` model.
This commit is contained in:
sabaimran 2024-10-23 17:38:37 -07:00 committed by GitHub
parent f3ce47b445
commit 5120597d4e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 237 additions and 90 deletions

View file

@ -718,7 +718,7 @@ export default function SettingsView() {
}; };
const updateModel = (name: string) => async (id: string) => { const updateModel = (name: string) => async (id: string) => {
if (!userConfig?.is_active && name !== "search") { if (!userConfig?.is_active) {
toast({ toast({
title: `Model Update`, title: `Model Update`,
description: `You need to be subscribed to update ${name} models`, description: `You need to be subscribed to update ${name} models`,
@ -1233,27 +1233,6 @@ export default function SettingsView() {
</CardFooter> </CardFooter>
</Card> </Card>
)} )}
{userConfig.search_model_options.length > 0 && (
<Card className={cardClassName}>
<CardHeader className="text-xl flex flex-row">
<FileMagnifyingGlass className="h-7 w-7 mr-2" />
Search
</CardHeader>
<CardContent className="overflow-hidden pb-12 grid gap-8 h-fit">
<p className="text-gray-400">
Pick the search model to find your documents
</p>
<DropdownComponent
items={userConfig.search_model_options}
selected={
userConfig.selected_search_model_config
}
callbackFunc={updateModel("search")}
/>
</CardContent>
<CardFooter className="flex flex-wrap gap-4"></CardFooter>
</Card>
)}
{userConfig.paint_model_options.length > 0 && ( {userConfig.paint_model_options.length > 0 && (
<Card className={cardClassName}> <Card className={cardClassName}>
<CardHeader className="text-xl flex flex-row"> <CardHeader className="text-xl flex flex-row">

View file

@ -466,18 +466,26 @@ async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
return config return config
def get_user_search_model_or_default(user=None): def get_default_search_model() -> SearchModelConfig:
if user and UserSearchModelConfig.objects.filter(user=user).exists(): default_search_model = SearchModelConfig.objects.filter(name="default").first()
return UserSearchModelConfig.objects.filter(user=user).first().setting
if SearchModelConfig.objects.filter(name="default").exists(): if default_search_model:
return SearchModelConfig.objects.filter(name="default").first() return default_search_model
else: else:
SearchModelConfig.objects.create() SearchModelConfig.objects.create()
return SearchModelConfig.objects.first() return SearchModelConfig.objects.first()
def get_user_default_search_model(user: KhojUser = None) -> SearchModelConfig:
if user:
user_search_model = UserSearchModelConfig.objects.filter(user=user).first()
if user_search_model:
return user_search_model.setting
return get_default_search_model()
def get_or_create_search_models(): def get_or_create_search_models():
search_models = SearchModelConfig.objects.all() search_models = SearchModelConfig.objects.all()
if search_models.count() == 0: if search_models.count() == 0:
@ -487,21 +495,6 @@ def get_or_create_search_models():
return search_models return search_models
async def aset_user_search_model(user: KhojUser, search_model_config_id: int):
config = await SearchModelConfig.objects.filter(id=search_model_config_id).afirst()
if not config:
return None
new_config, _ = await UserSearchModelConfig.objects.aupdate_or_create(user=user, defaults={"setting": config})
return new_config
async def aget_user_search_model(user: KhojUser):
config = await UserSearchModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()
if not config:
return None
return config.setting
class ProcessLockAdapters: class ProcessLockAdapters:
@staticmethod @staticmethod
def get_process_lock(process_name: str): def get_process_lock(process_name: str):

View file

@ -126,6 +126,7 @@ class EntryAdmin(admin.ModelAdmin):
"created_at", "created_at",
"updated_at", "updated_at",
"user", "user",
"agent",
"file_source", "file_source",
"file_type", "file_type",
"file_name", "file_name",
@ -135,6 +136,7 @@ class EntryAdmin(admin.ModelAdmin):
list_filter = ( list_filter = (
"file_type", "file_type",
"user__email", "user__email",
"search_model__name",
) )
ordering = ("-created_at",) ordering = ("-created_at",)

View file

@ -0,0 +1,182 @@
import logging
from typing import List
from django.core.management.base import BaseCommand
from django.db import transaction
from django.db.models import Count, Q
from tqdm import tqdm
from khoj.database.adapters import get_default_search_model
from khoj.database.models import (
Agent,
Entry,
KhojUser,
SearchModelConfig,
UserSearchModelConfig,
)
from khoj.processor.embeddings import EmbeddingsModel
# Configure the root logger at INFO level; basicConfig does nothing if the
# root logger already has handlers configured.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module, used for all messages below.
logger = logging.getLogger(__name__)
class Command(BaseCommand):
    """Move existing entries and user settings onto a new default search model.

    For the SearchModelConfig selected with ``--search_model_id`` the command
    reports, and with ``--apply`` performs, the following steps:

    * re-embed the entries of every user whose UserSearchModelConfig points
      at a different search model, then repoint that config;
    * when the current default differs from the requested model, re-embed the
      entries of users that have no UserSearchModelConfig and create one;
    * re-embed the entries of every agent;
    * rename the previous default to ``prev_default_<id>`` and name the
      requested model ``default``.

    Without ``--apply`` only the number of affected objects is logged.
    """

    help = "Convert all existing Entry objects to use a new default Search model."

    def add_arguments(self, parser):
        """Register the ``--search_model_id`` and ``--apply`` options on ``parser``."""
        # Pass default SearchModelConfig ID
        parser.add_argument(
            "--search_model_id",
            action="store",
            # Reject non-numeric ids at argument parsing time instead of
            # failing later inside the database lookup.
            type=int,
            help="ID of the SearchModelConfig object to set as the default search model for all existing Entry objects and UserSearchModelConfig objects.",
            required=True,
        )

        # Set the apply flag to apply the new default Search model to all existing Entry objects and UserSearchModelConfig objects.
        parser.add_argument(
            "--apply",
            action="store_true",
            help="Apply the new default Search model to all existing Entry objects and UserSearchModelConfig objects. Otherwise, only display the number of Entry objects and UserSearchModelConfig objects that will be affected.",
        )

    def handle(self, *args, **options):
        """Run the migration, or a dry run when ``--apply`` is not set.

        Raises:
            SearchModelConfig.DoesNotExist: if no SearchModelConfig has the
                id given by ``--search_model_id``.
        """

        @transaction.atomic
        def regenerate_entries(entry_filter: Q, embeddings_model: EmbeddingsModel, search_model: SearchModelConfig):
            """Re-embed the entries matching ``entry_filter`` and tag them with ``search_model``.

            Any exception raised while embedding or saving propagates to the
            caller, so the caller never records a migration that did not
            actually happen.
            """
            entries = Entry.objects.filter(entry_filter).all()
            compiled_entries = [entry.compiled for entry in entries]
            if not compiled_entries:
                # Nothing to embed; avoid calling the embedding model with an empty batch.
                return

            embeddings = embeddings_model.embed_documents(compiled_entries)

            updated_entries: List[Entry] = []
            for i, entry in enumerate(tqdm(entries)):
                entry.embeddings = embeddings[i]
                entry.search_model_id = search_model.id
                updated_entries.append(entry)

            # Only the two fields assigned above are written back.
            Entry.objects.bulk_update(updated_entries, ["embeddings", "search_model_id"])

        search_model_config_id = options.get("search_model_id")
        apply = options.get("apply")

        logger.info(f"SearchModelConfig ID: {search_model_config_id}")
        logger.info(f"Apply: {apply}")

        new_default_search_model_config = SearchModelConfig.objects.get(id=search_model_config_id)
        logger.info(f"New default Search model: {new_default_search_model_config}")

        # Only the new default model is ever used for embedding, so build it
        # directly from its own config rather than loading every configured
        # model and looking one up by (possibly non-unique) name.
        new_embeddings_model = EmbeddingsModel(
            new_default_search_model_config.bi_encoder,
            new_default_search_model_config.embeddings_inference_endpoint,
            new_default_search_model_config.embeddings_inference_endpoint_api_key,
            query_encode_kwargs=new_default_search_model_config.bi_encoder_query_encode_config,
            docs_encode_kwargs=new_default_search_model_config.bi_encoder_docs_encode_config,
            model_kwargs=new_default_search_model_config.bi_encoder_model_config,
        )

        # Users who explicitly selected a search model other than the new default.
        user_search_model_configs_to_update = UserSearchModelConfig.objects.exclude(
            setting_id=search_model_config_id
        ).all()
        logger.info(f"Number of UserSearchModelConfig objects to update: {user_search_model_configs_to_update.count()}")

        for user_config in user_search_model_configs_to_update:
            affected_user = user_config.user
            entry_filter = Q(user=affected_user)
            entry_count = Entry.objects.filter(entry_filter).count()
            logger.info(f"Number of Entry objects to update for user {affected_user}: {entry_count}")

            if apply:
                try:
                    regenerate_entries(
                        entry_filter,
                        new_embeddings_model,
                        new_default_search_model_config,
                    )
                    # Reached only when re-embedding succeeded.
                    user_config.setting = new_default_search_model_config
                    user_config.save()

                    logger.info(
                        f"Updated UserSearchModelConfig object for user {affected_user} to use the new default Search model."
                    )
                    logger.info(
                        f"Updated {entry_count} Entry objects for user {affected_user} to use the new default Search model."
                    )
                except Exception as e:
                    logger.error(f"Error embedding documents: {e}")

        logger.info("----")

        # There are also plenty of users who have indexed documents without explicitly creating a
        # UserSearchModelConfig object. These users have to be migrated as well, if the current
        # default is different from search_model_config_id.
        current_default = get_default_search_model()
        if current_default.id != new_default_search_model_config.id:
            users_without_user_search_model_config = KhojUser.objects.annotate(
                user_search_model_config_count=Count("usersearchmodelconfig")
            ).filter(user_search_model_config_count=0)

            logger.info(f"Number of User objects to update: {users_without_user_search_model_config.count()}")

            for user in users_without_user_search_model_config:
                entry_filter = Q(user=user)
                entry_count = Entry.objects.filter(entry_filter).count()
                logger.info(f"Number of Entry objects to update for user {user}: {entry_count}")

                if apply:
                    try:
                        regenerate_entries(
                            entry_filter,
                            new_embeddings_model,
                            new_default_search_model_config,
                        )
                        # Reached only when re-embedding succeeded.
                        UserSearchModelConfig.objects.create(user=user, setting=new_default_search_model_config)

                        logger.info(
                            f"Created UserSearchModelConfig object for user {user} to use the new default Search model."
                        )
                        logger.info(
                            f"Updated {entry_count} Entry objects for user {user} to use the new default Search model."
                        )
                    except Exception as e:
                        logger.error(f"Error embedding documents: {e}")
        else:
            logger.info("Default is the same as search_model_config_id.")

        all_agents = Agent.objects.all()
        logger.info(f"Number of Agent objects to update: {all_agents.count()}")

        for agent in all_agents:
            entry_filter = Q(agent=agent)
            entry_count = Entry.objects.filter(entry_filter).count()
            logger.info(f"Number of Entry objects to update for agent {agent}: {entry_count}")

            if apply:
                try:
                    regenerate_entries(
                        entry_filter,
                        new_embeddings_model,
                        new_default_search_model_config,
                    )
                    logger.info(
                        f"Updated {entry_count} Entry objects for agent {agent} to use the new default Search model."
                    )
                except Exception as e:
                    logger.error(f"Error embedding documents: {e}")

        if apply and current_default.id != new_default_search_model_config.id:
            # Get the existing default SearchModelConfig object and update its name
            current_default.name = f"prev_default_{current_default.id}"
            current_default.save()

            # Update the new default SearchModelConfig object's name
            new_default_search_model_config.name = "default"
            new_default_search_model_config.save()

        if not apply:
            logger.info("Run the command with the --apply flag to apply the new default Search model.")

View file

@ -0,0 +1,24 @@
# Generated by Django 5.0.8 on 2024-10-21 21:09
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
    """Add an optional ``search_model`` foreign key to the ``entry`` model."""

    # Must run after migration 0071 of the "database" app.
    dependencies = [("database", "0071_subscription_enabled_trial_at_and_more")]

    operations = [
        migrations.AddField(
            model_name="entry",
            name="search_model",
            # Nullable link to SearchModelConfig; the column is set to NULL
            # when the referenced config row is deleted.
            field=models.ForeignKey(
                to="database.searchmodelconfig",
                on_delete=django.db.models.deletion.SET_NULL,
                null=True,
                blank=True,
                default=None,
            ),
        ),
    ]

View file

@ -449,6 +449,7 @@ class UserVoiceModelConfig(BaseModel):
setting = models.ForeignKey(VoiceModelOption, on_delete=models.CASCADE, default=None, null=True, blank=True) setting = models.ForeignKey(VoiceModelOption, on_delete=models.CASCADE, default=None, null=True, blank=True)
# TODO Delete this model once all users have been migrated to the server's default settings
class UserSearchModelConfig(BaseModel): class UserSearchModelConfig(BaseModel):
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE) user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE) setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE)
@ -535,6 +536,7 @@ class Entry(BaseModel):
url = models.URLField(max_length=400, default=None, null=True, blank=True) url = models.URLField(max_length=400, default=None, null=True, blank=True)
hashed_value = models.CharField(max_length=100) hashed_value = models.CharField(max_length=100)
corpus_id = models.UUIDField(default=uuid.uuid4, editable=False) corpus_id = models.UUIDField(default=uuid.uuid4, editable=False)
search_model = models.ForeignKey(SearchModelConfig, on_delete=models.SET_NULL, default=None, null=True, blank=True)
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
if self.user and self.agent: if self.user and self.agent:

View file

@ -12,7 +12,8 @@ from tqdm import tqdm
from khoj.database.adapters import ( from khoj.database.adapters import (
EntryAdapters, EntryAdapters,
FileObjectAdapters, FileObjectAdapters,
get_user_search_model_or_default, get_default_search_model,
get_user_default_search_model,
) )
from khoj.database.models import Entry as DbEntry from khoj.database.models import Entry as DbEntry
from khoj.database.models import EntryDates, KhojUser from khoj.database.models import EntryDates, KhojUser
@ -148,10 +149,10 @@ class TextToEntries(ABC):
hashes_to_process |= hashes_for_file - existing_entry_hashes hashes_to_process |= hashes_for_file - existing_entry_hashes
embeddings = [] embeddings = []
model = get_user_default_search_model(user=user)
with timer("Generated embeddings for entries to add to database in", logger): with timer("Generated embeddings for entries to add to database in", logger):
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process] entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
data_to_embed = [getattr(entry, key) for entry in entries_to_process] data_to_embed = [getattr(entry, key) for entry in entries_to_process]
model = get_user_search_model_or_default(user)
embeddings += self.embeddings_model[model.name].embed_documents(data_to_embed) embeddings += self.embeddings_model[model.name].embed_documents(data_to_embed)
added_entries: list[DbEntry] = [] added_entries: list[DbEntry] = []
@ -177,6 +178,7 @@ class TextToEntries(ABC):
file_type=file_type, file_type=file_type,
hashed_value=entry_hash, hashed_value=entry_hash,
corpus_id=entry.corpus_id, corpus_id=entry.corpus_id,
search_model=model,
) )
) )
try: try:

View file

@ -25,8 +25,9 @@ from khoj.database.adapters import (
AutomationAdapters, AutomationAdapters,
ConversationAdapters, ConversationAdapters,
EntryAdapters, EntryAdapters,
get_default_search_model,
get_user_default_search_model,
get_user_photo, get_user_photo,
get_user_search_model_or_default,
) )
from khoj.database.models import ( from khoj.database.models import (
Agent, Agent,
@ -149,7 +150,7 @@ async def execute_search(
encoded_asymmetric_query = None encoded_asymmetric_query = None
if t != SearchType.Image: if t != SearchType.Image:
with timer("Encoding query took", logger=logger): with timer("Encoding query took", logger=logger):
search_model = await sync_to_async(get_user_search_model_or_default)(user) search_model = await sync_to_async(get_user_default_search_model)(user)
encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query) encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query)
with concurrent.futures.ThreadPoolExecutor() as executor: with concurrent.futures.ThreadPoolExecutor() as executor:

View file

@ -94,39 +94,6 @@ async def update_voice_model(
return Response(status_code=202, content=json.dumps({"status": "ok"})) return Response(status_code=202, content=json.dumps({"status": "ok"}))
@api_model.post("/search", status_code=200)
@requires(["authenticated"])
async def update_search_model(
request: Request,
id: str,
client: Optional[str] = None,
):
user = request.user.object
prev_config = await adapters.aget_user_search_model(user)
new_config = await adapters.aset_user_search_model(user, int(id))
if prev_config and int(id) != prev_config.id and new_config:
await EntryAdapters.adelete_all_entries(user)
if not prev_config:
# If the use was just using the default config, delete all the entries and set the new config.
await EntryAdapters.adelete_all_entries(user)
if new_config is None:
return {"status": "error", "message": "Model not found"}
else:
update_telemetry_state(
request=request,
telemetry_type="api",
api="set_search_model",
client=client,
metadata={"search_model": new_config.setting.name},
)
return {"status": "ok"}
@api_model.post("/paint", status_code=200) @api_model.post("/paint", status_code=200)
@requires(["authenticated"]) @requires(["authenticated"])
async def update_paint_model( async def update_paint_model(

View file

@ -1706,13 +1706,6 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
for chat_model in chat_models: for chat_model in chat_models:
chat_model_options.append({"name": chat_model.chat_model, "id": chat_model.id}) chat_model_options.append({"name": chat_model.chat_model, "id": chat_model.id})
search_model_options = adapters.get_or_create_search_models().all()
all_search_model_options = list()
for search_model_option in search_model_options:
all_search_model_options.append({"name": search_model_option.name, "id": search_model_option.id})
current_search_model_option = adapters.get_user_search_model_or_default(user)
selected_paint_model_config = ConversationAdapters.get_user_text_to_image_model_config(user) selected_paint_model_config = ConversationAdapters.get_user_text_to_image_model_config(user)
paint_model_options = ConversationAdapters.get_text_to_image_model_options().all() paint_model_options = ConversationAdapters.get_text_to_image_model_options().all()
all_paint_model_options = list() all_paint_model_options = list()
@ -1745,8 +1738,6 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
"has_documents": has_documents, "has_documents": has_documents,
"notion_token": notion_token, "notion_token": notion_token,
# user model settings # user model settings
"search_model_options": all_search_model_options,
"selected_search_model_config": current_search_model_option.id,
"chat_model_options": chat_model_options, "chat_model_options": chat_model_options,
"selected_chat_model_config": selected_chat_model_config.id if selected_chat_model_config else None, "selected_chat_model_config": selected_chat_model_config.id if selected_chat_model_config else None,
"paint_model_options": all_paint_model_options, "paint_model_options": all_paint_model_options,

View file

@ -8,7 +8,11 @@ import torch
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
from sentence_transformers import util from sentence_transformers import util
from khoj.database.adapters import EntryAdapters, get_user_search_model_or_default from khoj.database.adapters import (
EntryAdapters,
get_default_search_model,
get_user_default_search_model,
)
from khoj.database.models import Agent from khoj.database.models import Agent
from khoj.database.models import Entry as DbEntry from khoj.database.models import Entry as DbEntry
from khoj.database.models import KhojUser from khoj.database.models import KhojUser
@ -110,7 +114,7 @@ async def query(
file_type = search_type_to_embeddings_type[type.value] file_type = search_type_to_embeddings_type[type.value]
query = raw_query query = raw_query
search_model = await sync_to_async(get_user_search_model_or_default)(user) search_model = await sync_to_async(get_user_default_search_model)(user)
if not max_distance: if not max_distance:
if search_model.bi_encoder_confidence_threshold: if search_model.bi_encoder_confidence_threshold:
max_distance = search_model.bi_encoder_confidence_threshold max_distance = search_model.bi_encoder_confidence_threshold