mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-27 17:35:07 +01:00
Deprecate the UserSearchModelConfig and remove all references
- The server has moved to a model of standardization for the embeddings generation workflow. Remove references to the support for differentiated models. - The migration script fo ra new model needs to be updated to accommodate full regeneration.
This commit is contained in:
parent
99c1d2831a
commit
1e89baca7b
8 changed files with 25 additions and 99 deletions
|
@ -48,7 +48,6 @@ from khoj.database.models import (
|
||||||
TextToImageModelConfig,
|
TextToImageModelConfig,
|
||||||
UserConversationConfig,
|
UserConversationConfig,
|
||||||
UserRequests,
|
UserRequests,
|
||||||
UserSearchModelConfig,
|
|
||||||
UserTextToImageModelConfig,
|
UserTextToImageModelConfig,
|
||||||
UserVoiceModelConfig,
|
UserVoiceModelConfig,
|
||||||
VoiceModelOption,
|
VoiceModelOption,
|
||||||
|
@ -481,15 +480,6 @@ def get_default_search_model() -> SearchModelConfig:
|
||||||
return SearchModelConfig.objects.first()
|
return SearchModelConfig.objects.first()
|
||||||
|
|
||||||
|
|
||||||
def get_user_default_search_model(user: KhojUser = None) -> SearchModelConfig:
|
|
||||||
if user:
|
|
||||||
user_search_model = UserSearchModelConfig.objects.filter(user=user).first()
|
|
||||||
if user_search_model:
|
|
||||||
return user_search_model.setting
|
|
||||||
|
|
||||||
return get_default_search_model()
|
|
||||||
|
|
||||||
|
|
||||||
def get_or_create_search_models():
|
def get_or_create_search_models():
|
||||||
search_models = SearchModelConfig.objects.all()
|
search_models = SearchModelConfig.objects.all()
|
||||||
if search_models.count() == 0:
|
if search_models.count() == 0:
|
||||||
|
|
|
@ -28,7 +28,6 @@ from khoj.database.models import (
|
||||||
TextToImageModelConfig,
|
TextToImageModelConfig,
|
||||||
UserConversationConfig,
|
UserConversationConfig,
|
||||||
UserRequests,
|
UserRequests,
|
||||||
UserSearchModelConfig,
|
|
||||||
UserVoiceModelConfig,
|
UserVoiceModelConfig,
|
||||||
VoiceModelOption,
|
VoiceModelOption,
|
||||||
WebScraper,
|
WebScraper,
|
||||||
|
@ -99,7 +98,6 @@ admin.site.register(KhojUser, KhojUserAdmin)
|
||||||
admin.site.register(ProcessLock)
|
admin.site.register(ProcessLock)
|
||||||
admin.site.register(SpeechToTextModelOptions)
|
admin.site.register(SpeechToTextModelOptions)
|
||||||
admin.site.register(ReflectiveQuestion)
|
admin.site.register(ReflectiveQuestion)
|
||||||
admin.site.register(UserSearchModelConfig)
|
|
||||||
admin.site.register(ClientApplication)
|
admin.site.register(ClientApplication)
|
||||||
admin.site.register(GithubConfig)
|
admin.site.register(GithubConfig)
|
||||||
admin.site.register(NotionConfig)
|
admin.site.register(NotionConfig)
|
||||||
|
|
|
@ -7,13 +7,7 @@ from django.db.models import Count, Q
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from khoj.database.adapters import get_default_search_model
|
from khoj.database.adapters import get_default_search_model
|
||||||
from khoj.database.models import (
|
from khoj.database.models import Agent, Entry, KhojUser, SearchModelConfig
|
||||||
Agent,
|
|
||||||
Entry,
|
|
||||||
KhojUser,
|
|
||||||
SearchModelConfig,
|
|
||||||
UserSearchModelConfig,
|
|
||||||
)
|
|
||||||
from khoj.processor.embeddings import EmbeddingsModel
|
from khoj.processor.embeddings import EmbeddingsModel
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
@ -30,15 +24,15 @@ class Command(BaseCommand):
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--search_model_id",
|
"--search_model_id",
|
||||||
action="store",
|
action="store",
|
||||||
help="ID of the SearchModelConfig object to set as the default search model for all existing Entry objects and UserSearchModelConfig objects.",
|
help="ID of the SearchModelConfig object to set as the default search model for all existing Entry objects.",
|
||||||
required=True,
|
required=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set the apply flag to apply the new default Search model to all existing Entry objects and UserSearchModelConfig objects.
|
# Set the apply flag to apply the new default Search model to all existing Entry objects.
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--apply",
|
"--apply",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Apply the new default Search model to all existing Entry objects and UserSearchModelConfig objects. Otherwise, only display the number of Entry objects and UserSearchModelConfig objects that will be affected.",
|
help="Apply the new default Search model to all existing Entry objects. Otherwise, only display the number of Entry objects that will be affected.",
|
||||||
)
|
)
|
||||||
|
|
||||||
def handle(self, *args, **options):
|
def handle(self, *args, **options):
|
||||||
|
@ -88,72 +82,12 @@ class Command(BaseCommand):
|
||||||
|
|
||||||
new_default_search_model_config = SearchModelConfig.objects.get(id=search_model_config_id)
|
new_default_search_model_config = SearchModelConfig.objects.get(id=search_model_config_id)
|
||||||
logger.info(f"New default Search model: {new_default_search_model_config}")
|
logger.info(f"New default Search model: {new_default_search_model_config}")
|
||||||
user_search_model_configs_to_update = UserSearchModelConfig.objects.exclude(
|
|
||||||
setting_id=search_model_config_id
|
|
||||||
).all()
|
|
||||||
logger.info(f"Number of UserSearchModelConfig objects to update: {user_search_model_configs_to_update.count()}")
|
|
||||||
|
|
||||||
for user_config in user_search_model_configs_to_update:
|
|
||||||
affected_user = user_config.user
|
|
||||||
entry_filter = Q(user=affected_user)
|
|
||||||
relevant_entries = Entry.objects.filter(entry_filter).all()
|
|
||||||
logger.info(f"Number of Entry objects to update for user {affected_user}: {relevant_entries.count()}")
|
|
||||||
|
|
||||||
if apply:
|
|
||||||
try:
|
|
||||||
regenerate_entries(
|
|
||||||
entry_filter,
|
|
||||||
embeddings_model[new_default_search_model_config.name],
|
|
||||||
new_default_search_model_config,
|
|
||||||
)
|
|
||||||
user_config.setting = new_default_search_model_config
|
|
||||||
user_config.save()
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Updated UserSearchModelConfig object for user {affected_user} to use the new default Search model."
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f"Updated {relevant_entries.count()} Entry objects for user {affected_user} to use the new default Search model."
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error embedding documents: {e}")
|
|
||||||
|
|
||||||
logger.info("----")
|
logger.info("----")
|
||||||
|
|
||||||
# There are also plenty of users who have indexed documents without explicitly creating a UserSearchModelConfig object. You would have to migrate these users as well, if the default is different from search_model_config_id.
|
|
||||||
current_default = get_default_search_model()
|
current_default = get_default_search_model()
|
||||||
if current_default.id != new_default_search_model_config.id:
|
|
||||||
users_without_user_search_model_config = KhojUser.objects.annotate(
|
|
||||||
user_search_model_config_count=Count("usersearchmodelconfig")
|
|
||||||
).filter(user_search_model_config_count=0)
|
|
||||||
|
|
||||||
logger.info(f"Number of User objects to update: {users_without_user_search_model_config.count()}")
|
# TODO: Migrate all Entry objects to use the new default Search model
|
||||||
for user in users_without_user_search_model_config:
|
|
||||||
entry_filter = Q(user=user)
|
|
||||||
relevant_entries = Entry.objects.filter(entry_filter).all()
|
|
||||||
logger.info(f"Number of Entry objects to update for user {user}: {relevant_entries.count()}")
|
|
||||||
|
|
||||||
if apply:
|
|
||||||
try:
|
|
||||||
regenerate_entries(
|
|
||||||
entry_filter,
|
|
||||||
embeddings_model[new_default_search_model_config.name],
|
|
||||||
new_default_search_model_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
UserSearchModelConfig.objects.create(user=user, setting=new_default_search_model_config)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Created UserSearchModelConfig object for user {user} to use the new default Search model."
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
f"Updated {relevant_entries.count()} Entry objects for user {user} to use the new default Search model."
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error embedding documents: {e}")
|
|
||||||
else:
|
|
||||||
logger.info("Default is the same as search_model_config_id.")
|
|
||||||
|
|
||||||
all_agents = Agent.objects.all()
|
all_agents = Agent.objects.all()
|
||||||
logger.info(f"Number of Agent objects to update: {all_agents.count()}")
|
logger.info(f"Number of Agent objects to update: {all_agents.count()}")
|
||||||
|
@ -174,6 +108,7 @@ class Command(BaseCommand):
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error embedding documents: {e}")
|
logger.error(f"Error embedding documents: {e}")
|
||||||
|
|
||||||
if apply and current_default.id != new_default_search_model_config.id:
|
if apply and current_default.id != new_default_search_model_config.id:
|
||||||
# Get the existing default SearchModelConfig object and update its name
|
# Get the existing default SearchModelConfig object and update its name
|
||||||
current_default.name = f"prev_default_{current_default.id}"
|
current_default.name = f"prev_default_{current_default.id}"
|
||||||
|
|
|
@ -0,0 +1,15 @@
|
||||||
|
# Generated by Django 5.0.9 on 2024-11-04 19:56
|
||||||
|
|
||||||
|
from django.db import migrations
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0072_entry_search_model"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.DeleteModel(
|
||||||
|
name="UserSearchModelConfig",
|
||||||
|
),
|
||||||
|
]
|
|
@ -449,12 +449,6 @@ class UserVoiceModelConfig(BaseModel):
|
||||||
setting = models.ForeignKey(VoiceModelOption, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
setting = models.ForeignKey(VoiceModelOption, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||||
|
|
||||||
|
|
||||||
# TODO Delete this model once all users have been migrated to the server's default settings
|
|
||||||
class UserSearchModelConfig(BaseModel):
|
|
||||||
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
|
|
||||||
setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE)
|
|
||||||
|
|
||||||
|
|
||||||
class UserTextToImageModelConfig(BaseModel):
|
class UserTextToImageModelConfig(BaseModel):
|
||||||
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
|
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
|
||||||
setting = models.ForeignKey(TextToImageModelConfig, on_delete=models.CASCADE)
|
setting = models.ForeignKey(TextToImageModelConfig, on_delete=models.CASCADE)
|
||||||
|
|
|
@ -13,7 +13,6 @@ from khoj.database.adapters import (
|
||||||
EntryAdapters,
|
EntryAdapters,
|
||||||
FileObjectAdapters,
|
FileObjectAdapters,
|
||||||
get_default_search_model,
|
get_default_search_model,
|
||||||
get_user_default_search_model,
|
|
||||||
)
|
)
|
||||||
from khoj.database.models import Entry as DbEntry
|
from khoj.database.models import Entry as DbEntry
|
||||||
from khoj.database.models import EntryDates, KhojUser
|
from khoj.database.models import EntryDates, KhojUser
|
||||||
|
@ -149,7 +148,7 @@ class TextToEntries(ABC):
|
||||||
hashes_to_process |= hashes_for_file - existing_entry_hashes
|
hashes_to_process |= hashes_for_file - existing_entry_hashes
|
||||||
|
|
||||||
embeddings = []
|
embeddings = []
|
||||||
model = get_user_default_search_model(user=user)
|
model = get_default_search_model()
|
||||||
with timer("Generated embeddings for entries to add to database in", logger):
|
with timer("Generated embeddings for entries to add to database in", logger):
|
||||||
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
|
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
|
||||||
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
|
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
|
||||||
|
|
|
@ -26,7 +26,6 @@ from khoj.database.adapters import (
|
||||||
ConversationAdapters,
|
ConversationAdapters,
|
||||||
EntryAdapters,
|
EntryAdapters,
|
||||||
get_default_search_model,
|
get_default_search_model,
|
||||||
get_user_default_search_model,
|
|
||||||
get_user_photo,
|
get_user_photo,
|
||||||
)
|
)
|
||||||
from khoj.database.models import (
|
from khoj.database.models import (
|
||||||
|
@ -151,7 +150,7 @@ async def execute_search(
|
||||||
encoded_asymmetric_query = None
|
encoded_asymmetric_query = None
|
||||||
if t != SearchType.Image:
|
if t != SearchType.Image:
|
||||||
with timer("Encoding query took", logger=logger):
|
with timer("Encoding query took", logger=logger):
|
||||||
search_model = await sync_to_async(get_user_default_search_model)(user)
|
search_model = await sync_to_async(get_default_search_model)()
|
||||||
encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query)
|
encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query)
|
||||||
|
|
||||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||||
|
|
|
@ -8,11 +8,7 @@ import torch
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
from sentence_transformers import util
|
from sentence_transformers import util
|
||||||
|
|
||||||
from khoj.database.adapters import (
|
from khoj.database.adapters import EntryAdapters, get_default_search_model
|
||||||
EntryAdapters,
|
|
||||||
get_default_search_model,
|
|
||||||
get_user_default_search_model,
|
|
||||||
)
|
|
||||||
from khoj.database.models import Agent
|
from khoj.database.models import Agent
|
||||||
from khoj.database.models import Entry as DbEntry
|
from khoj.database.models import Entry as DbEntry
|
||||||
from khoj.database.models import KhojUser
|
from khoj.database.models import KhojUser
|
||||||
|
@ -114,7 +110,7 @@ async def query(
|
||||||
file_type = search_type_to_embeddings_type[type.value]
|
file_type = search_type_to_embeddings_type[type.value]
|
||||||
|
|
||||||
query = raw_query
|
query = raw_query
|
||||||
search_model = await sync_to_async(get_user_default_search_model)(user)
|
search_model = await sync_to_async(get_default_search_model)()
|
||||||
if not max_distance:
|
if not max_distance:
|
||||||
if search_model.bi_encoder_confidence_threshold:
|
if search_model.bi_encoder_confidence_threshold:
|
||||||
max_distance = search_model.bi_encoder_confidence_threshold
|
max_distance = search_model.bi_encoder_confidence_threshold
|
||||||
|
|
Loading…
Reference in a new issue