From 672f61529e4a6e5e8a1dc7ba48568933aa2a9ee1 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Mon, 6 Mar 2023 23:48:46 -0600 Subject: [PATCH] Make getting deduped search results configurable via Search API --- src/khoj/routers/api.py | 12 +++++++----- src/khoj/search_type/text_search.py | 9 +++++++-- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 637136a4..cd4d7a54 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -60,6 +60,7 @@ def search( t: Optional[SearchType] = None, r: Optional[bool] = False, score_threshold: Optional[float | None] = None, + dedupe: Optional[bool] = True, ): results: List[SearchResponse] = [] if q is None or q == "": @@ -72,7 +73,7 @@ def search( score_threshold = score_threshold if score_threshold is not None else -math.inf # return cached results, if available - query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}" + query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}-{dedupe}" if query_cache_key in state.query_cache: logger.debug(f"Return response from query cache") return state.query_cache[query_cache_key] @@ -81,7 +82,7 @@ def search( # query org-mode notes with timer("Query took", logger): hits, entries = text_search.query( - user_query, state.model.orgmode_search, rank_results=r, score_threshold=score_threshold + user_query, state.model.orgmode_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe ) # collate and return results @@ -92,7 +93,7 @@ def search( # query markdown files with timer("Query took", logger): hits, entries = text_search.query( - user_query, state.model.markdown_search, rank_results=r, score_threshold=score_threshold + user_query, state.model.markdown_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe ) # collate and return results @@ -103,7 +104,7 @@ def search( # query transactions with timer("Query took", logger): hits, entries = text_search.query( - user_query, state.model.ledger_search, rank_results=r, score_threshold=score_threshold + user_query, state.model.ledger_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe ) # collate and return results @@ -114,7 +115,7 @@ def search( # query music library with timer("Query took", logger): hits, entries = text_search.query( - user_query, state.model.music_search, rank_results=r, score_threshold=score_threshold + user_query, state.model.music_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe ) # collate and return results @@ -148,6 +149,7 @@ def search( state.model.plugin_search.get(t.value) or next(iter(state.model.plugin_search.values())), rank_results=r, score_threshold=score_threshold, + dedupe=dedupe, ) # collate and return results diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py index 5bc430c8..69093942 100644 --- a/src/khoj/search_type/text_search.py +++ b/src/khoj/search_type/text_search.py @@ -101,7 +101,11 @@ def compute_embeddings( def query( - raw_query: str, model: TextSearchModel, rank_results: bool = False, score_threshold: float = -math.inf + raw_query: str, + model: TextSearchModel, + rank_results: bool = False, + score_threshold: float = -math.inf, + dedupe: bool = True, ) -> Tuple[List[dict], List[Entry]]: "Search for entries that answer the query" query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings @@ -139,7 +143,8 @@ def query( hits = sort_results(rank_results, hits) # Deduplicate entries by raw entry text before showing to users - hits = deduplicate_results(entries, hits) + if dedupe: + hits = deduplicate_results(entries, hits) return hits, entries