Use word to entry map, list comprehension to speed up explicit filter

- Code Changes
  - Use list comprehension and the `torch.index_select` method
    - to speed up selection of the entries and embedding tensors satisfying the filter
    - to avoid deep copying entries and embeddings
    - to avoid mutating the existing lists of entries and embeddings

  - Use word to entry map and set operations to mark entries satisfying
    the inclusion and exclusion filters (see the sketch after this list)

- Results
  - Speed up explicit filtering by two orders of magnitude
  - Improve consistency of the speed-up across inclusion and exclusion filtering
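The core idea, sketched below on a toy prebuilt index (names and data invented for illustration, not taken from the commit): each word maps to the set of indices of the entries containing it, so an inclusion filter is a set intersection over the required words' index sets, an exclusion filter is a set union over the blocked words' index sets, and the surviving entries are their difference.

```python
# Toy word -> entry-indices map; the real one is built by ExplicitFilter.load()
word_to_entry_index = {
    "tree":   {0, 2},
    "garden": {0, 1},
    "book":   {2},
}

required_words = {"tree", "garden"}  # entries must contain all of these
blocked_words = {"book"}             # entries must contain none of these

# Entries containing every required word: intersect their index sets
included = set.intersection(*[word_to_entry_index.get(w, set()) for w in required_words])
# Entries containing any blocked word: union of their index sets
excluded = set.union(*[word_to_entry_index.get(w, set()) for w in blocked_words])

print(included - excluded)  # {0}
```

Each filter now touches only the index sets of the filter words rather than every entry's word set, which is what the two-orders-of-magnitude speed-up rests on.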
Author: Debanjum Singh Solanky
Date:   2022-09-04 15:09:09 +03:00
Parent: 28d3dc1434
Commit: 191a656ed7


```diff
@@ -25,7 +25,7 @@ class ExplicitFilter:
         self.filter_file = resolve_absolute_path(filter_directory / f"{search_type.name.lower()}_explicit_filter_entry_word_sets.pkl")
         self.entry_key = entry_key
         self.search_type = search_type
-        self.entries_by_word_set = None
+        self.word_to_entry_index = dict()
         self.cache = {}
```
```diff
@@ -33,24 +33,28 @@ class ExplicitFilter:
         if self.filter_file.exists() and not regenerate:
             start = time.time()
             with self.filter_file.open('rb') as f:
-                self.entries_by_word_set = pickle.load(f)
+                self.word_to_entry_index = pickle.load(f)
             end = time.time()
             logger.debug(f"Load {self.search_type} entries by word set from file: {end - start} seconds")
         else:
             start = time.time()
             self.cache = {}  # Clear cache on (re-)generating entries_by_word_set
             entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:'
-            self.entries_by_word_set = [set(word.lower()
-                                            for word
-                                            in re.split(entry_splitter, entry[self.entry_key])
-                                            if word != "")
-                                        for entry in entries]
+            # Create map of words to entries they exist in
+            for entry_index, entry in enumerate(entries):
+                for word in re.split(entry_splitter, entry[self.entry_key].lower()):
+                    if word == '':
+                        continue
+                    if word not in self.word_to_entry_index:
+                        self.word_to_entry_index[word] = set()
+                    self.word_to_entry_index[word].add(entry_index)
             with self.filter_file.open('wb') as f:
-                pickle.dump(self.entries_by_word_set, f)
+                pickle.dump(self.word_to_entry_index, f)
             end = time.time()
             logger.debug(f"Convert all {self.search_type} entries to word sets: {end - start} seconds")

-        return self.entries_by_word_set
+        return self.word_to_entry_index

     def can_filter(self, raw_query):
```
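To make the new data structure concrete, here is a hypothetical run of the construction loop above, reusing the commit's `entry_splitter` but with plain-string entries instead of the `entry[self.entry_key]` dicts:

```python
import re

entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:'
entries = ["Plant the tree", "Water the plant, the tree"]

# Same map-building loop as the hunk, minus the dict-valued entries
word_to_entry_index = {}
for entry_index, entry in enumerate(entries):
    for word in re.split(entry_splitter, entry.lower()):
        if word == '':
            continue
        if word not in word_to_entry_index:
            word_to_entry_index[word] = set()
        word_to_entry_index[word].add(entry_index)

print(word_to_entry_index)
# {'plant': {0, 1}, 'the': {0, 1}, 'tree': {0, 1}, 'water': {1}}
```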
```diff
@@ -69,7 +73,7 @@ class ExplicitFilter:
         required_words = set([word.lower() for word in re.findall(self.required_regex, raw_query)])
         blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, raw_query)])
-        query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', raw_query))
+        query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', raw_query)).strip()
         end = time.time()
         logger.debug(f"Extract required, blocked filters from query: {end - start} seconds")
```
```diff
@@ -84,41 +88,33 @@ class ExplicitFilter:
             entries, embeddings = self.cache[cache_key]
             return query, entries, embeddings

-        # deep copy original embeddings, entries before filtering
-        start = time.time()
-        embeddings= deepcopy(raw_embeddings)
-        entries = deepcopy(raw_entries)
-        end = time.time()
-        logger.debug(f"Create copy of embeddings, entries for manipulation: {end - start:.3f} seconds")
-
-        if not self.entries_by_word_set:
-            self.load(entries, regenerate=False)
+        if not self.word_to_entry_index:
+            self.load(raw_entries, regenerate=False)

-        # track id of entries to exclude
         start = time.time()
-        entries_to_exclude = set()

-        # mark entries that do not contain all required_words for exclusion
+        # mark entries that contain all required_words for inclusion
+        entries_with_all_required_words = set(range(len(raw_entries)))
         if len(required_words) > 0:
-            for id, words_in_entry in enumerate(self.entries_by_word_set):
-                if not required_words.issubset(words_in_entry):
-                    entries_to_exclude.add(id)
+            entries_with_all_required_words = set.intersection(*[self.word_to_entry_index.get(word, set()) for word in required_words])

         # mark entries that contain any blocked_words for exclusion
+        entries_with_any_blocked_words = set()
         if len(blocked_words) > 0:
-            for id, words_in_entry in enumerate(self.entries_by_word_set):
-                if words_in_entry.intersection(blocked_words):
-                    entries_to_exclude.add(id)
+            entries_with_any_blocked_words = set.union(*[self.word_to_entry_index.get(word, set()) for word in blocked_words])
         end = time.time()
-        logger.debug(f"Mark entries not satisfying filter: {end - start} seconds")
+        logger.debug(f"Mark entries satisfying filter: {end - start} seconds")

-        # delete entries (and their embeddings) marked for exclusion
+        # get entries (and their embeddings) satisfying inclusion and exclusion filters
         start = time.time()
-        for id in sorted(list(entries_to_exclude), reverse=True):
-            del entries[id]
-            embeddings = torch.cat((embeddings[:id], embeddings[id+1:]))
+        included_entry_indices = entries_with_all_required_words - entries_with_any_blocked_words
+        entries = [entry for id, entry in enumerate(raw_entries) if id in included_entry_indices]
+        embeddings = torch.index_select(raw_embeddings, 0, torch.tensor(list(included_entry_indices)))
         end = time.time()
-        logger.debug(f"Delete entries not satisfying filter: {end - start} seconds")
+        logger.debug(f"Keep entries satisfying filter: {end - start} seconds")

         # Cache results
         self.cache[cache_key] = entries, embeddings
```
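The core of the speed-up in this hunk is replacing per-row deletion (repeated `torch.cat` around each excluded row) with one vectorized row selection. A standalone demonstration of `torch.index_select` on a toy tensor:

```python
import torch

raw_embeddings = torch.arange(12.0).reshape(4, 3)  # 4 entries, 3-dim embeddings each
included_entry_indices = {0, 2}                    # survivors of the set filters

# Pick all surviving rows in a single call along dim 0
embeddings = torch.index_select(raw_embeddings, 0, torch.tensor(sorted(included_entry_indices)))
print(embeddings)
# tensor([[0., 1., 2.],
#         [6., 7., 8.]])
```

Note the commit passes `torch.tensor(list(included_entry_indices))`; since Python sets are unordered, sorting the indices first, as above, guarantees the surviving entries keep their original relative order.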