Support filtering for results above threshold score in search API

This commit is contained in:
Debanjum Singh Solanky 2023-03-05 15:43:27 -06:00
parent 45f461d175
commit d73042426d
3 changed files with 39 additions and 10 deletions

View file

@ -1,4 +1,5 @@
# Standard Packages # Standard Packages
import math
import yaml import yaml
import logging import logging
from typing import List, Optional from typing import List, Optional
@ -53,7 +54,13 @@ async def set_config_data(updated_config: FullConfig):
@api.get("/search", response_model=List[SearchResponse]) @api.get("/search", response_model=List[SearchResponse])
def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False): def search(
q: str,
n: Optional[int] = 5,
t: Optional[SearchType] = None,
r: Optional[bool] = False,
    score_threshold: Optional[float] = None,
):
results: List[SearchResponse] = [] results: List[SearchResponse] = []
if q is None or q == "": if q is None or q == "":
logger.warn(f"No query param (q) passed in API call to initiate search") logger.warn(f"No query param (q) passed in API call to initiate search")
@ -62,9 +69,10 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
# initialize variables # initialize variables
user_query = q.strip() user_query = q.strip()
results_count = n results_count = n
score_threshold = score_threshold if score_threshold is not None else -math.inf
# return cached results, if available # return cached results, if available
query_cache_key = f"{user_query}-{n}-{t}-{r}" query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}"
if query_cache_key in state.query_cache: if query_cache_key in state.query_cache:
logger.debug(f"Return response from query cache") logger.debug(f"Return response from query cache")
return state.query_cache[query_cache_key] return state.query_cache[query_cache_key]
@ -72,7 +80,9 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
if (t == SearchType.Org or t == None) and state.model.orgmode_search: if (t == SearchType.Org or t == None) and state.model.orgmode_search:
# query org-mode notes # query org-mode notes
with timer("Query took", logger): with timer("Query took", logger):
hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r) hits, entries = text_search.query(
user_query, state.model.orgmode_search, rank_results=r, score_threshold=score_threshold
)
# collate and return results # collate and return results
with timer("Collating results took", logger): with timer("Collating results took", logger):
@ -81,7 +91,9 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
elif (t == SearchType.Markdown or t == None) and state.model.markdown_search: elif (t == SearchType.Markdown or t == None) and state.model.markdown_search:
# query markdown files # query markdown files
with timer("Query took", logger): with timer("Query took", logger):
hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r) hits, entries = text_search.query(
user_query, state.model.markdown_search, rank_results=r, score_threshold=score_threshold
)
# collate and return results # collate and return results
with timer("Collating results took", logger): with timer("Collating results took", logger):
@ -90,7 +102,9 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
elif (t == SearchType.Ledger or t == None) and state.model.ledger_search: elif (t == SearchType.Ledger or t == None) and state.model.ledger_search:
# query transactions # query transactions
with timer("Query took", logger): with timer("Query took", logger):
hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r) hits, entries = text_search.query(
user_query, state.model.ledger_search, rank_results=r, score_threshold=score_threshold
)
# collate and return results # collate and return results
with timer("Collating results took", logger): with timer("Collating results took", logger):
@ -99,7 +113,9 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
elif (t == SearchType.Music or t == None) and state.model.music_search: elif (t == SearchType.Music or t == None) and state.model.music_search:
# query music library # query music library
with timer("Query took", logger): with timer("Query took", logger):
hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r) hits, entries = text_search.query(
user_query, state.model.music_search, rank_results=r, score_threshold=score_threshold
)
# collate and return results # collate and return results
with timer("Collating results took", logger): with timer("Collating results took", logger):
@ -108,7 +124,9 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
elif (t == SearchType.Image or t == None) and state.model.image_search: elif (t == SearchType.Image or t == None) and state.model.image_search:
# query images # query images
with timer("Query took", logger): with timer("Query took", logger):
hits = image_search.query(user_query, results_count, state.model.image_search) hits = image_search.query(
user_query, results_count, state.model.image_search, score_threshold=score_threshold
)
output_directory = constants.web_directory / "images" output_directory = constants.web_directory / "images"
# collate and return results # collate and return results
@ -129,6 +147,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
# Get plugin search model for specified search type, or the first one if none specified # Get plugin search model for specified search type, or the first one if none specified
state.model.plugin_search.get(t.value) or next(iter(state.model.plugin_search.values())), state.model.plugin_search.get(t.value) or next(iter(state.model.plugin_search.values())),
rank_results=r, rank_results=r,
score_threshold=score_threshold,
) )
# collate and return results # collate and return results

