From 0914f284bb861700f500f37407e11d029114b4e8 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sat, 31 Jul 2021 02:56:45 -0700 Subject: [PATCH] Re-rank using cross encoder to get even more relevant results The cross-encoder re-ranked results are much better for more distant queries. It does take more time with the cross-encoder re-ranking, but it seems worth it to get more relevant results. --- asymmetric.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/asymmetric.py b/asymmetric.py index 2e84bc20..fd4574cd 100644 --- a/asymmetric.py +++ b/asymmetric.py @@ -42,13 +42,13 @@ def search(query): hits = hits[0] # Get the hits for the first query ##### Re-Ranking ##### - ## Now, score all retrieved passages with the cross_encoder - #cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits] - #cross_scores = cross_encoder.predict(cross_inp) - # - ## Sort results by the cross-encoder scores - #for idx in range(len(cross_scores)): - # hits[idx]['cross-score'] = cross_scores[idx] + # Now, score all retrieved passages with the cross_encoder + cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits] + cross_scores = cross_encoder.predict(cross_inp) + + # Sort results by the cross-encoder scores + for idx in range(len(cross_scores)): + hits[idx]['cross-score'] = cross_scores[idx] # Output of top-5 hits from bi-encoder print("\n-------------------------\n") @@ -58,11 +58,11 @@ def search(query): print("\t{:.3f}\t{}".format(hit['score'], passages[hit['corpus_id']].replace("\n", " "))) # Output of top-5 hits from re-ranker - #print("\n-------------------------\n") - #print("Top-3 Cross-Encoder Re-ranker hits") - #hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True) - #for hit in hits[0:3]: - # print("\t{:.3f}\t{}".format(hit['cross-score'], passages[hit['corpus_id']].replace("\n", " "))) + print("\n-------------------------\n") + print("Top-3 Cross-Encoder Re-ranker hits") + hits = sorted(hits, 
key=lambda x: x['cross-score'], reverse=True) + for hit in hits[0:3]: + print("\t{:.3f}\t{}".format(hit['cross-score'], passages[hit['corpus_id']].replace("\n", " "))) while True: user_query = input("Enter your query: ")