diff --git a/pyproject.toml b/pyproject.toml
index 5a206cce..693415d4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -66,7 +66,7 @@ dependencies = [
"gpt4all >= 2.0.0; platform_system == 'Windows' or platform_system == 'Darwin'",
"itsdangerous == 2.1.2",
"httpx == 0.25.0",
- "pgvector == 0.2.3",
+ "pgvector == 0.2.4",
"psycopg2-binary == 2.9.9",
"google-auth == 2.23.3",
"python-multipart == 0.0.6",
diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py
index 0610d763..c11a8e8a 100644
--- a/src/khoj/database/adapters/__init__.py
+++ b/src/khoj/database/adapters/__init__.py
@@ -269,6 +269,14 @@ def get_or_create_search_models():
return search_models
+async def aset_user_search_model(user: KhojUser, search_model_config_id: int):
+ config = await SearchModelConfig.objects.filter(id=search_model_config_id).afirst()
+ if not config:
+ return None
+ new_config, _ = await UserSearchModelConfig.objects.aupdate_or_create(user=user, defaults={"setting": config})
+ return new_config
+
+
class ConversationAdapters:
@staticmethod
def get_conversation_by_user(user: KhojUser):
diff --git a/src/khoj/database/migrations/0024_alter_entry_embeddings.py b/src/khoj/database/migrations/0024_alter_entry_embeddings.py
new file mode 100644
index 00000000..a1bbf45d
--- /dev/null
+++ b/src/khoj/database/migrations/0024_alter_entry_embeddings.py
@@ -0,0 +1,18 @@
+# Generated by Django 4.2.7 on 2023-12-20 07:27
+
+from django.db import migrations
+import pgvector.django
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("database", "0023_usersearchmodelconfig"),
+ ]
+
+ operations = [
+ migrations.AlterField(
+ model_name="entry",
+ name="embeddings",
+ field=pgvector.django.VectorField(),
+ ),
+ ]
diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py
index 34d3d1d7..a9fa38f7 100644
--- a/src/khoj/database/models/__init__.py
+++ b/src/khoj/database/models/__init__.py
@@ -185,7 +185,7 @@ class Entry(BaseModel):
GITHUB = "github"
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True)
- embeddings = VectorField(dimensions=384)
+ embeddings = VectorField(dimensions=None)
raw = models.TextField()
compiled = models.TextField()
heading = models.CharField(max_length=1000, default=None, null=True, blank=True)
diff --git a/src/khoj/interface/web/assets/icons/matrix_blob.svg b/src/khoj/interface/web/assets/icons/matrix_blob.svg
new file mode 100644
index 00000000..592aa53e
--- /dev/null
+++ b/src/khoj/interface/web/assets/icons/matrix_blob.svg
@@ -0,0 +1,2 @@
+
+
diff --git a/src/khoj/interface/web/base_config.html b/src/khoj/interface/web/base_config.html
index 723f8aaa..b22960aa 100644
--- a/src/khoj/interface/web/base_config.html
+++ b/src/khoj/interface/web/base_config.html
@@ -296,6 +296,7 @@
height: 32px;
}
+ select#search-models,
select#chat-models {
margin-bottom: 0;
padding: 8px;
diff --git a/src/khoj/interface/web/config.html b/src/khoj/interface/web/config.html
index 3ce9c9cc..6aeeb426 100644
--- a/src/khoj/interface/web/config.html
+++ b/src/khoj/interface/web/config.html
@@ -146,6 +146,26 @@
+
+
+
+
+ Text Model
+
+
+
+
+
+
+
+
+
@@ -266,6 +286,30 @@
})
};
+ function updateSearchModel() {
+ const searchModel = document.getElementById("search-models").value;
+ const saveSearchModelButton = document.getElementById("save-search-model");
+ saveSearchModelButton.disabled = true;
+ saveSearchModelButton.innerHTML = "Saving...";
+
+ fetch('/api/config/data/search/model?id=' + searchModel, {
+ method: 'POST',
+ headers: {
+ 'Content-Type': 'application/json',
+ }
+ })
+ .then(response => response.json())
+ .then(data => {
+ if (data.status == "ok") {
+ saveSearchModelButton.innerHTML = "Save";
+ saveSearchModelButton.disabled = false;
+ } else {
+ saveSearchModelButton.innerHTML = "Error";
+ saveSearchModelButton.disabled = false;
+ }
+ })
+ };
+
function clearContentType(content_source) {
fetch('/api/config/data/content-source/' + content_source, {
method: 'DELETE',
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index ad360dd3..1f4fff17 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -332,6 +332,31 @@ async def update_chat_model(
return {"status": "ok"}
+@api.post("/config/data/search/model", status_code=200)
+@requires(["authenticated"])
+async def update_search_model(
+ request: Request,
+ id: str,
+ client: Optional[str] = None,
+):
+ user = request.user.object
+
+    new_config = await adapters.aset_user_search_model(user, int(id))
+
+    if new_config is None:
+        return {"status": "error", "message": "Model not found"}
+
+    update_telemetry_state(
+        request=request,
+        telemetry_type="api",
+        api="set_search_model",
+        client=client,
+        metadata={"search_model": new_config.setting.name},
+    )
+
+    return {"status": "ok"}
+
+
# Create Routes
@api.get("/config/data/default")
def get_default_config_data():
@@ -410,14 +435,10 @@ async def search(
defiltered_query = filter.defilter(defiltered_query)
encoded_asymmetric_query = None
- if t == SearchType.All or t != SearchType.Image:
- text_search_models: List[TextSearchModel] = [
- model for model in state.search_models.__dict__.values() if isinstance(model, TextSearchModel)
- ]
- if text_search_models:
- with timer("Encoding query took", logger=logger):
- search_model = await sync_to_async(get_user_search_model_or_default)(user)
- encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query)
+ if t != SearchType.Image:
+ with timer("Encoding query took", logger=logger):
+ search_model = await sync_to_async(get_user_search_model_or_default)(user)
+ encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query)
with concurrent.futures.ThreadPoolExecutor() as executor:
if t in [
@@ -473,9 +494,9 @@ async def search(
results += text_search.collate_results(hits, dedupe=dedupe)
# Sort results across all content types and take top results
- results = text_search.rerank_and_sort_results(results, query=defiltered_query, rank_results=r)[
- :results_count
- ]
+ results = text_search.rerank_and_sort_results(
+ results, query=defiltered_query, rank_results=r, search_model_name=search_model.name
+ )[:results_count]
# Cache results
if user:
diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py
index 7907f99e..00a087b1 100644
--- a/src/khoj/routers/web_client.py
+++ b/src/khoj/routers/web_client.py
@@ -155,6 +155,11 @@ def config_page(request: Request):
for conversation_option in conversation_options:
all_conversation_options.append({"chat_model": conversation_option.chat_model, "id": conversation_option.id})
+ search_model_options = adapters.get_or_create_search_models().all()
+ all_search_model_options = list()
+ for search_model_option in search_model_options:
+ all_search_model_options.append({"name": search_model_option.name, "id": search_model_option.id})
+
return templates.TemplateResponse(
"config.html",
context={
@@ -163,6 +168,7 @@ def config_page(request: Request):
"anonymous_mode": state.anonymous_mode,
"username": user.username,
"conversation_options": all_conversation_options,
+ "search_model_options": all_search_model_options,
"selected_conversation_config": selected_conversation_config.id if selected_conversation_config else None,
"user_photo": user_picture,
"billing_enabled": state.billing_enabled,
diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py
index a4b8c9b1..2111a8a7 100644
--- a/src/khoj/search_type/text_search.py
+++ b/src/khoj/search_type/text_search.py
@@ -180,13 +180,13 @@ def deduplicated_search_responses(hits: List[SearchResponse]):
)
-def rerank_and_sort_results(hits, query, rank_results):
+def rerank_and_sort_results(hits, query, rank_results, search_model_name):
# If we have more than one result and reranking is enabled
rank_results = rank_results and len(list(hits)) > 1
# Score all retrieved entries using the cross-encoder
if rank_results:
- hits = cross_encoder_score(query, hits)
+ hits = cross_encoder_score(query, hits, search_model_name)
# Sort results by cross-encoder score followed by bi-encoder score
hits = sort_results(rank_results=rank_results, hits=hits)
@@ -219,10 +219,10 @@ def setup(
)
-def cross_encoder_score(query: str, hits: List[SearchResponse]) -> List[SearchResponse]:
+def cross_encoder_score(query: str, hits: List[SearchResponse], search_model_name: str) -> List[SearchResponse]:
"""Score all retrieved entries using the cross-encoder"""
with timer("Cross-Encoder Predict Time", logger, state.device):
- cross_scores = state.cross_encoder_model.predict(query, hits)
+ cross_scores = state.cross_encoder_model[search_model_name].predict(query, hits)
# Convert cross-encoder scores to distances and pass in hits for reranking
for idx in range(len(cross_scores)):