mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Only search across content types that work with asymmetric search
This commit is contained in:
parent
6d94d6e75a
commit
0144e610d6
1 changed files with 11 additions and 22 deletions
|
@ -142,27 +142,17 @@ def search(
|
||||||
for filter in [DateFilter(), WordFilter(), FileFilter()]:
|
for filter in [DateFilter(), WordFilter(), FileFilter()]:
|
||||||
defiltered_query = filter.defilter(user_query)
|
defiltered_query = filter.defilter(user_query)
|
||||||
|
|
||||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
encoded_asymmetric_query = None
|
||||||
with timer("Encoding query for asymmetric search took", logger=logger):
|
if t == None or (t != SearchType.Ledger and t != SearchType.Image):
|
||||||
encode_asymmetric_futures = executor.submit(
|
with timer("Encoding query took", logger=logger):
|
||||||
state.model.org_search.bi_encoder.encode,
|
encoded_asymmetric_query = util.normalize_embeddings(
|
||||||
[defiltered_query],
|
state.model.org_search.bi_encoder.encode(
|
||||||
convert_to_tensor=True,
|
[defiltered_query],
|
||||||
device=state.device,
|
convert_to_tensor=True,
|
||||||
|
device=state.device,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
with timer("Encoding query for symmetric search took", logger=logger):
|
|
||||||
encode_symmetric_futures = executor.submit(
|
|
||||||
state.model.org_search.bi_encoder.encode,
|
|
||||||
[defiltered_query],
|
|
||||||
convert_to_tensor=True,
|
|
||||||
device=state.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
with timer("Normalizing query embeddings took", logger=logger):
|
|
||||||
encoded_asymmetric_query = util.normalize_embeddings(encode_asymmetric_futures.result())
|
|
||||||
encoded_symmetric_query = util.normalize_embeddings(encode_symmetric_futures.result())
|
|
||||||
|
|
||||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||||
if (t == SearchType.Org or t == None) and state.model.org_search:
|
if (t == SearchType.Org or t == None) and state.model.org_search:
|
||||||
# query org-mode notes
|
# query org-mode notes
|
||||||
|
@ -206,14 +196,13 @@ def search(
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
if (t == SearchType.Ledger or t == None) and state.model.ledger_search:
|
if (t == SearchType.Ledger) and state.model.ledger_search:
|
||||||
# query transactions
|
# query transactions
|
||||||
search_futures[t] += [
|
search_futures[t] += [
|
||||||
executor.submit(
|
executor.submit(
|
||||||
text_search.query,
|
text_search.query,
|
||||||
user_query,
|
user_query,
|
||||||
state.model.ledger_search,
|
state.model.ledger_search,
|
||||||
question_embedding=encoded_symmetric_query,
|
|
||||||
rank_results=r,
|
rank_results=r,
|
||||||
score_threshold=score_threshold,
|
score_threshold=score_threshold,
|
||||||
dedupe=dedupe,
|
dedupe=dedupe,
|
||||||
|
@ -294,7 +283,7 @@ def search(
|
||||||
state.previous_query = user_query
|
state.previous_query = user_query
|
||||||
|
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
logger.debug(f"🔍 Search took {end_time - start_time:.2f} seconds")
|
logger.debug(f"🔍 Search took: {end_time - start_time:.2f} seconds")
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue