mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Add retries in case the embeddings API fails (#628)
* Add retries in case the embeddings API fails * Improve error handling in the inference endpoint API request handler - retry only if HTTP exception - use logger to output information about errors
This commit is contained in:
parent
b782683e60
commit
71cbe5160d
1 changed files with 34 additions and 21 deletions
|
@ -4,6 +4,13 @@ from typing import List
|
|||
import requests
|
||||
import tqdm
|
||||
from sentence_transformers import CrossEncoder, SentenceTransformer
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
from torch import nn
|
||||
|
||||
from khoj.utils.helpers import get_device
|
||||
|
@ -28,13 +35,32 @@ class EmbeddingsModel:
|
|||
|
||||
def embed_query(self, query):
|
||||
if self.api_key is not None and self.inference_endpoint is not None:
|
||||
target_url = f"{self.inference_endpoint}"
|
||||
payload = {"inputs": [query]}
|
||||
headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
|
||||
response = requests.post(target_url, json=payload, headers=headers)
|
||||
return response.json()["embeddings"][0]
|
||||
return self.embed_with_api([query])[0]
|
||||
return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0]
|
||||
|
||||
@retry(
|
||||
retry=retry_if_exception_type(requests.exceptions.HTTPError),
|
||||
wait=wait_random_exponential(multiplier=1, max=10),
|
||||
stop=stop_after_attempt(5),
|
||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||
)
|
||||
def embed_with_api(self, docs):
|
||||
payload = {"inputs": docs}
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
try:
|
||||
response = requests.post(self.inference_endpoint, json=payload, headers=headers)
|
||||
response.raise_for_status()
|
||||
except requests.exceptions.HTTPError as e:
|
||||
logger.error(
|
||||
f" Error while calling inference endpoint {self.inference_endpoint} with error {e}, response {response.json()} ",
|
||||
exc_info=True,
|
||||
)
|
||||
raise e
|
||||
return response.json()["embeddings"]
|
||||
|
||||
def embed_documents(self, docs):
|
||||
if self.api_key is not None and self.inference_endpoint is not None:
|
||||
target_url = f"{self.inference_endpoint}"
|
||||
|
@ -44,25 +70,12 @@ class EmbeddingsModel:
|
|||
)
|
||||
return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
|
||||
# break up the docs payload in chunks of 1000 to avoid hitting rate limits
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
embeddings = []
|
||||
with tqdm.tqdm(total=len(docs)) as pbar:
|
||||
for i in range(0, len(docs), 1000):
|
||||
payload = {"inputs": docs[i : i + 1000]}
|
||||
response = requests.post(target_url, json=payload, headers=headers)
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except requests.exceptions.HTTPError as e:
|
||||
print(f"Error: {e}")
|
||||
print(f"Response: {response.json()}")
|
||||
raise e
|
||||
if i == 0:
|
||||
embeddings = response.json()["embeddings"]
|
||||
else:
|
||||
embeddings += response.json()["embeddings"]
|
||||
docs_to_embed = docs[i : i + 1000]
|
||||
generated_embeddings = self.embed_with_api(docs_to_embed)
|
||||
embeddings += generated_embeddings
|
||||
pbar.update(1000)
|
||||
return embeddings
|
||||
return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
|
||||
|
|
Loading…
Reference in a new issue