Add retries in case the embeddings API fails (#628)

* Add retries in case the embeddings API fails
* Improve error handling in the inference endpoint API request handler
- retry only if HTTP exception
- use logger to output information about errors
This commit is contained in:
sabaimran 2024-01-29 01:56:34 -08:00 committed by GitHub
parent b782683e60
commit 71cbe5160d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -4,6 +4,13 @@ from typing import List
import requests
import tqdm
from sentence_transformers import CrossEncoder, SentenceTransformer
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_random_exponential,
)
from torch import nn
from khoj.utils.helpers import get_device
@@ -28,13 +35,32 @@ class EmbeddingsModel:
def embed_query(self, query):
    """Embed a single query string.

    Delegates to the remote inference endpoint when both an API key and an
    endpoint are configured (inheriting the retry behavior of
    ``embed_with_api``); otherwise encodes locally with the
    sentence-transformers model.

    Args:
        query: The text to embed.

    Returns:
        The embedding vector for the query.
    """
    if self.api_key is not None and self.inference_endpoint is not None:
        # Remote path: the API takes a batch, so wrap the single query in a
        # list and unwrap the single result.
        return self.embed_with_api([query])[0]
    # Local path: progress bar off since there is only one input.
    return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0]
@retry(
    retry=retry_if_exception_type(requests.exceptions.HTTPError),
    wait=wait_random_exponential(multiplier=1, max=10),
    stop=stop_after_attempt(5),
    before_sleep=before_sleep_log(logger, logging.DEBUG),
)
def embed_with_api(self, docs):
    """POST ``docs`` to the configured inference endpoint and return embeddings.

    The ``@retry`` decorator re-invokes this method on ``HTTPError``, up to
    5 attempts with random exponential backoff (max 10s between tries).

    Args:
        docs: Batch of strings to embed.

    Returns:
        The ``"embeddings"`` list from the endpoint's JSON response, one
        vector per input doc.

    Raises:
        requests.exceptions.HTTPError: If the endpoint keeps returning a
            non-2xx status after all retry attempts are exhausted.
    """
    payload = {"inputs": docs}
    headers = {
        "Authorization": f"Bearer {self.api_key}",
        "Content-Type": "application/json",
    }
    # NOTE(review): no timeout= is set, so a hung endpoint blocks forever;
    # consider adding one once acceptable latency is confirmed.
    response = requests.post(self.inference_endpoint, json=payload, headers=headers)
    try:
        # Only raise_for_status() can raise HTTPError, so keep the try minimal.
        response.raise_for_status()
    except requests.exceptions.HTTPError:
        # Log response.text, not response.json(): error bodies are often not
        # valid JSON, and a decode failure here would mask the real HTTPError.
        logger.error(
            "Error calling inference endpoint %s, response: %s",
            self.inference_endpoint,
            response.text,
            exc_info=True,
        )
        # Bare raise preserves the original traceback for the retry decorator.
        raise
    return response.json()["embeddings"]
def embed_documents(self, docs):
if self.api_key is not None and self.inference_endpoint is not None:
target_url = f"{self.inference_endpoint}"
@@ -44,25 +70,12 @@ class EmbeddingsModel:
)
return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
# break up the docs payload in chunks of 1000 to avoid hitting rate limits
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
embeddings = []
with tqdm.tqdm(total=len(docs)) as pbar:
for i in range(0, len(docs), 1000):
payload = {"inputs": docs[i : i + 1000]}
response = requests.post(target_url, json=payload, headers=headers)
try:
response.raise_for_status()
except requests.exceptions.HTTPError as e:
print(f"Error: {e}")
print(f"Response: {response.json()}")
raise e
if i == 0:
embeddings = response.json()["embeddings"]
else:
embeddings += response.json()["embeddings"]
docs_to_embed = docs[i : i + 1000]
generated_embeddings = self.embed_with_api(docs_to_embed)
embeddings += generated_embeddings
pbar.update(1000)
return embeddings
return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()