From ef21d78c9998f3c29d7f3756cc363c6f39454b79 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Tue, 5 Dec 2023 00:35:40 -0500 Subject: [PATCH 1/9] Initial changes to support multiple search model configurations - All search models are loaded into memory, and stored in a dictionary indexed by name - Still need to add database migrations and create a UI for user to select their choice. Presently, it uses the default option --- src/khoj/configure.py | 12 +++++++++--- src/khoj/database/adapters/__init__.py | 17 ++++++++++++----- src/khoj/database/admin.py | 2 ++ src/khoj/database/models/__init__.py | 5 +++++ src/khoj/processor/content/text_to_entries.py | 5 +++-- src/khoj/routers/api.py | 5 +++-- src/khoj/search_type/text_search.py | 5 +++-- src/khoj/utils/state.py | 2 +- tests/conftest.py | 12 ++++++++---- 9 files changed, 46 insertions(+), 19 deletions(-) diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 19e7d403..8c302450 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -23,7 +23,7 @@ from starlette.authentication import ( from khoj.database.models import KhojUser, Subscription from khoj.database.adapters import ( get_all_users, - get_or_create_search_model, + get_or_create_search_models, aget_user_subscription_state, SubscriptionState, ) @@ -140,8 +140,14 @@ def configure_server( # Initialize Search Models from Config and initialize content try: - state.embeddings_model = EmbeddingsModel(get_or_create_search_model().bi_encoder) - state.cross_encoder_model = CrossEncoderModel(get_or_create_search_model().cross_encoder) + search_models = get_or_create_search_models() + state.embeddings_model = dict() + state.cross_encoder_model = dict() + + for model in search_models: + state.embeddings_model.update({model.name: EmbeddingsModel(model.bi_encoder)}) + state.cross_encoder_model.update({model.name: CrossEncoderModel(model.cross_encoder)}) + state.SearchType = configure_search_types() state.search_models = configure_search(state.search_models, state.config.search_type) initialize_content(regenerate, search_type, init, user) diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 9d18c815..41b0ef06 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -249,12 +249,19 @@ async def set_user_github_config(user: KhojUser, pat_token: str, repos: list): return config -def get_or_create_search_model(): - search_model = SearchModelConfig.objects.filter().first() - if not search_model: - search_model = SearchModelConfig.objects.create() +def get_default_search_model(): + if SearchModelConfig.objects.filter(name="default").exists(): + return SearchModelConfig.objects.filter(name="default").first() + return SearchModelConfig.objects.first() - return search_model + +def get_or_create_search_models(): + search_models = SearchModelConfig.objects.all() + if search_models.count() == 0: + SearchModelConfig.objects.create() + search_models = SearchModelConfig.objects.all() + + return search_models class ConversationAdapters: diff --git a/src/khoj/database/admin.py b/src/khoj/database/admin.py index 2213fb6e..2561a5da 100644 --- a/src/khoj/database/admin.py +++ b/src/khoj/database/admin.py @@ -12,6 +12,7 @@ from khoj.database.models import ( SpeechToTextModelOptions, Subscription, ReflectiveQuestion, + UserSearchModelConfig, ) admin.site.register(KhojUser, UserAdmin) @@ -23,3 +24,4 @@ admin.site.register(OfflineChatProcessorConversationConfig) admin.site.register(SearchModelConfig) admin.site.register(Subscription) admin.site.register(ReflectiveQuestion) +admin.site.register(UserSearchModelConfig) diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 82348fbe..19393c9c 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -145,6 +145,11 @@ class UserConversationConfig(BaseModel): setting = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True) +class UserSearchModelConfig(BaseModel): + user = models.OneToOneField(KhojUser, on_delete=models.CASCADE) + setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE) + + class Conversation(BaseModel): user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) conversation_log = models.JSONField(default=dict) diff --git a/src/khoj/processor/content/text_to_entries.py b/src/khoj/processor/content/text_to_entries.py index 109c58e6..bfcf37f7 100644 --- a/src/khoj/processor/content/text_to_entries.py +++ b/src/khoj/processor/content/text_to_entries.py @@ -14,7 +14,7 @@ from khoj.utils.helpers import is_none_or_empty, timer, batcher from khoj.utils.rawconfig import Entry from khoj.search_filter.date_filter import DateFilter from khoj.database.models import KhojUser, Entry as DbEntry, EntryDates -from khoj.database.adapters import EntryAdapters +from khoj.database.adapters import EntryAdapters, get_default_search_model logger = logging.getLogger(__name__) @@ -112,7 +112,8 @@ class TextToEntries(ABC): with timer("Generated embeddings for entries to add to database in", logger): entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process] data_to_embed = [getattr(entry, key) for entry in entries_to_process] - embeddings += self.embeddings_model.embed_documents(data_to_embed) + model = get_default_search_model() + embeddings += self.embeddings_model[model.name].embed_documents(data_to_embed) added_entries: list[DbEntry] = [] with timer("Added entries to database in", logger): diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index ae125980..52b772fd 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -18,7 +18,7 @@ from starlette.authentication import requires # Internal Packages from khoj.configure import configure_server from khoj.database import adapters -from khoj.database.adapters import ConversationAdapters, EntryAdapters +from khoj.database.adapters import ConversationAdapters, EntryAdapters, get_default_search_model from khoj.database.models import ChatModelOptions from khoj.database.models import Entry as DbEntry from khoj.database.models import ( @@ -412,7 +412,8 @@ async def search( ] if text_search_models: with timer("Encoding query took", logger=logger): - encoded_asymmetric_query = state.embeddings_model.embed_query(defiltered_query) + search_model = await sync_to_async(get_default_search_model)() + encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query) with concurrent.futures.ThreadPoolExecutor() as executor: if t in [ diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py index d04d4c6a..1523473c 100644 --- a/src/khoj/search_type/text_search.py +++ b/src/khoj/search_type/text_search.py @@ -19,7 +19,7 @@ from khoj.utils.state import SearchType from khoj.utils.rawconfig import SearchResponse, Entry from khoj.utils.jsonl import load_jsonl from khoj.processor.content.text_to_entries import TextToEntries -from khoj.database.adapters import EntryAdapters +from khoj.database.adapters import EntryAdapters, get_default_search_model from khoj.database.models import KhojUser, Entry as DbEntry logger = logging.getLogger(__name__) @@ -115,7 +115,8 @@ async def query( # Encode the query using the bi-encoder if question_embedding is None: with timer("Query Encode Time", logger, state.device): - question_embedding = state.embeddings_model.embed_query(query) + search_model = await sync_to_async(get_default_search_model)() + question_embedding = state.embeddings_model[search_model.name].embed_query(query) # Find relevant entries for the query top_k = 10 diff --git a/src/khoj/utils/state.py b/src/khoj/utils/state.py index b54cf4b3..4e135b18 100644 --- a/src/khoj/utils/state.py +++ b/src/khoj/utils/state.py @@ -18,7 +18,7 @@ from khoj.utils.rawconfig import FullConfig # Application Global State config = FullConfig() search_models = SearchModels() -embeddings_model: EmbeddingsModel = None +embeddings_model: Dict[str, EmbeddingsModel] = None cross_encoder_model: CrossEncoderModel = None content_index = ContentIndex() gpt4all_processor_config: GPT4AllProcessorModel = None diff --git a/tests/conftest.py b/tests/conftest.py index 9a500609..bbb3aa39 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -45,8 +45,10 @@ def enable_db_access_for_all_tests(db): @pytest.fixture(scope="session") def search_config() -> SearchConfig: - state.embeddings_model = EmbeddingsModel() - state.cross_encoder_model = CrossEncoderModel() + state.embeddings_model = dict() + state.embeddings_model["default"] = EmbeddingsModel() + state.cross_encoder_model = dict() + state.cross_encoder_model["default"] = CrossEncoderModel() model_dir = resolve_absolute_path("~/.khoj/search") model_dir.mkdir(parents=True, exist_ok=True) @@ -317,8 +319,10 @@ def client( state.config.content_type = content_config state.config.search_type = search_config state.SearchType = configure_search_types() - state.embeddings_model = EmbeddingsModel() - state.cross_encoder_model = CrossEncoderModel() + state.embeddings_model = dict() + state.embeddings_model["default"] = EmbeddingsModel() + state.cross_encoder_model = dict() + state.cross_encoder_model["default"] = CrossEncoderModel() # These lines help us Mock the Search models for these search types state.search_models.image_search = image_search.initialize_model(search_config.image) From 0f6e4ff683d5ae5baa5e6ca0960d96e309fa39cf Mon Sep 17 00:00:00 2001 From: sabaimran Date: Wed, 20 Dec 2023 09:22:26 +0530 Subject: [PATCH 2/9] Add a model that specifies the user's search model configuration - Update all endpoints that generate embeddings to use the new model. Incl. generating text embeddings, creating embeddings for a search query --- src/khoj/database/adapters/__init__.py | 6 +++- .../migrations/0023_usersearchmodelconfig.py | 33 +++++++++++++++++++ src/khoj/processor/content/text_to_entries.py | 4 +-- src/khoj/routers/api.py | 4 +-- src/khoj/search_type/text_search.py | 4 +-- 5 files changed, 44 insertions(+), 7 deletions(-) create mode 100644 src/khoj/database/migrations/0023_usersearchmodelconfig.py diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index b04e4828..0610d763 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -32,6 +32,7 @@ from khoj.database.models import ( SpeechToTextModelOptions, Subscription, UserConversationConfig, + UserSearchModelConfig, OpenAIProcessorConversationConfig, OfflineChatProcessorConversationConfig, ReflectiveQuestion, @@ -250,7 +251,10 @@ async def set_user_github_config(user: KhojUser, pat_token: str, repos: list): return config -def get_default_search_model(): +def get_user_search_model_or_default(user=None): + if user and UserSearchModelConfig.objects.filter(user=user).exists(): + return UserSearchModelConfig.objects.filter(user=user).first().setting + if SearchModelConfig.objects.filter(name="default").exists(): return SearchModelConfig.objects.filter(name="default").first() return SearchModelConfig.objects.first() diff --git a/src/khoj/database/migrations/0023_usersearchmodelconfig.py b/src/khoj/database/migrations/0023_usersearchmodelconfig.py new file mode 100644 index 00000000..0aec99f0 --- /dev/null +++ b/src/khoj/database/migrations/0023_usersearchmodelconfig.py @@ -0,0 +1,33 @@ +# Generated by Django 4.2.7 on 2023-12-19 15:44 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0022_texttoimagemodelconfig"), + ] + + operations = [ + migrations.CreateModel( + name="UserSearchModelConfig", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ( + "setting", + models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to="database.searchmodelconfig"), + ), + ( + "user", + models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL), + ), + ], + options={ + "abstract": False, + }, + ), + ] diff --git a/src/khoj/processor/content/text_to_entries.py b/src/khoj/processor/content/text_to_entries.py index bfcf37f7..0c092ed4 100644 --- a/src/khoj/processor/content/text_to_entries.py +++ b/src/khoj/processor/content/text_to_entries.py @@ -14,7 +14,7 @@ from khoj.utils.helpers import is_none_or_empty, timer, batcher from khoj.utils.rawconfig import Entry from khoj.search_filter.date_filter import DateFilter from khoj.database.models import KhojUser, Entry as DbEntry, EntryDates -from khoj.database.adapters import EntryAdapters, get_default_search_model +from khoj.database.adapters import EntryAdapters, get_user_search_model_or_default logger = logging.getLogger(__name__) @@ -112,7 +112,7 @@ class TextToEntries(ABC): with timer("Generated embeddings for entries to add to database in", logger): entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process] data_to_embed = [getattr(entry, key) for entry in entries_to_process] - model = get_default_search_model() + model = get_user_search_model_or_default(user) embeddings += self.embeddings_model[model.name].embed_documents(data_to_embed) added_entries: list[DbEntry] = [] diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index dfdfceda..ad360dd3 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -18,7 +18,7 @@ from starlette.authentication import requires # Internal Packages from khoj.configure import configure_server from khoj.database import adapters -from khoj.database.adapters import ConversationAdapters, EntryAdapters, get_default_search_model +from khoj.database.adapters import ConversationAdapters, EntryAdapters, get_user_search_model_or_default from khoj.database.models import ChatModelOptions, SpeechToTextModelOptions from khoj.database.models import Entry as DbEntry from khoj.database.models import ( @@ -416,7 +416,7 @@ async def search( ] if text_search_models: with timer("Encoding query took", logger=logger): - search_model = await sync_to_async(get_default_search_model)() + search_model = await sync_to_async(get_user_search_model_or_default)(user) encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query) with concurrent.futures.ThreadPoolExecutor() as executor: diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py index 1523473c..a4b8c9b1 100644 --- a/src/khoj/search_type/text_search.py +++ b/src/khoj/search_type/text_search.py @@ -19,7 +19,7 @@ from khoj.utils.state import SearchType from khoj.utils.rawconfig import SearchResponse, Entry from khoj.utils.jsonl import load_jsonl from khoj.processor.content.text_to_entries import TextToEntries -from khoj.database.adapters import EntryAdapters, get_default_search_model +from khoj.database.adapters import EntryAdapters, get_user_search_model_or_default from khoj.database.models import KhojUser, Entry as DbEntry logger = logging.getLogger(__name__) @@ -115,7 +115,7 @@ async def query( # Encode the query using the bi-encoder if question_embedding is None: with timer("Query Encode Time", logger, state.device): - search_model = await sync_to_async(get_default_search_model)() + search_model = await sync_to_async(get_user_search_model_or_default)(user) question_embedding = state.embeddings_model[search_model.name].embed_query(query) # Find relevant entries for the query From 5ff9df9d4cace02e6e6cdcdffe1abd9dc9d6a905 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Wed, 20 Dec 2023 13:25:43 +0530 Subject: [PATCH 3/9] Add support per user for configuring the preferred search model from the config page - Honor this setting across the relevant places where embeddings are used - Convert the VectorField object to have None for dimensions in order to make the search model easily configurable --- pyproject.toml | 2 +- src/khoj/database/adapters/__init__.py | 8 ++++ .../migrations/0024_alter_entry_embeddings.py | 18 ++++++++ src/khoj/database/models/__init__.py | 2 +- .../web/assets/icons/matrix_blob.svg | 2 + src/khoj/interface/web/base_config.html | 1 + src/khoj/interface/web/config.html | 44 +++++++++++++++++++ src/khoj/routers/api.py | 43 +++++++++++++----- src/khoj/routers/web_client.py | 6 +++ src/khoj/search_type/text_search.py | 8 ++-- 10 files changed, 117 insertions(+), 17 deletions(-) create mode 100644 src/khoj/database/migrations/0024_alter_entry_embeddings.py create mode 100644 src/khoj/interface/web/assets/icons/matrix_blob.svg diff --git a/pyproject.toml b/pyproject.toml index 5a206cce..693415d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,7 +66,7 @@ dependencies = [ "gpt4all >= 2.0.0; platform_system == 'Windows' or platform_system == 'Darwin'", "itsdangerous == 2.1.2", "httpx == 0.25.0", - "pgvector == 0.2.3", + "pgvector == 0.2.4", "psycopg2-binary == 2.9.9", "google-auth == 2.23.3", "python-multipart == 0.0.6", diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 0610d763..c11a8e8a 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -269,6 +269,14 @@ def get_or_create_search_models(): return search_models +async def aset_user_search_model(user: KhojUser, search_model_config_id: int): + config = await SearchModelConfig.objects.filter(id=search_model_config_id).afirst() + if not config: + return None + new_config, _ = await UserSearchModelConfig.objects.aupdate_or_create(user=user, defaults={"setting": config}) + return new_config + + class ConversationAdapters: @staticmethod def get_conversation_by_user(user: KhojUser): diff --git a/src/khoj/database/migrations/0024_alter_entry_embeddings.py b/src/khoj/database/migrations/0024_alter_entry_embeddings.py new file mode 100644 index 00000000..a1bbf45d --- /dev/null +++ b/src/khoj/database/migrations/0024_alter_entry_embeddings.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.7 on 2023-12-20 07:27 + +from django.db import migrations +import pgvector.django + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0023_usersearchmodelconfig"), + ] + + operations = [ + migrations.AlterField( + model_name="entry", + name="embeddings", + field=pgvector.django.VectorField(), + ), + ] diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 34d3d1d7..a9fa38f7 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -185,7 +185,7 @@ class Entry(BaseModel): GITHUB = "github" user = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True) - embeddings = VectorField(dimensions=384) + embeddings = VectorField(dimensions=None) raw = models.TextField() compiled = models.TextField() heading = models.CharField(max_length=1000, default=None, null=True, blank=True) diff --git a/src/khoj/interface/web/assets/icons/matrix_blob.svg b/src/khoj/interface/web/assets/icons/matrix_blob.svg new file mode 100644 index 00000000..592aa53e --- /dev/null +++ b/src/khoj/interface/web/assets/icons/matrix_blob.svg @@ -0,0 +1,2 @@ + + diff --git a/src/khoj/interface/web/base_config.html b/src/khoj/interface/web/base_config.html index 723f8aaa..b22960aa 100644 --- a/src/khoj/interface/web/base_config.html +++ b/src/khoj/interface/web/base_config.html @@ -296,6 +296,7 @@ height: 32px; } + select#search-models, select#chat-models { margin-bottom: 0; padding: 8px; diff --git a/src/khoj/interface/web/config.html b/src/khoj/interface/web/config.html index 3ce9c9cc..6aeeb426 100644 --- a/src/khoj/interface/web/config.html +++ b/src/khoj/interface/web/config.html @@ -146,6 +146,26 @@ +
+
+ Chat +

+ Text Model +

+
+
+ +
+
+ +
+
@@ -266,6 +286,30 @@ }) }; + function updateSearchModel() { + const searchModel = document.getElementById("search-models").value; + const saveSearchModelButton = document.getElementById("save-search-model"); + saveSearchModelButton.disabled = true; + saveSearchModelButton.innerHTML = "Saving..."; + + fetch('/api/config/data/search/model?id=' + searchModel, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + } + }) + .then(response => response.json()) + .then(data => { + if (data.status == "ok") { + saveSearchModelButton.innerHTML = "Save"; + saveSearchModelButton.disabled = false; + } else { + saveSearchModelButton.innerHTML = "Error"; + saveSearchModelButton.disabled = false; + } + }) + }; + function clearContentType(content_source) { fetch('/api/config/data/content-source/' + content_source, { method: 'DELETE', diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index ad360dd3..1f4fff17 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -332,6 +332,31 @@ async def update_chat_model( return {"status": "ok"} +@api.post("/config/data/search/model", status_code=200) +@requires(["authenticated"]) +async def update_chat_model( + request: Request, + id: str, + client: Optional[str] = None, +): + user = request.user.object + + new_config = await adapters.aset_user_search_model(user, int(id)) + + update_telemetry_state( + request=request, + telemetry_type="api", + api="set_search_model", + client=client, + metadata={"search_model": new_config.setting.name}, + ) + + if new_config is None: + return {"status": "error", "message": "Model not found"} + + return {"status": "ok"} + + # Create Routes @api.get("/config/data/default") def get_default_config_data(): @@ -410,14 +435,10 @@ async def search( defiltered_query = filter.defilter(defiltered_query) encoded_asymmetric_query = None - if t == SearchType.All or t != SearchType.Image: - text_search_models: List[TextSearchModel] = [ - model for model in state.search_models.__dict__.values() if isinstance(model, TextSearchModel) - ] - if text_search_models: - with timer("Encoding query took", logger=logger): - search_model = await sync_to_async(get_user_search_model_or_default)(user) - encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query) + if t != SearchType.Image: + with timer("Encoding query took", logger=logger): + search_model = await sync_to_async(get_user_search_model_or_default)(user) + encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query) with concurrent.futures.ThreadPoolExecutor() as executor: if t in [ @@ -473,9 +494,9 @@ async def search( results += text_search.collate_results(hits, dedupe=dedupe) # Sort results across all content types and take top results - results = text_search.rerank_and_sort_results(results, query=defiltered_query, rank_results=r)[ - :results_count - ] + results = text_search.rerank_and_sort_results( + results, query=defiltered_query, rank_results=r, search_model_name=search_model.name + )[:results_count] # Cache results if user: diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py index 7907f99e..00a087b1 100644 --- a/src/khoj/routers/web_client.py +++ b/src/khoj/routers/web_client.py @@ -155,6 +155,11 @@ def config_page(request: Request): for conversation_option in conversation_options: all_conversation_options.append({"chat_model": conversation_option.chat_model, "id": conversation_option.id}) + search_model_options = adapters.get_or_create_search_models().all() + all_search_model_options = list() + for search_model_option in search_model_options: + all_search_model_options.append({"name": search_model_option.name, "id": search_model_option.id}) + return templates.TemplateResponse( "config.html", context={ @@ -163,6 +168,7 @@ def config_page(request: Request): "anonymous_mode": state.anonymous_mode, "username": user.username, "conversation_options": all_conversation_options, + "search_model_options": all_search_model_options, "selected_conversation_config": selected_conversation_config.id if selected_conversation_config else None, "user_photo": user_picture, "billing_enabled": state.billing_enabled, diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py index a4b8c9b1..2111a8a7 100644 --- a/src/khoj/search_type/text_search.py +++ b/src/khoj/search_type/text_search.py @@ -180,13 +180,13 @@ def deduplicated_search_responses(hits: List[SearchResponse]): ) -def rerank_and_sort_results(hits, query, rank_results): +def rerank_and_sort_results(hits, query, rank_results, search_model_name): # If we have more than one result and reranking is enabled rank_results = rank_results and len(list(hits)) > 1 # Score all retrieved entries using the cross-encoder if rank_results: - hits = cross_encoder_score(query, hits) + hits = cross_encoder_score(query, hits, search_model_name) # Sort results by cross-encoder score followed by bi-encoder score hits = sort_results(rank_results=rank_results, hits=hits) @@ -219,10 +219,10 @@ def setup( ) -def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchResponse]: +def cross_encoder_score(query: str, hits: List[SearchResponse], search_model_name: str) -> List[SearchResponse]: """Score all retrieved entries using the cross-encoder""" with timer("Cross-Encoder Predict Time", logger, state.device): - cross_scores = state.cross_encoder_model.predict(query, hits) + cross_scores = state.cross_encoder_model[search_model_name].predict(query, hits) # Convert cross-encoder scores to distances and pass in hits for reranking for idx in range(len(cross_scores)): From aa23da60a3e4dcf0f450985664d961271471473c Mon Sep 17 00:00:00 2001 From: sabaimran Date: Wed, 20 Dec 2023 14:22:08 +0530 Subject: [PATCH 4/9] Add a notification banner to show temporary messages --- src/khoj/interface/web/base_config.html | 14 ++++++++++++++ src/khoj/interface/web/config.html | 13 ++++++++++++- 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/src/khoj/interface/web/base_config.html b/src/khoj/interface/web/base_config.html index b22960aa..b3a5a3a3 100644 --- a/src/khoj/interface/web/base_config.html +++ b/src/khoj/interface/web/base_config.html @@ -296,6 +296,20 @@ height: 32px; } + div#notification-banner { + background-color: var(--primary); + border: 1px solid var(--primary-hover); + padding: 8px; + box-shadow: 0px 1px 3px 0px rgba(0,0,0,0.1),0px 1px 2px -1px rgba(0,0,0,0.8); + } + + div#notification-banner-parent { + position: fixed; + right: 0; + bottom: 0; + margin: 20px; + } + select#search-models, select#chat-models { margin-bottom: 0; diff --git a/src/khoj/interface/web/config.html b/src/khoj/interface/web/config.html index 6aeeb426..11c78fc1 100644 --- a/src/khoj/interface/web/config.html +++ b/src/khoj/interface/web/config.html @@ -150,7 +150,7 @@
Chat

- Text Model + Language

@@ -259,6 +259,10 @@
{% endif %}
+
+
+
+