From 6d7ab501138b0a86fc30e2192c5c16d9749e8e44 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Tue, 12 Jul 2022 13:58:32 +0400 Subject: [PATCH] Run Explicit Filter on Entries, Embeddings before Semantic Search for Query - Issue - Explicit filtering was earlier being done after the search by the bi-encoder but before re-ranking by the cross-encoder - This was limiting the quality of results being returned, as the bi-encoder returned results which were going to be excluded. So the burden of improving those limited results post-filtering fell on the cross-encoder, which had to re-rank the remaining results based on the query - Fix - Given that an entry and its corresponding embedding are at the same index in their respective lists, we can run the filter for blocked and required words before the search by the bi-encoder model, and so limit the entries and embeddings being considered for the current query - Result - Semantic search by the bi-encoder gets to return the most relevant results for the query, knowing that the results aren't going to be filtered out afterwards. 
So the cross-encoder shoulders less of the burden of improving results - Corollary - This pre-filtering technique allows us to apply other explicit filters on entries relevant for the current query - E.g limit search for entries within date/time specified in query --- src/main.py | 8 ++--- src/search_type/asymmetric.py | 67 ++++++++++++++++++++++------------- 2 files changed, 46 insertions(+), 29 deletions(-) diff --git a/src/main.py b/src/main.py index 4666e30f..2b580cb8 100644 --- a/src/main.py +++ b/src/main.py @@ -58,17 +58,17 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None): if (t == SearchType.Notes or t == None) and model.notes_search: # query notes - hits = asymmetric.query(user_query, model.notes_search, device=device) + hits, entries = asymmetric.query(user_query, model.notes_search, device=device) # collate and return results - return asymmetric.collate_results(hits, model.notes_search.entries, results_count) + return asymmetric.collate_results(hits, entries, results_count) if (t == SearchType.Music or t == None) and model.music_search: # query music library - hits = asymmetric.query(user_query, model.music_search, device=device) + hits, entries = asymmetric.query(user_query, model.music_search, device=device) # collate and return results - return asymmetric.collate_results(hits, model.music_search.entries, results_count) + return asymmetric.collate_results(hits, entries, results_count) if (t == SearchType.Ledger or t == None) and model.ledger_search: # query transactions diff --git a/src/search_type/asymmetric.py b/src/search_type/asymmetric.py index 524eb9a1..611501f7 100644 --- a/src/search_type/asymmetric.py +++ b/src/search_type/asymmetric.py @@ -6,6 +6,7 @@ import gzip import re import argparse import pathlib +from copy import deepcopy # External Packages import torch @@ -100,24 +101,25 @@ def query(raw_query: str, model: TextSearchModel, device=torch.device('cpu')): required_words = set([word[1:].lower() for word in 
raw_query.split() if word.startswith("+")]) blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")]) + # Copy original embeddings, entries to filter them for query + corpus_embeddings = deepcopy(model.corpus_embeddings) + entries = deepcopy(model.entries) + + # Filter to entries that contain all required_words and no blocked_words + entries, corpus_embeddings = explicit_filter(entries, corpus_embeddings, required_words, blocked_words) + if entries is None or len(entries) == 0: + return {} + # Encode the query using the bi-encoder question_embedding = model.bi_encoder.encode([query], convert_to_tensor=True) question_embedding.to(device) question_embedding = util.normalize_embeddings(question_embedding) # Find relevant entries for the query - hits = util.semantic_search(question_embedding, model.corpus_embeddings, top_k=model.top_k, score_function=util.dot_score) - hits = hits[0] # Get the hits for the first query - - # Filter out entries that contain required words and do not contain blocked words - hits = explicit_filter(hits, - [entry[0] for entry in model.entries], - required_words,blocked_words) - if hits is None or len(hits) == 0: - return hits + hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=model.top_k, score_function=util.dot_score)[0] # Score all retrieved entries using the cross-encoder - cross_inp = [[query, model.entries[hit['corpus_id']][0]] for hit in hits] + cross_inp = [[query, entries[hit['corpus_id']][0]] for hit in hits] cross_scores = model.cross_encoder.predict(cross_inp) # Store cross-encoder scores in results dictionary for ranking @@ -128,28 +130,43 @@ def query(raw_query: str, model: TextSearchModel, device=torch.device('cpu')): hits.sort(key=lambda x: x['score'], reverse=True) # sort by bi-encoder score hits.sort(key=lambda x: x['cross-score'], reverse=True) # sort by cross-encoder score - return hits + return hits, entries -def explicit_filter(hits, entries, required_words, 
blocked_words): - hits_by_word_set = [(set(word.lower() +def explicit_filter(entries, embeddings, required_words, blocked_words): + if len(required_words) == 0 and len(blocked_words) == 0: + return entries, embeddings + + # convert each entry to a set of words + entries_by_word_set = [set(word.lower() for word in re.split( - r',|\.| |\]|\[\(|\)|\{|\}', - entries[hit['corpus_id']]) - if word != ""), - hit) - for hit in hits] + r',|\.| |\]|\[\(|\)|\{|\}', # split on fullstop, comma or any brackets + entry[0]) + if word != "") + for entry in entries] - if len(required_words) == 0 and len(blocked_words) == 0: - return hits + # track id of entries to exclude + entries_to_exclude = set() + + # mark entries that do not contain all required_words for exclusion if len(required_words) > 0: - return [hit for (words_in_entry, hit) in hits_by_word_set - if required_words.intersection(words_in_entry) and not blocked_words.intersection(words_in_entry)] + for id, words_in_entry in enumerate(entries_by_word_set): + if not required_words.issubset(words_in_entry): + entries_to_exclude.add(id) + + # mark entries that contain any blocked_words for exclusion if len(blocked_words) > 0: - return [hit for (words_in_entry, hit) in hits_by_word_set - if not blocked_words.intersection(words_in_entry)] - return hits + for id, words_in_entry in enumerate(entries_by_word_set): + if words_in_entry.intersection(blocked_words): + entries_to_exclude.add(id) + + # delete entries (and their embeddings) marked for exclusion + for id in sorted(list(entries_to_exclude), reverse=True): + del entries[id] + embeddings = torch.cat((embeddings[:id], embeddings[id+1:])) + + return entries, embeddings def render_results(hits, entries, count=5, display_biencoder_results=False):