mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-27 17:35:07 +01:00
Remove user customized search model (#946)
- Use a single standard search model across the server. There's diminishing benefits for having multiple user-customizable search models. - We may want to add server-level customization for specific tasks - Store the search model used to generate a given entry on the `Entry` object - Remove user-facing APIs and view - Add a management command for migrating the default search model on the server In a future PR (after running the migration), we'll also remove the `UserSearchModelConfig`
This commit is contained in:
parent
f3ce47b445
commit
5120597d4e
11 changed files with 237 additions and 90 deletions
|
@ -718,7 +718,7 @@ export default function SettingsView() {
|
|||
};
|
||||
|
||||
const updateModel = (name: string) => async (id: string) => {
|
||||
if (!userConfig?.is_active && name !== "search") {
|
||||
if (!userConfig?.is_active) {
|
||||
toast({
|
||||
title: `Model Update`,
|
||||
description: `You need to be subscribed to update ${name} models`,
|
||||
|
@ -1233,27 +1233,6 @@ export default function SettingsView() {
|
|||
</CardFooter>
|
||||
</Card>
|
||||
)}
|
||||
{userConfig.search_model_options.length > 0 && (
|
||||
<Card className={cardClassName}>
|
||||
<CardHeader className="text-xl flex flex-row">
|
||||
<FileMagnifyingGlass className="h-7 w-7 mr-2" />
|
||||
Search
|
||||
</CardHeader>
|
||||
<CardContent className="overflow-hidden pb-12 grid gap-8 h-fit">
|
||||
<p className="text-gray-400">
|
||||
Pick the search model to find your documents
|
||||
</p>
|
||||
<DropdownComponent
|
||||
items={userConfig.search_model_options}
|
||||
selected={
|
||||
userConfig.selected_search_model_config
|
||||
}
|
||||
callbackFunc={updateModel("search")}
|
||||
/>
|
||||
</CardContent>
|
||||
<CardFooter className="flex flex-wrap gap-4"></CardFooter>
|
||||
</Card>
|
||||
)}
|
||||
{userConfig.paint_model_options.length > 0 && (
|
||||
<Card className={cardClassName}>
|
||||
<CardHeader className="text-xl flex flex-row">
|
||||
|
|
|
@ -466,18 +466,26 @@ async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
|
|||
return config
|
||||
|
||||
|
||||
def get_user_search_model_or_default(user=None):
|
||||
if user and UserSearchModelConfig.objects.filter(user=user).exists():
|
||||
return UserSearchModelConfig.objects.filter(user=user).first().setting
|
||||
def get_default_search_model() -> SearchModelConfig:
|
||||
default_search_model = SearchModelConfig.objects.filter(name="default").first()
|
||||
|
||||
if SearchModelConfig.objects.filter(name="default").exists():
|
||||
return SearchModelConfig.objects.filter(name="default").first()
|
||||
if default_search_model:
|
||||
return default_search_model
|
||||
else:
|
||||
SearchModelConfig.objects.create()
|
||||
|
||||
return SearchModelConfig.objects.first()
|
||||
|
||||
|
||||
def get_user_default_search_model(user: KhojUser = None) -> SearchModelConfig:
|
||||
if user:
|
||||
user_search_model = UserSearchModelConfig.objects.filter(user=user).first()
|
||||
if user_search_model:
|
||||
return user_search_model.setting
|
||||
|
||||
return get_default_search_model()
|
||||
|
||||
|
||||
def get_or_create_search_models():
|
||||
search_models = SearchModelConfig.objects.all()
|
||||
if search_models.count() == 0:
|
||||
|
@ -487,21 +495,6 @@ def get_or_create_search_models():
|
|||
return search_models
|
||||
|
||||
|
||||
async def aset_user_search_model(user: KhojUser, search_model_config_id: int):
|
||||
config = await SearchModelConfig.objects.filter(id=search_model_config_id).afirst()
|
||||
if not config:
|
||||
return None
|
||||
new_config, _ = await UserSearchModelConfig.objects.aupdate_or_create(user=user, defaults={"setting": config})
|
||||
return new_config
|
||||
|
||||
|
||||
async def aget_user_search_model(user: KhojUser):
|
||||
config = await UserSearchModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()
|
||||
if not config:
|
||||
return None
|
||||
return config.setting
|
||||
|
||||
|
||||
class ProcessLockAdapters:
|
||||
@staticmethod
|
||||
def get_process_lock(process_name: str):
|
||||
|
|
|
@ -126,6 +126,7 @@ class EntryAdmin(admin.ModelAdmin):
|
|||
"created_at",
|
||||
"updated_at",
|
||||
"user",
|
||||
"agent",
|
||||
"file_source",
|
||||
"file_type",
|
||||
"file_name",
|
||||
|
@ -135,6 +136,7 @@ class EntryAdmin(admin.ModelAdmin):
|
|||
list_filter = (
|
||||
"file_type",
|
||||
"user__email",
|
||||
"search_model__name",
|
||||
)
|
||||
ordering = ("-created_at",)
|
||||
|
||||
|
|
182
src/khoj/database/management/commands/change_default_model.py
Normal file
182
src/khoj/database/management/commands/change_default_model.py
Normal file
|
@ -0,0 +1,182 @@
|
|||
import logging
|
||||
from typing import List
|
||||
|
||||
from django.core.management.base import BaseCommand
|
||||
from django.db import transaction
|
||||
from django.db.models import Count, Q
|
||||
from tqdm import tqdm
|
||||
|
||||
from khoj.database.adapters import get_default_search_model
|
||||
from khoj.database.models import (
|
||||
Agent,
|
||||
Entry,
|
||||
KhojUser,
|
||||
SearchModelConfig,
|
||||
UserSearchModelConfig,
|
||||
)
|
||||
from khoj.processor.embeddings import EmbeddingsModel
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
help = "Convert all existing Entry objects to use a new default Search model."
|
||||
|
||||
def add_arguments(self, parser):
|
||||
# Pass default SearchModelConfig ID
|
||||
parser.add_argument(
|
||||
"--search_model_id",
|
||||
action="store",
|
||||
help="ID of the SearchModelConfig object to set as the default search model for all existing Entry objects and UserSearchModelConfig objects.",
|
||||
required=True,
|
||||
)
|
||||
|
||||
# Set the apply flag to apply the new default Search model to all existing Entry objects and UserSearchModelConfig objects.
|
||||
parser.add_argument(
|
||||
"--apply",
|
||||
action="store_true",
|
||||
help="Apply the new default Search model to all existing Entry objects and UserSearchModelConfig objects. Otherwise, only display the number of Entry objects and UserSearchModelConfig objects that will be affected.",
|
||||
)
|
||||
|
||||
def handle(self, *args, **options):
|
||||
@transaction.atomic
|
||||
def regenerate_entries(entry_filter: Q, embeddings_model: EmbeddingsModel, search_model: SearchModelConfig):
|
||||
entries = Entry.objects.filter(entry_filter).all()
|
||||
compiled_entries = [entry.compiled for entry in entries]
|
||||
updated_entries: List[Entry] = []
|
||||
try:
|
||||
embeddings = embeddings_model.embed_documents(compiled_entries)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error embedding documents: {e}")
|
||||
return
|
||||
|
||||
for i, entry in enumerate(tqdm(entries)):
|
||||
entry.embeddings = embeddings[i]
|
||||
entry.search_model_id = search_model.id
|
||||
updated_entries.append(entry)
|
||||
|
||||
Entry.objects.bulk_update(updated_entries, ["embeddings", "search_model_id", "file_path"])
|
||||
|
||||
search_model_config_id = options.get("search_model_id")
|
||||
apply = options.get("apply")
|
||||
|
||||
logger.info(f"SearchModelConfig ID: {search_model_config_id}")
|
||||
logger.info(f"Apply: {apply}")
|
||||
|
||||
embeddings_model = dict()
|
||||
|
||||
search_models = SearchModelConfig.objects.all()
|
||||
for model in search_models:
|
||||
embeddings_model.update(
|
||||
{
|
||||
model.name: EmbeddingsModel(
|
||||
model.bi_encoder,
|
||||
model.embeddings_inference_endpoint,
|
||||
model.embeddings_inference_endpoint_api_key,
|
||||
query_encode_kwargs=model.bi_encoder_query_encode_config,
|
||||
docs_encode_kwargs=model.bi_encoder_docs_encode_config,
|
||||
model_kwargs=model.bi_encoder_model_config,
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
new_default_search_model_config = SearchModelConfig.objects.get(id=search_model_config_id)
|
||||
logger.info(f"New default Search model: {new_default_search_model_config}")
|
||||
user_search_model_configs_to_update = UserSearchModelConfig.objects.exclude(
|
||||
setting_id=search_model_config_id
|
||||
).all()
|
||||
logger.info(f"Number of UserSearchModelConfig objects to update: {user_search_model_configs_to_update.count()}")
|
||||
|
||||
for user_config in user_search_model_configs_to_update:
|
||||
affected_user = user_config.user
|
||||
entry_filter = Q(user=affected_user)
|
||||
relevant_entries = Entry.objects.filter(entry_filter).all()
|
||||
logger.info(f"Number of Entry objects to update for user {affected_user}: {relevant_entries.count()}")
|
||||
|
||||
if apply:
|
||||
try:
|
||||
regenerate_entries(
|
||||
entry_filter,
|
||||
embeddings_model[new_default_search_model_config.name],
|
||||
new_default_search_model_config,
|
||||
)
|
||||
user_config.setting = new_default_search_model_config
|
||||
user_config.save()
|
||||
|
||||
logger.info(
|
||||
f"Updated UserSearchModelConfig object for user {affected_user} to use the new default Search model."
|
||||
)
|
||||
logger.info(
|
||||
f"Updated {relevant_entries.count()} Entry objects for user {affected_user} to use the new default Search model."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error embedding documents: {e}")
|
||||
|
||||
logger.info("----")
|
||||
|
||||
# There are also plenty of users who have indexed documents without explicitly creating a UserSearchModelConfig object. You would have to migrate these users as well, if the default is different from search_model_config_id.
|
||||
current_default = get_default_search_model()
|
||||
if current_default.id != new_default_search_model_config.id:
|
||||
users_without_user_search_model_config = KhojUser.objects.annotate(
|
||||
user_search_model_config_count=Count("usersearchmodelconfig")
|
||||
).filter(user_search_model_config_count=0)
|
||||
|
||||
logger.info(f"Number of User objects to update: {users_without_user_search_model_config.count()}")
|
||||
for user in users_without_user_search_model_config:
|
||||
entry_filter = Q(user=user)
|
||||
relevant_entries = Entry.objects.filter(entry_filter).all()
|
||||
logger.info(f"Number of Entry objects to update for user {user}: {relevant_entries.count()}")
|
||||
|
||||
if apply:
|
||||
try:
|
||||
regenerate_entries(
|
||||
entry_filter,
|
||||
embeddings_model[new_default_search_model_config.name],
|
||||
new_default_search_model_config,
|
||||
)
|
||||
|
||||
UserSearchModelConfig.objects.create(user=user, setting=new_default_search_model_config)
|
||||
|
||||
logger.info(
|
||||
f"Created UserSearchModelConfig object for user {user} to use the new default Search model."
|
||||
)
|
||||
logger.info(
|
||||
f"Updated {relevant_entries.count()} Entry objects for user {user} to use the new default Search model."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error embedding documents: {e}")
|
||||
else:
|
||||
logger.info("Default is the same as search_model_config_id.")
|
||||
|
||||
all_agents = Agent.objects.all()
|
||||
logger.info(f"Number of Agent objects to update: {all_agents.count()}")
|
||||
for agent in all_agents:
|
||||
entry_filter = Q(agent=agent)
|
||||
relevant_entries = Entry.objects.filter(entry_filter).all()
|
||||
logger.info(f"Number of Entry objects to update for agent {agent}: {relevant_entries.count()}")
|
||||
|
||||
if apply:
|
||||
try:
|
||||
regenerate_entries(
|
||||
entry_filter,
|
||||
embeddings_model[new_default_search_model_config.name],
|
||||
new_default_search_model_config,
|
||||
)
|
||||
logger.info(
|
||||
f"Updated {relevant_entries.count()} Entry objects for agent {agent} to use the new default Search model."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error embedding documents: {e}")
|
||||
if apply and current_default.id != new_default_search_model_config.id:
|
||||
# Get the existing default SearchModelConfig object and update its name
|
||||
current_default.name = f"prev_default_{current_default.id}"
|
||||
current_default.save()
|
||||
|
||||
# Update the new default SearchModelConfig object's name
|
||||
new_default_search_model_config.name = "default"
|
||||
new_default_search_model_config.save()
|
||||
if not apply:
|
||||
logger.info("Run the command with the --apply flag to apply the new default Search model.")
|
24
src/khoj/database/migrations/0072_entry_search_model.py
Normal file
24
src/khoj/database/migrations/0072_entry_search_model.py
Normal file
|
@ -0,0 +1,24 @@
|
|||
# Generated by Django 5.0.8 on 2024-10-21 21:09
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0071_subscription_enabled_trial_at_and_more"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="entry",
|
||||
name="search_model",
|
||||
field=models.ForeignKey(
|
||||
blank=True,
|
||||
default=None,
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.SET_NULL,
|
||||
to="database.searchmodelconfig",
|
||||
),
|
||||
),
|
||||
]
|
|
@ -449,6 +449,7 @@ class UserVoiceModelConfig(BaseModel):
|
|||
setting = models.ForeignKey(VoiceModelOption, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||
|
||||
|
||||
# TODO Delete this model once all users have been migrated to the server's default settings
|
||||
class UserSearchModelConfig(BaseModel):
|
||||
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
|
||||
setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE)
|
||||
|
@ -535,6 +536,7 @@ class Entry(BaseModel):
|
|||
url = models.URLField(max_length=400, default=None, null=True, blank=True)
|
||||
hashed_value = models.CharField(max_length=100)
|
||||
corpus_id = models.UUIDField(default=uuid.uuid4, editable=False)
|
||||
search_model = models.ForeignKey(SearchModelConfig, on_delete=models.SET_NULL, default=None, null=True, blank=True)
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
if self.user and self.agent:
|
||||
|
|
|
@ -12,7 +12,8 @@ from tqdm import tqdm
|
|||
from khoj.database.adapters import (
|
||||
EntryAdapters,
|
||||
FileObjectAdapters,
|
||||
get_user_search_model_or_default,
|
||||
get_default_search_model,
|
||||
get_user_default_search_model,
|
||||
)
|
||||
from khoj.database.models import Entry as DbEntry
|
||||
from khoj.database.models import EntryDates, KhojUser
|
||||
|
@ -148,10 +149,10 @@ class TextToEntries(ABC):
|
|||
hashes_to_process |= hashes_for_file - existing_entry_hashes
|
||||
|
||||
embeddings = []
|
||||
model = get_user_default_search_model(user=user)
|
||||
with timer("Generated embeddings for entries to add to database in", logger):
|
||||
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
|
||||
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
|
||||
model = get_user_search_model_or_default(user)
|
||||
embeddings += self.embeddings_model[model.name].embed_documents(data_to_embed)
|
||||
|
||||
added_entries: list[DbEntry] = []
|
||||
|
@ -177,6 +178,7 @@ class TextToEntries(ABC):
|
|||
file_type=file_type,
|
||||
hashed_value=entry_hash,
|
||||
corpus_id=entry.corpus_id,
|
||||
search_model=model,
|
||||
)
|
||||
)
|
||||
try:
|
||||
|
|
|
@ -25,8 +25,9 @@ from khoj.database.adapters import (
|
|||
AutomationAdapters,
|
||||
ConversationAdapters,
|
||||
EntryAdapters,
|
||||
get_default_search_model,
|
||||
get_user_default_search_model,
|
||||
get_user_photo,
|
||||
get_user_search_model_or_default,
|
||||
)
|
||||
from khoj.database.models import (
|
||||
Agent,
|
||||
|
@ -149,7 +150,7 @@ async def execute_search(
|
|||
encoded_asymmetric_query = None
|
||||
if t != SearchType.Image:
|
||||
with timer("Encoding query took", logger=logger):
|
||||
search_model = await sync_to_async(get_user_search_model_or_default)(user)
|
||||
search_model = await sync_to_async(get_user_default_search_model)(user)
|
||||
encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query)
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
|
|
|
@ -94,39 +94,6 @@ async def update_voice_model(
|
|||
return Response(status_code=202, content=json.dumps({"status": "ok"}))
|
||||
|
||||
|
||||
@api_model.post("/search", status_code=200)
|
||||
@requires(["authenticated"])
|
||||
async def update_search_model(
|
||||
request: Request,
|
||||
id: str,
|
||||
client: Optional[str] = None,
|
||||
):
|
||||
user = request.user.object
|
||||
|
||||
prev_config = await adapters.aget_user_search_model(user)
|
||||
new_config = await adapters.aset_user_search_model(user, int(id))
|
||||
|
||||
if prev_config and int(id) != prev_config.id and new_config:
|
||||
await EntryAdapters.adelete_all_entries(user)
|
||||
|
||||
if not prev_config:
|
||||
# If the use was just using the default config, delete all the entries and set the new config.
|
||||
await EntryAdapters.adelete_all_entries(user)
|
||||
|
||||
if new_config is None:
|
||||
return {"status": "error", "message": "Model not found"}
|
||||
else:
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
telemetry_type="api",
|
||||
api="set_search_model",
|
||||
client=client,
|
||||
metadata={"search_model": new_config.setting.name},
|
||||
)
|
||||
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@api_model.post("/paint", status_code=200)
|
||||
@requires(["authenticated"])
|
||||
async def update_paint_model(
|
||||
|
|
|
@ -1706,13 +1706,6 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
|
|||
for chat_model in chat_models:
|
||||
chat_model_options.append({"name": chat_model.chat_model, "id": chat_model.id})
|
||||
|
||||
search_model_options = adapters.get_or_create_search_models().all()
|
||||
all_search_model_options = list()
|
||||
for search_model_option in search_model_options:
|
||||
all_search_model_options.append({"name": search_model_option.name, "id": search_model_option.id})
|
||||
|
||||
current_search_model_option = adapters.get_user_search_model_or_default(user)
|
||||
|
||||
selected_paint_model_config = ConversationAdapters.get_user_text_to_image_model_config(user)
|
||||
paint_model_options = ConversationAdapters.get_text_to_image_model_options().all()
|
||||
all_paint_model_options = list()
|
||||
|
@ -1745,8 +1738,6 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
|
|||
"has_documents": has_documents,
|
||||
"notion_token": notion_token,
|
||||
# user model settings
|
||||
"search_model_options": all_search_model_options,
|
||||
"selected_search_model_config": current_search_model_option.id,
|
||||
"chat_model_options": chat_model_options,
|
||||
"selected_chat_model_config": selected_chat_model_config.id if selected_chat_model_config else None,
|
||||
"paint_model_options": all_paint_model_options,
|
||||
|
|
|
@ -8,7 +8,11 @@ import torch
|
|||
from asgiref.sync import sync_to_async
|
||||
from sentence_transformers import util
|
||||
|
||||
from khoj.database.adapters import EntryAdapters, get_user_search_model_or_default
|
||||
from khoj.database.adapters import (
|
||||
EntryAdapters,
|
||||
get_default_search_model,
|
||||
get_user_default_search_model,
|
||||
)
|
||||
from khoj.database.models import Agent
|
||||
from khoj.database.models import Entry as DbEntry
|
||||
from khoj.database.models import KhojUser
|
||||
|
@ -110,7 +114,7 @@ async def query(
|
|||
file_type = search_type_to_embeddings_type[type.value]
|
||||
|
||||
query = raw_query
|
||||
search_model = await sync_to_async(get_user_search_model_or_default)(user)
|
||||
search_model = await sync_to_async(get_user_default_search_model)(user)
|
||||
if not max_distance:
|
||||
if search_model.bi_encoder_confidence_threshold:
|
||||
max_distance = search_model.bi_encoder_confidence_threshold
|
||||
|
|
Loading…
Reference in a new issue