mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-30 19:03:01 +01:00
Make cross-encoder re-rank results if query param set on /search API
- Improve search speed by ~10x. Tested on a corpus of 125K lines, 12.5K entries - Allow the cross-encoder to re-rank results by setting r=true when querying the /search API - It's an optional param that defaults to False - Earlier, all results were re-ranked by the cross-encoder - Making this configurable allows much faster results, if desired, at the cost of lower accuracy
This commit is contained in:
parent
b1e64fd4a8
commit
1168244c92
4 changed files with 22 additions and 19 deletions
10
src/main.py
10
src/main.py
|
@ -59,7 +59,7 @@ async def config_data(updated_config: FullConfig):
|
||||||
return config
|
return config
|
||||||
|
|
||||||
@app.get('/search')
|
@app.get('/search')
|
||||||
def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
|
def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Optional[bool] = False):
|
||||||
if q is None or q == '':
|
if q is None or q == '':
|
||||||
print(f'No query param (q) passed in API call to initiate search')
|
print(f'No query param (q) passed in API call to initiate search')
|
||||||
return {}
|
return {}
|
||||||
|
@ -72,7 +72,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
|
||||||
if (t == SearchType.Org or t == None) and model.orgmode_search:
|
if (t == SearchType.Org or t == None) and model.orgmode_search:
|
||||||
# query org-mode notes
|
# query org-mode notes
|
||||||
query_start = time.time()
|
query_start = time.time()
|
||||||
hits, entries = text_search.query(user_query, model.orgmode_search, device=device, filters=[DateFilter(), ExplicitFilter()], verbose=verbose)
|
hits, entries = text_search.query(user_query, model.orgmode_search, rank_results=r, device=device, filters=[DateFilter(), ExplicitFilter()], verbose=verbose)
|
||||||
query_end = time.time()
|
query_end = time.time()
|
||||||
|
|
||||||
# collate and return results
|
# collate and return results
|
||||||
|
@ -83,7 +83,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
|
||||||
if (t == SearchType.Music or t == None) and model.music_search:
|
if (t == SearchType.Music or t == None) and model.music_search:
|
||||||
# query music library
|
# query music library
|
||||||
query_start = time.time()
|
query_start = time.time()
|
||||||
hits, entries = text_search.query(user_query, model.music_search, device=device, filters=[DateFilter(), ExplicitFilter()], verbose=verbose)
|
hits, entries = text_search.query(user_query, model.music_search, rank_results=r, device=device, filters=[DateFilter(), ExplicitFilter()], verbose=verbose)
|
||||||
query_end = time.time()
|
query_end = time.time()
|
||||||
|
|
||||||
# collate and return results
|
# collate and return results
|
||||||
|
@ -94,7 +94,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
|
||||||
if (t == SearchType.Markdown or t == None) and model.orgmode_search:
|
if (t == SearchType.Markdown or t == None) and model.orgmode_search:
|
||||||
# query markdown files
|
# query markdown files
|
||||||
query_start = time.time()
|
query_start = time.time()
|
||||||
hits, entries = text_search.query(user_query, model.markdown_search, device=device, filters=[ExplicitFilter(), DateFilter()], verbose=verbose)
|
hits, entries = text_search.query(user_query, model.markdown_search, rank_results=r, device=device, filters=[ExplicitFilter(), DateFilter()], verbose=verbose)
|
||||||
query_end = time.time()
|
query_end = time.time()
|
||||||
|
|
||||||
# collate and return results
|
# collate and return results
|
||||||
|
@ -105,7 +105,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None):
|
||||||
if (t == SearchType.Ledger or t == None) and model.ledger_search:
|
if (t == SearchType.Ledger or t == None) and model.ledger_search:
|
||||||
# query transactions
|
# query transactions
|
||||||
query_start = time.time()
|
query_start = time.time()
|
||||||
hits, entries = text_search.query(user_query, model.ledger_search, filters=[ExplicitFilter(), DateFilter()], verbose=verbose)
|
hits, entries = text_search.query(user_query, model.ledger_search, rank_results=r, device=device, filters=[ExplicitFilter(), DateFilter()], verbose=verbose)
|
||||||
query_end = time.time()
|
query_end = time.time()
|
||||||
|
|
||||||
# collate and return results
|
# collate and return results
|
||||||
|
|
|
@ -63,7 +63,7 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False, d
|
||||||
return corpus_embeddings
|
return corpus_embeddings
|
||||||
|
|
||||||
|
|
||||||
def query(raw_query: str, model: TextSearchModel, device='cpu', filters: list = [], verbose=0):
|
def query(raw_query: str, model: TextSearchModel, rank_results=False, device='cpu', filters: list = [], verbose=0):
|
||||||
"Search for entries that answer the query"
|
"Search for entries that answer the query"
|
||||||
query = raw_query
|
query = raw_query
|
||||||
|
|
||||||
|
@ -108,6 +108,7 @@ def query(raw_query: str, model: TextSearchModel, device='cpu', filters: list =
|
||||||
print(f"Search Time: {end - start:.3f} seconds")
|
print(f"Search Time: {end - start:.3f} seconds")
|
||||||
|
|
||||||
# Score all retrieved entries using the cross-encoder
|
# Score all retrieved entries using the cross-encoder
|
||||||
|
if rank_results:
|
||||||
start = time.time()
|
start = time.time()
|
||||||
cross_inp = [[query, entries[hit['corpus_id']]['compiled']] for hit in hits]
|
cross_inp = [[query, entries[hit['corpus_id']]['compiled']] for hit in hits]
|
||||||
cross_scores = model.cross_encoder.predict(cross_inp)
|
cross_scores = model.cross_encoder.predict(cross_inp)
|
||||||
|
@ -122,6 +123,7 @@ def query(raw_query: str, model: TextSearchModel, device='cpu', filters: list =
|
||||||
# Order results by cross-encoder score followed by bi-encoder score
|
# Order results by cross-encoder score followed by bi-encoder score
|
||||||
start = time.time()
|
start = time.time()
|
||||||
hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score
|
hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score
|
||||||
|
if rank_results:
|
||||||
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
|
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score
|
||||||
end = time.time()
|
end = time.time()
|
||||||
if verbose > 1:
|
if verbose > 1:
|
||||||
|
@ -152,7 +154,7 @@ def collate_results(hits, entries, count=5):
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
"entry": entries[hit['corpus_id']]['raw'],
|
"entry": entries[hit['corpus_id']]['raw'],
|
||||||
"score": f"{hit['cross-score']:.3f}"
|
"score": f"{hit['cross-score'] if 'cross-score' in hit else hit['score']:.3f}"
|
||||||
}
|
}
|
||||||
for hit
|
for hit
|
||||||
in hits[0:count]]
|
in hits[0:count]]
|
||||||
|
|
|
@ -29,7 +29,8 @@ def test_asymmetric_search(content_config: ContentConfig, search_config: SearchC
|
||||||
# Act
|
# Act
|
||||||
hits, entries = text_search.query(
|
hits, entries = text_search.query(
|
||||||
query,
|
query,
|
||||||
model = model.notes_search)
|
model = model.notes_search,
|
||||||
|
rank_results=True)
|
||||||
|
|
||||||
results = text_search.collate_results(
|
results = text_search.collate_results(
|
||||||
hits,
|
hits,
|
||||||
|
|
|
@ -119,7 +119,7 @@ def test_notes_search(content_config: ContentConfig, search_config: SearchConfig
|
||||||
user_query = "How to git install application?"
|
user_query = "How to git install application?"
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response = client.get(f"/search?q={user_query}&n=1&t=org")
|
response = client.get(f"/search?q={user_query}&n=1&t=org&r=true")
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
Loading…
Reference in a new issue