Only search across content types that work with asymmetric search

This commit is contained in:
Debanjum Singh Solanky 2023-06-20 02:28:51 -07:00
parent 6d94d6e75a
commit 0144e610d6

View file

@ -142,27 +142,17 @@ def search(
for filter in [DateFilter(), WordFilter(), FileFilter()]: for filter in [DateFilter(), WordFilter(), FileFilter()]:
defiltered_query = filter.defilter(user_query) defiltered_query = filter.defilter(user_query)
with concurrent.futures.ThreadPoolExecutor() as executor: encoded_asymmetric_query = None
with timer("Encoding query for asymmetric search took", logger=logger): if t == None or (t != SearchType.Ledger and t != SearchType.Image):
encode_asymmetric_futures = executor.submit( with timer("Encoding query took", logger=logger):
state.model.org_search.bi_encoder.encode, encoded_asymmetric_query = util.normalize_embeddings(
[defiltered_query], state.model.org_search.bi_encoder.encode(
convert_to_tensor=True, [defiltered_query],
device=state.device, convert_to_tensor=True,
device=state.device,
)
) )
with timer("Encoding query for symmetric search took", logger=logger):
encode_symmetric_futures = executor.submit(
state.model.org_search.bi_encoder.encode,
[defiltered_query],
convert_to_tensor=True,
device=state.device,
)
with timer("Normalizing query embeddings took", logger=logger):
encoded_asymmetric_query = util.normalize_embeddings(encode_asymmetric_futures.result())
encoded_symmetric_query = util.normalize_embeddings(encode_symmetric_futures.result())
with concurrent.futures.ThreadPoolExecutor() as executor: with concurrent.futures.ThreadPoolExecutor() as executor:
if (t == SearchType.Org or t == None) and state.model.org_search: if (t == SearchType.Org or t == None) and state.model.org_search:
# query org-mode notes # query org-mode notes
@ -206,14 +196,13 @@ def search(
) )
] ]
if (t == SearchType.Ledger or t == None) and state.model.ledger_search: if (t == SearchType.Ledger) and state.model.ledger_search:
# query transactions # query transactions
search_futures[t] += [ search_futures[t] += [
executor.submit( executor.submit(
text_search.query, text_search.query,
user_query, user_query,
state.model.ledger_search, state.model.ledger_search,
question_embedding=encoded_symmetric_query,
rank_results=r, rank_results=r,
score_threshold=score_threshold, score_threshold=score_threshold,
dedupe=dedupe, dedupe=dedupe,
@ -294,7 +283,7 @@ def search(
state.previous_query = user_query state.previous_query = user_query
end_time = time.time() end_time = time.time()
logger.debug(f"🔍 Search took {end_time - start_time:.2f} seconds") logger.debug(f"🔍 Search took: {end_time - start_time:.2f} seconds")
return results return results