Save Image Search Model to Disk

This commit is contained in:
Debanjum Singh Solanky 2022-01-14 17:09:18 -05:00
parent 934ec233b0
commit 510faa1904
4 changed files with 15 additions and 7 deletions

View file

@ -36,6 +36,7 @@ search-type:
image: image:
encoder: "clip-ViT-B-32" encoder: "clip-ViT-B-32"
model_directory: "tests/data/.image_encoder"
processor: processor:
conversation: conversation:

View file

@ -145,7 +145,7 @@ def initialize_search(config: FullConfig, regenerate: bool, t: SearchType = None
# Initialize Image Search # Initialize Image Search
if (t == SearchType.Image or t == None) and config.content_type.image: if (t == SearchType.Image or t == None) and config.content_type.image:
# Extract Entries, Generate Image Embeddings # Extract Entries, Generate Image Embeddings
model.image_search = image_search.setup(config.content_type.image, regenerate=regenerate, verbose=verbose) model.image_search = image_search.setup(config.content_type.image, search_config=config.search_type.image, regenerate=regenerate, verbose=verbose)
return model return model

View file

@ -10,16 +10,22 @@ from tqdm import trange
import torch import torch
# Internal Packages # Internal Packages
from src.utils.helpers import resolve_absolute_path from src.utils.helpers import resolve_absolute_path, load_model
import src.utils.exiftool as exiftool import src.utils.exiftool as exiftool
from src.utils.config import ImageSearchModel from src.utils.config import ImageSearchModel
from src.utils.rawconfig import ImageSearchConfig from src.utils.rawconfig import ImageSearchConfig, ImageSearchTypeConfig
def initialize_model(): def initialize_model(search_config: ImageSearchTypeConfig):
# Initialize Model # Initialize Model
torch.set_num_threads(4) torch.set_num_threads(4)
encoder = SentenceTransformer('sentence-transformers/clip-ViT-B-32') #Load the CLIP model
# Load the CLIP model
encoder = load_model(
model_dir = search_config.model_directory,
model_name = search_config.encoder,
model_type = SentenceTransformer)
return encoder return encoder
@ -154,9 +160,9 @@ def collate_results(hits, image_names, image_directory, count=5):
in hits[0:count]] in hits[0:count]]
def setup(config: ImageSearchConfig, regenerate: bool, verbose: bool=False) -> ImageSearchModel: def setup(config: ImageSearchConfig, search_config: ImageSearchTypeConfig, regenerate: bool, verbose: bool=False) -> ImageSearchModel:
# Initialize Model # Initialize Model
encoder = initialize_model() encoder = initialize_model(search_config)
# Extract Entries # Extract Entries
image_directory = resolve_absolute_path(config.input_directory, strict=True) image_directory = resolve_absolute_path(config.input_directory, strict=True)

View file

@ -49,6 +49,7 @@ class AsymmetricConfig(ConfigBase):
class ImageSearchTypeConfig(ConfigBase): class ImageSearchTypeConfig(ConfigBase):
encoder: Optional[str] encoder: Optional[str]
model_directory: Optional[Path]
class SearchTypeConfig(ConfigBase): class SearchTypeConfig(ConfigBase):
asymmetric: Optional[AsymmetricConfig] asymmetric: Optional[AsymmetricConfig]