Mirror of https://github.com/khoj-ai/khoj.git (synced 2024-11-27 17:35:07 +01:00)
Support filtering for results above threshold score in search API

parent 45f461d175
commit d73042426d

3 changed files with 39 additions and 10 deletions
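The diff below adds an optional score_threshold query parameter to the /search endpoint and threads it through the text and image search backends. As a rough usage sketch (the host, port, and the "org" search-type value are assumptions not stated in this diff; q, n, t, r, and score_threshold are the parameters defined below), a client could call the updated endpoint like this:

# Minimal sketch of calling the updated /search endpoint.
# Host/port and any router prefix are assumptions; the query
# parameters mirror the new search() signature in the diff.
import requests

response = requests.get(
    "http://localhost:8000/search",  # hypothetical local khoj instance
    params={
        "q": "my notes on philosophy",
        "n": 5,                  # max results to return
        "t": "org",              # restrict to org-mode notes (assumed enum value)
        "r": True,               # rerank with the cross-encoder
        "score_threshold": 0.3,  # drop hits scoring below 0.3
    },
)
print(response.json())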
Changed file 1 of 3 (the search API router):

@@ -1,4 +1,5 @@
 # Standard Packages
+import math
 import yaml
 import logging
 from typing import List, Optional
@@ -53,7 +54,13 @@ async def set_config_data(updated_config: FullConfig):


 @api.get("/search", response_model=List[SearchResponse])
-def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False):
+def search(
+    q: str,
+    n: Optional[int] = 5,
+    t: Optional[SearchType] = None,
+    r: Optional[bool] = False,
+    score_threshold: Optional[float | None] = None,
+):
     results: List[SearchResponse] = []
     if q is None or q == "":
         logger.warn(f"No query param (q) passed in API call to initiate search")
@@ -62,9 +69,10 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
     # initialize variables
     user_query = q.strip()
     results_count = n
+    score_threshold = score_threshold if score_threshold is not None else -math.inf

     # return cached results, if available
-    query_cache_key = f"{user_query}-{n}-{t}-{r}"
+    query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}"
     if query_cache_key in state.query_cache:
         logger.debug(f"Return response from query cache")
         return state.query_cache[query_cache_key]
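Because the threshold changes which hits are returned, the commit also folds it into the query cache key; otherwise an unfiltered cached response could be served for a later call that asks for a stricter threshold. A minimal standalone sketch of that behaviour (the cache is just a dict here, standing in for state.query_cache; the example values are illustrative):

# Sketch: why score_threshold belongs in the cache key.
query_cache = {}

def cache_key(user_query, n, t, r, score_threshold):
    # Same key shape as the updated line in the diff.
    return f"{user_query}-{n}-{t}-{r}-{score_threshold}"

key_loose = cache_key("tax notes", 5, None, False, -float("inf"))
key_strict = cache_key("tax notes", 5, None, False, 0.5)

query_cache[key_loose] = ["low-score hit", "high-score hit"]

# A stricter threshold produces a different key, so it cannot
# accidentally reuse the unfiltered cached results.
assert key_strict not in query_cache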
@@ -72,7 +80,9 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
     if (t == SearchType.Org or t == None) and state.model.orgmode_search:
         # query org-mode notes
         with timer("Query took", logger):
-            hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r)
+            hits, entries = text_search.query(
+                user_query, state.model.orgmode_search, rank_results=r, score_threshold=score_threshold
+            )

         # collate and return results
         with timer("Collating results took", logger):
@@ -81,7 +91,9 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
     elif (t == SearchType.Markdown or t == None) and state.model.markdown_search:
         # query markdown files
         with timer("Query took", logger):
-            hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r)
+            hits, entries = text_search.query(
+                user_query, state.model.markdown_search, rank_results=r, score_threshold=score_threshold
+            )

         # collate and return results
         with timer("Collating results took", logger):
@@ -90,7 +102,9 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
     elif (t == SearchType.Ledger or t == None) and state.model.ledger_search:
         # query transactions
         with timer("Query took", logger):
-            hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r)
+            hits, entries = text_search.query(
+                user_query, state.model.ledger_search, rank_results=r, score_threshold=score_threshold
+            )

         # collate and return results
         with timer("Collating results took", logger):
@@ -99,7 +113,9 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
     elif (t == SearchType.Music or t == None) and state.model.music_search:
         # query music library
         with timer("Query took", logger):
-            hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r)
+            hits, entries = text_search.query(
+                user_query, state.model.music_search, rank_results=r, score_threshold=score_threshold
+            )

         # collate and return results
         with timer("Collating results took", logger):
@@ -108,7 +124,9 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
     elif (t == SearchType.Image or t == None) and state.model.image_search:
         # query images
         with timer("Query took", logger):
-            hits = image_search.query(user_query, results_count, state.model.image_search)
+            hits = image_search.query(
+                user_query, results_count, state.model.image_search, score_threshold=score_threshold
+            )
             output_directory = constants.web_directory / "images"

         # collate and return results
@@ -129,6 +147,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
                 # Get plugin search model for specified search type, or the first one if none specified
                 state.model.plugin_search.get(t.value) or next(iter(state.model.plugin_search.values())),
                 rank_results=r,
+                score_threshold=score_threshold,
             )

         # collate and return results
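Every backend call above receives score_threshold, and the endpoint defaults it to -math.inf when the caller omits it, so the filter is a no-op unless a threshold is explicitly requested. A small standalone sketch of that default (the sample scores are illustrative, not taken from the diff):

import math

def filter_hits(hits, score_threshold=-math.inf):
    # Mirrors the filtering idea used in the search backends:
    # with the -inf default every hit passes; a real threshold prunes.
    return [hit for hit in hits if hit["score"] >= score_threshold]

hits = [{"score": 0.82}, {"score": 0.41}, {"score": 0.12}]
assert filter_hits(hits) == hits           # default keeps everything
assert filter_hits(hits, 0.4) == hits[:2]  # threshold drops weak hits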
Changed file 2 of 3 (the image search module):

@@ -1,5 +1,6 @@
 # Standard Packages
 import glob
+import math
 import pathlib
 import copy
 import shutil
@@ -142,7 +143,7 @@ def extract_metadata(image_name):
     return image_processed_metadata


-def query(raw_query, count, model: ImageSearchModel):
+def query(raw_query, count, model: ImageSearchModel, score_threshold: float = -math.inf):
     # Set query to image content if query is of form file:/path/to/file.png
     if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file():
         query_imagepath = resolve_absolute_path(pathlib.Path(raw_query[5:]), strict=True)
@@ -198,6 +199,9 @@ def query(raw_query, count, model: ImageSearchModel):
        for corpus_id, scores in image_hits.items()
    ]

+    # Filter results by score threshold
+    hits = [hit for hit in hits if hit["image_score"] >= score_threshold]
+
    # Sort the images based on their combined metadata, image scores
    return sorted(hits, key=lambda hit: hit["score"], reverse=True)

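In the image backend the filter keys on the raw image similarity ("image_score") while the final ordering still uses the combined metadata and image score ("score"), so a hit must clear the threshold on image similarity even if its combined score is high. A small illustrative sketch (the sample hit dicts are made up; only the two keys mirror the hunk above):

# Illustrative only: "image_score" gates the hit, "score" orders the result.
hits = [
    {"image_score": 0.62, "score": 0.71},
    {"image_score": 0.18, "score": 0.90},  # filtered out despite a high combined score
    {"image_score": 0.55, "score": 0.40},
]

score_threshold = 0.3
filtered = [hit for hit in hits if hit["image_score"] >= score_threshold]
ranked = sorted(filtered, key=lambda hit: hit["score"], reverse=True)

assert [hit["score"] for hit in ranked] == [0.71, 0.40]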
Changed file 3 of 3 (the text search module):

@@ -1,5 +1,6 @@
 # Standard Packages
 import logging
+import math
 from pathlib import Path
 from typing import List, Tuple, Type

@@ -99,7 +100,9 @@ def compute_embeddings(
     return corpus_embeddings


-def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) -> Tuple[List[dict], List[Entry]]:
+def query(
+    raw_query: str, model: TextSearchModel, rank_results: bool = False, score_threshold: float = -math.inf
+) -> Tuple[List[dict], List[Entry]]:
     "Search for entries that answer the query"
     query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings

@@ -129,6 +132,9 @@ def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) ->
     if rank_results:
         hits = cross_encoder_score(model.cross_encoder, query, entries, hits)

+    # Filter results by score threshold
+    hits = [hit for hit in hits if hit.get("cross-score", hit.get("score")) >= score_threshold]
+
     # Order results by cross-encoder score followed by bi-encoder score
     hits = sort_results(rank_results, hits)

@@ -143,7 +149,7 @@ def collate_results(hits, entries: List[Entry], count=5) -> List[SearchResponse]
        SearchResponse.parse_obj(
            {
                "entry": entries[hit["corpus_id"]].raw,
-                "score": f"{hit['cross-score'] if 'cross-score' in hit else hit['score']:.3f}",
+                "score": f"{hit.get('cross-score', 'score')}:.3f",
                "additional": {"file": entries[hit["corpus_id"]].file, "compiled": entries[hit["corpus_id"]].compiled},
            }
        )
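The text-search filter compares against the cross-encoder score when reranking produced one and falls back to the bi-encoder score otherwise, so the same threshold parameter works with and without r=true. A standalone sketch of that fallback (the sample scores are illustrative):

import math

def passes(hit, score_threshold=-math.inf):
    # Same key fallback as the text-search hunk above:
    # prefer the reranked "cross-score", else the bi-encoder "score".
    return hit.get("cross-score", hit.get("score")) >= score_threshold

reranked = {"score": 0.35, "cross-score": 0.72}  # judged by its cross-encoder score
plain = {"score": 0.35}                          # judged by its bi-encoder score

assert passes(reranked, 0.5) is True
assert passes(plain, 0.5) is False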