2021-08-09 07:17:19 +02:00
from sentence_transformers import SentenceTransformer , util
from PIL import Image
import torch
import argparse
import pathlib
import copy
def initialize_model(model_name='clip-ViT-B-32', num_threads=4, top_k=3):
    """Load the CLIP model used to embed both images and text queries.

    Args:
        model_name: Name handed to ``SentenceTransformer``. It must not
            carry surrounding whitespace, otherwise it is not the model id.
        num_threads: Value passed to ``torch.set_num_threads``.
        top_k: Default number of results, returned unchanged with the model.

    Returns:
        A ``(model, top_k)`` tuple.
    """
    # Initialize Model
    torch.set_num_threads(num_threads)
    model = SentenceTransformer(model_name)  # Load the CLIP model
    return model, top_k
def extract_entries(image_directory, verbose=False):
    """Return the paths of the ``*.jpg`` files directly inside a directory.

    Args:
        image_directory: ``pathlib.Path`` of the directory to scan. The scan
            is not recursive.
        verbose: When true, print how many images were found.

    Returns:
        A sorted list of ``pathlib.Path`` objects, one per matching image.
    """
    # glob() yields entries in arbitrary order. Search hits refer to images
    # by their index in this list, and embeddings cached on disk are reused
    # across runs, so the order has to be deterministic.
    image_names = sorted(image_directory.glob('*.jpg'))

    if verbose:
        print(f' Found {len(image_names)} images in {image_directory} ')
    return image_names
def compute_embeddings(image_names, model, embeddings_file, verbose=False):
    """Compute (and Save) Embeddings or Load Pre-Computed Embeddings.

    Args:
        image_names: Paths of the images to embed. Only read when
            ``embeddings_file`` does not exist yet.
        model: Object whose ``encode`` method turns a list of PIL images
            into embeddings.
        embeddings_file: ``pathlib.Path`` of the cache file to load from, or
            to create after computing.
        verbose: When true, print progress messages.

    Returns:
        The embeddings loaded from ``embeddings_file``, or those freshly
        returned by ``model.encode``.

    Raises:
        ValueError: If there is neither a cache file nor any image to embed.
    """
    # Load pre-computed embeddings from file if exists
    if embeddings_file.exists():
        # NOTE(security): torch.load relies on unpickling; only point this
        # at embeddings files you trust.
        # NOTE: the cache is used as-is and is not checked against
        # image_names, so delete the file when the image set changes.
        image_embeddings = torch.load(embeddings_file)
        if verbose:
            print(f" Loaded pre-computed embeddings from {embeddings_file} ")
        return image_embeddings

    # Else compute the image_embeddings from scratch, which can take a while
    if not image_names:
        # Previously this fell through to the return statement with the
        # result never assigned and crashed with UnboundLocalError.
        raise ValueError(
            f"No images to embed and no embeddings file at {embeddings_file}"
        )

    if verbose:
        print(f" Loading the {len(image_names)} images into memory ")
    images = []
    for image_name in image_names:
        # copy() forces the pixel data to load and detaches it from the
        # file, so the handle can be closed right away instead of leaking.
        with Image.open(image_name) as image:
            images.append(image.copy())

    image_embeddings = model.encode(
        images,
        batch_size=128,
        convert_to_tensor=True,
        show_progress_bar=True,
    )
    torch.save(image_embeddings, embeddings_file)
    if verbose:
        print(f" Saved computed embeddings to {embeddings_file} ")
    return image_embeddings
