Make getting deduped search results configurable via Search API

This commit is contained in:
Debanjum Singh Solanky 2023-03-06 23:48:46 -06:00
parent b6cdc5c7cb
commit 672f61529e
2 changed files with 14 additions and 7 deletions

View file

@ -60,6 +60,7 @@ def search(
t: Optional[SearchType] = None, t: Optional[SearchType] = None,
r: Optional[bool] = False, r: Optional[bool] = False,
score_threshold: Optional[float | None] = None, score_threshold: Optional[float | None] = None,
dedupe: Optional[bool] = True,
): ):
results: List[SearchResponse] = [] results: List[SearchResponse] = []
if q is None or q == "": if q is None or q == "":
@ -72,7 +73,7 @@ def search(
score_threshold = score_threshold if score_threshold is not None else -math.inf score_threshold = score_threshold if score_threshold is not None else -math.inf
# return cached results, if available # return cached results, if available
query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}" query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}-{dedupe}"
if query_cache_key in state.query_cache: if query_cache_key in state.query_cache:
logger.debug(f"Return response from query cache") logger.debug(f"Return response from query cache")
return state.query_cache[query_cache_key] return state.query_cache[query_cache_key]
@ -81,7 +82,7 @@ def search(
# query org-mode notes # query org-mode notes
with timer("Query took", logger): with timer("Query took", logger):
hits, entries = text_search.query( hits, entries = text_search.query(
user_query, state.model.orgmode_search, rank_results=r, score_threshold=score_threshold user_query, state.model.orgmode_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe
) )
# collate and return results # collate and return results
@ -92,7 +93,7 @@ def search(
# query markdown files # query markdown files
with timer("Query took", logger): with timer("Query took", logger):
hits, entries = text_search.query( hits, entries = text_search.query(
user_query, state.model.markdown_search, rank_results=r, score_threshold=score_threshold user_query, state.model.markdown_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe
) )
# collate and return results # collate and return results
@ -103,7 +104,7 @@ def search(
# query transactions # query transactions
with timer("Query took", logger): with timer("Query took", logger):
hits, entries = text_search.query( hits, entries = text_search.query(
user_query, state.model.ledger_search, rank_results=r, score_threshold=score_threshold user_query, state.model.ledger_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe
) )
# collate and return results # collate and return results
@ -114,7 +115,7 @@ def search(
# query music library # query music library
with timer("Query took", logger): with timer("Query took", logger):
hits, entries = text_search.query( hits, entries = text_search.query(
user_query, state.model.music_search, rank_results=r, score_threshold=score_threshold user_query, state.model.music_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe
) )
# collate and return results # collate and return results
@ -148,6 +149,7 @@ def search(
state.model.plugin_search.get(t.value) or next(iter(state.model.plugin_search.values())), state.model.plugin_search.get(t.value) or next(iter(state.model.plugin_search.values())),
rank_results=r, rank_results=r,
score_threshold=score_threshold, score_threshold=score_threshold,
dedupe=dedupe,
) )
# collate and return results # collate and return results

View file

@ -101,7 +101,11 @@ def compute_embeddings(
def query( def query(
raw_query: str, model: TextSearchModel, rank_results: bool = False, score_threshold: float = -math.inf raw_query: str,
model: TextSearchModel,
rank_results: bool = False,
score_threshold: float = -math.inf,
dedupe: bool = True,
) -> Tuple[List[dict], List[Entry]]: ) -> Tuple[List[dict], List[Entry]]:
"Search for entries that answer the query" "Search for entries that answer the query"
query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings
@ -139,7 +143,8 @@ def query(
hits = sort_results(rank_results, hits) hits = sort_results(rank_results, hits)
# Deduplicate entries by raw entry text before showing to users # Deduplicate entries by raw entry text before showing to users
hits = deduplicate_results(entries, hits) if dedupe:
hits = deduplicate_results(entries, hits)
return hits, entries return hits, entries