mirror of
https://github.com/khoj-ai/khoj.git
synced 2025-02-17 08:04:21 +00:00
Add retries in case the embeddings API fails (#628)
* Add retries in case the embeddings API fails.
* Improve error handling in the inference endpoint API request handler: retry only on HTTP exceptions, and use the logger to report error details.
This commit is contained in:
parent
b782683e60
commit
71cbe5160d
1 changed files with 34 additions and 21 deletions
|
@ -4,6 +4,13 @@ from typing import List
|
||||||
import requests
|
import requests
|
||||||
import tqdm
|
import tqdm
|
||||||
from sentence_transformers import CrossEncoder, SentenceTransformer
|
from sentence_transformers import CrossEncoder, SentenceTransformer
|
||||||
|
from tenacity import (
|
||||||
|
before_sleep_log,
|
||||||
|
retry,
|
||||||
|
retry_if_exception_type,
|
||||||
|
stop_after_attempt,
|
||||||
|
wait_random_exponential,
|
||||||
|
)
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from khoj.utils.helpers import get_device
|
from khoj.utils.helpers import get_device
|
||||||
|
@ -28,13 +35,32 @@ class EmbeddingsModel:
|
||||||
|
|
||||||
def embed_query(self, query):
    """Embed a single query string and return its embedding vector.

    Routes the request to the hosted inference endpoint when both an API
    key and an endpoint URL are configured; otherwise encodes the query
    locally with the sentence-transformers model.
    """
    use_remote_api = self.api_key is not None and self.inference_endpoint is not None
    if use_remote_api:
        # Wrap the query in a one-element batch and unwrap the single result.
        return self.embed_with_api([query])[0]
    # Local encoding; the progress bar is noise for a single query.
    return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0]
|
||||||
|
|
||||||
|
@retry(
    retry=retry_if_exception_type(requests.exceptions.HTTPError),
    wait=wait_random_exponential(multiplier=1, max=10),
    stop=stop_after_attempt(5),
    before_sleep=before_sleep_log(logger, logging.DEBUG),
)
def embed_with_api(self, docs):
    """POST `docs` to the configured inference endpoint and return their embeddings.

    Retries up to 5 times with random exponential backoff on HTTP errors
    (via the tenacity decorator above); other request failures propagate
    immediately.

    Args:
        docs: list of strings to embed in a single API call.

    Returns:
        The `embeddings` list from the endpoint's JSON response.

    Raises:
        requests.exceptions.HTTPError: once all retries are exhausted.
    """
    headers = {
        "Authorization": f"Bearer {self.api_key}",
        "Content-Type": "application/json",
    }
    try:
        response = requests.post(
            self.inference_endpoint,
            json={"inputs": docs},
            headers=headers,
            # Without a timeout, an unresponsive endpoint hangs this call forever.
            timeout=30,
        )
        response.raise_for_status()
    except requests.exceptions.HTTPError as e:
        # Log response.text, not response.json(): a non-JSON error body would
        # raise a secondary decode error here and mask the original HTTPError.
        logger.error(
            f" Error while calling inference endpoint {self.inference_endpoint} with error {e}, response {response.text} ",
            exc_info=True,
        )
        # Bare raise preserves the original traceback for tenacity to inspect.
        raise
    return response.json()["embeddings"]
|
||||||
|
|
||||||
def embed_documents(self, docs):
|
def embed_documents(self, docs):
|
||||||
if self.api_key is not None and self.inference_endpoint is not None:
|
if self.api_key is not None and self.inference_endpoint is not None:
|
||||||
target_url = f"{self.inference_endpoint}"
|
target_url = f"{self.inference_endpoint}"
|
||||||
|
@ -44,25 +70,12 @@ class EmbeddingsModel:
|
||||||
)
|
)
|
||||||
return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
|
return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
|
||||||
# break up the docs payload in chunks of 1000 to avoid hitting rate limits
|
# break up the docs payload in chunks of 1000 to avoid hitting rate limits
|
||||||
headers = {
|
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
}
|
|
||||||
embeddings = []
|
embeddings = []
|
||||||
with tqdm.tqdm(total=len(docs)) as pbar:
|
with tqdm.tqdm(total=len(docs)) as pbar:
|
||||||
for i in range(0, len(docs), 1000):
|
for i in range(0, len(docs), 1000):
|
||||||
payload = {"inputs": docs[i : i + 1000]}
|
docs_to_embed = docs[i : i + 1000]
|
||||||
response = requests.post(target_url, json=payload, headers=headers)
|
generated_embeddings = self.embed_with_api(docs_to_embed)
|
||||||
try:
|
embeddings += generated_embeddings
|
||||||
response.raise_for_status()
|
|
||||||
except requests.exceptions.HTTPError as e:
|
|
||||||
print(f"Error: {e}")
|
|
||||||
print(f"Response: {response.json()}")
|
|
||||||
raise e
|
|
||||||
if i == 0:
|
|
||||||
embeddings = response.json()["embeddings"]
|
|
||||||
else:
|
|
||||||
embeddings += response.json()["embeddings"]
|
|
||||||
pbar.update(1000)
|
pbar.update(1000)
|
||||||
return embeddings
|
return embeddings
|
||||||
return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
|
return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
|
||||||
|
|
Loading…
Add table
Reference in a new issue