Mirror of https://github.com/khoj-ai/khoj.git, synced 2024-11-23 15:38:55 +01:00
Speed up embedding model load, and hence app start, when there is no internet connection
- Set `local_files_only=True` when not connected to the internet.
- This flag is passed through the sentence_transformers library to huggingface_hub.
- When `local_files_only=True`, huggingface_hub looks for the model directly on disk rather than trying to retrieve it from the Hugging Face Hub first.
- This speeds up model load time, and hence app start, when offline.

See https://github.com/UKPLab/sentence-transformers/pull/2603 for details; a minimal sketch of the behavior follows below.
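For context, a minimal sketch of what the flag changes at the library level, assuming sentence-transformers >= 3.0.0 (which exposes `local_files_only` on its model constructors, per the PR linked above). The model name here is only an example and is assumed to already exist in the local Hugging Face cache:

    from sentence_transformers import SentenceTransformer

    # With local_files_only=True, huggingface_hub resolves the model from the
    # local cache directly instead of first querying the Hugging Face Hub, so
    # loading works offline and skips the startup network round trips.
    model = SentenceTransformer(
        "sentence-transformers/all-MiniLM-L6-v2",  # example model, assumed already cached
        local_files_only=True,
    )
    embeddings = model.encode(["hello world"], normalize_embeddings=True)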
Parent: a9c383e62c
Commit: ef746f8ef1
3 changed files with 9 additions and 6 deletions
@@ -52,7 +52,7 @@ dependencies = [
     "pyyaml ~= 6.0",
     "rich >= 13.3.1",
     "schedule == 1.1.0",
-    "sentence-transformers == 2.5.1",
+    "sentence-transformers == 3.0.0",
     "transformers >= 4.28.0",
     "torch == 2.2.2",
     "uvicorn == 0.17.6",

@@ -13,7 +13,7 @@ from tenacity import (
 )
 from torch import nn
 
-from khoj.utils.helpers import get_device, merge_dicts
+from khoj.utils.helpers import get_device, is_internet_connected, merge_dicts
 from khoj.utils.rawconfig import SearchResponse
 
 logger = logging.getLogger(__name__)
@@ -31,9 +31,10 @@ class EmbeddingsModel:
     ):
         default_query_encode_kwargs = {"show_progress_bar": False, "normalize_embeddings": True}
         default_docs_encode_kwargs = {"show_progress_bar": True, "normalize_embeddings": True}
+        default_model_kwargs = {"device": get_device(), "local_files_only": not is_internet_connected(timeout=5)}
         self.query_encode_kwargs = merge_dicts(query_encode_kwargs, default_query_encode_kwargs)
         self.docs_encode_kwargs = merge_dicts(docs_encode_kwargs, default_docs_encode_kwargs)
-        self.model_kwargs = merge_dicts(model_kwargs, {"device": get_device()})
+        self.model_kwargs = merge_dicts(model_kwargs, default_model_kwargs)
         self.model_name = model_name
         self.inference_endpoint = embeddings_inference_endpoint
         self.api_key = embeddings_inference_endpoint_api_key
@@ -97,7 +98,9 @@ class CrossEncoderModel:
         cross_encoder_inference_endpoint_api_key: str = None,
     ):
         self.model_name = model_name
-        self.cross_encoder_model = CrossEncoder(model_name=self.model_name, device=get_device())
+        self.cross_encoder_model = CrossEncoder(
+            model_name=self.model_name, device=get_device(), local_files_only=not is_internet_connected(timeout=5)
+        )
         self.inference_endpoint = cross_encoder_inference_endpoint
         self.api_key = cross_encoder_inference_endpoint_api_key
 

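The hunks above only show how the default kwargs are assembled, not how they are consumed. A rough sketch of the assumed flow, where the merged `model_kwargs` are splatted into the SentenceTransformer constructor and caller-supplied values take precedence over the defaults (both the splatting and the precedence behavior of `merge_dicts` are assumptions not shown in this diff; the model name is again only an example):

    from sentence_transformers import SentenceTransformer

    # Offline-friendly defaults, mirroring default_model_kwargs above.
    default_model_kwargs = {"device": "cpu", "local_files_only": True}
    # A caller could still force a fresh download by overriding the default.
    caller_model_kwargs = {"local_files_only": False}

    # Assumed precedence: caller-supplied kwargs win over the defaults.
    merged = {**default_model_kwargs, **caller_model_kwargs}

    model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", **merged)
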
@@ -404,9 +404,9 @@ def is_valid_url(url: str) -> bool:
         return False
 
 
-def is_internet_connected():
+def is_internet_connected(timeout=None):
     try:
-        response = requests.head("https://www.google.com")
+        response = requests.head("https://www.google.com", timeout=timeout)
         return response.status_code == 200
     except:
         return False
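
The new `timeout` parameter matters because `requests.head()` with no timeout can block until the underlying socket gives up, which on an offline machine would stall startup rather than speed it up. A small self-contained sketch of how the probe feeds the flag, using the same 5 second bound as the diff above (the sketch narrows the bare `except` to `requests.RequestException`):

    import requests

    def is_internet_connected(timeout=None):
        # Bounded connectivity probe; fails fast once `timeout` seconds elapse.
        try:
            response = requests.head("https://www.google.com", timeout=timeout)
            return response.status_code == 200
        except requests.RequestException:
            return False

    # No connectivity => load models straight from the local cache.
    local_files_only = not is_internet_connected(timeout=5)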