diff --git a/pyproject.toml b/pyproject.toml index 5a206cce..693415d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,7 +66,7 @@ dependencies = [ "gpt4all >= 2.0.0; platform_system == 'Windows' or platform_system == 'Darwin'", "itsdangerous == 2.1.2", "httpx == 0.25.0", - "pgvector == 0.2.3", + "pgvector == 0.2.4", "psycopg2-binary == 2.9.9", "google-auth == 2.23.3", "python-multipart == 0.0.6", diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 0610d763..c11a8e8a 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -269,6 +269,14 @@ def get_or_create_search_models(): return search_models +async def aset_user_search_model(user: KhojUser, search_model_config_id: int): + config = await SearchModelConfig.objects.filter(id=search_model_config_id).afirst() + if not config: + return None + new_config, _ = await UserSearchModelConfig.objects.aupdate_or_create(user=user, defaults={"setting": config}) + return new_config + + class ConversationAdapters: @staticmethod def get_conversation_by_user(user: KhojUser): diff --git a/src/khoj/database/migrations/0024_alter_entry_embeddings.py b/src/khoj/database/migrations/0024_alter_entry_embeddings.py new file mode 100644 index 00000000..a1bbf45d --- /dev/null +++ b/src/khoj/database/migrations/0024_alter_entry_embeddings.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.7 on 2023-12-20 07:27 + +from django.db import migrations +import pgvector.django + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0023_usersearchmodelconfig"), + ] + + operations = [ + migrations.AlterField( + model_name="entry", + name="embeddings", + field=pgvector.django.VectorField(), + ), + ] diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 34d3d1d7..a9fa38f7 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -185,7 +185,7 @@ class 
Entry(BaseModel): GITHUB = "github" user = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True) - embeddings = VectorField(dimensions=384) + embeddings = VectorField(dimensions=None) raw = models.TextField() compiled = models.TextField() heading = models.CharField(max_length=1000, default=None, null=True, blank=True) diff --git a/src/khoj/interface/web/assets/icons/matrix_blob.svg b/src/khoj/interface/web/assets/icons/matrix_blob.svg new file mode 100644 index 00000000..592aa53e --- /dev/null +++ b/src/khoj/interface/web/assets/icons/matrix_blob.svg @@ -0,0 +1,2 @@ + + diff --git a/src/khoj/interface/web/base_config.html b/src/khoj/interface/web/base_config.html index 723f8aaa..b22960aa 100644 --- a/src/khoj/interface/web/base_config.html +++ b/src/khoj/interface/web/base_config.html @@ -296,6 +296,7 @@ height: 32px; } + select#search-models, select#chat-models { margin-bottom: 0; padding: 8px; diff --git a/src/khoj/interface/web/config.html b/src/khoj/interface/web/config.html index 3ce9c9cc..6aeeb426 100644 --- a/src/khoj/interface/web/config.html +++ b/src/khoj/interface/web/config.html @@ -146,6 +146,26 @@ +
+
+ Search

+ Text Model +

