diff --git a/src/search_type/image_search.py b/src/search_type/image_search.py
index 2091aa30..2dc61b47 100644
--- a/src/search_type/image_search.py
+++ b/src/search_type/image_search.py
@@ -18,8 +18,8 @@ from utils.config import ImageSearchModel, ImageSearchConfig
 def initialize_model():
     # Initialize Model
     torch.set_num_threads(4)
-    model = SentenceTransformer('clip-ViT-B-32') #Load the CLIP model
-    return model
+    encoder = SentenceTransformer('clip-ViT-B-32') #Load the CLIP model
+    return encoder
 
 
 def extract_entries(image_directory, verbose=0):
@@ -32,16 +32,16 @@ def extract_entries(image_directory, verbose=0):
     return image_names
 
 
-def compute_embeddings(image_names, model, embeddings_file, batch_size=50, regenerate=False, use_xmp_metadata=False, verbose=0):
+def compute_embeddings(image_names, encoder, embeddings_file, batch_size=50, use_xmp_metadata=False, regenerate=False, verbose=0):
     "Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
 
-    image_embeddings = compute_image_embeddings(image_names, model, embeddings_file, batch_size, regenerate, verbose)
-    image_metadata_embeddings = compute_metadata_embeddings(image_names, model, embeddings_file, batch_size, use_xmp_metadata, regenerate, verbose)
+    image_embeddings = compute_image_embeddings(image_names, encoder, embeddings_file, batch_size, regenerate, verbose)
+    image_metadata_embeddings = compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_size, use_xmp_metadata, regenerate, verbose)
 
     return image_embeddings, image_metadata_embeddings
 
 
-def compute_image_embeddings(image_names, model, embeddings_file, batch_size=50, regenerate=False, verbose=0):
+def compute_image_embeddings(image_names, encoder, embeddings_file, batch_size=50, regenerate=False, verbose=0):
     image_embeddings = None
 
     # Load pre-computed image embeddings from file if exists
@@ -55,7 +55,7 @@ def compute_image_embeddings(image_names, model, embeddings_file, batch_size=50,
         image_embeddings = []
         for index in trange(0, len(image_names), batch_size):
             images = [Image.open(image_name) for image_name in image_names[index:index+batch_size]]
-            image_embeddings += model.encode(images, convert_to_tensor=True, batch_size=batch_size)
+            image_embeddings += encoder.encode(images, convert_to_tensor=True, batch_size=batch_size)
         torch.save(image_embeddings, embeddings_file)
         if verbose > 0:
             print(f"Saved computed embeddings to {embeddings_file}")
@@ -63,7 +63,7 @@ def compute_image_embeddings(image_names, model, embeddings_file, batch_size=50,
     return image_embeddings
 
 
-def compute_metadata_embeddings(image_names, model, embeddings_file, batch_size=50, regenerate=False, use_xmp_metadata=False, verbose=0):
+def compute_metadata_embeddings(image_names, encoder, embeddings_file, batch_size=50, use_xmp_metadata=False, regenerate=False, verbose=0):
     image_metadata_embeddings = None
 
     # Load pre-computed image metadata embedding file if exists
@@ -77,7 +77,7 @@ def compute_metadata_embeddings(image_names, model, embeddings_file, batch_size=
         image_metadata_embeddings = []
         for index in trange(0, len(image_names), batch_size):
            image_metadata = [extract_metadata(image_name, verbose) for image_name in image_names[index:index+batch_size]]
-            image_metadata_embeddings += model.encode(image_metadata, convert_to_tensor=True, batch_size=batch_size)
+            image_metadata_embeddings += encoder.encode(image_metadata, convert_to_tensor=True, batch_size=batch_size)
         torch.save(image_metadata_embeddings, f"{embeddings_file}_metadata")
         if verbose > 0:
             print(f"Saved computed metadata embeddings to {embeddings_file}_metadata")
@@ -155,17 +155,17 @@ def collate_results(hits, image_names, image_directory, count=5):
 
 def setup(config: ImageSearchConfig, regenerate: bool) -> ImageSearchModel:
     # Initialize Model
-    model = initialize_model()
+    encoder = initialize_model()
 
     # Extract Entries
     image_directory = resolve_absolute_path(config.input_directory, strict=True)
-    image_names = extract_entries(config.input_directory, config.verbose)
+    image_names = extract_entries(image_directory, config.verbose)
 
     # Compute or Load Embeddings
     embeddings_file = resolve_absolute_path(config.embeddings_file)
     image_embeddings, image_metadata_embeddings = compute_embeddings(
         image_names,
-        model,
+        encoder,
         embeddings_file,
         batch_size=config.batch_size,
         regenerate=regenerate,
@@ -175,7 +175,7 @@ def setup(config: ImageSearchConfig, regenerate: bool) -> ImageSearchModel:
 
     return ImageSearchModel(image_names,
                             image_embeddings,
                             image_metadata_embeddings,
-                            model,
+                            encoder,
                             config.verbose)
 
diff --git a/src/tests/data/guineapig_grass.jpg b/src/tests/data/guineapig_grass.jpg
new file mode 100644
index 00000000..a94668da
Binary files /dev/null and b/src/tests/data/guineapig_grass.jpg differ
diff --git a/src/tests/data/horse_dog.jpg b/src/tests/data/horse_dog.jpg
new file mode 100644
index 00000000..3937613a
Binary files /dev/null and b/src/tests/data/horse_dog.jpg differ
diff --git a/src/tests/data/kitten_park.jpg b/src/tests/data/kitten_park.jpg
new file mode 100644
index 00000000..98899d65
Binary files /dev/null and b/src/tests/data/kitten_park.jpg differ
diff --git a/src/tests/test_main.py b/src/tests/test_main.py
index dcda9f56..c06b9964 100644
--- a/src/tests/test_main.py
+++ b/src/tests/test_main.py
@@ -7,8 +7,9 @@ from fastapi.testclient import TestClient
 
 # Internal Packages
 from main import app, search_config, model
-from search_type import asymmetric
-from utils.config import SearchConfig, TextSearchConfig
+from search_type import asymmetric, image_search
+from utils.config import SearchConfig, TextSearchConfig, ImageSearchConfig
+from utils.helpers import resolve_absolute_path
 
 
 # Arrange
@@ -91,6 +92,47 @@ def test_notes_search():
     assert "Semantic Search via Emacs" in search_result
 
 
+# ----------------------------------------------------------------------------------------------------
+def test_image_search():
+    # Arrange
+    search_config = SearchConfig()
+    search_config.image = ImageSearchConfig(
+        input_directory = Path('tests/data'),
+        embeddings_file = Path('tests/data/.image_embeddings.pt'),
+        batch_size = 10,
+        use_xmp_metadata = False,
+        verbose = 2)
+
+    # Act
+    model.image_search = image_search.setup(search_config.image, regenerate=True)
+
+    # Assert
+    assert len(model.image_search.image_names) == 3
+    assert len(model.image_search.image_embeddings) == 3
+
+    # Arrange
+    for query, expected_image_name in [("kitten in a park", "kitten_park.jpg"),
+                                       ("horse and dog in a farm", "horse_dog.jpg"),
+                                       ("A guinea pig eating grass", "guineapig_grass.jpg")]:
+        # Act
+        hits = image_search.query(
+            query,
+            count = 1,
+            model = model.image_search)
+
+        results = image_search.collate_results(
+            hits,
+            model.image_search.image_names,
+            search_config.image.input_directory,
+            count=1)
+
+        actual_image = results[0]["Entry"]
+        expected_image = resolve_absolute_path(search_config.image.input_directory.joinpath(expected_image_name))
+
+        # Assert
+        assert expected_image == actual_image
+
+
 # ----------------------------------------------------------------------------------------------------
 def test_notes_regenerate():
     # Arrange
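
For reviewers who want to try the new image search flow by hand, below is a minimal sketch of the same pipeline the test exercises, outside the test harness. It is an illustration, not part of the change: it assumes it is run from the src/ directory so the internal imports resolve, and it reuses the test fixture paths purely as an example; any directory of .jpg images should work the same way.

    # Minimal sketch; assumes cwd is src/ and reuses the test fixture paths
    from pathlib import Path

    from search_type import image_search
    from utils.config import ImageSearchConfig

    config = ImageSearchConfig(
        input_directory = Path('tests/data'),
        embeddings_file = Path('tests/data/.image_embeddings.pt'),
        batch_size = 10,
        use_xmp_metadata = False,
        verbose = 0)

    # setup() loads the CLIP encoder, extracts image entries from
    # input_directory and computes their embeddings (or loads the
    # cached embeddings_file when regenerate=False)
    search_model = image_search.setup(config, regenerate=False)

    # query() ranks the images against a natural-language query;
    # collate_results() maps the top hits back to image paths
    hits = image_search.query("kitten in a park", count=1, model=search_model)
    results = image_search.collate_results(
        hits,
        search_model.image_names,
        config.input_directory,
        count=1)

    print(results[0]["Entry"])  # path of the best-matching image

Note the design choice the rename supports: `initialize_model()` returns a bare sentence-transformers encoder, while `setup()` returns the richer `ImageSearchModel` bundle; calling the former `encoder` keeps the two from being confused at the call sites in `compute_embeddings` and `setup`.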