Mirror of https://github.com/khoj-ai/khoj.git (synced 2024-11-27 17:35:07 +01:00)
Support filtering for results above threshold score in search API
commit d73042426d
parent 45f461d175
3 changed files with 39 additions and 10 deletions
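
This change adds an optional score_threshold query parameter to the /search endpoint: hits scoring below the threshold are dropped before results are collated, and when the parameter is omitted the threshold resolves to negative infinity, so nothing is filtered out. A minimal client sketch, assuming a locally running khoj server (the base URL, port, and router mount point are illustrative, not part of this commit):

    import requests

    # Only keep results that score at least 0.5; drop the parameter for the old, unfiltered behavior.
    response = requests.get(
        "http://localhost:8000/api/search",  # assumed address and mount point for the api router
        params={"q": "ideas from last week", "n": 5, "score_threshold": 0.5},
    )
    print(response.json())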
@@ -1,4 +1,5 @@
 # Standard Packages
+import math
 import yaml
 import logging
 from typing import List, Optional
@@ -53,7 +54,13 @@ async def set_config_data(updated_config: FullConfig):
 
 
 @api.get("/search", response_model=List[SearchResponse])
-def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False):
+def search(
+    q: str,
+    n: Optional[int] = 5,
+    t: Optional[SearchType] = None,
+    r: Optional[bool] = False,
+    score_threshold: Optional[float | None] = None,
+):
     results: List[SearchResponse] = []
     if q is None or q == "":
         logger.warn(f"No query param (q) passed in API call to initiate search")
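
A side note on the new parameter's annotation: Optional[float | None] is equivalent to plain Optional[float] (which already means float | None), and FastAPI exposes either form as an optional query parameter named score_threshold. A slightly simpler, behaviorally identical declaration, shown only for comparison and not part of the commit:

    score_threshold: Optional[float] = None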
@@ -62,9 +69,10 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
     # initialize variables
     user_query = q.strip()
     results_count = n
+    score_threshold = score_threshold if score_threshold is not None else -math.inf
 
     # return cached results, if available
-    query_cache_key = f"{user_query}-{n}-{t}-{r}"
+    query_cache_key = f"{user_query}-{n}-{t}-{r}-{score_threshold}"
     if query_cache_key in state.query_cache:
         logger.debug(f"Return response from query cache")
         return state.query_cache[query_cache_key]
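
Because the resolved threshold is folded into the cache key, the same query issued with different thresholds is cached and served independently, and omitting the parameter still yields a stable key since it resolves to -math.inf first. An illustrative sketch of the resulting keys (the query and values are hypothetical):

    # user_query="python notes", n=5, t=None, r=False
    # score_threshold omitted -> "python notes-5-None-False--inf"
    # score_threshold=0.5     -> "python notes-5-None-False-0.5"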
@@ -72,7 +80,9 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
     if (t == SearchType.Org or t == None) and state.model.orgmode_search:
         # query org-mode notes
         with timer("Query took", logger):
-            hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r)
+            hits, entries = text_search.query(
+                user_query, state.model.orgmode_search, rank_results=r, score_threshold=score_threshold
+            )
 
         # collate and return results
         with timer("Collating results took", logger):
@@ -81,7 +91,9 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
     elif (t == SearchType.Markdown or t == None) and state.model.markdown_search:
         # query markdown files
         with timer("Query took", logger):
-            hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r)
+            hits, entries = text_search.query(
+                user_query, state.model.markdown_search, rank_results=r, score_threshold=score_threshold
+            )
 
         # collate and return results
         with timer("Collating results took", logger):
@@ -90,7 +102,9 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
     elif (t == SearchType.Ledger or t == None) and state.model.ledger_search:
         # query transactions
         with timer("Query took", logger):
-            hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r)
+            hits, entries = text_search.query(
+                user_query, state.model.ledger_search, rank_results=r, score_threshold=score_threshold
+            )
 
         # collate and return results
         with timer("Collating results took", logger):
@@ -99,7 +113,9 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
     elif (t == SearchType.Music or t == None) and state.model.music_search:
         # query music library
         with timer("Query took", logger):
-            hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r)
+            hits, entries = text_search.query(
+                user_query, state.model.music_search, rank_results=r, score_threshold=score_threshold
+            )
 
         # collate and return results
         with timer("Collating results took", logger):
@@ -108,7 +124,9 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
     elif (t == SearchType.Image or t == None) and state.model.image_search:
         # query images
         with timer("Query took", logger):
-            hits = image_search.query(user_query, results_count, state.model.image_search)
+            hits = image_search.query(
+                user_query, results_count, state.model.image_search, score_threshold=score_threshold
+            )
             output_directory = constants.web_directory / "images"
 
         # collate and return results
@@ -129,6 +147,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
                 # Get plugin search model for specified search type, or the first one if none specified
                 state.model.plugin_search.get(t.value) or next(iter(state.model.plugin_search.values())),
                 rank_results=r,
+                score_threshold=score_threshold,
             )
 
         # collate and return results

@@ -1,5 +1,6 @@
 # Standard Packages
 import glob
+import math
 import pathlib
 import copy
 import shutil
@@ -142,7 +143,7 @@ def extract_metadata(image_name):
     return image_processed_metadata
 
 
-def query(raw_query, count, model: ImageSearchModel):
+def query(raw_query, count, model: ImageSearchModel, score_threshold: float = -math.inf):
     # Set query to image content if query is of form file:/path/to/file.png
     if raw_query.startswith("file:") and pathlib.Path(raw_query[5:]).is_file():
         query_imagepath = resolve_absolute_path(pathlib.Path(raw_query[5:]), strict=True)
@@ -198,6 +199,9 @@ def query(raw_query, count, model: ImageSearchModel):
         for corpus_id, scores in image_hits.items()
     ]
 
+    # Filter results by score threshold
+    hits = [hit for hit in hits if hit["image_score"] >= score_threshold]
+
     # Sort the images based on their combined metadata, image scores
     return sorted(hits, key=lambda hit: hit["score"], reverse=True)
 
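
In the image search module, the filter above compares against the raw image-to-query similarity (hit["image_score"]), while the final ordering still uses the combined metadata-and-image score (hit["score"]), so a hit is kept or dropped on its image score alone. A minimal sketch of that filter with hypothetical hits:

    score_threshold = 0.2
    hits = [
        {"corpus_id": 0, "image_score": 0.35, "score": 0.42},  # kept: image_score clears the threshold
        {"corpus_id": 1, "image_score": 0.10, "score": 0.30},  # dropped: image_score is below the threshold despite the higher combined score
    ]
    hits = [hit for hit in hits if hit["image_score"] >= score_threshold]  # leaves only the first hit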

@@ -1,5 +1,6 @@
 # Standard Packages
 import logging
+import math
 from pathlib import Path
 from typing import List, Tuple, Type
 
@@ -99,7 +100,9 @@ def compute_embeddings(
     return corpus_embeddings
 
 
-def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) -> Tuple[List[dict], List[Entry]]:
+def query(
+    raw_query: str, model: TextSearchModel, rank_results: bool = False, score_threshold: float = -math.inf
+) -> Tuple[List[dict], List[Entry]]:
     "Search for entries that answer the query"
     query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings
 
@@ -129,6 +132,9 @@ def query(raw_query: str, model: TextSearchModel, rank_results: bool = False) ->
     if rank_results:
         hits = cross_encoder_score(model.cross_encoder, query, entries, hits)
 
+    # Filter results by score threshold
+    hits = [hit for hit in hits if hit.get("cross-score", hit.get("score")) >= score_threshold]
+
     # Order results by cross-encoder score followed by bi-encoder score
     hits = sort_results(rank_results, hits)
 
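
The text search filter compares against the cross-encoder score when re-ranking was requested (rank_results=True) and falls back to the bi-encoder score otherwise, since only re-ranked hits carry a "cross-score" key. A minimal sketch of that fallback with hypothetical hits:

    score_threshold = 0.3
    hits = [
        {"corpus_id": 0, "score": 0.61, "cross-score": 0.72},  # re-ranked hit: compared on its cross-score (0.72), kept
        {"corpus_id": 1, "score": 0.28},                       # un-ranked hit: compared on its bi-encoder score (0.28), dropped
    ]
    hits = [hit for hit in hits if hit.get("cross-score", hit.get("score")) >= score_threshold]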
@@ -143,7 +149,7 @@ def collate_results(hits, entries: List[Entry], count=5) -> List[SearchResponse]
         SearchResponse.parse_obj(
             {
                 "entry": entries[hit["corpus_id"]].raw,
-                "score": f"{hit['cross-score'] if 'cross-score' in hit else hit['score']:.3f}",
+                "score": f"{hit.get('cross-score', 'score')}:.3f",
                 "additional": {"file": entries[hit["corpus_id"]].file, "compiled": entries[hit["corpus_id"]].compiled},
             }
         )
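
One caveat about the rewritten "score" field above: the format spec now sits outside the replacement field, so ":.3f" is appended to the output literally, and the .get() fallback is the string 'score' rather than the hit's bi-encoder score. If the intent was only to simplify the original conditional, a sketch that preserves the old three-decimal formatting and fallback would be (an assumption about intent, not part of the commit):

    "score": f"{hit.get('cross-score', hit.get('score')):.3f}",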