Load entries_by_word_set from file only once on first load of explicit filter

Debanjum Singh Solanky 2022-09-04 00:37:37 +03:00
parent 858d86075b
commit 8d9f507df3


@@ -24,29 +24,30 @@ class ExplicitFilter:
         self.filter_file = resolve_absolute_path(filter_directory / f"{search_type.name.lower()}_explicit_filter_entry_word_sets.pkl")
         self.entry_key = entry_key
         self.search_type = search_type
+        self.entries_by_word_set = None

     def load(self, entries, regenerate=False):
         if self.filter_file.exists() and not regenerate:
             start = time.time()
             with self.filter_file.open('rb') as f:
-                entries_by_word_set = pickle.load(f)
+                self.entries_by_word_set = pickle.load(f)
             end = time.time()
             logger.debug(f"Load {self.search_type} entries by word set from file: {end - start} seconds")
         else:
             start = time.time()
             entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:'
-            entries_by_word_set = [set(word.lower()
+            self.entries_by_word_set = [set(word.lower()
                                         for word
                                         in re.split(entry_splitter, entry[self.entry_key])
                                         if word != "")
                                     for entry in entries]
             with self.filter_file.open('wb') as f:
-                pickle.dump(entries_by_word_set, f)
+                pickle.dump(self.entries_by_word_set, f)
             end = time.time()
             logger.debug(f"Convert all {self.search_type} entries to word sets: {end - start} seconds")

-        return entries_by_word_set
+        return self.entries_by_word_set

@@ -73,8 +74,8 @@ class ExplicitFilter:
         if len(required_words) == 0 and len(blocked_words) == 0:
             return query, entries, embeddings

-        # load or generate word set for each entry
-        entries_by_word_set = self.load(entries, regenerate=False)
+        if not self.entries_by_word_set:
+            self.load(entries, regenerate=False)

         # track id of entries to exclude
         start = time.time()

@@ -82,13 +83,13 @@ class ExplicitFilter:
         # mark entries that do not contain all required_words for exclusion
         if len(required_words) > 0:
-            for id, words_in_entry in enumerate(entries_by_word_set):
+            for id, words_in_entry in enumerate(self.entries_by_word_set):
                 if not required_words.issubset(words_in_entry):
                     entries_to_exclude.add(id)

         # mark entries that contain any blocked_words for exclusion
         if len(blocked_words) > 0:
-            for id, words_in_entry in enumerate(entries_by_word_set):
+            for id, words_in_entry in enumerate(self.entries_by_word_set):
                 if words_in_entry.intersection(blocked_words):
                     entries_to_exclude.add(id)

         end = time.time()
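
The change memoizes the per-entry word sets on the filter instance: load now assigns to self.entries_by_word_set, and filter only calls load when that attribute is still unset, so the pickled word sets are read from disk (or regenerated) once per process instead of on every query. Below is a minimal sketch of this lazy-load pattern; the class name and the plain split()-based word splitting are illustrative stand-ins, not the code from this commit:

    import pickle
    from pathlib import Path

    class CachedWordSetFilter:
        """Illustrative stand-in for the caching behaviour added here."""

        def __init__(self, filter_file: Path):
            self.filter_file = filter_file
            self.entries_by_word_set = None  # filled on first use, then reused

        def load(self, entries, regenerate=False):
            if self.filter_file.exists() and not regenerate:
                # cache the pickled word sets on the instance
                with self.filter_file.open('rb') as f:
                    self.entries_by_word_set = pickle.load(f)
            else:
                # stand-in for the regex-based entry_splitter logic
                self.entries_by_word_set = [set(entry.lower().split()) for entry in entries]
                with self.filter_file.open('wb') as f:
                    pickle.dump(self.entries_by_word_set, f)
            return self.entries_by_word_set

        def filter(self, entries, required_words):
            if not self.entries_by_word_set:  # only the first call hits disk
                self.load(entries)
            return [entry for entry, words in zip(entries, self.entries_by_word_set)
                    if required_words.issubset(words)]

A second filter call on the same instance reuses the cached word sets without touching the pickle file. Note that the truthiness guard also treats an empty list as "not loaded"; an `is None` check would be the stricter guard if empty entry lists are expected.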