Batch encode images to keep memory consumption manageable

- Issue:
  The process would get killed while encoding images
  because it consumed too much memory

- Fix:
  - Encode images in batches and append to image_embeddings
  - No need to use copy or deepcopy anymore with batch processing.
    Without them, it would earlier throw a "too many open files" error

Other Changes:
  - Use tqdm to show progress even when encoding in batches
  - Show the encoding progress bar regardless of verbosity (for now)
This commit is contained in:
Debanjum Singh Solanky 2021-09-16 10:15:54 -07:00
parent d8abbc0552
commit 41c328dae0

View file

@ -6,6 +6,7 @@ import copy
# External Packages
from sentence_transformers import SentenceTransformer, util
from PIL import Image
from tqdm import trange
import torch
# Internal Packages
@ -50,23 +51,20 @@ def compute_embeddings(image_names, model, embeddings_file, regenerate=False, ve
if verbose > 0:
print(f"Loading the {len(image_names)} images into memory")
batch_size = 50
if image_embeddings is None:
image_embeddings = model.encode(
[Image.open(image_name).copy() for image_name in image_names],
batch_size=128, convert_to_tensor=True, show_progress_bar=verbose > 0)
image_embeddings = []
for index in trange(0, len(image_names), batch_size):
images = [Image.open(image_name) for image_name in image_names[index:index+batch_size]]
image_embeddings += model.encode(images, convert_to_tensor=True, batch_size=batch_size)
torch.save(image_embeddings, embeddings_file)
if verbose > 0:
print(f"Saved computed embeddings to {embeddings_file}")
if image_metadata_embeddings is None:
image_metadata_embeddings = model.encode(
[extract_metadata(image_name, verbose) for image_name in image_names],
batch_size=128, convert_to_tensor=True, show_progress_bar=verbose > 0)
image_metadata = [extract_metadata(image_name, verbose) for image_name in image_names],
image_metadata_embeddings = model.encode(image_metadata, batch_size=batch_size, convert_to_tensor=True, show_progress_bar=True)
torch.save(image_metadata_embeddings, f"{embeddings_file}_metadata")
if verbose > 0:
print(f"Saved computed metadata embeddings to {embeddings_file}_metadata")