mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Re-rank using cross encoder to get even more relevant results
The cross-encoder re-ranked results are much better for more distant queries. Re-ranking with the cross-encoder does take more time, but it seems worth it to get more relevant results.
This commit is contained in:
parent
9864a2b551
commit
0914f284bb
1 changed file with 12 additions and 12 deletions
|
@ -42,13 +42,13 @@ def search(query):
|
|||
hits = hits[0] # Get the hits for the first query
|
||||
|
||||
##### Re-Ranking #####
|
||||
## Now, score all retrieved passages with the cross_encoder
|
||||
#cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
|
||||
#cross_scores = cross_encoder.predict(cross_inp)
|
||||
#
|
||||
## Sort results by the cross-encoder scores
|
||||
#for idx in range(len(cross_scores)):
|
||||
# hits[idx]['cross-score'] = cross_scores[idx]
|
||||
# Now, score all retrieved passages with the cross_encoder
|
||||
cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
|
||||
cross_scores = cross_encoder.predict(cross_inp)
|
||||
|
||||
# Sort results by the cross-encoder scores
|
||||
for idx in range(len(cross_scores)):
|
||||
hits[idx]['cross-score'] = cross_scores[idx]
|
||||
|
||||
# Output of top-5 hits from bi-encoder
|
||||
print("\n-------------------------\n")
|
||||
|
@ -58,11 +58,11 @@ def search(query):
|
|||
print("\t{:.3f}\t{}".format(hit['score'], passages[hit['corpus_id']].replace("\n", " ")))
|
||||
|
||||
# Output of top-3 hits from re-ranker
|
||||
#print("\n-------------------------\n")
|
||||
#print("Top-3 Cross-Encoder Re-ranker hits")
|
||||
#hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
|
||||
#for hit in hits[0:3]:
|
||||
# print("\t{:.3f}\t{}".format(hit['cross-score'], passages[hit['corpus_id']].replace("\n", " ")))
|
||||
print("\n-------------------------\n")
|
||||
print("Top-3 Cross-Encoder Re-ranker hits")
|
||||
hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
|
||||
for hit in hits[0:3]:
|
||||
print("\t{:.3f}\t{}".format(hit['cross-score'], passages[hit['corpus_id']].replace("\n", " ")))
|
||||
|
||||
while True:
|
||||
user_query = input("Enter your query: ")
|
||||
|
|
Loading…
Reference in a new issue