diff --git a/src/search_type/image_search.py b/src/search_type/image_search.py index e04bbe49..10d429c3 100644 --- a/src/search_type/image_search.py +++ b/src/search_type/image_search.py @@ -36,7 +36,7 @@ def initialize_model(search_config: ImageSearchConfig): encoder = load_model( model_dir = search_config.model_directory, model_name = search_config.encoder, - model_type = SentenceTransformer) + model_type = search_config.encoder_type or SentenceTransformer) return encoder diff --git a/src/search_type/text_search.py b/src/search_type/text_search.py index 53eb3c3d..5fd04470 100644 --- a/src/search_type/text_search.py +++ b/src/search_type/text_search.py @@ -37,7 +37,7 @@ def initialize_model(search_config: TextSearchConfig): bi_encoder = load_model( model_dir = search_config.model_directory, model_name = search_config.encoder, - model_type = SentenceTransformer, + model_type = search_config.encoder_type or SentenceTransformer, device=f'{state.device}') # The cross-encoder re-ranks the results to improve quality diff --git a/src/utils/helpers.py b/src/utils/helpers.py index 2285018d..38181ad6 100644 --- a/src/utils/helpers.py +++ b/src/utils/helpers.py @@ -1,5 +1,6 @@ # Standard Packages from pathlib import Path +from importlib import import_module import sys from os.path import join from collections import OrderedDict @@ -44,17 +45,18 @@ def merge_dicts(priority_dict: dict, default_dict: dict): return merged_dict -def load_model(model_name, model_dir, model_type, device:str=None): +def load_model(model_name: str, model_type, model_dir=None, device:str=None): "Load model from disk or huggingface" # Construct model path model_path = join(model_dir, model_name.replace("/", "_")) if model_dir is not None else None # Load model from model_path if it exists there + model_type_class = get_class_by_name(model_type) if isinstance(model_type, str) else model_type if model_path is not None and resolve_absolute_path(model_path).exists(): - model = model_type(get_absolute_path(model_path), device=device) + model = model_type_class(get_absolute_path(model_path), device=device) # Else load the model from the model_name else: - model = model_type(model_name, device=device) + model = model_type_class(model_name, device=device) if model_path is not None: model.save(model_path) @@ -66,6 +68,12 @@ def is_pyinstaller_app(): return getattr(sys, 'frozen', False) and hasattr(sys, '_MEIPASS') +def get_class_by_name(name: str) -> object: + "Returns the class object from name string" + module_name, class_name = name.rsplit('.', 1) + return getattr(import_module(module_name), class_name) + + class LRU(OrderedDict): def __init__(self, *args, capacity=128, **kwargs): self.capacity = capacity diff --git a/src/utils/rawconfig.py b/src/utils/rawconfig.py index 165be0d1..5ed3a9eb 100644 --- a/src/utils/rawconfig.py +++ b/src/utils/rawconfig.py @@ -50,10 +50,12 @@ class ContentConfig(ConfigBase): class TextSearchConfig(ConfigBase): encoder: str cross_encoder: str + encoder_type: Optional[str] model_directory: Optional[Path] class ImageSearchConfig(ConfigBase): encoder: str + encoder_type: Optional[str] model_directory: Optional[Path] class SearchConfig(ConfigBase):