mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Truncate image queries below max tokens length supported by ML model
This would previously return the infamous tensor size mismatch error Verify this error is not raised since adding the query truncation logic
This commit is contained in:
parent
3d9ed91e42
commit
6908b6eed3
2 changed files with 26 additions and 1 deletions
|
@ -143,7 +143,9 @@ def query(raw_query, count, model: ImageSearchModel):
|
|||
query.thumbnail((640, query.height)) # scale down image for faster processing
|
||||
logger.info(f"Find Images by Image: {query_imagepath}")
|
||||
else:
|
||||
query = raw_query
|
||||
# Truncate words in query to stay below max_tokens supported by ML model
|
||||
max_words = 20
|
||||
query = " ".join(raw_query.split()[:max_words])
|
||||
logger.info(f"Find Images by Text: {query}")
|
||||
|
||||
# Now we encode the query (which can either be an image or a text string)
|
||||
|
|
|
@ -23,6 +23,7 @@ def test_image_search_setup(content_config: ContentConfig, search_config: Search
|
|||
assert len(image_search_model.image_embeddings) == 3
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_image_metadata(content_config: ContentConfig):
|
||||
"Verify XMP Description and Subjects Extracted from Image"
|
||||
# Arrange
|
||||
|
@ -80,6 +81,28 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig
|
|||
actual_image_path.unlink()
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_image_search_query_truncated(content_config: ContentConfig, search_config: SearchConfig, caplog):
|
||||
# Arrange
|
||||
model.image_search = image_search.setup(content_config.image, search_config.image, regenerate=False)
|
||||
max_words_supported = 10
|
||||
query = " ".join(["hello"]*100)
|
||||
truncated_query = " ".join(["hello"]*max_words_supported)
|
||||
|
||||
# Act
|
||||
try:
|
||||
with caplog.at_level(logging.INFO, logger="src.search_type.image_search"):
|
||||
image_search.query(
|
||||
query,
|
||||
count = 1,
|
||||
model = model.image_search)
|
||||
# Assert
|
||||
except RuntimeError as e:
|
||||
if "The size of tensor a (102) must match the size of tensor b (77)" in str(e):
|
||||
assert False, f"Query length exceeds max tokens supported by model\n"
|
||||
assert f"Find Images by Text: {truncated_query}" in caplog.text, "Query not truncated"
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_image_search_by_filepath(content_config: ContentConfig, search_config: SearchConfig, caplog):
|
||||
# Arrange
|
||||
|
|
Loading…
Reference in a new issue