Re-rank using cross encoder to get even more relevant results

The cross-encoder re-ranked results are much better for more distant queries.
Cross-encoder re-ranking does take more time, but it seems
worth it to get more relevant results.
This commit is contained in:
Debanjum Singh Solanky 2021-07-31 02:56:45 -07:00
parent 9864a2b551
commit 0914f284bb

View file

@ -42,13 +42,13 @@ def search(query):
hits = hits[0] # Get the hits for the first query hits = hits[0] # Get the hits for the first query
##### Re-Ranking ##### ##### Re-Ranking #####
## Now, score all retrieved passages with the cross_encoder # Now, score all retrieved passages with the cross_encoder
#cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits] cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
#cross_scores = cross_encoder.predict(cross_inp) cross_scores = cross_encoder.predict(cross_inp)
#
## Sort results by the cross-encoder scores # Sort results by the cross-encoder scores
#for idx in range(len(cross_scores)): for idx in range(len(cross_scores)):
# hits[idx]['cross-score'] = cross_scores[idx] hits[idx]['cross-score'] = cross_scores[idx]
# Output of top-5 hits from bi-encoder # Output of top-5 hits from bi-encoder
print("\n-------------------------\n") print("\n-------------------------\n")
@ -58,11 +58,11 @@ def search(query):
print("\t{:.3f}\t{}".format(hit['score'], passages[hit['corpus_id']].replace("\n", " "))) print("\t{:.3f}\t{}".format(hit['score'], passages[hit['corpus_id']].replace("\n", " ")))
# Output of top-5 hits from re-ranker # Output of top-5 hits from re-ranker
#print("\n-------------------------\n") print("\n-------------------------\n")
#print("Top-3 Cross-Encoder Re-ranker hits") print("Top-3 Cross-Encoder Re-ranker hits")
#hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True) hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
#for hit in hits[0:3]: for hit in hits[0:3]:
# print("\t{:.3f}\t{}".format(hit['cross-score'], passages[hit['corpus_id']].replace("\n", " "))) print("\t{:.3f}\t{}".format(hit['cross-score'], passages[hit['corpus_id']].replace("\n", " ")))
while True: while True:
user_query = input("Enter your query: ") user_query = input("Enter your query: ")