diff --git a/sample_config.yml b/sample_config.yml index 3509f038..8d5de409 100644 --- a/sample_config.yml +++ b/sample_config.yml @@ -36,6 +36,7 @@ search-type: image: encoder: "clip-ViT-B-32" + model_directory: "tests/data/.image_encoder" processor: conversation: diff --git a/src/main.py b/src/main.py index 2edbc209..de6aa952 100644 --- a/src/main.py +++ b/src/main.py @@ -145,7 +145,7 @@ def initialize_search(config: FullConfig, regenerate: bool, t: SearchType = None # Initialize Image Search if (t == SearchType.Image or t == None) and config.content_type.image: # Extract Entries, Generate Image Embeddings - model.image_search = image_search.setup(config.content_type.image, regenerate=regenerate, verbose=verbose) + model.image_search = image_search.setup(config.content_type.image, search_config=config.search_type.image, regenerate=regenerate, verbose=verbose) return model diff --git a/src/search_type/image_search.py b/src/search_type/image_search.py index f8af7d8e..9f4d04f2 100644 --- a/src/search_type/image_search.py +++ b/src/search_type/image_search.py @@ -10,16 +10,22 @@ from tqdm import trange import torch # Internal Packages -from src.utils.helpers import resolve_absolute_path +from src.utils.helpers import resolve_absolute_path, load_model import src.utils.exiftool as exiftool from src.utils.config import ImageSearchModel -from src.utils.rawconfig import ImageSearchConfig +from src.utils.rawconfig import ImageSearchConfig, ImageSearchTypeConfig -def initialize_model(): +def initialize_model(search_config: ImageSearchTypeConfig): # Initialize Model torch.set_num_threads(4) - encoder = SentenceTransformer('sentence-transformers/clip-ViT-B-32') #Load the CLIP model + + # Load the CLIP model + encoder = load_model( + model_dir = search_config.model_directory, + model_name = search_config.encoder, + model_type = SentenceTransformer) + return encoder @@ -154,9 +160,9 @@ def collate_results(hits, image_names, image_directory, count=5): in hits[0:count]] -def setup(config: ImageSearchConfig, regenerate: bool, verbose: bool=False) -> ImageSearchModel: +def setup(config: ImageSearchConfig, search_config: ImageSearchTypeConfig, regenerate: bool, verbose: bool=False) -> ImageSearchModel: # Initialize Model - encoder = initialize_model() + encoder = initialize_model(search_config) # Extract Entries image_directory = resolve_absolute_path(config.input_directory, strict=True) diff --git a/src/utils/rawconfig.py b/src/utils/rawconfig.py index 00eb2fc7..bbb2de31 100644 --- a/src/utils/rawconfig.py +++ b/src/utils/rawconfig.py @@ -49,6 +49,7 @@ class AsymmetricConfig(ConfigBase): class ImageSearchTypeConfig(ConfigBase): encoder: Optional[str] + model_directory: Optional[Path] class SearchTypeConfig(ConfigBase): asymmetric: Optional[AsymmetricConfig]