Encode the asymmetric and symmetric search queries in parallel for speed

Use a timer to measure the time taken to encode queries and the total search time
This commit is contained in:
Debanjum Singh Solanky 2023-06-20 01:17:21 -07:00
parent db07362ca3
commit 6d94d6e75a

View file

@ -2,6 +2,7 @@
from collections import defaultdict
import concurrent.futures
import math
import time
import yaml
import logging
from datetime import datetime
@ -118,6 +119,8 @@ def search(
dedupe: Optional[bool] = True,
client: Optional[str] = None,
):
start_time = time.time()
results: List[SearchResponse] = []
if q is None or q == "":
logger.warn(f"No query param (q) passed in API call to initiate search")
@ -139,15 +142,26 @@ def search(
for filter in [DateFilter(), WordFilter(), FileFilter()]:
defiltered_query = filter.defilter(user_query)
encoded_asymmetric_query = state.model.org_search.bi_encoder.encode(
[defiltered_query], convert_to_tensor=True, device=state.device
)
encoded_asymmetric_query = util.normalize_embeddings(encoded_asymmetric_query)
with concurrent.futures.ThreadPoolExecutor() as executor:
with timer("Encoding query for asymmetric search took", logger=logger):
encode_asymmetric_futures = executor.submit(
state.model.org_search.bi_encoder.encode,
[defiltered_query],
convert_to_tensor=True,
device=state.device,
)
encoded_symmetric_query = state.model.org_search.bi_encoder.encode(
[defiltered_query], convert_to_tensor=True, device=state.device
)
encoded_symmetric_query = util.normalize_embeddings(encoded_symmetric_query)
with timer("Encoding query for symmetric search took", logger=logger):
encode_symmetric_futures = executor.submit(
state.model.org_search.bi_encoder.encode,
[defiltered_query],
convert_to_tensor=True,
device=state.device,
)
with timer("Normalizing query embeddings took", logger=logger):
encoded_asymmetric_query = util.normalize_embeddings(encode_asymmetric_futures.result())
encoded_symmetric_query = util.normalize_embeddings(encode_symmetric_futures.result())
with concurrent.futures.ThreadPoolExecutor() as executor:
if (t == SearchType.Org or t == None) and state.model.org_search:
@ -279,6 +293,9 @@ def search(
]
state.previous_query = user_query
end_time = time.time()
logger.debug(f"🔍 Search took {end_time - start_time:.2f} seconds")
return results