2021-08-09 08:11:15 +02:00
def search(query, image_embeddings, model, count=3, verbose=False):
    """Rank the embedded images against a text or image query.

    Args:
        query: Free text, or the path of an existing image file. When it
            names an existing file, the image content becomes the query.
        image_embeddings: Embeddings of the image corpus to search.
        model: Object whose ``encode`` method embeds the query.
        count: Number of top-ranked hits to request.
        verbose: When true, print which kind of query is being run.

    Returns:
        The hits for this query, i.e. the first element of the result of
        ``util.semantic_search``.
    """
    # Set query to image content if query is a filepath
    try:
        candidate = pathlib.Path(query).expanduser()
        query_is_image = candidate.is_file()
    except (OSError, RuntimeError):
        # Free text is not always a usable path: an over-long string can
        # make the file check fail, and "~name" for an unknown user makes
        # expanduser() raise. Such queries are plain text.
        query_is_image = False

    if query_is_image:
        query_imagepath = candidate.resolve(strict=True)
        # copy() loads the pixels so the file handle can be closed here.
        with Image.open(query_imagepath) as image:
            query = image.copy()
        if verbose:
            print(f" Find Images similar to Image at {query_imagepath} ")
    elif verbose:
        # Gated on verbose like every other progress message in this file.
        print(f" Find Images by Text: {query} ")

    # Now we encode the query (which can either be an image or a text string)
    query_embedding = model.encode([query], convert_to_tensor=True, show_progress_bar=False)

    # Then, we use the util.semantic_search function, which computes the cosine-similarity
    # between the query embedding and all image embeddings.
    # It then returns the top_k highest ranked images, which we output
    hits = util.semantic_search(query_embedding, image_embeddings, top_k=count)[0]
    return hits
def render_results(hits, image_names, image_directory, count):
    """Print and display the images behind the first ``count`` hits.

    Args:
        hits: Sequence of hit dicts. Each one carries a ``'corpus_id'`` that
            indexes into ``image_names``.
        image_names: Image paths, in the order the embeddings were built from.
        image_directory: ``pathlib.Path`` the image names are joined onto.
        count: Maximum number of hits to render.
    """
    for hit in hits[:count]:
        # The key must be exactly 'corpus_id'; a whitespace-padded key
        # raises KeyError on every hit.
        image_name = image_names[hit['corpus_id']]
        print(image_name)
        # joinpath() returns image_name itself when it is already absolute.
        image_path = image_directory.joinpath(image_name)
        with Image.open(image_path) as img:
            img.show()
# Command-line entry point: embed (or load cached embeddings for) the images
# in a directory, then optionally answer search queries typed by the user.
if __name__ == '__main__':
    # Setup Argument Parser
    # Option strings, actions and defaults must not carry surrounding
    # whitespace: argparse rejects option strings that do not start with '-'.
    parser = argparse.ArgumentParser(description=" Semantic Search on Images ")
    parser.add_argument('--image-directory', '-i', required=True, type=pathlib.Path, help=" Image directory to query ")
    parser.add_argument('--embeddings-file', '-e', default='embeddings.pt', type=pathlib.Path, help=" File to save/load model embeddings to/from. Default: ./embeddings.pt ")
    parser.add_argument('--results-count', '-n', default=5, type=int, help=" Number of results to render. Default: 5 ")
    # The help text used to claim "Default: true" although the flag is off by default.
    parser.add_argument('--interactive', action='store_true', default=False, help=" Interactive mode allows user to run queries on the model. Default: false ")
    parser.add_argument('--verbose', action='store_true', default=False, help=" Show verbose conversion logs. Default: false ")
    args = parser.parse_args()

    # Resolve file, directory paths in args to absolute paths
    embeddings_file = args.embeddings_file.expanduser().resolve()
    # strict=True makes a missing image directory fail here, before the model is loaded
    image_directory = args.image_directory.expanduser().resolve(strict=True)

    # Initialize Model
    # The default result count returned here is unused: --results-count decides.
    model, _ = initialize_model()

    # Extract Entries
    image_names = extract_entries(image_directory, args.verbose)

    # Compute or Load Embeddings
    image_embeddings = compute_embeddings(image_names, model, embeddings_file, args.verbose)

    # Run User Queries on Entries in Interactive Mode
    while args.interactive:
        # get query from user
        try:
            user_query = input(" Enter your query: ")
        except EOFError:
            # Closed stdin (Ctrl-D, exhausted pipe) ends the session cleanly.
            raise SystemExit(0)
        if user_query.strip() == "exit":
            # SystemExit instead of exit(): the latter only exists when the
            # site module is loaded.
            raise SystemExit(0)

        # query images
        hits = search(user_query, image_embeddings, model, args.results_count, args.verbose)

        # render results
        render_results(hits, image_names, image_directory, count=args.results_count)