Make cross-encoder re-rank results if query param set on /search API

- Improve search speed by ~10x
  Tested on corpus of 125K lines, 12.5K entries

- Allow cross-encoder to re-rank results by settings &?r=true when querying /search API
  - It's an optional param that default to False
  - Earlier all results were re-ranked by cross-encoder
  - Making this configurable allows for much faster results, if desired
    but for lower accuracy
This commit is contained in:
Debanjum Singh Solanky 2022-07-26 22:56:36 +04:00
parent b1e64fd4a8
commit 1168244c92
4 changed files with 22 additions and 19 deletions

View file

@ -59,7 +59,7 @@ async def config_data(updated_config: FullConfig):
return config return config
@app.get('/search') @app.get('/search')
def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None): def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False):
if q is None or q == '': if q is None or q == '':
print(f'No query param (q) passed in API call to initiate search') print(f'No query param (q) passed in API call to initiate search')
return {} return {}
@ -72,7 +72,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
if (t == SearchType.Org or t == None) and model.orgmode_search: if (t == SearchType.Org or t == None) and model.orgmode_search:
# query org-mode notes # query org-mode notes
query_start = time.time() query_start = time.time()
hits, entries = text_search.query(user_query, model.orgmode_search, device=device, filters=[DateFilter(), ExplicitFilter()], verbose=verbose) hits, entries = text_search.query(user_query, model.orgmode_search, rank_results=r, device=device, filters=[DateFilter(), ExplicitFilter()], verbose=verbose)
query_end = time.time() query_end = time.time()
# collate and return results # collate and return results
@ -83,7 +83,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
if (t == SearchType.Music or t == None) and model.music_search: if (t == SearchType.Music or t == None) and model.music_search:
# query music library # query music library
query_start = time.time() query_start = time.time()
hits, entries = text_search.query(user_query, model.music_search, device=device, filters=[DateFilter(), ExplicitFilter()], verbose=verbose) hits, entries = text_search.query(user_query, model.music_search, rank_results=r, device=device, filters=[DateFilter(), ExplicitFilter()], verbose=verbose)
query_end = time.time() query_end = time.time()
# collate and return results # collate and return results
@ -94,7 +94,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
if (t == SearchType.Markdown or t == None) and model.orgmode_search: if (t == SearchType.Markdown or t == None) and model.orgmode_search:
# query markdown files # query markdown files
query_start = time.time() query_start = time.time()
hits, entries = text_search.query(user_query, model.markdown_search, device=device, filters=[ExplicitFilter(), DateFilter()], verbose=verbose) hits, entries = text_search.query(user_query, model.markdown_search, rank_results=r, device=device, filters=[ExplicitFilter(), DateFilter()], verbose=verbose)
query_end = time.time() query_end = time.time()
# collate and return results # collate and return results
@ -105,7 +105,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
if (t == SearchType.Ledger or t == None) and model.ledger_search: if (t == SearchType.Ledger or t == None) and model.ledger_search:
# query transactions # query transactions
query_start = time.time() query_start = time.time()
hits, entries = text_search.query(user_query, model.ledger_search, filters=[ExplicitFilter(), DateFilter()], verbose=verbose) hits, entries = text_search.query(user_query, model.ledger_search, rank_results=r, device=device, filters=[ExplicitFilter(), DateFilter()], verbose=verbose)
query_end = time.time() query_end = time.time()
# collate and return results # collate and return results

View file

@ -63,7 +63,7 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, d
return corpus_embeddings return corpus_embeddings
def query(raw_query: str, model: TextSearchModel, device='cpu', filters: list = [], verbose=0): def query(raw_query: str, model: TextSearchModel, rank_results=False, device='cpu', filters: list = [], verbose=0):
"Search for entries that answer the query" "Search for entries that answer the query"
query = raw_query query = raw_query
@ -108,6 +108,7 @@ def query(raw_query: str, model: TextSearchModel, device='cpu', filters: list =
print(f"Search Time: {end - start:.3f} seconds") print(f"Search Time: {end - start:.3f} seconds")
# Score all retrieved entries using the cross-encoder # Score all retrieved entries using the cross-encoder
if rank_results:
start = time.time() start = time.time()
cross_inp = [[query, entries[hit['corpus_id']]['compiled']] for hit in hits] cross_inp = [[query, entries[hit['corpus_id']]['compiled']] for hit in hits]
cross_scores = model.cross_encoder.predict(cross_inp) cross_scores = model.cross_encoder.predict(cross_inp)
@ -122,6 +123,7 @@ def query(raw_query: str, model: TextSearchModel, device='cpu', filters: list =
# Order results by cross-encoder score followed by bi-encoder score # Order results by cross-encoder score followed by bi-encoder score
start = time.time() start = time.time()
hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score
if rank_results:
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
end = time.time() end = time.time()
if verbose > 1: if verbose > 1:
@ -152,7 +154,7 @@ def collate_results(hits, entries, count=5):
return [ return [
{ {
"entry": entries[hit['corpus_id']]['raw'], "entry": entries[hit['corpus_id']]['raw'],
"score": f"{hit['cross-score']:.3f}" "score": f"{hit['cross-score'] if 'cross-score' in hit else hit['score']:.3f}"
} }
for hit for hit
in hits[0:count]] in hits[0:count]]

View file

@ -29,7 +29,8 @@ def test_asymmetric_search(content_config: ContentConfig, search_config: SearchC
# Act # Act
hits, entries = text_search.query( hits, entries = text_search.query(
query, query,
model = model.notes_search) model = model.notes_search,
rank_results=True)
results = text_search.collate_results( results = text_search.collate_results(
hits, hits,

View file

@ -119,7 +119,7 @@ def test_notes_search(content_config: ContentConfig, search_config: SearchConfig
user_query = "How to git install application?" user_query = "How to git install application?"
# Act # Act
response = client.get(f"/search?q={user_query}&n=1&t=org") response = client.get(f"/search?q={user_query}&n=1&t=org&r=true")
# Assert # Assert
assert response.status_code == 200 assert response.status_code == 200