Use word to entry map, list comprehension to speed up explicit filter

- Code Changes
  - Use list comprehension and `torch.index_select' methods
    - to speed selection of entries, embedding tensors satisfying filter
    - avoid deep copy of entries, embeddings
    - avoid updating existing lists (of entries, embeddings)

  - Use word to entry map and set operations to mark entries satisfying
    inclusion, exclusion filters

- Results
  - Speed up explicit filtering by two orders of magnitude
  - Improve consistency of speed up across inclusion and exclusion filtering
This commit is contained in:
Debanjum Singh Solanky 2022-09-04 15:09:09 +03:00
parent 28d3dc1434
commit 191a656ed7

View file

@ -25,7 +25,7 @@ class ExplicitFilter:
self.filter_file = resolve_absolute_path(filter_directory / f"{search_type.name.lower()}_explicit_filter_entry_word_sets.pkl")
self.entry_key = entry_key
self.search_type = search_type
self.entries_by_word_set = None
self.word_to_entry_index = dict()
self.cache = {}
@ -33,24 +33,28 @@ class ExplicitFilter:
if self.filter_file.exists() and not regenerate:
start = time.time()
with self.filter_file.open('rb') as f:
self.entries_by_word_set = pickle.load(f)
self.word_to_entry_index = pickle.load(f)
end = time.time()
logger.debug(f"Load {self.search_type} entries by word set from file: {end - start} seconds")
else:
start = time.time()
self.cache = {} # Clear cache on (re-)generating entries_by_word_set
entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:'
self.entries_by_word_set = [set(word.lower()
for word
in re.split(entry_splitter, entry[self.entry_key])
if word != "")
for entry in entries]
# Create map of words to entries they exist in
for entry_index, entry in enumerate(entries):
for word in re.split(entry_splitter, entry[self.entry_key].lower()):
if word == '':
continue
if word not in self.word_to_entry_index:
self.word_to_entry_index[word] = set()
self.word_to_entry_index[word].add(entry_index)
with self.filter_file.open('wb') as f:
pickle.dump(self.entries_by_word_set, f)
pickle.dump(self.word_to_entry_index, f)
end = time.time()
logger.debug(f"Convert all {self.search_type} entries to word sets: {end - start} seconds")
return self.entries_by_word_set
return self.word_to_entry_index
def can_filter(self, raw_query):
@ -69,7 +73,7 @@ class ExplicitFilter:
required_words = set([word.lower() for word in re.findall(self.required_regex, raw_query)])
blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, raw_query)])
query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', raw_query))
query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', raw_query)).strip()
end = time.time()
logger.debug(f"Extract required, blocked filters from query: {end - start} seconds")
@ -84,41 +88,33 @@ class ExplicitFilter:
entries, embeddings = self.cache[cache_key]
return query, entries, embeddings
# deep copy original embeddings, entries before filtering
if not self.word_to_entry_index:
self.load(raw_entries, regenerate=False)
start = time.time()
embeddings= deepcopy(raw_embeddings)
entries = deepcopy(raw_entries)
end = time.time()
logger.debug(f"Create copy of embeddings, entries for manipulation: {end - start:.3f} seconds")
if not self.entries_by_word_set:
self.load(entries, regenerate=False)
# track id of entries to exclude
start = time.time()
entries_to_exclude = set()
# mark entries that do not contain all required_words for exclusion
# mark entries that contain all required_words for inclusion
entries_with_all_required_words = set(range(len(raw_entries)))
if len(required_words) > 0:
for id, words_in_entry in enumerate(self.entries_by_word_set):
if not required_words.issubset(words_in_entry):
entries_to_exclude.add(id)
entries_with_all_required_words = set.intersection(*[self.word_to_entry_index.get(word, set()) for word in required_words])
# mark entries that contain any blocked_words for exclusion
entries_with_any_blocked_words = set()
if len(blocked_words) > 0:
for id, words_in_entry in enumerate(self.entries_by_word_set):
if words_in_entry.intersection(blocked_words):
entries_to_exclude.add(id)
end = time.time()
logger.debug(f"Mark entries not satisfying filter: {end - start} seconds")
entries_with_any_blocked_words = set.union(*[self.word_to_entry_index.get(word, set()) for word in blocked_words])
# delete entries (and their embeddings) marked for exclusion
start = time.time()
for id in sorted(list(entries_to_exclude), reverse=True):
del entries[id]
embeddings = torch.cat((embeddings[:id], embeddings[id+1:]))
end = time.time()
logger.debug(f"Delete entries not satisfying filter: {end - start} seconds")
logger.debug(f"Mark entries satisfying filter: {end - start} seconds")
# get entries (and their embeddings) satisfying inclusion and exclusion filters
start = time.time()
included_entry_indices = entries_with_all_required_words - entries_with_any_blocked_words
entries = [entry for id, entry in enumerate(raw_entries) if id in included_entry_indices]
embeddings = torch.index_select(raw_embeddings, 0, torch.tensor(list(included_entry_indices)))
end = time.time()
logger.debug(f"Keep entries satisfying filter: {end - start} seconds")
# Cache results
self.cache[cache_key] = entries, embeddings