mirror of
https://github.com/khoj-ai/khoj.git
synced 2025-02-17 08:04:21 +00:00
Batch encode images to keep memory consumption manageable
- Issue: the process would get killed while encoding images because it consumed too much memory. - Fix: encode images in batches and append each batch to image_embeddings; with batch processing there is no longer any need for copy or deep_copy, which previously caused a "too many open files" error. Other changes: - use tqdm so progress is visible even when batching - show the encoding progress bar regardless of verbosity (for now)
This commit is contained in:
parent
d8abbc0552
commit
41c328dae0
1 changed files with 8 additions and 10 deletions
|
@ -6,6 +6,7 @@ import copy
|
|||
# External Packages
|
||||
from sentence_transformers import SentenceTransformer, util
|
||||
from PIL import Image
|
||||
from tqdm import trange
|
||||
import torch
|
||||
|
||||
# Internal Packages
|
||||
|
@ -50,23 +51,20 @@ def compute_embeddings(image_names, model, embeddings_file, regenerate=False, ve
|
|||
if verbose > 0:
|
||||
print(f"Loading the {len(image_names)} images into memory")
|
||||
|
||||
batch_size = 50
|
||||
if image_embeddings is None:
|
||||
image_embeddings = model.encode(
|
||||
[Image.open(image_name).copy() for image_name in image_names],
|
||||
batch_size=128, convert_to_tensor=True, show_progress_bar=verbose > 0)
|
||||
|
||||
image_embeddings = []
|
||||
for index in trange(0, len(image_names), batch_size):
|
||||
images = [Image.open(image_name) for image_name in image_names[index:index+batch_size]]
|
||||
image_embeddings += model.encode(images, convert_to_tensor=True, batch_size=batch_size)
|
||||
torch.save(image_embeddings, embeddings_file)
|
||||
|
||||
if verbose > 0:
|
||||
print(f"Saved computed embeddings to {embeddings_file}")
|
||||
|
||||
if image_metadata_embeddings is None:
|
||||
image_metadata_embeddings = model.encode(
|
||||
[extract_metadata(image_name, verbose) for image_name in image_names],
|
||||
batch_size=128, convert_to_tensor=True, show_progress_bar=verbose > 0)
|
||||
|
||||
image_metadata = [extract_metadata(image_name, verbose) for image_name in image_names],
|
||||
image_metadata_embeddings = model.encode(image_metadata, batch_size=batch_size, convert_to_tensor=True, show_progress_bar=True)
|
||||
torch.save(image_metadata_embeddings, f"{embeddings_file}_metadata")
|
||||
|
||||
if verbose > 0:
|
||||
print(f"Saved computed metadata embeddings to {embeddings_file}_metadata")
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue