mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-27 17:35:07 +01:00
Encode the asymmetric, symmetric search queries in parallel for speed
Use timer to measure time to encode queries and total search time
This commit is contained in:
parent
db07362ca3
commit
6d94d6e75a
1 changed files with 25 additions and 8 deletions
|
@ -2,6 +2,7 @@
|
|||
from collections import defaultdict
|
||||
import concurrent.futures
|
||||
import math
|
||||
import time
|
||||
import yaml
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
@ -118,6 +119,8 @@ def search(
|
|||
dedupe: Optional[bool] = True,
|
||||
client: Optional[str] = None,
|
||||
):
|
||||
start_time = time.time()
|
||||
|
||||
results: List[SearchResponse] = []
|
||||
if q is None or q == "":
|
||||
logger.warn(f"No query param (q) passed in API call to initiate search")
|
||||
|
@ -139,15 +142,26 @@ def search(
|
|||
for filter in [DateFilter(), WordFilter(), FileFilter()]:
|
||||
defiltered_query = filter.defilter(user_query)
|
||||
|
||||
encoded_asymmetric_query = state.model.org_search.bi_encoder.encode(
|
||||
[defiltered_query], convert_to_tensor=True, device=state.device
|
||||
)
|
||||
encoded_asymmetric_query = util.normalize_embeddings(encoded_asymmetric_query)
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
with timer("Encoding query for asymmetric search took", logger=logger):
|
||||
encode_asymmetric_futures = executor.submit(
|
||||
state.model.org_search.bi_encoder.encode,
|
||||
[defiltered_query],
|
||||
convert_to_tensor=True,
|
||||
device=state.device,
|
||||
)
|
||||
|
||||
encoded_symmetric_query = state.model.org_search.bi_encoder.encode(
|
||||
[defiltered_query], convert_to_tensor=True, device=state.device
|
||||
)
|
||||
encoded_symmetric_query = util.normalize_embeddings(encoded_symmetric_query)
|
||||
with timer("Encoding query for symmetric search took", logger=logger):
|
||||
encode_symmetric_futures = executor.submit(
|
||||
state.model.org_search.bi_encoder.encode,
|
||||
[defiltered_query],
|
||||
convert_to_tensor=True,
|
||||
device=state.device,
|
||||
)
|
||||
|
||||
with timer("Normalizing query embeddings took", logger=logger):
|
||||
encoded_asymmetric_query = util.normalize_embeddings(encode_asymmetric_futures.result())
|
||||
encoded_symmetric_query = util.normalize_embeddings(encode_symmetric_futures.result())
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
if (t == SearchType.Org or t == None) and state.model.org_search:
|
||||
|
@ -279,6 +293,9 @@ def search(
|
|||
]
|
||||
state.previous_query = user_query
|
||||
|
||||
end_time = time.time()
|
||||
logger.debug(f"🔍 Search took {end_time - start_time:.2f} seconds")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue