diff --git a/asymmetric.py b/asymmetric.py index 2e84bc20..fd4574cd 100644 --- a/asymmetric.py +++ b/asymmetric.py @@ -42,13 +42,13 @@ def search(query): hits = hits[0] # Get the hits for the first query ##### Re-Ranking ##### - ## Now, score all retrieved passages with the cross_encoder - #cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits] - #cross_scores = cross_encoder.predict(cross_inp) - # - ## Sort results by the cross-encoder scores - #for idx in range(len(cross_scores)): - # hits[idx]['cross-score'] = cross_scores[idx] + # Now, score all retrieved passages with the cross_encoder + cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits] + cross_scores = cross_encoder.predict(cross_inp) + + # Sort results by the cross-encoder scores + for idx in range(len(cross_scores)): + hits[idx]['cross-score'] = cross_scores[idx] # Output of top-5 hits from bi-encoder print("\n-------------------------\n") @@ -58,11 +58,11 @@ def search(query): print("\t{:.3f}\t{}".format(hit['score'], passages[hit['corpus_id']].replace("\n", " "))) # Output of top-5 hits from re-ranker - #print("\n-------------------------\n") - #print("Top-3 Cross-Encoder Re-ranker hits") - #hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True) - #for hit in hits[0:3]: - # print("\t{:.3f}\t{}".format(hit['cross-score'], passages[hit['corpus_id']].replace("\n", " "))) + print("\n-------------------------\n") + print("Top-3 Cross-Encoder Re-ranker hits") + hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True) + for hit in hits[0:3]: + print("\t{:.3f}\t{}".format(hit['cross-score'], passages[hit['corpus_id']].replace("\n", " "))) while True: user_query = input("Enter your query: ")