mirror of
https://github.com/khoj-ai/khoj.git
synced 2025-02-20 06:55:08 +00:00
Use cross-encoder to rerank search results by default on GPU machines
The latest sentence-transformers package uses the GPU for the cross-encoder. This makes it fast enough to enable reranking on machines with a GPU. Enabling search reranking by default allows (at least) users with GPUs to side-step learning the UI affordance to rerank results (i.e., hitting Cmd/Ctrl-Enter or ENTER).
This commit is contained in:
parent
42d4bc6b14
commit
1105d8814f
3 changed files with 10 additions and 4 deletions
|
@ -50,7 +50,7 @@ dependencies = [
|
||||||
"pyyaml == 6.0",
|
"pyyaml == 6.0",
|
||||||
"rich >= 13.3.1",
|
"rich >= 13.3.1",
|
||||||
"schedule == 1.1.0",
|
"schedule == 1.1.0",
|
||||||
"sentence-transformers == 2.3.1",
|
"sentence-transformers == 2.5.1",
|
||||||
"transformers >= 4.28.0",
|
"transformers >= 4.28.0",
|
||||||
"torch == 2.0.1",
|
"torch == 2.0.1",
|
||||||
"uvicorn == 0.17.6",
|
"uvicorn == 0.17.6",
|
||||||
|
|
|
@ -177,8 +177,9 @@ def deduplicated_search_responses(hits: List[SearchResponse]):
|
||||||
|
|
||||||
|
|
||||||
def rerank_and_sort_results(hits, query, rank_results, search_model_name):
|
def rerank_and_sort_results(hits, query, rank_results, search_model_name):
|
||||||
# If we have more than one result and reranking is enabled
|
# Rerank results if explicitly requested or if device has GPU
|
||||||
rank_results = rank_results and len(list(hits)) > 1
|
# AND if we have more than one result
|
||||||
|
rank_results = (rank_results or state.device.type != "cpu") and len(list(hits)) > 1
|
||||||
|
|
||||||
# Score all retrieved entries using the cross-encoder
|
# Score all retrieved entries using the cross-encoder
|
||||||
if rank_results:
|
if rank_results:
|
||||||
|
|
|
@ -331,7 +331,12 @@ def batcher(iterable, max_n):
|
||||||
yield (x for x in chunk if x is not None)
|
yield (x for x in chunk if x is not None)
|
||||||
|
|
||||||
|
|
||||||
|
def is_env_var_true(env_var: str, default: str = "false") -> bool:
|
||||||
|
"""Get state of boolean environment variable"""
|
||||||
|
return os.getenv(env_var, default).lower() == "true"
|
||||||
|
|
||||||
|
|
||||||
def in_debug_mode():
|
def in_debug_mode():
|
||||||
"""Check if Khoj is running in debug mode.
|
"""Check if Khoj is running in debug mode.
|
||||||
Set KHOJ_DEBUG environment variable to true to enable debug mode."""
|
Set KHOJ_DEBUG environment variable to true to enable debug mode."""
|
||||||
return os.getenv("KHOJ_DEBUG", "false").lower() == "true"
|
return is_env_var_true("KHOJ_DEBUG")
|
||||||
|
|
Loading…
Add table
Reference in a new issue