diff --git a/src/khoj/processor/embeddings.py b/src/khoj/processor/embeddings.py
index e15e75bb..cada1532 100644
--- a/src/khoj/processor/embeddings.py
+++ b/src/khoj/processor/embeddings.py
@@ -4,6 +4,13 @@ from typing import List
 import requests
 import tqdm
 from sentence_transformers import CrossEncoder, SentenceTransformer
+from tenacity import (
+    before_sleep_log,
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_random_exponential,
+)
 from torch import nn
 
 from khoj.utils.helpers import get_device
@@ -28,13 +35,32 @@ class EmbeddingsModel:
 
     def embed_query(self, query):
         if self.api_key is not None and self.inference_endpoint is not None:
-            target_url = f"{self.inference_endpoint}"
-            payload = {"inputs": [query]}
-            headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
-            response = requests.post(target_url, json=payload, headers=headers)
-            return response.json()["embeddings"][0]
+            return self.embed_with_api([query])[0]
         return self.embeddings_model.encode([query], show_progress_bar=False, **self.encode_kwargs)[0]
 
+    @retry(
+        retry=retry_if_exception_type(requests.exceptions.HTTPError),
+        wait=wait_random_exponential(multiplier=1, max=10),
+        stop=stop_after_attempt(5),
+        before_sleep=before_sleep_log(logger, logging.DEBUG),
+    )
+    def embed_with_api(self, docs):
+        payload = {"inputs": docs}
+        headers = {
+            "Authorization": f"Bearer {self.api_key}",
+            "Content-Type": "application/json",
+        }
+        try:
+            response = requests.post(self.inference_endpoint, json=payload, headers=headers)
+            response.raise_for_status()
+        except requests.exceptions.HTTPError as e:
+            logger.error(
+                f"Error while calling inference endpoint {self.inference_endpoint} with error {e}, response {response.json()}",
+                exc_info=True,
+            )
+            raise e
+        return response.json()["embeddings"]
+
     def embed_documents(self, docs):
         if self.api_key is not None and self.inference_endpoint is not None:
             target_url = f"{self.inference_endpoint}"
@@ -44,25 +70,12 @@ class EmbeddingsModel:
                 )
                 return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
             # break up the docs payload in chunks of 1000 to avoid hitting rate limits
-            headers = {
-                "Authorization": f"Bearer {self.api_key}",
-                "Content-Type": "application/json",
-            }
             embeddings = []
             with tqdm.tqdm(total=len(docs)) as pbar:
                 for i in range(0, len(docs), 1000):
-                    payload = {"inputs": docs[i : i + 1000]}
-                    response = requests.post(target_url, json=payload, headers=headers)
-                    try:
-                        response.raise_for_status()
-                    except requests.exceptions.HTTPError as e:
-                        print(f"Error: {e}")
-                        print(f"Response: {response.json()}")
-                        raise e
-                    if i == 0:
-                        embeddings = response.json()["embeddings"]
-                    else:
-                        embeddings += response.json()["embeddings"]
+                    docs_to_embed = docs[i : i + 1000]
+                    generated_embeddings = self.embed_with_api(docs_to_embed)
+                    embeddings += generated_embeddings
                     pbar.update(1000)
             return embeddings
         return self.embeddings_model.encode(docs, show_progress_bar=True, **self.encode_kwargs).tolist()
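
For readers unfamiliar with tenacity, here is a minimal, self-contained sketch (not part of this change) of the retry policy that `embed_with_api` now uses: retry only on `requests.exceptions.HTTPError`, randomized exponential backoff capped at 10 seconds, at most 5 attempts, and a DEBUG log before each sleep. The endpoint URL, API key, and the `call_inference_endpoint` name below are hypothetical placeholders for `self.inference_endpoint`, `self.api_key`, and the real method.

```python
import logging

import requests
from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_random_exponential,
)

logger = logging.getLogger(__name__)


@retry(
    retry=retry_if_exception_type(requests.exceptions.HTTPError),  # only retry HTTP errors
    wait=wait_random_exponential(multiplier=1, max=10),  # randomized exponential backoff, capped at 10s
    stop=stop_after_attempt(5),  # give up and re-raise after 5 attempts
    before_sleep=before_sleep_log(logger, logging.DEBUG),  # log each retry at DEBUG level
)
def call_inference_endpoint(docs):
    # Hypothetical URL and key, standing in for self.inference_endpoint / self.api_key
    response = requests.post(
        "https://example.com/embeddings",
        json={"inputs": docs},
        headers={"Authorization": "Bearer <api-key>", "Content-Type": "application/json"},
    )
    # 4xx/5xx responses raise HTTPError here, which triggers the retry policy above
    response.raise_for_status()
    return response.json()["embeddings"]
```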