From 6908b6eed3edc49762abb38cd1f6c6e11e444bdd Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sat, 21 Jan 2023 14:11:00 -0300 Subject: [PATCH] Truncate image queries below max tokens length supported by ML model This would previously return the infamous tensor size mismatch error Verify this error is not raised since adding the query truncation logic --- src/search_type/image_search.py | 4 +++- tests/test_image_search.py | 23 +++++++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/src/search_type/image_search.py b/src/search_type/image_search.py index 323556e7..6dc57b6e 100644 --- a/src/search_type/image_search.py +++ b/src/search_type/image_search.py @@ -143,7 +143,9 @@ def query(raw_query, count, model: ImageSearchModel): query.thumbnail((640, query.height)) # scale down image for faster processing logger.info(f"Find Images by Image: {query_imagepath}") else: - query = raw_query + # Truncate words in query to stay below max_tokens supported by ML model + max_words = 20 + query = " ".join(raw_query.split()[:max_words]) logger.info(f"Find Images by Text: {query}") # Now we encode the query (which can either be an image or a text string) diff --git a/tests/test_image_search.py b/tests/test_image_search.py index 2a396f54..77000d60 100644 --- a/tests/test_image_search.py +++ b/tests/test_image_search.py @@ -23,6 +23,7 @@ def test_image_search_setup(content_config: ContentConfig, search_config: Search assert len(image_search_model.image_embeddings) == 3 +# ---------------------------------------------------------------------------------------------------- def test_image_metadata(content_config: ContentConfig): "Verify XMP Description and Subjects Extracted from Image" # Arrange @@ -80,6 +81,28 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig actual_image_path.unlink() +# ---------------------------------------------------------------------------------------------------- +def test_image_search_query_truncated(content_config: ContentConfig, search_config: SearchConfig, caplog): + # Arrange + model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False) + max_words_supported = 10 + query = " ".join(["hello"]*100) + truncated_query = " ".join(["hello"]*max_words_supported) + + # Act + try: + with caplog.at_level(logging.INFO, logger="src.search_type.image_search"): + image_search.query( + query, + count = 1, + model = model.image_search) + # Assert + except RuntimeError as e: + if "The size of tensor a (102) must match the size of tensor b (77)" in str(e): + assert False, f"Query length exceeds max tokens supported by model\n" + assert f"Find Images by Text: {truncated_query}" in caplog.text, "Query not truncated" + + # ---------------------------------------------------------------------------------------------------- def test_image_search_by_filepath(content_config: ContentConfig, search_config: SearchConfig, caplog): # Arrange