From 1105d8814fcd83ae6fa1357547042e3c6a29d34e Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Wed, 14 Feb 2024 18:37:53 +0530 Subject: [PATCH] Use cross-encoder to rerank search results by default on GPU machines Latest sentence-transformers package uses GPU for cross-encoder. This makes it fast enough to enable reranking on machines with GPU. Enabling search reranking by default allows (at least) users with GPUs to side-step learning the UI affordance to rerank results (i.e. hitting Cmd/Ctrl-Enter or ENTER). --- pyproject.toml | 2 +- src/khoj/search_type/text_search.py | 5 +++-- src/khoj/utils/helpers.py | 7 ++++++- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 17003c6c..8bab7876 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,7 @@ dependencies = [ "pyyaml == 6.0", "rich >= 13.3.1", "schedule == 1.1.0", - "sentence-transformers == 2.3.1", + "sentence-transformers == 2.5.1", "transformers >= 4.28.0", "torch == 2.0.1", "uvicorn == 0.17.6", diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py index d5ea35e6..48bc9e46 100644 --- a/src/khoj/search_type/text_search.py +++ b/src/khoj/search_type/text_search.py @@ -177,8 +177,9 @@ def deduplicated_search_responses(hits: List[SearchResponse]): def rerank_and_sort_results(hits, query, rank_results, search_model_name): - # If we have more than one result and reranking is enabled - rank_results = rank_results and len(list(hits)) > 1 + # Rerank results if explicitly requested or if device has GPU + # AND if we have more than one result + rank_results = (rank_results or state.device.type != "cpu") and len(list(hits)) > 1 # Score all retrieved entries using the cross-encoder if rank_results: diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index d2b64296..f30ddd04 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -331,7 +331,12 @@ def batcher(iterable, max_n): yield (x for x 
in chunk if x is not None) +def is_env_var_true(env_var: str, default: str = "false") -> bool: + """Get state of boolean environment variable""" + return os.getenv(env_var, default).lower() == "true" + + def in_debug_mode(): """Check if Khoj is running in debug mode. Set KHOJ_DEBUG environment variable to true to enable debug mode.""" - return os.getenv("KHOJ_DEBUG", "false").lower() == "true" + return is_env_var_true("KHOJ_DEBUG")