mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-27 09:25:06 +01:00
Speed up load of embedding models and hence app start when no internet
- Set `local_files_only=True' when not connected to internet. - This variable is passed through sentence_transformers library to huggingface_hub. - Huggingface_hub directly looks for model on disk rather than trying to retrieve it from huggingface when `local_files_only=True'. - This speeds up model load time on app start. See https://github.com/UKPLab/sentence-transformers/pull/2603 for details
This commit is contained in:
parent
a9c383e62c
commit
ef746f8ef1
3 changed files with 9 additions and 6 deletions
|
@ -52,7 +52,7 @@ dependencies = [
|
||||||
"pyyaml ~= 6.0",
|
"pyyaml ~= 6.0",
|
||||||
"rich >= 13.3.1",
|
"rich >= 13.3.1",
|
||||||
"schedule == 1.1.0",
|
"schedule == 1.1.0",
|
||||||
"sentence-transformers == 2.5.1",
|
"sentence-transformers == 3.0.0",
|
||||||
"transformers >= 4.28.0",
|
"transformers >= 4.28.0",
|
||||||
"torch == 2.2.2",
|
"torch == 2.2.2",
|
||||||
"uvicorn == 0.17.6",
|
"uvicorn == 0.17.6",
|
||||||
|
|
|
@ -13,7 +13,7 @@ from tenacity import (
|
||||||
)
|
)
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from khoj.utils.helpers import get_device, merge_dicts
|
from khoj.utils.helpers import get_device, is_internet_connected, merge_dicts
|
||||||
from khoj.utils.rawconfig import SearchResponse
|
from khoj.utils.rawconfig import SearchResponse
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -31,9 +31,10 @@ class EmbeddingsModel:
|
||||||
):
|
):
|
||||||
default_query_encode_kwargs = {"show_progress_bar": False, "normalize_embeddings": True}
|
default_query_encode_kwargs = {"show_progress_bar": False, "normalize_embeddings": True}
|
||||||
default_docs_encode_kwargs = {"show_progress_bar": True, "normalize_embeddings": True}
|
default_docs_encode_kwargs = {"show_progress_bar": True, "normalize_embeddings": True}
|
||||||
|
default_model_kwargs = {"device": get_device(), "local_files_only": not is_internet_connected(timeout=5)}
|
||||||
self.query_encode_kwargs = merge_dicts(query_encode_kwargs, default_query_encode_kwargs)
|
self.query_encode_kwargs = merge_dicts(query_encode_kwargs, default_query_encode_kwargs)
|
||||||
self.docs_encode_kwargs = merge_dicts(docs_encode_kwargs, default_docs_encode_kwargs)
|
self.docs_encode_kwargs = merge_dicts(docs_encode_kwargs, default_docs_encode_kwargs)
|
||||||
self.model_kwargs = merge_dicts(model_kwargs, {"device": get_device()})
|
self.model_kwargs = merge_dicts(model_kwargs, default_model_kwargs)
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.inference_endpoint = embeddings_inference_endpoint
|
self.inference_endpoint = embeddings_inference_endpoint
|
||||||
self.api_key = embeddings_inference_endpoint_api_key
|
self.api_key = embeddings_inference_endpoint_api_key
|
||||||
|
@ -97,7 +98,9 @@ class CrossEncoderModel:
|
||||||
cross_encoder_inference_endpoint_api_key: str = None,
|
cross_encoder_inference_endpoint_api_key: str = None,
|
||||||
):
|
):
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.cross_encoder_model = CrossEncoder(model_name=self.model_name, device=get_device())
|
self.cross_encoder_model = CrossEncoder(
|
||||||
|
model_name=self.model_name, device=get_device(), local_files_only=not is_internet_connected(timeout=5)
|
||||||
|
)
|
||||||
self.inference_endpoint = cross_encoder_inference_endpoint
|
self.inference_endpoint = cross_encoder_inference_endpoint
|
||||||
self.api_key = cross_encoder_inference_endpoint_api_key
|
self.api_key = cross_encoder_inference_endpoint_api_key
|
||||||
|
|
||||||
|
|
|
@ -404,9 +404,9 @@ def is_valid_url(url: str) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def is_internet_connected():
|
def is_internet_connected(timeout=None):
|
||||||
try:
|
try:
|
||||||
response = requests.head("https://www.google.com")
|
response = requests.head("https://www.google.com", timeout=timeout)
|
||||||
return response.status_code == 200
|
return response.status_code == 200
|
||||||
except:
|
except:
|
||||||
return False
|
return False
|
||||||
|
|
Loading…
Reference in a new issue