diff --git a/src/search_type/image_search.py b/src/search_type/image_search.py
index 323556e7..6dc57b6e 100644
--- a/src/search_type/image_search.py
+++ b/src/search_type/image_search.py
@@ -143,7 +143,9 @@ def query(raw_query, count, model: ImageSearchModel):
         query.thumbnail((640, query.height)) # scale down image for faster processing
         logger.info(f"Find Images by Image: {query_imagepath}")
     else:
-        query = raw_query
+        # Truncate words in query to stay below max_tokens supported by ML model
+        max_words = 20
+        query = " ".join(raw_query.split()[:max_words])
         logger.info(f"Find Images by Text: {query}")
 
     # Now we encode the query (which can either be an image or a text string)
diff --git a/tests/test_image_search.py b/tests/test_image_search.py
index 2a396f54..77000d60 100644
--- a/tests/test_image_search.py
+++ b/tests/test_image_search.py
@@ -23,6 +23,7 @@ def test_image_search_setup(content_config: ContentConfig, search_config: Search
     assert len(image_search_model.image_embeddings) == 3
 
 
+# ----------------------------------------------------------------------------------------------------
 def test_image_metadata(content_config: ContentConfig):
     "Verify XMP Description and Subjects Extracted from Image"
     # Arrange
@@ -80,6 +81,30 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig
         actual_image_path.unlink()
 
 
+# ----------------------------------------------------------------------------------------------------
+def test_image_search_query_truncated(content_config: ContentConfig, search_config: SearchConfig, caplog):
+    "Verify Long Text Query is Truncated to the Word Limit Applied by Image Search Query"
+    # Arrange
+    model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
+    max_words_supported = 20
+    query = " ".join(["hello"]*100)
+    truncated_query = " ".join(["hello"]*max_words_supported)
+
+    # Act
+    try:
+        with caplog.at_level(logging.INFO, logger="src.search_type.image_search"):
+            image_search.query(
+                query,
+                count = 1,
+                model = model.image_search)
+    # Assert
+    except RuntimeError as e:
+        if "The size of tensor a (102) must match the size of tensor b (77)" in str(e):
+            assert False, "Query length exceeds max tokens supported by model\n"
+    # Compare whole log messages; a substring match would also pass for an untruncated query
+    assert f"Find Images by Text: {truncated_query}" in caplog.messages, "Query not truncated"
+
+
 # ----------------------------------------------------------------------------------------------------
 def test_image_search_by_filepath(content_config: ContentConfig, search_config: SearchConfig, caplog):
     # Arrange