mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-27 17:35:07 +01:00
Remove user customized search model (#946)
- Use a single standard search model across the server. There's diminishing benefits for having multiple user-customizable search models. - We may want to add server-level customization for specific tasks - Store the search model used to generate a given entry on the `Entry` object - Remove user-facing APIs and view - Add a management command for migrating the default search model on the server In a future PR (after running the migration), we'll also remove the `UserSearchModelConfig`
This commit is contained in:
parent
f3ce47b445
commit
5120597d4e
11 changed files with 237 additions and 90 deletions
|
@ -718,7 +718,7 @@ export default function SettingsView() {
|
||||||
};
|
};
|
||||||
|
|
||||||
const updateModel = (name: string) => async (id: string) => {
|
const updateModel = (name: string) => async (id: string) => {
|
||||||
if (!userConfig?.is_active && name !== "search") {
|
if (!userConfig?.is_active) {
|
||||||
toast({
|
toast({
|
||||||
title: `Model Update`,
|
title: `Model Update`,
|
||||||
description: `You need to be subscribed to update ${name} models`,
|
description: `You need to be subscribed to update ${name} models`,
|
||||||
|
@ -1233,27 +1233,6 @@ export default function SettingsView() {
|
||||||
</CardFooter>
|
</CardFooter>
|
||||||
</Card>
|
</Card>
|
||||||
)}
|
)}
|
||||||
{userConfig.search_model_options.length > 0 && (
|
|
||||||
<Card className={cardClassName}>
|
|
||||||
<CardHeader className="text-xl flex flex-row">
|
|
||||||
<FileMagnifyingGlass className="h-7 w-7 mr-2" />
|
|
||||||
Search
|
|
||||||
</CardHeader>
|
|
||||||
<CardContent className="overflow-hidden pb-12 grid gap-8 h-fit">
|
|
||||||
<p className="text-gray-400">
|
|
||||||
Pick the search model to find your documents
|
|
||||||
</p>
|
|
||||||
<DropdownComponent
|
|
||||||
items={userConfig.search_model_options}
|
|
||||||
selected={
|
|
||||||
userConfig.selected_search_model_config
|
|
||||||
}
|
|
||||||
callbackFunc={updateModel("search")}
|
|
||||||
/>
|
|
||||||
</CardContent>
|
|
||||||
<CardFooter className="flex flex-wrap gap-4"></CardFooter>
|
|
||||||
</Card>
|
|
||||||
)}
|
|
||||||
{userConfig.paint_model_options.length > 0 && (
|
{userConfig.paint_model_options.length > 0 && (
|
||||||
<Card className={cardClassName}>
|
<Card className={cardClassName}>
|
||||||
<CardHeader className="text-xl flex flex-row">
|
<CardHeader className="text-xl flex flex-row">
|
||||||
|
|
|
@ -466,18 +466,26 @@ async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
def get_user_search_model_or_default(user=None):
|
def get_default_search_model() -> SearchModelConfig:
|
||||||
if user and UserSearchModelConfig.objects.filter(user=user).exists():
|
default_search_model = SearchModelConfig.objects.filter(name="default").first()
|
||||||
return UserSearchModelConfig.objects.filter(user=user).first().setting
|
|
||||||
|
|
||||||
if SearchModelConfig.objects.filter(name="default").exists():
|
if default_search_model:
|
||||||
return SearchModelConfig.objects.filter(name="default").first()
|
return default_search_model
|
||||||
else:
|
else:
|
||||||
SearchModelConfig.objects.create()
|
SearchModelConfig.objects.create()
|
||||||
|
|
||||||
return SearchModelConfig.objects.first()
|
return SearchModelConfig.objects.first()
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_default_search_model(user: KhojUser = None) -> SearchModelConfig:
|
||||||
|
if user:
|
||||||
|
user_search_model = UserSearchModelConfig.objects.filter(user=user).first()
|
||||||
|
if user_search_model:
|
||||||
|
return user_search_model.setting
|
||||||
|
|
||||||
|
return get_default_search_model()
|
||||||
|
|
||||||
|
|
||||||
def get_or_create_search_models():
|
def get_or_create_search_models():
|
||||||
search_models = SearchModelConfig.objects.all()
|
search_models = SearchModelConfig.objects.all()
|
||||||
if search_models.count() == 0:
|
if search_models.count() == 0:
|
||||||
|
@ -487,21 +495,6 @@ def get_or_create_search_models():
|
||||||
return search_models
|
return search_models
|
||||||
|
|
||||||
|
|
||||||
async def aset_user_search_model(user: KhojUser, search_model_config_id: int):
|
|
||||||
config = await SearchModelConfig.objects.filter(id=search_model_config_id).afirst()
|
|
||||||
if not config:
|
|
||||||
return None
|
|
||||||
new_config, _ = await UserSearchModelConfig.objects.aupdate_or_create(user=user, defaults={"setting": config})
|
|
||||||
return new_config
|
|
||||||
|
|
||||||
|
|
||||||
async def aget_user_search_model(user: KhojUser):
|
|
||||||
config = await UserSearchModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()
|
|
||||||
if not config:
|
|
||||||
return None
|
|
||||||
return config.setting
|
|
||||||
|
|
||||||
|
|
||||||
class ProcessLockAdapters:
|
class ProcessLockAdapters:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_process_lock(process_name: str):
|
def get_process_lock(process_name: str):
|
||||||
|
|
|
@ -126,6 +126,7 @@ class EntryAdmin(admin.ModelAdmin):
|
||||||
"created_at",
|
"created_at",
|
||||||
"updated_at",
|
"updated_at",
|
||||||
"user",
|
"user",
|
||||||
|
"agent",
|
||||||
"file_source",
|
"file_source",
|
||||||
"file_type",
|
"file_type",
|
||||||
"file_name",
|
"file_name",
|
||||||
|
@ -135,6 +136,7 @@ class EntryAdmin(admin.ModelAdmin):
|
||||||
list_filter = (
|
list_filter = (
|
||||||
"file_type",
|
"file_type",
|
||||||
"user__email",
|
"user__email",
|
||||||
|
"search_model__name",
|
||||||
)
|
)
|
||||||
ordering = ("-created_at",)
|
ordering = ("-created_at",)
|
||||||
|
|
||||||
|
|
182
src/khoj/database/management/commands/change_default_model.py
Normal file
182
src/khoj/database/management/commands/change_default_model.py
Normal file
|
@ -0,0 +1,182 @@
|
||||||
|
import logging
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from django.core.management.base import BaseCommand
|
||||||
|
from django.db import transaction
|
||||||
|
from django.db.models import Count, Q
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from khoj.database.adapters import get_default_search_model
|
||||||
|
from khoj.database.models import (
|
||||||
|
Agent,
|
||||||
|
Entry,
|
||||||
|
KhojUser,
|
||||||
|
SearchModelConfig,
|
||||||
|
UserSearchModelConfig,
|
||||||
|
)
|
||||||
|
from khoj.processor.embeddings import EmbeddingsModel
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Command(BaseCommand):
|
||||||
|
help = "Convert all existing Entry objects to use a new default Search model."
|
||||||
|
|
||||||
|
def add_arguments(self, parser):
|
||||||
|
# Pass default SearchModelConfig ID
|
||||||
|
parser.add_argument(
|
||||||
|
"--search_model_id",
|
||||||
|
action="store",
|
||||||
|
help="ID of the SearchModelConfig object to set as the default search model for all existing Entry objects and UserSearchModelConfig objects.",
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set the apply flag to apply the new default Search model to all existing Entry objects and UserSearchModelConfig objects.
|
||||||
|
parser.add_argument(
|
||||||
|
"--apply",
|
||||||
|
action="store_true",
|
||||||
|
help="Apply the new default Search model to all existing Entry objects and UserSearchModelConfig objects. Otherwise, only display the number of Entry objects and UserSearchModelConfig objects that will be affected.",
|
||||||
|
)
|
||||||
|
|
||||||
|
def handle(self, *args, **options):
|
||||||
|
@transaction.atomic
|
||||||
|
def regenerate_entries(entry_filter: Q, embeddings_model: EmbeddingsModel, search_model: SearchModelConfig):
|
||||||
|
entries = Entry.objects.filter(entry_filter).all()
|
||||||
|
compiled_entries = [entry.compiled for entry in entries]
|
||||||
|
updated_entries: List[Entry] = []
|
||||||
|
try:
|
||||||
|
embeddings = embeddings_model.embed_documents(compiled_entries)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error embedding documents: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
for i, entry in enumerate(tqdm(entries)):
|
||||||
|
entry.embeddings = embeddings[i]
|
||||||
|
entry.search_model_id = search_model.id
|
||||||
|
updated_entries.append(entry)
|
||||||
|
|
||||||
|
Entry.objects.bulk_update(updated_entries, ["embeddings", "search_model_id", "file_path"])
|
||||||
|
|
||||||
|
search_model_config_id = options.get("search_model_id")
|
||||||
|
apply = options.get("apply")
|
||||||
|
|
||||||
|
logger.info(f"SearchModelConfig ID: {search_model_config_id}")
|
||||||
|
logger.info(f"Apply: {apply}")
|
||||||
|
|
||||||
|
embeddings_model = dict()
|
||||||
|
|
||||||
|
search_models = SearchModelConfig.objects.all()
|
||||||
|
for model in search_models:
|
||||||
|
embeddings_model.update(
|
||||||
|
{
|
||||||
|
model.name: EmbeddingsModel(
|
||||||
|
model.bi_encoder,
|
||||||
|
model.embeddings_inference_endpoint,
|
||||||
|
model.embeddings_inference_endpoint_api_key,
|
||||||
|
query_encode_kwargs=model.bi_encoder_query_encode_config,
|
||||||
|
docs_encode_kwargs=model.bi_encoder_docs_encode_config,
|
||||||
|
model_kwargs=model.bi_encoder_model_config,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
new_default_search_model_config = SearchModelConfig.objects.get(id=search_model_config_id)
|
||||||
|
logger.info(f"New default Search model: {new_default_search_model_config}")
|
||||||
|
user_search_model_configs_to_update = UserSearchModelConfig.objects.exclude(
|
||||||
|
setting_id=search_model_config_id
|
||||||
|
).all()
|
||||||
|
logger.info(f"Number of UserSearchModelConfig objects to update: {user_search_model_configs_to_update.count()}")
|
||||||
|
|
||||||
|
for user_config in user_search_model_configs_to_update:
|
||||||
|
affected_user = user_config.user
|
||||||
|
entry_filter = Q(user=affected_user)
|
||||||
|
relevant_entries = Entry.objects.filter(entry_filter).all()
|
||||||
|
logger.info(f"Number of Entry objects to update for user {affected_user}: {relevant_entries.count()}")
|
||||||
|
|
||||||
|
if apply:
|
||||||
|
try:
|
||||||
|
regenerate_entries(
|
||||||
|
entry_filter,
|
||||||
|
embeddings_model[new_default_search_model_config.name],
|
||||||
|
new_default_search_model_config,
|
||||||
|
)
|
||||||
|
user_config.setting = new_default_search_model_config
|
||||||
|
user_config.save()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Updated UserSearchModelConfig object for user {affected_user} to use the new default Search model."
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Updated {relevant_entries.count()} Entry objects for user {affected_user} to use the new default Search model."
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error embedding documents: {e}")
|
||||||
|
|
||||||
|
logger.info("----")
|
||||||
|
|
||||||
|
# There are also plenty of users who have indexed documents without explicitly creating a UserSearchModelConfig object. You would have to migrate these users as well, if the default is different from search_model_config_id.
|
||||||
|
current_default = get_default_search_model()
|
||||||
|
if current_default.id != new_default_search_model_config.id:
|
||||||
|
users_without_user_search_model_config = KhojUser.objects.annotate(
|
||||||
|
user_search_model_config_count=Count("usersearchmodelconfig")
|
||||||
|
).filter(user_search_model_config_count=0)
|
||||||
|
|
||||||
|
logger.info(f"Number of User objects to update: {users_without_user_search_model_config.count()}")
|
||||||
|
for user in users_without_user_search_model_config:
|
||||||
|
entry_filter = Q(user=user)
|
||||||
|
relevant_entries = Entry.objects.filter(entry_filter).all()
|
||||||
|
logger.info(f"Number of Entry objects to update for user {user}: {relevant_entries.count()}")
|
||||||
|
|
||||||
|
if apply:
|
||||||
|
try:
|
||||||
|
regenerate_entries(
|
||||||
|
entry_filter,
|
||||||
|
embeddings_model[new_default_search_model_config.name],
|
||||||
|
new_default_search_model_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
UserSearchModelConfig.objects.create(user=user, setting=new_default_search_model_config)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Created UserSearchModelConfig object for user {user} to use the new default Search model."
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Updated {relevant_entries.count()} Entry objects for user {user} to use the new default Search model."
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error embedding documents: {e}")
|
||||||
|
else:
|
||||||
|
logger.info("Default is the same as search_model_config_id.")
|
||||||
|
|
||||||
|
all_agents = Agent.objects.all()
|
||||||
|
logger.info(f"Number of Agent objects to update: {all_agents.count()}")
|
||||||
|
for agent in all_agents:
|
||||||
|
entry_filter = Q(agent=agent)
|
||||||
|
relevant_entries = Entry.objects.filter(entry_filter).all()
|
||||||
|
logger.info(f"Number of Entry objects to update for agent {agent}: {relevant_entries.count()}")
|
||||||
|
|
||||||
|
if apply:
|
||||||
|
try:
|
||||||
|
regenerate_entries(
|
||||||
|
entry_filter,
|
||||||
|
embeddings_model[new_default_search_model_config.name],
|
||||||
|
new_default_search_model_config,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Updated {relevant_entries.count()} Entry objects for agent {agent} to use the new default Search model."
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error embedding documents: {e}")
|
||||||
|
if apply and current_default.id != new_default_search_model_config.id:
|
||||||
|
# Get the existing default SearchModelConfig object and update its name
|
||||||
|
current_default.name = f"prev_default_{current_default.id}"
|
||||||
|
current_default.save()
|
||||||
|
|
||||||
|
# Update the new default SearchModelConfig object's name
|
||||||
|
new_default_search_model_config.name = "default"
|
||||||
|
new_default_search_model_config.save()
|
||||||
|
if not apply:
|
||||||
|
logger.info("Run the command with the --apply flag to apply the new default Search model.")
|
24
src/khoj/database/migrations/0072_entry_search_model.py
Normal file
24
src/khoj/database/migrations/0072_entry_search_model.py
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
# Generated by Django 5.0.8 on 2024-10-21 21:09
|
||||||
|
|
||||||
|
import django.db.models.deletion
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0071_subscription_enabled_trial_at_and_more"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="entry",
|
||||||
|
name="search_model",
|
||||||
|
field=models.ForeignKey(
|
||||||
|
blank=True,
|
||||||
|
default=None,
|
||||||
|
null=True,
|
||||||
|
on_delete=django.db.models.deletion.SET_NULL,
|
||||||
|
to="database.searchmodelconfig",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
|
@ -449,6 +449,7 @@ class UserVoiceModelConfig(BaseModel):
|
||||||
setting = models.ForeignKey(VoiceModelOption, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
setting = models.ForeignKey(VoiceModelOption, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO Delete this model once all users have been migrated to the server's default settings
|
||||||
class UserSearchModelConfig(BaseModel):
|
class UserSearchModelConfig(BaseModel):
|
||||||
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
|
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
|
||||||
setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE)
|
setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE)
|
||||||
|
@ -535,6 +536,7 @@ class Entry(BaseModel):
|
||||||
url = models.URLField(max_length=400, default=None, null=True, blank=True)
|
url = models.URLField(max_length=400, default=None, null=True, blank=True)
|
||||||
hashed_value = models.CharField(max_length=100)
|
hashed_value = models.CharField(max_length=100)
|
||||||
corpus_id = models.UUIDField(default=uuid.uuid4, editable=False)
|
corpus_id = models.UUIDField(default=uuid.uuid4, editable=False)
|
||||||
|
search_model = models.ForeignKey(SearchModelConfig, on_delete=models.SET_NULL, default=None, null=True, blank=True)
|
||||||
|
|
||||||
def save(self, *args, **kwargs):
|
def save(self, *args, **kwargs):
|
||||||
if self.user and self.agent:
|
if self.user and self.agent:
|
||||||
|
|
|
@ -12,7 +12,8 @@ from tqdm import tqdm
|
||||||
from khoj.database.adapters import (
|
from khoj.database.adapters import (
|
||||||
EntryAdapters,
|
EntryAdapters,
|
||||||
FileObjectAdapters,
|
FileObjectAdapters,
|
||||||
get_user_search_model_or_default,
|
get_default_search_model,
|
||||||
|
get_user_default_search_model,
|
||||||
)
|
)
|
||||||
from khoj.database.models import Entry as DbEntry
|
from khoj.database.models import Entry as DbEntry
|
||||||
from khoj.database.models import EntryDates, KhojUser
|
from khoj.database.models import EntryDates, KhojUser
|
||||||
|
@ -148,10 +149,10 @@ class TextToEntries(ABC):
|
||||||
hashes_to_process |= hashes_for_file - existing_entry_hashes
|
hashes_to_process |= hashes_for_file - existing_entry_hashes
|
||||||
|
|
||||||
embeddings = []
|
embeddings = []
|
||||||
|
model = get_user_default_search_model(user=user)
|
||||||
with timer("Generated embeddings for entries to add to database in", logger):
|
with timer("Generated embeddings for entries to add to database in", logger):
|
||||||
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
|
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
|
||||||
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
|
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
|
||||||
model = get_user_search_model_or_default(user)
|
|
||||||
embeddings += self.embeddings_model[model.name].embed_documents(data_to_embed)
|
embeddings += self.embeddings_model[model.name].embed_documents(data_to_embed)
|
||||||
|
|
||||||
added_entries: list[DbEntry] = []
|
added_entries: list[DbEntry] = []
|
||||||
|
@ -177,6 +178,7 @@ class TextToEntries(ABC):
|
||||||
file_type=file_type,
|
file_type=file_type,
|
||||||
hashed_value=entry_hash,
|
hashed_value=entry_hash,
|
||||||
corpus_id=entry.corpus_id,
|
corpus_id=entry.corpus_id,
|
||||||
|
search_model=model,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -25,8 +25,9 @@ from khoj.database.adapters import (
|
||||||
AutomationAdapters,
|
AutomationAdapters,
|
||||||
ConversationAdapters,
|
ConversationAdapters,
|
||||||
EntryAdapters,
|
EntryAdapters,
|
||||||
|
get_default_search_model,
|
||||||
|
get_user_default_search_model,
|
||||||
get_user_photo,
|
get_user_photo,
|
||||||
get_user_search_model_or_default,
|
|
||||||
)
|
)
|
||||||
from khoj.database.models import (
|
from khoj.database.models import (
|
||||||
Agent,
|
Agent,
|
||||||
|
@ -149,7 +150,7 @@ async def execute_search(
|
||||||
encoded_asymmetric_query = None
|
encoded_asymmetric_query = None
|
||||||
if t != SearchType.Image:
|
if t != SearchType.Image:
|
||||||
with timer("Encoding query took", logger=logger):
|
with timer("Encoding query took", logger=logger):
|
||||||
search_model = await sync_to_async(get_user_search_model_or_default)(user)
|
search_model = await sync_to_async(get_user_default_search_model)(user)
|
||||||
encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query)
|
encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query)
|
||||||
|
|
||||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||||
|
|
|
@ -94,39 +94,6 @@ async def update_voice_model(
|
||||||
return Response(status_code=202, content=json.dumps({"status": "ok"}))
|
return Response(status_code=202, content=json.dumps({"status": "ok"}))
|
||||||
|
|
||||||
|
|
||||||
@api_model.post("/search", status_code=200)
|
|
||||||
@requires(["authenticated"])
|
|
||||||
async def update_search_model(
|
|
||||||
request: Request,
|
|
||||||
id: str,
|
|
||||||
client: Optional[str] = None,
|
|
||||||
):
|
|
||||||
user = request.user.object
|
|
||||||
|
|
||||||
prev_config = await adapters.aget_user_search_model(user)
|
|
||||||
new_config = await adapters.aset_user_search_model(user, int(id))
|
|
||||||
|
|
||||||
if prev_config and int(id) != prev_config.id and new_config:
|
|
||||||
await EntryAdapters.adelete_all_entries(user)
|
|
||||||
|
|
||||||
if not prev_config:
|
|
||||||
# If the use was just using the default config, delete all the entries and set the new config.
|
|
||||||
await EntryAdapters.adelete_all_entries(user)
|
|
||||||
|
|
||||||
if new_config is None:
|
|
||||||
return {"status": "error", "message": "Model not found"}
|
|
||||||
else:
|
|
||||||
update_telemetry_state(
|
|
||||||
request=request,
|
|
||||||
telemetry_type="api",
|
|
||||||
api="set_search_model",
|
|
||||||
client=client,
|
|
||||||
metadata={"search_model": new_config.setting.name},
|
|
||||||
)
|
|
||||||
|
|
||||||
return {"status": "ok"}
|
|
||||||
|
|
||||||
|
|
||||||
@api_model.post("/paint", status_code=200)
|
@api_model.post("/paint", status_code=200)
|
||||||
@requires(["authenticated"])
|
@requires(["authenticated"])
|
||||||
async def update_paint_model(
|
async def update_paint_model(
|
||||||
|
|
|
@ -1706,13 +1706,6 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
|
||||||
for chat_model in chat_models:
|
for chat_model in chat_models:
|
||||||
chat_model_options.append({"name": chat_model.chat_model, "id": chat_model.id})
|
chat_model_options.append({"name": chat_model.chat_model, "id": chat_model.id})
|
||||||
|
|
||||||
search_model_options = adapters.get_or_create_search_models().all()
|
|
||||||
all_search_model_options = list()
|
|
||||||
for search_model_option in search_model_options:
|
|
||||||
all_search_model_options.append({"name": search_model_option.name, "id": search_model_option.id})
|
|
||||||
|
|
||||||
current_search_model_option = adapters.get_user_search_model_or_default(user)
|
|
||||||
|
|
||||||
selected_paint_model_config = ConversationAdapters.get_user_text_to_image_model_config(user)
|
selected_paint_model_config = ConversationAdapters.get_user_text_to_image_model_config(user)
|
||||||
paint_model_options = ConversationAdapters.get_text_to_image_model_options().all()
|
paint_model_options = ConversationAdapters.get_text_to_image_model_options().all()
|
||||||
all_paint_model_options = list()
|
all_paint_model_options = list()
|
||||||
|
@ -1745,8 +1738,6 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
|
||||||
"has_documents": has_documents,
|
"has_documents": has_documents,
|
||||||
"notion_token": notion_token,
|
"notion_token": notion_token,
|
||||||
# user model settings
|
# user model settings
|
||||||
"search_model_options": all_search_model_options,
|
|
||||||
"selected_search_model_config": current_search_model_option.id,
|
|
||||||
"chat_model_options": chat_model_options,
|
"chat_model_options": chat_model_options,
|
||||||
"selected_chat_model_config": selected_chat_model_config.id if selected_chat_model_config else None,
|
"selected_chat_model_config": selected_chat_model_config.id if selected_chat_model_config else None,
|
||||||
"paint_model_options": all_paint_model_options,
|
"paint_model_options": all_paint_model_options,
|
||||||
|
|
|
@ -8,7 +8,11 @@ import torch
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
from sentence_transformers import util
|
from sentence_transformers import util
|
||||||
|
|
||||||
from khoj.database.adapters import EntryAdapters, get_user_search_model_or_default
|
from khoj.database.adapters import (
|
||||||
|
EntryAdapters,
|
||||||
|
get_default_search_model,
|
||||||
|
get_user_default_search_model,
|
||||||
|
)
|
||||||
from khoj.database.models import Agent
|
from khoj.database.models import Agent
|
||||||
from khoj.database.models import Entry as DbEntry
|
from khoj.database.models import Entry as DbEntry
|
||||||
from khoj.database.models import KhojUser
|
from khoj.database.models import KhojUser
|
||||||
|
@ -110,7 +114,7 @@ async def query(
|
||||||
file_type = search_type_to_embeddings_type[type.value]
|
file_type = search_type_to_embeddings_type[type.value]
|
||||||
|
|
||||||
query = raw_query
|
query = raw_query
|
||||||
search_model = await sync_to_async(get_user_search_model_or_default)(user)
|
search_model = await sync_to_async(get_user_default_search_model)(user)
|
||||||
if not max_distance:
|
if not max_distance:
|
||||||
if search_model.bi_encoder_confidence_threshold:
|
if search_model.bi_encoder_confidence_threshold:
|
||||||
max_distance = search_model.bi_encoder_confidence_threshold
|
max_distance = search_model.bi_encoder_confidence_threshold
|
||||||
|
|
Loading…
Reference in a new issue