Speed up load of embedding models, and hence app start, when there is no internet connection

- Set `local_files_only=True` when not connected to the internet.
- This variable is passed through sentence_transformers library to huggingface_hub.
- Huggingface_hub directly looks for the model on disk rather than trying to
  retrieve it from Hugging Face when `local_files_only=True`.
- This speeds up model load time on app start.

See https://github.com/UKPLab/sentence-transformers/pull/2603 for
details.
This commit is contained in:
Debanjum Singh Solanky 2024-06-04 09:31:33 +05:30
parent a9c383e62c
commit ef746f8ef1
3 changed files with 9 additions and 6 deletions

View file

@@ -52,7 +52,7 @@ dependencies = [
"pyyaml ~= 6.0",
"rich >= 13.3.1",
"schedule == 1.1.0",
"sentence-transformers == 2.5.1",
"sentence-transformers == 3.0.0",
"transformers >= 4.28.0",
"torch == 2.2.2",
"uvicorn == 0.17.6",

View file

@@ -13,7 +13,7 @@ from tenacity import (
)
from torch import nn
from khoj.utils.helpers import get_device, merge_dicts
from khoj.utils.helpers import get_device, is_internet_connected, merge_dicts
from khoj.utils.rawconfig import SearchResponse
logger = logging.getLogger(__name__)
@@ -31,9 +31,10 @@ class EmbeddingsModel:
):
default_query_encode_kwargs = {"show_progress_bar": False, "normalize_embeddings": True}
default_docs_encode_kwargs = {"show_progress_bar": True, "normalize_embeddings": True}
default_model_kwargs = {"device": get_device(), "local_files_only": not is_internet_connected(timeout=5)}
self.query_encode_kwargs = merge_dicts(query_encode_kwargs, default_query_encode_kwargs)
self.docs_encode_kwargs = merge_dicts(docs_encode_kwargs, default_docs_encode_kwargs)
self.model_kwargs = merge_dicts(model_kwargs, {"device": get_device()})
self.model_kwargs = merge_dicts(model_kwargs, default_model_kwargs)
self.model_name = model_name
self.inference_endpoint = embeddings_inference_endpoint
self.api_key = embeddings_inference_endpoint_api_key
@@ -97,7 +98,9 @@ class CrossEncoderModel:
cross_encoder_inference_endpoint_api_key: str = None,
):
self.model_name = model_name
self.cross_encoder_model = CrossEncoder(model_name=self.model_name, device=get_device())
self.cross_encoder_model = CrossEncoder(
model_name=self.model_name, device=get_device(), local_files_only=not is_internet_connected(timeout=5)
)
self.inference_endpoint = cross_encoder_inference_endpoint
self.api_key = cross_encoder_inference_endpoint_api_key

View file

@@ -404,9 +404,9 @@ def is_valid_url(url: str) -> bool:
return False
def is_internet_connected(timeout=None):
    """Return True if an HTTP HEAD request to google.com succeeds with 200.

    Args:
        timeout: Seconds to wait for the request before giving up.
            None (the default) waits indefinitely, preserving the
            previous behavior for existing callers.

    Returns:
        bool: True when the request returns HTTP 200, False on any
        network failure, timeout, or non-200 status.
    """
    try:
        response = requests.head("https://www.google.com", timeout=timeout)
        return response.status_code == 200
    except requests.RequestException:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # are not swallowed; any network-level failure means "offline".
        return False