From 1168244c923e7bdd578aba956e38f20fe18aea9f Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Tue, 26 Jul 2022 22:56:36 +0400 Subject: [PATCH] Make cross-encoder re-rank results if query param set on /search API - Improve search speed by ~10x Tested on corpus of 125K lines, 12.5K entries - Allow cross-encoder to re-rank results by settings &?r=true when querying /search API - It's an optional param that default to False - Earlier all results were re-ranked by cross-encoder - Making this configurable allows for much faster results, if desired but for lower accuracy --- src/main.py | 10 +++++----- src/search_type/text_search.py | 26 ++++++++++++++------------ tests/test_asymmetric_search.py | 3 ++- tests/test_client.py | 2 +- 4 files changed, 22 insertions(+), 19 deletions(-) diff --git a/src/main.py b/src/main.py index cb975322..de8e287f 100644 --- a/src/main.py +++ b/src/main.py @@ -59,7 +59,7 @@ async def config_data(updated_config: FullConfig): return config @app.get('/search') -def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None): +def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False): if q is None or q == '': print(f'No query param (q) passed in API call to initiate search') return {} @@ -72,7 +72,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None): if (t == SearchType.Org or t == None) and model.orgmode_search: # query org-mode notes query_start = time.time() - hits, entries = text_search.query(user_query, model.orgmode_search, device=device, filters=[DateFilter(), ExplicitFilter()], verbose=verbose) + hits, entries = text_search.query(user_query, model.orgmode_search, rank_results=r, device=device, filters=[DateFilter(), ExplicitFilter()], verbose=verbose) query_end = time.time() # collate and return results @@ -83,7 +83,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None): if (t == SearchType.Music or t == None) and model.music_search: # query music library query_start = time.time() - hits, entries = text_search.query(user_query, model.music_search, device=device, filters=[DateFilter(), ExplicitFilter()], verbose=verbose) + hits, entries = text_search.query(user_query, model.music_search, rank_results=r, device=device, filters=[DateFilter(), ExplicitFilter()], verbose=verbose) query_end = time.time() # collate and return results @@ -94,7 +94,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None): if (t == SearchType.Markdown or t == None) and model.orgmode_search: # query markdown files query_start = time.time() - hits, entries = text_search.query(user_query, model.markdown_search, device=device, filters=[ExplicitFilter(), DateFilter()], verbose=verbose) + hits, entries = text_search.query(user_query, model.markdown_search, rank_results=r, device=device, filters=[ExplicitFilter(), DateFilter()], verbose=verbose) query_end = time.time() # collate and return results @@ -105,7 +105,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None): if (t == SearchType.Ledger or t == None) and model.ledger_search: # query transactions query_start = time.time() - hits, entries = text_search.query(user_query, model.ledger_search, filters=[ExplicitFilter(), DateFilter()], verbose=verbose) + hits, entries = text_search.query(user_query, model.ledger_search, rank_results=r, device=device, filters=[ExplicitFilter(), DateFilter()], verbose=verbose) query_end = time.time() # collate and return results diff --git a/src/search_type/text_search.py b/src/search_type/text_search.py index 739aba40..81c92605 100644 --- a/src/search_type/text_search.py +++ b/src/search_type/text_search.py @@ -63,7 +63,7 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, d return corpus_embeddings -def query(raw_query: str, model: TextSearchModel, device='cpu', filters: list = [], verbose=0): +def query(raw_query: str, model: TextSearchModel, rank_results=False, device='cpu', filters: list = [], verbose=0): "Search for entries that answer the query" query = raw_query @@ -108,21 +108,23 @@ def query(raw_query: str, model: TextSearchModel, device='cpu', filters: list = print(f"Search Time: {end - start:.3f} seconds") # Score all retrieved entries using the cross-encoder - start = time.time() - cross_inp = [[query, entries[hit['corpus_id']]['compiled']] for hit in hits] - cross_scores = model.cross_encoder.predict(cross_inp) - end = time.time() - if verbose > 1: - print(f"Cross-Encoder Predict Time: {end - start:.3f} seconds") + if rank_results: + start = time.time() + cross_inp = [[query, entries[hit['corpus_id']]['compiled']] for hit in hits] + cross_scores = model.cross_encoder.predict(cross_inp) + end = time.time() + if verbose > 1: + print(f"Cross-Encoder Predict Time: {end - start:.3f} seconds") - # Store cross-encoder scores in results dictionary for ranking - for idx in range(len(cross_scores)): - hits[idx]['cross-score'] = cross_scores[idx] + # Store cross-encoder scores in results dictionary for ranking + for idx in range(len(cross_scores)): + hits[idx]['cross-score'] = cross_scores[idx] # Order results by cross-encoder score followed by bi-encoder score start = time.time() hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score - hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score + if rank_results: + hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score end = time.time() if verbose > 1: print(f"Rank Time: {end - start:.3f} seconds") @@ -152,7 +154,7 @@ def collate_results(hits, entries, count=5): return [ { "entry": entries[hit['corpus_id']]['raw'], - "score": f"{hit['cross-score']:.3f}" + "score": f"{hit['cross-score'] if 'cross-score' in hit else hit['score']:.3f}" } for hit in hits[0:count]] diff --git a/tests/test_asymmetric_search.py b/tests/test_asymmetric_search.py index b14cc10d..135f9680 100644 --- a/tests/test_asymmetric_search.py +++ b/tests/test_asymmetric_search.py @@ -29,7 +29,8 @@ def test_asymmetric_search(content_config: ContentConfig, search_config: SearchC # Act hits, entries = text_search.query( query, - model = model.notes_search) + model = model.notes_search, + rank_results=True) results = text_search.collate_results( hits, diff --git a/tests/test_client.py b/tests/test_client.py index 3efce1b8..04d26a80 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -119,7 +119,7 @@ def test_notes_search(content_config: ContentConfig, search_config: SearchConfig user_query = "How to git install application?" # Act - response = client.get(f"/search?q={user_query}&n=1&t=org") + response = client.get(f"/search?q={user_query}&n=1&t=org&r=true") # Assert assert response.status_code == 200