From ef746f8ef17d212263186c4592d05d06588374d3 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Tue, 4 Jun 2024 09:31:33 +0530 Subject: [PATCH] Speed up load of embedding models and hence app start when no internet - Set `local_files_only=True` when not connected to the internet. - This variable is passed through the sentence_transformers library to huggingface_hub. - Huggingface_hub directly looks for the model on disk rather than trying to retrieve it from huggingface when `local_files_only=True`. - This speeds up model load time on app start. See https://github.com/UKPLab/sentence-transformers/pull/2603 for details --- pyproject.toml | 2 +- src/khoj/processor/embeddings.py | 9 ++++++--- src/khoj/utils/helpers.py | 4 ++-- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 73ea19e3..0cbf3009 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,7 @@ dependencies = [ "pyyaml ~= 6.0", "rich >= 13.3.1", "schedule == 1.1.0", - "sentence-transformers == 2.5.1", + "sentence-transformers == 3.0.0", "transformers >= 4.28.0", "torch == 2.2.2", "uvicorn == 0.17.6", diff --git a/src/khoj/processor/embeddings.py b/src/khoj/processor/embeddings.py index 701bbfac..8cf0599c 100644 --- a/src/khoj/processor/embeddings.py +++ b/src/khoj/processor/embeddings.py @@ -13,7 +13,7 @@ from tenacity import ( ) from torch import nn -from khoj.utils.helpers import get_device, merge_dicts +from khoj.utils.helpers import get_device, is_internet_connected, merge_dicts from khoj.utils.rawconfig import SearchResponse logger = logging.getLogger(__name__) @@ -31,9 +31,10 @@ class EmbeddingsModel: ): default_query_encode_kwargs = {"show_progress_bar": False, "normalize_embeddings": True} default_docs_encode_kwargs = {"show_progress_bar": True, "normalize_embeddings": True} + default_model_kwargs = {"device": get_device(), "local_files_only": not is_internet_connected(timeout=5)} self.query_encode_kwargs = merge_dicts(query_encode_kwargs, 
default_query_encode_kwargs) self.docs_encode_kwargs = merge_dicts(docs_encode_kwargs, default_docs_encode_kwargs) - self.model_kwargs = merge_dicts(model_kwargs, {"device": get_device()}) + self.model_kwargs = merge_dicts(model_kwargs, default_model_kwargs) self.model_name = model_name self.inference_endpoint = embeddings_inference_endpoint self.api_key = embeddings_inference_endpoint_api_key @@ -97,7 +98,9 @@ class CrossEncoderModel: cross_encoder_inference_endpoint_api_key: str = None, ): self.model_name = model_name - self.cross_encoder_model = CrossEncoder(model_name=self.model_name, device=get_device()) + self.cross_encoder_model = CrossEncoder( + model_name=self.model_name, device=get_device(), local_files_only=not is_internet_connected(timeout=5) + ) self.inference_endpoint = cross_encoder_inference_endpoint self.api_key = cross_encoder_inference_endpoint_api_key diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index 48be715b..e7e02c4c 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -404,9 +404,9 @@ def is_valid_url(url: str) -> bool: return False -def is_internet_connected(): +def is_internet_connected(timeout=None): try: - response = requests.head("https://www.google.com") + response = requests.head("https://www.google.com", timeout=timeout) return response.status_code == 200 except: return False