Truncate image queries below max tokens length supported by ML model

This would previously return the infamous tensor size mismatch error
Verify this error is not raised since adding the query truncation logic
This commit is contained in:
Debanjum Singh Solanky 2023-01-21 14:11:00 -03:00
parent 3d9ed91e42
commit 6908b6eed3
2 changed files with 26 additions and 1 deletions

View file

@ -143,7 +143,9 @@ def query(raw_query, count, model: ImageSearchModel):
query.thumbnail((640, query.height)) # scale down image for faster processing
logger.info(f"Find Images by Image: {query_imagepath}")
else:
query = raw_query
# Truncate words in query to stay below max_tokens supported by ML model
max_words = 20
query = " ".join(raw_query.split()[:max_words])
logger.info(f"Find Images by Text: {query}")
# Now we encode the query (which can either be an image or a text string)

View file

@ -23,6 +23,7 @@ def test_image_search_setup(content_config: ContentConfig, search_config: Search
assert len(image_search_model.image_embeddings) == 3
# ----------------------------------------------------------------------------------------------------
def test_image_metadata(content_config: ContentConfig):
"Verify XMP Description and Subjects Extracted from Image"
# Arrange
@ -80,6 +81,28 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig
actual_image_path.unlink()
# ----------------------------------------------------------------------------------------------------
def test_image_search_query_truncated(content_config: ContentConfig, search_config: SearchConfig, caplog):
# Arrange
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
max_words_supported = 10
query = " ".join(["hello"]*100)
truncated_query = " ".join(["hello"]*max_words_supported)
# Act
try:
with caplog.at_level(logging.INFO, logger="src.search_type.image_search"):
image_search.query(
query,
count = 1,
model = model.image_search)
# Assert
except RuntimeError as e:
if "The size of tensor a (102) must match the size of tensor b (77)" in str(e):
assert False, f"Query length exceeds max tokens supported by model\n"
assert f"Find Images by Text: {truncated_query}" in caplog.text, "Query not truncated"
# ----------------------------------------------------------------------------------------------------
def test_image_search_by_filepath(content_config: ContentConfig, search_config: SearchConfig, caplog):
# Arrange