+
+
+ +
+
+ +
+
@@ -266,6 +286,30 @@ }) }; + function updateSearchModel() { + const searchModel = document.getElementById("search-models").value; + const saveSearchModelButton = document.getElementById("save-search-model"); + saveSearchModelButton.disabled = true; + saveSearchModelButton.innerHTML = "Saving..."; + + fetch('/api/config/data/search/model?id=' + searchModel, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + } + }) + .then(response => response.json()) + .then(data => { + if (data.status == "ok") { + saveSearchModelButton.innerHTML = "Save"; + saveSearchModelButton.disabled = false; + } else { + saveSearchModelButton.innerHTML = "Error"; + saveSearchModelButton.disabled = false; + } + }) + }; + function clearContentType(content_source) { fetch('/api/config/data/content-source/' + content_source, { method: 'DELETE', diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index ad360dd3..1f4fff17 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -332,6 +332,31 @@ async def update_chat_model( return {"status": "ok"} +@api.post("/config/data/search/model", status_code=200) +@requires(["authenticated"]) +async def update_search_model( + request: Request, + id: str, + client: Optional[str] = None, +): + user = request.user.object + + new_config = await adapters.aset_user_search_model(user, int(id)) + + if new_config is None: + return {"status": "error", "message": "Model not found"} + + update_telemetry_state( + request=request, + telemetry_type="api", + api="set_search_model", + client=client, + metadata={"search_model": new_config.setting.name}, + ) + + return {"status": "ok"} + + # Create Routes @api.get("/config/data/default") def get_default_config_data(): @@ -410,14 +435,10 @@ async def search( defiltered_query = filter.defilter(defiltered_query) encoded_asymmetric_query = None - if t == SearchType.All or t != SearchType.Image: - text_search_models: List[TextSearchModel] = [ - model for model in
state.search_models.__dict__.values() if isinstance(model, TextSearchModel) - ] - if text_search_models: - with timer("Encoding query took", logger=logger): - search_model = await sync_to_async(get_user_search_model_or_default)(user) - encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query) + if t != SearchType.Image: + with timer("Encoding query took", logger=logger): + search_model = await sync_to_async(get_user_search_model_or_default)(user) + encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query) with concurrent.futures.ThreadPoolExecutor() as executor: if t in [ @@ -473,9 +494,9 @@ async def search( results += text_search.collate_results(hits, dedupe=dedupe) # Sort results across all content types and take top results - results = text_search.rerank_and_sort_results(results, query=defiltered_query, rank_results=r)[ - :results_count - ] + results = text_search.rerank_and_sort_results( + results, query=defiltered_query, rank_results=r, search_model_name=search_model.name + )[:results_count] # Cache results if user: diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py index 7907f99e..00a087b1 100644 --- a/src/khoj/routers/web_client.py +++ b/src/khoj/routers/web_client.py @@ -155,6 +155,11 @@ def config_page(request: Request): for conversation_option in conversation_options: all_conversation_options.append({"chat_model": conversation_option.chat_model, "id": conversation_option.id}) + search_model_options = adapters.get_or_create_search_models().all() + all_search_model_options = list() + for search_model_option in search_model_options: + all_search_model_options.append({"name": search_model_option.name, "id": search_model_option.id}) + return templates.TemplateResponse( "config.html", context={ @@ -163,6 +168,7 @@ def config_page(request: Request): "anonymous_mode": state.anonymous_mode, "username": user.username, "conversation_options": 
all_conversation_options, + "search_model_options": all_search_model_options, "selected_conversation_config": selected_conversation_config.id if selected_conversation_config else None, "user_photo": user_picture, "billing_enabled": state.billing_enabled, diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py index a4b8c9b1..2111a8a7 100644 --- a/src/khoj/search_type/text_search.py +++ b/src/khoj/search_type/text_search.py @@ -180,13 +180,13 @@ def deduplicated_search_responses(hits: List[SearchResponse]): ) -def rerank_and_sort_results(hits, query, rank_results): +def rerank_and_sort_results(hits, query, rank_results, search_model_name): # If we have more than one result and reranking is enabled rank_results = rank_results and len(list(hits)) > 1 # Score all retrieved entries using the cross-encoder if rank_results: - hits = cross_encoder_score(query, hits) + hits = cross_encoder_score(query, hits, search_model_name) # Sort results by cross-encoder score followed by bi-encoder score hits = sort_results(rank_results=rank_results, hits=hits) @@ -219,10 +219,10 @@ def setup( ) -def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchResponse]: +def cross_encoder_score(query: str, hits: List[SearchResponse], search_model_name: str) -> List[SearchResponse]: """Score all retrieved entries using the cross-encoder""" with timer("Cross-Encoder Predict Time", logger, state.device): - cross_scores = state.cross_encoder_model.predict(query, hits) + cross_scores = state.cross_encoder_model[search_model_name].predict(query, hits) # Convert cross-encoder scores to distances and pass in hits for reranking for idx in range(len(cross_scores)):