Make image batch size to encode configurable via config.yml

This commit is contained in:
Debanjum Singh Solanky 2021-09-16 10:51:39 -07:00
parent 41c328dae0
commit 3afe054312
4 changed files with 9 additions and 8 deletions

View file

@ -11,7 +11,7 @@ content-type:
image:
embeddings-file: '.image_embeddings.pt'
batch-size: 50
search-type:
asymmetric:

View file

@ -187,8 +187,9 @@ if __name__ == '__main__':
image_names, image_embeddings, image_metadata_embeddings, image_encoder = image_search.setup(
pathlib.Path(image_config['input-directory']),
pathlib.Path(image_config['embeddings-file']),
args.regenerate,
args.verbose)
batch_size=image_config['batch-size'],
regenerate=args.regenerate,
verbose=args.verbose)
# Start Application Server
uvicorn.run(app)

View file

@ -30,7 +30,7 @@ def extract_entries(image_directory, verbose=0):
return image_names
def compute_embeddings(image_names, model, embeddings_file, regenerate=False, verbose=0):
def compute_embeddings(image_names, model, embeddings_file, batch_size=50, regenerate=False, verbose=0):
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
image_embeddings = None
image_metadata_embeddings = None
@ -51,7 +51,6 @@ def compute_embeddings(image_names, model, embeddings_file, regenerate=False, ve
if verbose > 0:
print(f"Loading the {len(image_names)} images into memory")
batch_size = 50
if image_embeddings is None:
image_embeddings = []
for index in trange(0, len(image_names), batch_size):
@ -137,7 +136,7 @@ def collate_results(hits, image_names, image_directory, count=5):
in hits[0:count]]
def setup(image_directory, embeddings_file, regenerate=False, verbose=0):
def setup(image_directory, embeddings_file, batch_size=50, regenerate=False, verbose=0):
# Initialize Model
model = initialize_model()
@ -147,7 +146,7 @@ def setup(image_directory, embeddings_file, regenerate=False, verbose=0):
# Compute or Load Embeddings
embeddings_file = resolve_absolute_path(embeddings_file)
image_embeddings, image_metadata_embeddings = compute_embeddings(image_names, model, embeddings_file, regenerate=regenerate, verbose=verbose)
image_embeddings, image_metadata_embeddings = compute_embeddings(image_names, model, embeddings_file, batch_size=batch_size, regenerate=regenerate, verbose=verbose)
return image_names, image_embeddings, image_metadata_embeddings, model

View file

@ -56,7 +56,8 @@ default_config = {
},
'image':
{
'embeddings-file': '.image_embeddings.pt'
'embeddings-file': '.image_embeddings.pt',
'batch-size': 50
},
'music':
{