diff --git a/src/interface/desktop/main.js b/src/interface/desktop/main.js index 4bb087d9..ec3e6fa4 100644 --- a/src/interface/desktop/main.js +++ b/src/interface/desktop/main.js @@ -208,7 +208,10 @@ function pushDataToKhoj (regenerate = false) { }) .catch(error => { console.error(error); - if (error.response.status == 429) { + if (error.code == 'ECONNREFUSED') { + const win = BrowserWindow.getAllWindows()[0]; + if (win) win.webContents.send('update-state', state); + } else if (error.response.status == 429) { const win = BrowserWindow.getAllWindows()[0]; if (win) win.webContents.send('needsSubscription', true); if (win) win.webContents.send('update-state', state); diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 0f9e5fef..ebeade48 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -144,7 +144,15 @@ def configure_server( state.cross_encoder_model = dict() for model in search_models: - state.embeddings_model.update({model.name: EmbeddingsModel(model.bi_encoder)}) + state.embeddings_model.update( + { + model.name: EmbeddingsModel( + model.bi_encoder, + model.embeddings_inference_endpoint, + model.embeddings_inference_endpoint_api_key, + ) + } + ) state.cross_encoder_model.update({model.name: CrossEncoderModel(model.cross_encoder)}) state.SearchType = configure_search_types() diff --git a/src/khoj/database/migrations/0025_searchmodelconfig_embeddings_inference_endpoint_and_more.py b/src/khoj/database/migrations/0025_searchmodelconfig_embeddings_inference_endpoint_and_more.py new file mode 100644 index 00000000..ef79e223 --- /dev/null +++ b/src/khoj/database/migrations/0025_searchmodelconfig_embeddings_inference_endpoint_and_more.py @@ -0,0 +1,22 @@ +# Generated by Django 4.2.7 on 2024-01-15 18:12 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0024_alter_entry_embeddings"), + ] + + operations = [ + migrations.AddField( + model_name="searchmodelconfig", + name="embeddings_inference_endpoint", + field=models.CharField(blank=True, default=None, max_length=200, null=True), + ), + migrations.AddField( + model_name="searchmodelconfig", + name="embeddings_inference_endpoint_api_key", + field=models.CharField(blank=True, default=None, max_length=200, null=True), + ), + ] diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 873e9628..2b8887f7 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -110,6 +110,8 @@ class SearchModelConfig(BaseModel): model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.TEXT) bi_encoder = models.CharField(max_length=200, default="thenlper/gte-small") cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2") + embeddings_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True) + embeddings_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True) class TextToImageModelConfig(BaseModel): diff --git a/src/khoj/processor/embeddings.py b/src/khoj/processor/embeddings.py index 4cb01823..c0e91ce4 100644 --- a/src/khoj/processor/embeddings.py +++ b/src/khoj/processor/embeddings.py @@ -1,23 +1,69 @@ +import logging from typing import List +import requests +import tqdm from sentence_transformers import CrossEncoder, SentenceTransformer from torch import nn from khoj.utils.helpers import get_device from khoj.utils.rawconfig import SearchResponse +logger = logging.getLogger(__name__) + class EmbeddingsModel: - def __init__(self, model_name: str = "thenlper/gte-small"): + def __init__( + self, + model_name: str = "thenlper/gte-small", + embeddings_inference_endpoint: str = None, + embeddings_inference_endpoint_api_key: str = None, + ): self.encode_kwargs = {"normalize_embeddings": True} self.model_kwargs = {"device": get_device()} self.model_name = model_name + self.inference_endpoint = embeddings_inference_endpoint + self.api_key = embeddings_inference_endpoint_api_key self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs) def embed_query(self, query): + if self.api_key is not None and self.inference_endpoint is not None: + target_url = f"{self.inference_endpoint}" + payload = {"inputs": [query]} + headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"} + response = requests.post(target_url, json=payload, headers=headers) + return response.json()["embeddings"][0] return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0] def embed_documents(self, docs): + if self.api_key is not None and self.inference_endpoint is not None: + target_url = f"{self.inference_endpoint}" + if "huggingface" not in target_url: + logger.warning( + f"Using custom inference endpoint {target_url} is not yet supported. Please us a HuggingFace inference endpoint." + ) + return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist() + # break up the docs payload in chunks of 1000 to avoid hitting rate limits + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + with tqdm.tqdm(total=len(docs)) as pbar: + for i in range(0, len(docs), 1000): + payload = {"inputs": docs[i : i + 1000]} + response = requests.post(target_url, json=payload, headers=headers) + try: + response.raise_for_status() + except requests.exceptions.HTTPError as e: + print(f"Error: {e}") + print(f"Response: {response.json()}") + raise e + if i == 0: + embeddings = response.json()["embeddings"] + else: + embeddings += response.json()["embeddings"] + pbar.update(1000) + return embeddings return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()