From 3afe05431229b26e3eca34610c7af2091c8ab835 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Thu, 16 Sep 2021 10:51:39 -0700 Subject: [PATCH] Make image batch size to encode configurable via config.yml --- sample_config.yml | 2 +- src/main.py | 5 +++-- src/search_type/image_search.py | 7 +++---- src/utils/cli.py | 3 ++- 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/sample_config.yml b/sample_config.yml index b01a0acd..68024f45 100644 --- a/sample_config.yml +++ b/sample_config.yml @@ -11,7 +11,7 @@ content-type: image: embeddings-file: '.image_embeddings.pt' - + batch-size: 50 search-type: asymmetric: diff --git a/src/main.py b/src/main.py index a38336ce..e13a0cc7 100644 --- a/src/main.py +++ b/src/main.py @@ -187,8 +187,9 @@ if __name__ == '__main__': image_names, image_embeddings, image_metadata_embeddings, image_encoder = image_search.setup( pathlib.Path(image_config['input-directory']), pathlib.Path(image_config['embeddings-file']), - args.regenerate, - args.verbose) + batch_size=image_config['batch-size'], + regenerate=args.regenerate, + verbose=args.verbose) # Start Application Server uvicorn.run(app) diff --git a/src/search_type/image_search.py b/src/search_type/image_search.py index 8c1eb852..181be13f 100644 --- a/src/search_type/image_search.py +++ b/src/search_type/image_search.py @@ -30,7 +30,7 @@ def extract_entries(image_directory, verbose=0): return image_names -def compute_embeddings(image_names, model, embeddings_file, regenerate=False, verbose=0): +def compute_embeddings(image_names, model, embeddings_file, batch_size=50, regenerate=False, verbose=0): "Compute (and Save) Embeddings or Load Pre-Computed Embeddings" image_embeddings = None image_metadata_embeddings = None @@ -51,7 +51,6 @@ def compute_embeddings(image_names, model, embeddings_file, regenerate=False, ve if verbose > 0: print(f"Loading the {len(image_names)} images into memory") - batch_size = 50 if image_embeddings is None: image_embeddings = [] for index in trange(0, len(image_names), batch_size): @@ -137,7 +136,7 @@ def collate_results(hits, image_names, image_directory, count=5): in hits[0:count]] -def setup(image_directory, embeddings_file, regenerate=False, verbose=0): +def setup(image_directory, embeddings_file, batch_size=50, regenerate=False, verbose=0): # Initialize Model model = initialize_model() @@ -147,7 +146,7 @@ def setup(image_directory, embeddings_file, regenerate=False, verbose=0): # Compute or Load Embeddings embeddings_file = resolve_absolute_path(embeddings_file) - image_embeddings, image_metadata_embeddings = compute_embeddings(image_names, model, embeddings_file, regenerate=regenerate, verbose=verbose) + image_embeddings, image_metadata_embeddings = compute_embeddings(image_names, model, embeddings_file, batch_size=batch_size, regenerate=regenerate, verbose=verbose) return image_names, image_embeddings, image_metadata_embeddings, model diff --git a/src/utils/cli.py b/src/utils/cli.py index 106c9d89..8d46bd02 100644 --- a/src/utils/cli.py +++ b/src/utils/cli.py @@ -56,7 +56,8 @@ default_config = { }, 'image': { - 'embeddings-file': '.image_embeddings.pt' + 'embeddings-file': '.image_embeddings.pt', + 'batch-size': 50 }, 'music': {