mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Make image batch size to encode configurable via config.yml
This commit is contained in:
parent
41c328dae0
commit
3afe054312
4 changed files with 9 additions and 8 deletions
|
@ -11,7 +11,7 @@ content-type:
|
||||||
|
|
||||||
image:
|
image:
|
||||||
embeddings-file: '.image_embeddings.pt'
|
embeddings-file: '.image_embeddings.pt'
|
||||||
|
batch-size: 50
|
||||||
|
|
||||||
search-type:
|
search-type:
|
||||||
asymmetric:
|
asymmetric:
|
||||||
|
|
|
@ -187,8 +187,9 @@ if __name__ == '__main__':
|
||||||
image_names, image_embeddings, image_metadata_embeddings, image_encoder = image_search.setup(
|
image_names, image_embeddings, image_metadata_embeddings, image_encoder = image_search.setup(
|
||||||
pathlib.Path(image_config['input-directory']),
|
pathlib.Path(image_config['input-directory']),
|
||||||
pathlib.Path(image_config['embeddings-file']),
|
pathlib.Path(image_config['embeddings-file']),
|
||||||
args.regenerate,
|
batch_size=image_config['batch-size'],
|
||||||
args.verbose)
|
regenerate=args.regenerate,
|
||||||
|
verbose=args.verbose)
|
||||||
|
|
||||||
# Start Application Server
|
# Start Application Server
|
||||||
uvicorn.run(app)
|
uvicorn.run(app)
|
||||||
|
|
|
@ -30,7 +30,7 @@ def extract_entries(image_directory, verbose=0):
|
||||||
return image_names
|
return image_names
|
||||||
|
|
||||||
|
|
||||||
def compute_embeddings(image_names, model, embeddings_file, regenerate=False, verbose=0):
|
def compute_embeddings(image_names, model, embeddings_file, batch_size=50, regenerate=False, verbose=0):
|
||||||
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
|
"Compute (and Save) Embeddings or Load Pre-Computed Embeddings"
|
||||||
image_embeddings = None
|
image_embeddings = None
|
||||||
image_metadata_embeddings = None
|
image_metadata_embeddings = None
|
||||||
|
@ -51,7 +51,6 @@ def compute_embeddings(image_names, model, embeddings_file, regenerate=False, ve
|
||||||
if verbose > 0:
|
if verbose > 0:
|
||||||
print(f"Loading the {len(image_names)} images into memory")
|
print(f"Loading the {len(image_names)} images into memory")
|
||||||
|
|
||||||
batch_size = 50
|
|
||||||
if image_embeddings is None:
|
if image_embeddings is None:
|
||||||
image_embeddings = []
|
image_embeddings = []
|
||||||
for index in trange(0, len(image_names), batch_size):
|
for index in trange(0, len(image_names), batch_size):
|
||||||
|
@ -137,7 +136,7 @@ def collate_results(hits, image_names, image_directory, count=5):
|
||||||
in hits[0:count]]
|
in hits[0:count]]
|
||||||
|
|
||||||
|
|
||||||
def setup(image_directory, embeddings_file, regenerate=False, verbose=0):
|
def setup(image_directory, embeddings_file, batch_size=50, regenerate=False, verbose=0):
|
||||||
# Initialize Model
|
# Initialize Model
|
||||||
model = initialize_model()
|
model = initialize_model()
|
||||||
|
|
||||||
|
@ -147,7 +146,7 @@ def setup(image_directory, embeddings_file, regenerate=False, verbose=0):
|
||||||
|
|
||||||
# Compute or Load Embeddings
|
# Compute or Load Embeddings
|
||||||
embeddings_file = resolve_absolute_path(embeddings_file)
|
embeddings_file = resolve_absolute_path(embeddings_file)
|
||||||
image_embeddings, image_metadata_embeddings = compute_embeddings(image_names, model, embeddings_file, regenerate=regenerate, verbose=verbose)
|
image_embeddings, image_metadata_embeddings = compute_embeddings(image_names, model, embeddings_file, batch_size=batch_size, regenerate=regenerate, verbose=verbose)
|
||||||
|
|
||||||
return image_names, image_embeddings, image_metadata_embeddings, model
|
return image_names, image_embeddings, image_metadata_embeddings, model
|
||||||
|
|
||||||
|
|
|
@ -56,7 +56,8 @@ default_config = {
|
||||||
},
|
},
|
||||||
'image':
|
'image':
|
||||||
{
|
{
|
||||||
'embeddings-file': '.image_embeddings.pt'
|
'embeddings-file': '.image_embeddings.pt',
|
||||||
|
'batch-size': 50
|
||||||
},
|
},
|
||||||
'music':
|
'music':
|
||||||
{
|
{
|
||||||
|
|
Loading…
Reference in a new issue