From 191a656ed7c0c8442f208abe1d30520c305e947f Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Sun, 4 Sep 2022 15:09:09 +0300
Subject: [PATCH] Use word to entry map, list comprehension to speed up explicit filter

- Code Changes
  - Use list comprehension and `torch.index_select` methods
    to speed up selection of entries, embedding tensors satisfying filter
    - Avoid deep copy of entries, embeddings
    - Avoid updating existing lists (of entries, embeddings)
  - Use word to entry map and set operations to mark entries satisfying
    inclusion, exclusion filters
- Results
  - Speed up explicit filtering by two orders of magnitude
  - Improve consistency of speed up across inclusion and exclusion filtering
---
 src/search_filter/explicit_filter.py | 70 +++++++++++++---------------
 1 file changed, 33 insertions(+), 37 deletions(-)

diff --git a/src/search_filter/explicit_filter.py b/src/search_filter/explicit_filter.py
index e715e8b6..6f64ede5 100644
--- a/src/search_filter/explicit_filter.py
+++ b/src/search_filter/explicit_filter.py
@@ -25,7 +25,7 @@ class ExplicitFilter:
         self.filter_file = resolve_absolute_path(filter_directory / f"{search_type.name.lower()}_explicit_filter_entry_word_sets.pkl")
         self.entry_key = entry_key
         self.search_type = search_type
-        self.entries_by_word_set = None
+        self.word_to_entry_index = dict()
         self.cache = {}
 
 
@@ -33,24 +33,28 @@ class ExplicitFilter:
         if self.filter_file.exists() and not regenerate:
             start = time.time()
             with self.filter_file.open('rb') as f:
-                self.entries_by_word_set = pickle.load(f)
+                self.word_to_entry_index = pickle.load(f)
             end = time.time()
             logger.debug(f"Load {self.search_type} entries by word set from file: {end - start} seconds")
         else:
             start = time.time()
             self.cache = {}  # Clear cache on (re-)generating entries_by_word_set
             entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:'
-            self.entries_by_word_set = [set(word.lower()
-                                            for word
-                                            in re.split(entry_splitter, entry[self.entry_key])
-                                            if word != "")
-                                        for entry in entries]
+            # Create map of words to entries they exist in
+            for entry_index, entry in enumerate(entries):
+                for word in re.split(entry_splitter, entry[self.entry_key].lower()):
+                    if word == '':
+                        continue
+                    if word not in self.word_to_entry_index:
+                        self.word_to_entry_index[word] = set()
+                    self.word_to_entry_index[word].add(entry_index)
+
             with self.filter_file.open('wb') as f:
-                pickle.dump(self.entries_by_word_set, f)
+                pickle.dump(self.word_to_entry_index, f)
             end = time.time()
             logger.debug(f"Convert all {self.search_type} entries to word sets: {end - start} seconds")
 
-        return self.entries_by_word_set
+        return self.word_to_entry_index
 
 
     def can_filter(self, raw_query):
@@ -69,7 +73,7 @@ class ExplicitFilter:
 
         required_words = set([word.lower() for word in re.findall(self.required_regex, raw_query)])
         blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, raw_query)])
-        query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', raw_query))
+        query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', raw_query)).strip()
         end = time.time()
         logger.debug(f"Extract required, blocked filters from query: {end - start} seconds")
 
@@ -84,41 +88,33 @@ class ExplicitFilter:
             entries, embeddings = self.cache[cache_key]
             return query, entries, embeddings
 
-        # deep copy original embeddings, entries before filtering
+        if not self.word_to_entry_index:
+            self.load(raw_entries, regenerate=False)
+
         start = time.time()
-        embeddings= deepcopy(raw_embeddings)
-        entries = deepcopy(raw_entries)
-        end = time.time()
logger.debug(f"Create copy of embeddings, entries for manipulation: {end - start:.3f} seconds") - if not self.entries_by_word_set: - self.load(entries, regenerate=False) - - # track id of entries to exclude - start = time.time() - entries_to_exclude = set() - - # mark entries that do not contain all required_words for exclusion + # mark entries that contain all required_words for inclusion + entries_with_all_required_words = set(range(len(raw_entries))) if len(required_words) > 0: - for id, words_in_entry in enumerate(self.entries_by_word_set): - if not required_words.issubset(words_in_entry): - entries_to_exclude.add(id) + entries_with_all_required_words = set.intersection(*[self.word_to_entry_index.get(word, set()) for word in required_words]) # mark entries that contain any blocked_words for exclusion + entries_with_any_blocked_words = set() if len(blocked_words) > 0: - for id, words_in_entry in enumerate(self.entries_by_word_set): - if words_in_entry.intersection(blocked_words): - entries_to_exclude.add(id) - end = time.time() - logger.debug(f"Mark entries not satisfying filter: {end - start} seconds") + entries_with_any_blocked_words = set.union(*[self.word_to_entry_index.get(word, set()) for word in blocked_words]) - # delete entries (and their embeddings) marked for exclusion - start = time.time() - for id in sorted(list(entries_to_exclude), reverse=True): - del entries[id] - embeddings = torch.cat((embeddings[:id], embeddings[id+1:])) end = time.time() - logger.debug(f"Delete entries not satisfying filter: {end - start} seconds") + logger.debug(f"Mark entries satisfying filter: {end - start} seconds") + + # get entries (and their embeddings) satisfying inclusion and exclusion filters + start = time.time() + + included_entry_indices = entries_with_all_required_words - entries_with_any_blocked_words + entries = [entry for id, entry in enumerate(raw_entries) if id in included_entry_indices] + embeddings = torch.index_select(raw_embeddings, 0, torch.tensor(list(included_entry_indices))) + + end = time.time() + logger.debug(f"Keep entries satisfying filter: {end - start} seconds") # Cache results self.cache[cache_key] = entries, embeddings