mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-30 19:03:01 +01:00
Make getting deduped search results configurable via Search API
This commit is contained in:
parent
b6cdc5c7cb
commit
672f61529e
2 changed files with 14 additions and 7 deletions
|
@ -60,6 +60,7 @@ def search(
|
||||||
t: Optional[SearchType] = None,
|
t: Optional[SearchType] = None,
|
||||||
r: Optional[bool] = False,
|
r: Optional[bool] = False,
|
||||||
score_threshold: Optional[float | None] = None,
|
score_threshold: Optional[float | None] = None,
|
||||||
|
dedupe: Optional[bool] = True,
|
||||||
):
|
):
|
||||||
results: List[SearchResponse] = []
|
results: List[SearchResponse] = []
|
||||||
if q is None or q == "":
|
if q is None or q == "":
|
||||||
|
@ -72,7 +73,7 @@ def search(
|
||||||
score_threshold = score_threshold if score_threshold is not None else -math.inf
|
score_threshold = score_threshold if score_threshold is not None else -math.inf
|
||||||
|
|
||||||
# return cached results, if available
|
# return cached results, if available
|
||||||
query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}"
|
query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}-{dedupe}"
|
||||||
if query_cache_key in state.query_cache:
|
if query_cache_key in state.query_cache:
|
||||||
logger.debug(f"Return response from query cache")
|
logger.debug(f"Return response from query cache")
|
||||||
return state.query_cache[query_cache_key]
|
return state.query_cache[query_cache_key]
|
||||||
|
@ -81,7 +82,7 @@ def search(
|
||||||
# query org-mode notes
|
# query org-mode notes
|
||||||
with timer("Query took", logger):
|
with timer("Query took", logger):
|
||||||
hits, entries = text_search.query(
|
hits, entries = text_search.query(
|
||||||
user_query, state.model.orgmode_search, rank_results=r, score_threshold=score_threshold
|
user_query, state.model.orgmode_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe
|
||||||
)
|
)
|
||||||
|
|
||||||
# collate and return results
|
# collate and return results
|
||||||
|
@ -92,7 +93,7 @@ def search(
|
||||||
# query markdown files
|
# query markdown files
|
||||||
with timer("Query took", logger):
|
with timer("Query took", logger):
|
||||||
hits, entries = text_search.query(
|
hits, entries = text_search.query(
|
||||||
user_query, state.model.markdown_search, rank_results=r, score_threshold=score_threshold
|
user_query, state.model.markdown_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe
|
||||||
)
|
)
|
||||||
|
|
||||||
# collate and return results
|
# collate and return results
|
||||||
|
@ -103,7 +104,7 @@ def search(
|
||||||
# query transactions
|
# query transactions
|
||||||
with timer("Query took", logger):
|
with timer("Query took", logger):
|
||||||
hits, entries = text_search.query(
|
hits, entries = text_search.query(
|
||||||
user_query, state.model.ledger_search, rank_results=r, score_threshold=score_threshold
|
user_query, state.model.ledger_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe
|
||||||
)
|
)
|
||||||
|
|
||||||
# collate and return results
|
# collate and return results
|
||||||
|
@ -114,7 +115,7 @@ def search(
|
||||||
# query music library
|
# query music library
|
||||||
with timer("Query took", logger):
|
with timer("Query took", logger):
|
||||||
hits, entries = text_search.query(
|
hits, entries = text_search.query(
|
||||||
user_query, state.model.music_search, rank_results=r, score_threshold=score_threshold
|
user_query, state.model.music_search, rank_results=r, score_threshold=score_threshold, dedupe=dedupe
|
||||||
)
|
)
|
||||||
|
|
||||||
# collate and return results
|
# collate and return results
|
||||||
|
@ -148,6 +149,7 @@ def search(
|
||||||
state.model.plugin_search.get(t.value) or next(iter(state.model.plugin_search.values())),
|
state.model.plugin_search.get(t.value) or next(iter(state.model.plugin_search.values())),
|
||||||
rank_results=r,
|
rank_results=r,
|
||||||
score_threshold=score_threshold,
|
score_threshold=score_threshold,
|
||||||
|
dedupe=dedupe,
|
||||||
)
|
)
|
||||||
|
|
||||||
# collate and return results
|
# collate and return results
|
||||||
|
|
|
@ -101,7 +101,11 @@ def compute_embeddings(
|
||||||
|
|
||||||
|
|
||||||
def query(
|
def query(
|
||||||
raw_query: str, model: TextSearchModel, rank_results: bool = False, score_threshold: float = -math.inf
|
raw_query: str,
|
||||||
|
model: TextSearchModel,
|
||||||
|
rank_results: bool = False,
|
||||||
|
score_threshold: float = -math.inf,
|
||||||
|
dedupe: bool = True,
|
||||||
) -> Tuple[List[dict], List[Entry]]:
|
) -> Tuple[List[dict], List[Entry]]:
|
||||||
"Search for entries that answer the query"
|
"Search for entries that answer the query"
|
||||||
query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings
|
query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings
|
||||||
|
@ -139,7 +143,8 @@ def query(
|
||||||
hits = sort_results(rank_results, hits)
|
hits = sort_results(rank_results, hits)
|
||||||
|
|
||||||
# Deduplicate entries by raw entry text before showing to users
|
# Deduplicate entries by raw entry text before showing to users
|
||||||
hits = deduplicate_results(entries, hits)
|
if dedupe:
|
||||||
|
hits = deduplicate_results(entries, hits)
|
||||||
|
|
||||||
return hits, entries
|
return hits, entries
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue