diff --git a/pyproject.toml b/pyproject.toml index 73ea19e3..0cbf3009 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,7 @@ dependencies = [ "pyyaml ~= 6.0", "rich >= 13.3.1", "schedule == 1.1.0", - "sentence-transformers == 2.5.1", + "sentence-transformers == 3.0.0", "transformers >= 4.28.0", "torch == 2.2.2", "uvicorn == 0.17.6", diff --git a/src/khoj/processor/embeddings.py b/src/khoj/processor/embeddings.py index 701bbfac..8cf0599c 100644 --- a/src/khoj/processor/embeddings.py +++ b/src/khoj/processor/embeddings.py @@ -13,7 +13,7 @@ from tenacity import ( ) from torch import nn -from khoj.utils.helpers import get_device, merge_dicts +from khoj.utils.helpers import get_device, is_internet_connected, merge_dicts from khoj.utils.rawconfig import SearchResponse logger = logging.getLogger(__name__) @@ -31,9 +31,10 @@ class EmbeddingsModel: ): default_query_encode_kwargs = {"show_progress_bar": False, "normalize_embeddings": True} default_docs_encode_kwargs = {"show_progress_bar": True, "normalize_embeddings": True} + default_model_kwargs = {"device": get_device(), "local_files_only": not is_internet_connected(timeout=5)} self.query_encode_kwargs = merge_dicts(query_encode_kwargs, default_query_encode_kwargs) self.docs_encode_kwargs = merge_dicts(docs_encode_kwargs, default_docs_encode_kwargs) - self.model_kwargs = merge_dicts(model_kwargs, {"device": get_device()}) + self.model_kwargs = merge_dicts(model_kwargs, default_model_kwargs) self.model_name = model_name self.inference_endpoint = embeddings_inference_endpoint self.api_key = embeddings_inference_endpoint_api_key @@ -97,7 +98,9 @@ class CrossEncoderModel: cross_encoder_inference_endpoint_api_key: str = None, ): self.model_name = model_name - self.cross_encoder_model = CrossEncoder(model_name=self.model_name, device=get_device()) + self.cross_encoder_model = CrossEncoder( + model_name=self.model_name, device=get_device(), local_files_only=not is_internet_connected(timeout=5) + ) self.inference_endpoint = cross_encoder_inference_endpoint self.api_key = cross_encoder_inference_endpoint_api_key diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index 48be715b..e7e02c4c 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -404,9 +404,9 @@ def is_valid_url(url: str) -> bool: return False -def is_internet_connected(): +def is_internet_connected(timeout=None): try: - response = requests.head("https://www.google.com") + response = requests.head("https://www.google.com", timeout=timeout) return response.status_code == 200 except: return False