Mirror of https://github.com/khoj-ai/khoj.git
Use word to entry map, list comprehension to speed up explicit filter
- Code Changes
  - Use list comprehension and `torch.index_select` methods to speed selection of entries, embedding tensors satisfying filter
    - Avoid deep copy of entries, embeddings
    - Avoid updating existing lists (of entries, embeddings)
  - Use word to entry map and set operations to mark entries satisfying inclusion, exclusion filters
- Results
  - Speed up explicit filtering by two orders of magnitude
  - Improve consistency of speed up across inclusion and exclusion filtering
parent 28d3dc1434
commit 191a656ed7

1 changed file with 33 additions and 37 deletions
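The gist of the change: replace a per-entry scan over word sets with an inverted index from word to entry ids, so inclusion and exclusion filters reduce to set algebra. A minimal sketch of that idea follows; the toy entries and variable names are illustrative, not the committed code:

    # Build an inverted index: word -> ids of entries containing it
    entries = ["plant tomato seeds", "prune apple tree", "harvest tomato"]
    word_to_entry_index = {}
    for entry_id, entry in enumerate(entries):
        for word in entry.lower().split():
            word_to_entry_index.setdefault(word, set()).add(entry_id)

    required_words, blocked_words = {"tomato"}, {"seeds"}

    # Entries containing all required words: intersect their id sets
    included = set.intersection(*[word_to_entry_index.get(w, set()) for w in required_words])
    # Entries containing any blocked word: union of their id sets
    excluded = set.union(*[word_to_entry_index.get(w, set()) for w in blocked_words])

    print(included - excluded)  # {2}: only "harvest tomato" passes both filters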
@@ -25,7 +25,7 @@ class ExplicitFilter:
         self.filter_file = resolve_absolute_path(filter_directory / f"{search_type.name.lower()}_explicit_filter_entry_word_sets.pkl")
         self.entry_key = entry_key
         self.search_type = search_type
-        self.entries_by_word_set = None
+        self.word_to_entry_index = dict()
         self.cache = {}


@@ -33,24 +33,28 @@
         if self.filter_file.exists() and not regenerate:
             start = time.time()
             with self.filter_file.open('rb') as f:
-                self.entries_by_word_set = pickle.load(f)
+                self.word_to_entry_index = pickle.load(f)
             end = time.time()
             logger.debug(f"Load {self.search_type} entries by word set from file: {end - start} seconds")
         else:
             start = time.time()
             self.cache = {}  # Clear cache on (re-)generating entries_by_word_set
             entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:'
-            self.entries_by_word_set = [set(word.lower()
-                                            for word
-                                            in re.split(entry_splitter, entry[self.entry_key])
-                                            if word != "")
-                                        for entry in entries]
+            # Create map of words to entries they exist in
+            for entry_index, entry in enumerate(entries):
+                for word in re.split(entry_splitter, entry[self.entry_key].lower()):
+                    if word == '':
+                        continue
+                    if word not in self.word_to_entry_index:
+                        self.word_to_entry_index[word] = set()
+                    self.word_to_entry_index[word].add(entry_index)
+
             with self.filter_file.open('wb') as f:
-                pickle.dump(self.entries_by_word_set, f)
+                pickle.dump(self.word_to_entry_index, f)
             end = time.time()
             logger.debug(f"Convert all {self.search_type} entries to word sets: {end - start} seconds")

-        return self.entries_by_word_set
+        return self.word_to_entry_index


     def can_filter(self, raw_query):
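The loop added above guards dictionary membership by hand before inserting. A `collections.defaultdict` would build the same index more compactly; a sketch of that equivalent variant, reusing the diff's `entry_splitter` on toy entries (not the committed code):

    import re
    from collections import defaultdict

    entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:'
    entries = [{'raw': 'Saw a robin today.'}, {'raw': 'Fed the robin, then the cat'}]

    word_to_entry_index = defaultdict(set)
    for entry_index, entry in enumerate(entries):
        for word in re.split(entry_splitter, entry['raw'].lower()):
            if word:
                word_to_entry_index[word].add(entry_index)

    print(word_to_entry_index['robin'])  # {0, 1}

A `defaultdict` pickles like a plain dict, so the persistence path above would be unaffected.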
@@ -69,7 +73,7 @@

         required_words = set([word.lower() for word in re.findall(self.required_regex, raw_query)])
         blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, raw_query)])
-        query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', raw_query))
+        query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', raw_query)).strip()

         end = time.time()
         logger.debug(f"Extract required, blocked filters from query: {end - start} seconds")
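For context, explicit filters are written as +"word" to require and -"word" to block a term. The `required_regex` and `blocked_regex` attributes are defined outside this hunk, so the patterns below are assumptions for illustration; the example also shows the trailing whitespace that the added `.strip()` removes:

    import re

    # Assumed patterns; the real ones are class attributes defined elsewhere
    required_regex = r'\+"(\w+)" ?'
    blocked_regex = r'-"(\w+)" ?'

    raw_query = 'what to sow +"tomato" -"winter"'
    required_words = set(word.lower() for word in re.findall(required_regex, raw_query))
    blocked_words = set(word.lower() for word in re.findall(blocked_regex, raw_query))

    # Scrubbing the filters leaves 'what to sow ' -- .strip() drops the tail
    query = re.sub(blocked_regex, '', re.sub(required_regex, '', raw_query)).strip()
    print(required_words, blocked_words, repr(query))  # {'tomato'} {'winter'} 'what to sow'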
@@ -84,41 +88,33 @@
             entries, embeddings = self.cache[cache_key]
             return query, entries, embeddings

-        # deep copy original embeddings, entries before filtering
-        start = time.time()
-        embeddings= deepcopy(raw_embeddings)
-        entries = deepcopy(raw_entries)
-        end = time.time()
-        logger.debug(f"Create copy of embeddings, entries for manipulation: {end - start:.3f} seconds")
-
-        if not self.entries_by_word_set:
-            self.load(entries, regenerate=False)
+        if not self.word_to_entry_index:
+            self.load(raw_entries, regenerate=False)

-        # track id of entries to exclude
         start = time.time()
-        entries_to_exclude = set()

-        # mark entries that do not contain all required_words for exclusion
+        # mark entries that contain all required_words for inclusion
+        entries_with_all_required_words = set(range(len(raw_entries)))
         if len(required_words) > 0:
-            for id, words_in_entry in enumerate(self.entries_by_word_set):
-                if not required_words.issubset(words_in_entry):
-                    entries_to_exclude.add(id)
+            entries_with_all_required_words = set.intersection(*[self.word_to_entry_index.get(word, set()) for word in required_words])

         # mark entries that contain any blocked_words for exclusion
+        entries_with_any_blocked_words = set()
         if len(blocked_words) > 0:
-            for id, words_in_entry in enumerate(self.entries_by_word_set):
-                if words_in_entry.intersection(blocked_words):
-                    entries_to_exclude.add(id)
+            entries_with_any_blocked_words = set.union(*[self.word_to_entry_index.get(word, set()) for word in blocked_words])
+
         end = time.time()
-        logger.debug(f"Mark entries not satisfying filter: {end - start} seconds")
+        logger.debug(f"Mark entries satisfying filter: {end - start} seconds")

-        # delete entries (and their embeddings) marked for exclusion
+        # get entries (and their embeddings) satisfying inclusion and exclusion filters
         start = time.time()
-        for id in sorted(list(entries_to_exclude), reverse=True):
-            del entries[id]
-            embeddings = torch.cat((embeddings[:id], embeddings[id+1:]))
+
+        included_entry_indices = entries_with_all_required_words - entries_with_any_blocked_words
+        entries = [entry for id, entry in enumerate(raw_entries) if id in included_entry_indices]
+        embeddings = torch.index_select(raw_embeddings, 0, torch.tensor(list(included_entry_indices)))
+
         end = time.time()
-        logger.debug(f"Delete entries not satisfying filter: {end - start} seconds")
+        logger.debug(f"Keep entries satisfying filter: {end - start} seconds")

         # Cache results
         self.cache[cache_key] = entries, embeddings
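A small sketch of the new selection step with toy tensors, showing how one `torch.index_select` call replaces the per-id `torch.cat` row deletions (illustrative data, not from the repository):

    import torch

    raw_embeddings = torch.arange(12.).reshape(4, 3)  # 4 entries, 3-dim embeddings
    entries_with_all_required_words = {0, 2, 3}
    entries_with_any_blocked_words = {3}

    # One set difference plus one gather over the surviving rows
    included_entry_indices = sorted(entries_with_all_required_words - entries_with_any_blocked_words)
    embeddings = torch.index_select(raw_embeddings, 0, torch.tensor(included_entry_indices))
    print(embeddings.shape)  # torch.Size([2, 3]): rows 0 and 2 survive

The sketch sorts the indices before selecting: Python set iteration order is not guaranteed, so sorting keeps the selected embedding rows aligned with the order-preserving `entries` list comprehension in the diff.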