Add explicit filters to asymmetric search

User can filter results to ones which include, exclude specified words
To show entries which include, exclude specific words, user should prepend
a '+', '-' before the word. E.g "+hello -bye"
This commit is contained in:
Debanjum Singh Solanky 2021-08-15 17:12:04 -07:00
parent 91a2c598fe
commit 660e6c3937

View file

@ -6,6 +6,7 @@ import time
import gzip
import os
import sys
import re
import torch
import argparse
import pathlib
@ -56,8 +57,13 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, verbose=False):
return corpus_embeddings
def query_notes(query, corpus_embeddings, entries, bi_encoder, cross_encoder, topk=100):
def query_notes(raw_query, corpus_embeddings, entries, bi_encoder, cross_encoder, top_k=100):
"Search all notes for entries that answer the query"
# Separate natural query from explicit required, blocked words filters
query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
# Encode the query using the bi-encoder
question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
@ -65,6 +71,11 @@ def query_notes(query, corpus_embeddings, entries, bi_encoder, cross_encoder, to
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)
hits = hits[0] # Get the hits for the first query
# Filter results using explicit filters
hits = explicit_filter(hits, entries, required_words, blocked_words)
if hits is None or len(hits) == 0:
return hits
# Score all retrieved entries using the cross-encoder
cross_inp = [[query, entries[hit['corpus_id']]] for hit in hits]
cross_scores = cross_encoder.predict(cross_inp)
@ -76,6 +87,28 @@ def query_notes(query, corpus_embeddings, entries, bi_encoder, cross_encoder, to
# Order results by cross encoder score followed by biencoder score
hits.sort(key=lambda x: x['score'], reverse=True) # sort by biencoder score
hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross encoder score
return hits
def explicit_filter(hits, entries, required_words, blocked_words):
hits_by_word_set = [(set(word.lower()
for word
in re.split(
',|\.| |\]|\[\(|\)|\{|\}',
entries[hit['corpus_id']])
if word != ""),
hit)
for hit in hits]
if len(required_words) == 0 and len(blocked_words) == 0:
return hits
if len(required_words) > 0:
return [hit for (words_in_entry, hit) in hits_by_word_set
if required_words.intersection(words_in_entry) and not blocked_words.intersection(words_in_entry)]
if len(blocked_words) > 0:
return [hit for (words_in_entry, hit) in hits_by_word_set
if not blocked_words.intersection(words_in_entry)]
return hits