View file

@ -1,5 +1,6 @@
# Standard Packages # Standard Packages
import glob import glob
import math
import pathlib import pathlib
import copy import copy
import shutil import shutil
@ -142,7 +143,7 @@ def extract_metadata(image_name):
return image_processed_metadata return image_processed_metadata
def query(raw_query, count, model: ImageSearchModel): def query(raw_query, count, model: ImageSearchModel, score_threshold: float = -math.inf):
# Set query to image content if query is of form file:/path/to/file.png # Set query to image content if query is of form file:/path/to/file.png
if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file(): if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file():
query_imagepath = resolve_absolute_path(pathlib.Path(raw_query[5:]), strict=True) query_imagepath = resolve_absolute_path(pathlib.Path(raw_query[5:]), strict=True)
@ -198,6 +199,9 @@ def query(raw_query, count, model: ImageSearchModel):
for corpus_id, scores in image_hits.items() for corpus_id, scores in image_hits.items()
] ]
# Filter results by score threshold
hits = [hit for hit in hits if hit["image_score"] >= score_threshold]
# Sort the images based on their combined metadata, image scores # Sort the images based on their combined metadata, image scores
return sorted(hits, key=lambda hit: hit["score"], reverse=True) return sorted(hits, key=lambda hit: hit["score"], reverse=True)

View file

@ -1,5 +1,6 @@
# Standard Packages # Standard Packages
import logging import logging
import math
from pathlib import Path from pathlib import Path
from typing import List, Tuple, Type from typing import List, Tuple, Type
@ -99,7 +100,9 @@ def compute_embeddings(
return corpus_embeddings return corpus_embeddings
def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) -> Tuple[List[dict], List[Entry]]: def query(
raw_query: str, model: TextSearchModel, rank_results: bool = False, score_threshold: float = -math.inf
) -> Tuple[List[dict], List[Entry]]:
"Search for entries that answer the query" "Search for entries that answer the query"
query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings
@ -129,6 +132,9 @@ def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) ->
if rank_results: if rank_results:
hits = cross_encoder_score(model.cross_encoder, query, entries, hits) hits = cross_encoder_score(model.cross_encoder, query, entries, hits)
# Filter results by score threshold
hits = [hit for hit in hits if hit.get("cross-score", hit.get("score")) >= score_threshold]
# Order results by cross-encoder score followed by bi-encoder score # Order results by cross-encoder score followed by bi-encoder score
hits = sort_results(rank_results, hits) hits = sort_results(rank_results, hits)
@ -143,7 +149,7 @@ def collate_results(hits, entries: List[Entry], count=5) -> List[SearchResponse]
SearchResponse.parse_obj( SearchResponse.parse_obj(
{ {
"entry": entries[hit["corpus_id"]].raw, "entry": entries[hit["corpus_id"]].raw,
"score": f"{hit['cross-score'] if 'cross-score' in hit else hit['score']:.3f}", "score": f"{hit.get('cross-score', 'score')}:.3f",
"additional": {"file": entries[hit["corpus_id"]].file, "compiled": entries[hit["corpus_id"]].compiled}, "additional": {"file": entries[hit["corpus_id"]].file, "compiled": entries[hit["corpus_id"]].compiled},
} }
) )