mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Make cross-encoder re-rank results if query param set on /search API
- Improve search speed by ~10x Tested on corpus of 125K lines, 12.5K entries - Allow cross-encoder to re-rank results by settings &?r=true when querying /search API - It's an optional param that default to False - Earlier all results were re-ranked by cross-encoder - Making this configurable allows for much faster results, if desired but for lower accuracy
This commit is contained in:
parent
b1e64fd4a8
commit
1168244c92
4 changed files with 22 additions and 19 deletions
10
src/main.py
10
src/main.py
|
@ -59,7 +59,7 @@ async def config_data(updated_config: FullConfig):
|
|||
return config
|
||||
|
||||
@app.get('/search')
|
||||
def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
|
||||
def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False):
|
||||
if q is None or q == '':
|
||||
print(f'No query param (q) passed in API call to initiate search')
|
||||
return {}
|
||||
|
@ -72,7 +72,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
|
|||
if (t == SearchType.Org or t == None) and model.orgmode_search:
|
||||
# query org-mode notes
|
||||
query_start = time.time()
|
||||
hits, entries = text_search.query(user_query, model.orgmode_search, device=device, filters=[DateFilter(), ExplicitFilter()], verbose=verbose)
|
||||
hits, entries = text_search.query(user_query, model.orgmode_search, rank_results=r, device=device, filters=[DateFilter(), ExplicitFilter()], verbose=verbose)
|
||||
query_end = time.time()
|
||||
|
||||
# collate and return results
|
||||
|
@ -83,7 +83,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
|
|||
if (t == SearchType.Music or t == None) and model.music_search:
|
||||
# query music library
|
||||
query_start = time.time()
|
||||
hits, entries = text_search.query(user_query, model.music_search, device=device, filters=[DateFilter(), ExplicitFilter()], verbose=verbose)
|
||||
hits, entries = text_search.query(user_query, model.music_search, rank_results=r, device=device, filters=[DateFilter(), ExplicitFilter()], verbose=verbose)
|
||||
query_end = time.time()
|
||||
|
||||
# collate and return results
|
||||
|
@ -94,7 +94,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
|
|||
if (t == SearchType.Markdown or t == None) and model.orgmode_search:
|
||||
# query markdown files
|
||||
query_start = time.time()
|
||||
hits, entries = text_search.query(user_query, model.markdown_search, device=device, filters=[ExplicitFilter(), DateFilter()], verbose=verbose)
|
||||
hits, entries = text_search.query(user_query, model.markdown_search, rank_results=r, device=device, filters=[ExplicitFilter(), DateFilter()], verbose=verbose)
|
||||
query_end = time.time()
|
||||
|
||||
# collate and return results
|
||||
|
@ -105,7 +105,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
|
|||
if (t == SearchType.Ledger or t == None) and model.ledger_search:
|
||||
# query transactions
|
||||
query_start = time.time()
|
||||
hits, entries = text_search.query(user_query, model.ledger_search, filters=[ExplicitFilter(), DateFilter()], verbose=verbose)
|
||||
hits, entries = text_search.query(user_query, model.ledger_search, rank_results=r, device=device, filters=[ExplicitFilter(), DateFilter()], verbose=verbose)
|
||||
query_end = time.time()
|
||||
|
||||
# collate and return results
|
||||
|
|
|
@ -63,7 +63,7 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, d
|
|||
return corpus_embeddings
|
||||
|
||||
|
||||
def query(raw_query: str, model: TextSearchModel, device='cpu', filters: list = [], verbose=0):
|
||||
def query(raw_query: str, model: TextSearchModel, rank_results=False, device='cpu', filters: list = [], verbose=0):
|
||||
"Search for entries that answer the query"
|
||||
query = raw_query
|
||||
|
||||
|
@ -108,21 +108,23 @@ def query(raw_query: str, model: TextSearchModel, device='cpu', filters: list =
|
|||
print(f"Search Time: {end - start:.3f} seconds")
|
||||
|
||||
# Score all retrieved entries using the cross-encoder
|
||||
start = time.time()
|
||||
cross_inp = [[query, entries[hit['corpus_id']]['compiled']] for hit in hits]
|
||||
cross_scores = model.cross_encoder.predict(cross_inp)
|
||||
end = time.time()
|
||||
if verbose > 1:
|
||||
print(f"Cross-Encoder Predict Time: {end - start:.3f} seconds")
|
||||
if rank_results:
|
||||
start = time.time()
|
||||
cross_inp = [[query, entries[hit['corpus_id']]['compiled']] for hit in hits]
|
||||
cross_scores = model.cross_encoder.predict(cross_inp)
|
||||
end = time.time()
|
||||
if verbose > 1:
|
||||
print(f"Cross-Encoder Predict Time: {end - start:.3f} seconds")
|
||||
|
||||
# Store cross-encoder scores in results dictionary for ranking
|
||||
for idx in range(len(cross_scores)):
|
||||
hits[idx]['cross-score'] = cross_scores[idx]
|
||||
# Store cross-encoder scores in results dictionary for ranking
|
||||
for idx in range(len(cross_scores)):
|
||||
hits[idx]['cross-score'] = cross_scores[idx]
|
||||
|
||||
# Order results by cross-encoder score followed by bi-encoder score
|
||||
start = time.time()
|
||||
hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score
|
||||
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
|
||||
if rank_results:
|
||||
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
|
||||
end = time.time()
|
||||
if verbose > 1:
|
||||
print(f"Rank Time: {end - start:.3f} seconds")
|
||||
|
@ -152,7 +154,7 @@ def collate_results(hits, entries, count=5):
|
|||
return [
|
||||
{
|
||||
"entry": entries[hit['corpus_id']]['raw'],
|
||||
"score": f"{hit['cross-score']:.3f}"
|
||||
"score": f"{hit['cross-score'] if 'cross-score' in hit else hit['score']:.3f}"
|
||||
}
|
||||
for hit
|
||||
in hits[0:count]]
|
||||
|
|
|
@ -29,7 +29,8 @@ def test_asymmetric_search(content_config: ContentConfig, search_config: SearchC
|
|||
# Act
|
||||
hits, entries = text_search.query(
|
||||
query,
|
||||
model = model.notes_search)
|
||||
model = model.notes_search,
|
||||
rank_results=True)
|
||||
|
||||
results = text_search.collate_results(
|
||||
hits,
|
||||
|
|
|
@ -119,7 +119,7 @@ def test_notes_search(content_config: ContentConfig, search_config: SearchConfig
|
|||
user_query = "How to git install application?"
|
||||
|
||||
# Act
|
||||
response = client.get(f"/search?q={user_query}&n=1&t=org")
|
||||
response = client.get(f"/search?q={user_query}&n=1&t=org&r=true")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
|
|
Loading…
Reference in a new issue