mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Add a model that specifies the user's search model configuration
- Update all endpoints that generate embeddings to use the new model. Incl. generating text embeddings, creating embeddings for a search query
This commit is contained in:
parent
6dd2b05bf5
commit
0f6e4ff683
5 changed files with 44 additions and 7 deletions
|
@ -32,6 +32,7 @@ from khoj.database.models import (
|
|||
SpeechToTextModelOptions,
|
||||
Subscription,
|
||||
UserConversationConfig,
|
||||
UserSearchModelConfig,
|
||||
OpenAIProcessorConversationConfig,
|
||||
OfflineChatProcessorConversationConfig,
|
||||
ReflectiveQuestion,
|
||||
|
@ -250,7 +251,10 @@ async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
|
|||
return config
|
||||
|
||||
|
||||
def get_default_search_model():
|
||||
def get_user_search_model_or_default(user=None):
|
||||
if user and UserSearchModelConfig.objects.filter(user=user).exists():
|
||||
return UserSearchModelConfig.objects.filter(user=user).first().setting
|
||||
|
||||
if SearchModelConfig.objects.filter(name="default").exists():
|
||||
return SearchModelConfig.objects.filter(name="default").first()
|
||||
return SearchModelConfig.objects.first()
|
||||
|
|
33
src/khoj/database/migrations/0023_usersearchmodelconfig.py
Normal file
33
src/khoj/database/migrations/0023_usersearchmodelconfig.py
Normal file
|
@ -0,0 +1,33 @@
|
|||
# Generated by Django 4.2.7 on 2023-12-19 15:44
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
import django.db.models.deletion
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0022_texttoimagemodelconfig"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name="UserSearchModelConfig",
|
||||
fields=[
|
||||
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||
("updated_at", models.DateTimeField(auto_now=True)),
|
||||
(
|
||||
"setting",
|
||||
models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to="database.searchmodelconfig"),
|
||||
),
|
||||
(
|
||||
"user",
|
||||
models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"abstract": False,
|
||||
},
|
||||
),
|
||||
]
|
|
@ -14,7 +14,7 @@ from khoj.utils.helpers import is_none_or_empty, timer, batcher
|
|||
from khoj.utils.rawconfig import Entry
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
from khoj.database.models import KhojUser, Entry as DbEntry, EntryDates
|
||||
from khoj.database.adapters import EntryAdapters, get_default_search_model
|
||||
from khoj.database.adapters import EntryAdapters, get_user_search_model_or_default
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -112,7 +112,7 @@ class TextToEntries(ABC):
|
|||
with timer("Generated embeddings for entries to add to database in", logger):
|
||||
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
|
||||
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
|
||||
model = get_default_search_model()
|
||||
model = get_user_search_model_or_default(user)
|
||||
embeddings += self.embeddings_model[model.name].embed_documents(data_to_embed)
|
||||
|
||||
added_entries: list[DbEntry] = []
|
||||
|
|
|
@ -18,7 +18,7 @@ from starlette.authentication import requires
|
|||
# Internal Packages
|
||||
from khoj.configure import configure_server
|
||||
from khoj.database import adapters
|
||||
from khoj.database.adapters import ConversationAdapters, EntryAdapters, get_default_search_model
|
||||
from khoj.database.adapters import ConversationAdapters, EntryAdapters, get_user_search_model_or_default
|
||||
from khoj.database.models import ChatModelOptions, SpeechToTextModelOptions
|
||||
from khoj.database.models import Entry as DbEntry
|
||||
from khoj.database.models import (
|
||||
|
@ -416,7 +416,7 @@ async def search(
|
|||
]
|
||||
if text_search_models:
|
||||
with timer("Encoding query took", logger=logger):
|
||||
search_model = await sync_to_async(get_default_search_model)()
|
||||
search_model = await sync_to_async(get_user_search_model_or_default)(user)
|
||||
encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query)
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
|
|
|
@ -19,7 +19,7 @@ from khoj.utils.state import SearchType
|
|||
from khoj.utils.rawconfig import SearchResponse, Entry
|
||||
from khoj.utils.jsonl import load_jsonl
|
||||
from khoj.processor.content.text_to_entries import TextToEntries
|
||||
from khoj.database.adapters import EntryAdapters, get_default_search_model
|
||||
from khoj.database.adapters import EntryAdapters, get_user_search_model_or_default
|
||||
from khoj.database.models import KhojUser, Entry as DbEntry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -115,7 +115,7 @@ async def query(
|
|||
# Encode the query using the bi-encoder
|
||||
if question_embedding is None:
|
||||
with timer("Query Encode Time", logger, state.device):
|
||||
search_model = await sync_to_async(get_default_search_model)()
|
||||
search_model = await sync_to_async(get_user_search_model_or_default)(user)
|
||||
question_embedding = state.embeddings_model[search_model.name].embed_query(query)
|
||||
|
||||
# Find relevant entries for the query
|
||||
|
|
Loading…
Reference in a new issue