mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-27 17:35:07 +01:00
Add option to use HuggingFace's inference endpoint for generating embeddings (#609)
* Support using a hosted HuggingFace inference endpoint for embeddings generation * Since the HuggingFace inference endpoint is model-specific, make the URL an optional property of the search model config * Handle ECONNREFUSED error in desktop app * Drive the API key via the search model config and use more generic names
This commit is contained in:
parent
02187b19bb
commit
50575b749b
5 changed files with 84 additions and 3 deletions
|
@ -208,7 +208,10 @@ function pushDataToKhoj (regenerate = false) {
|
|||
})
|
||||
.catch(error => {
|
||||
console.error(error);
|
||||
if (error.response.status == 429) {
|
||||
if (error.code == 'ECONNREFUSED') {
|
||||
const win = BrowserWindow.getAllWindows()[0];
|
||||
if (win) win.webContents.send('update-state', state);
|
||||
} else if (error.response.status == 429) {
|
||||
const win = BrowserWindow.getAllWindows()[0];
|
||||
if (win) win.webContents.send('needsSubscription', true);
|
||||
if (win) win.webContents.send('update-state', state);
|
||||
|
|
|
@ -144,7 +144,15 @@ def configure_server(
|
|||
state.cross_encoder_model = dict()
|
||||
|
||||
for model in search_models:
|
||||
state.embeddings_model.update({model.name: EmbeddingsModel(model.bi_encoder)})
|
||||
state.embeddings_model.update(
|
||||
{
|
||||
model.name: EmbeddingsModel(
|
||||
model.bi_encoder,
|
||||
model.embeddings_inference_endpoint,
|
||||
model.embeddings_inference_endpoint_api_key,
|
||||
)
|
||||
}
|
||||
)
|
||||
state.cross_encoder_model.update({model.name: CrossEncoderModel(model.cross_encoder)})
|
||||
|
||||
state.SearchType = configure_search_types()
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
# Generated by Django 4.2.7 on 2024-01-15 18:12
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
    """Add two optional inference-endpoint columns to ``searchmodelconfig``.

    Both columns are nullable, blank-able ``CharField``s (max length 200,
    default ``None``), so existing rows need no backfill.
    """

    dependencies = [("database", "0024_alter_entry_embeddings")]

    # The two new columns share an identical field definition, so build one
    # AddField operation per column name.
    operations = [
        migrations.AddField(
            model_name="searchmodelconfig",
            name=column_name,
            field=models.CharField(blank=True, default=None, max_length=200, null=True),
        )
        for column_name in (
            "embeddings_inference_endpoint",
            "embeddings_inference_endpoint_api_key",
        )
    ]
|
|
@ -110,6 +110,8 @@ class SearchModelConfig(BaseModel):
|
|||
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.TEXT)
|
||||
bi_encoder = models.CharField(max_length=200, default="thenlper/gte-small")
|
||||
cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2")
|
||||
embeddings_inference_endpoint = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||
embeddings_inference_endpoint_api_key = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||
|
||||
|
||||
class TextToImageModelConfig(BaseModel):
|
||||
|
|
|
@ -1,23 +1,69 @@
|
|||
import logging
|
||||
from typing import List
|
||||
|
||||
import requests
|
||||
import tqdm
|
||||
from sentence_transformers import CrossEncoder, SentenceTransformer
|
||||
from torch import nn
|
||||
|
||||
from khoj.utils.helpers import get_device
|
||||
from khoj.utils.rawconfig import SearchResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmbeddingsModel:
    """Embed text with a local bi-encoder or a hosted inference endpoint.

    A local ``SentenceTransformer`` is always loaded. When both an inference
    endpoint URL and an API key are configured, embedding requests are sent to
    that endpoint instead of being computed locally.
    """

    def __init__(
        self,
        model_name: str = "thenlper/gte-small",
        embeddings_inference_endpoint: str = None,
        embeddings_inference_endpoint_api_key: str = None,
    ):
        """Load the local model and remember the optional remote endpoint settings.

        Args:
            model_name: Name of the bi-encoder passed to ``SentenceTransformer``.
            embeddings_inference_endpoint: URL of the hosted inference endpoint, if any.
            embeddings_inference_endpoint_api_key: Bearer token for that endpoint, if any.
        """
        self.encode_kwargs = {"normalize_embeddings": True}
        self.model_kwargs = {"device": get_device()}
        self.model_name = model_name
        self.inference_endpoint = embeddings_inference_endpoint
        self.api_key = embeddings_inference_endpoint_api_key
        self.embeddings_model = SentenceTransformer(self.model_name, **self.model_kwargs)

    def _inference_endpoint_enabled(self) -> bool:
        """Return True when both the endpoint URL and its API key are configured."""
        return self.api_key is not None and self.inference_endpoint is not None

    def _embed_with_api(self, docs):
        """POST ``docs`` to the inference endpoint and return its ``"embeddings"`` list.

        Raises:
            requests.exceptions.HTTPError: If the endpoint answers with an error status.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        response = requests.post(self.inference_endpoint, json={"inputs": docs}, headers=headers)
        try:
            response.raise_for_status()
        except requests.exceptions.HTTPError as e:
            # Log the raw body: an error response is not guaranteed to be JSON,
            # so parsing it here could mask the original HTTP error.
            logger.error(f"Error: {e}")
            logger.error(f"Response: {response.text}")
            raise
        return response.json()["embeddings"]

    def embed_query(self, query):
        """Return the embedding for a single query string.

        Uses the inference endpoint when it is configured, otherwise the local model.
        """
        # NOTE(review): unlike embed_documents, this path does not fall back to the
        # local model for non-HuggingFace endpoints — confirm that is intended.
        if self._inference_endpoint_enabled():
            return self._embed_with_api([query])[0]
        return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0]

    def embed_documents(self, docs):
        """Return a list with one embedding per entry in ``docs``.

        Uses the inference endpoint when it is configured and its URL contains
        "huggingface"; otherwise the embeddings are computed with the local model.

        Raises:
            requests.exceptions.HTTPError: If the endpoint answers with an error status.
        """
        if self._inference_endpoint_enabled():
            target_url = f"{self.inference_endpoint}"
            if "huggingface" not in target_url:
                logger.warning(
                    f"Using custom inference endpoint {target_url} is not yet supported. Please use a HuggingFace inference endpoint."
                )
                return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()

            # Nothing to send; also avoids returning an unbound result below.
            if len(docs) == 0:
                return []

            # break up the docs payload in chunks of 1000 to avoid hitting rate limits
            batch_size = 1000
            embeddings = []
            with tqdm.tqdm(total=len(docs)) as pbar:
                for start in range(0, len(docs), batch_size):
                    batch = docs[start : start + batch_size]
                    embeddings += self._embed_with_api(batch)
                    # Advance by the real batch size so the final partial batch
                    # does not push the bar past its total.
                    pbar.update(len(batch))
            return embeddings

        return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue