diff --git a/src/configure.py b/src/configure.py index f6476951..ed46af37 100644 --- a/src/configure.py +++ b/src/configure.py @@ -48,11 +48,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate, - filters=[ - DateFilter(), - WordFilter(config.content_type.org.compressed_jsonl.parent, SearchType.Org), - FileFilter(), - ]) + filters=[DateFilter(), WordFilter(), FileFilter()]) # Initialize Org Music Search if (t == SearchType.Music or t == None) and config.content_type.music: @@ -71,11 +67,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, config.content_type.markdown, search_config=config.search_type.asymmetric, regenerate=regenerate, - filters=[ - DateFilter(), - WordFilter(config.content_type.markdown.compressed_jsonl.parent, SearchType.Markdown), - FileFilter(), - ]) + filters=[DateFilter(), WordFilter(), FileFilter()]) # Initialize Ledger Search if (t == SearchType.Ledger or t == None) and config.content_type.ledger: @@ -85,11 +77,7 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, config.content_type.ledger, search_config=config.search_type.symmetric, regenerate=regenerate, - filters=[ - DateFilter(), - WordFilter(config.content_type.ledger.compressed_jsonl.parent, SearchType.Ledger), - FileFilter(), - ]) + filters=[DateFilter(), WordFilter(), FileFilter()]) # Initialize Image Search if (t == SearchType.Image or t == None) and config.content_type.image: diff --git a/src/search_filter/word_filter.py b/src/search_filter/word_filter.py index c7c5d059..6fe0b31e 100644 --- a/src/search_filter/word_filter.py +++ b/src/search_filter/word_filter.py @@ -3,6 +3,7 @@ import re import time import pickle import logging +from collections import defaultdict # Internal Packages from src.search_filter.base_filter import BaseFilter @@ -18,38 +19,24 @@ class WordFilter(BaseFilter): required_regex = r'\+"(\w+)" ?' blocked_regex = r'\-"(\w+)" ?' - def __init__(self, filter_directory, search_type: SearchType, entry_key='raw'): - self.filter_file = resolve_absolute_path(filter_directory / f"word_filter_{search_type.name.lower()}_index.pkl") + def __init__(self, entry_key='raw'): self.entry_key = entry_key - self.search_type = search_type - self.word_to_entry_index = dict() + self.word_to_entry_index = defaultdict(set) self.cache = LRU() def load(self, entries, regenerate=False): - if self.filter_file.exists() and not regenerate: - start = time.time() - with self.filter_file.open('rb') as f: - self.word_to_entry_index = pickle.load(f) - end = time.time() - logger.debug(f"Load word filter index for {self.search_type} from {self.filter_file}: {end - start} seconds") - else: - start = time.time() - self.cache = {} # Clear cache on (re-)generating entries_by_word_set - entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:' - # Create map of words to entries they exist in - for entry_index, entry in enumerate(entries): - for word in re.split(entry_splitter, entry[self.entry_key].lower()): - if word == '': - continue - if word not in self.word_to_entry_index: - self.word_to_entry_index[word] = set() - self.word_to_entry_index[word].add(entry_index) - - with self.filter_file.open('wb') as f: - pickle.dump(self.word_to_entry_index, f) - end = time.time() - logger.debug(f"Index {self.search_type} for word filter to {self.filter_file}: {end - start} seconds") + start = time.time() + self.cache = {} # Clear cache on reload of filter + entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\<|\>|\t|\n|\:|\;|\?|\!|\(|\)|\&|\^|\$|\@|\%|\+|\=|\/|\\|\||\~|\`|\"|\'' + # Create map of words to entries they exist in + for entry_index, entry in enumerate(entries): + for word in re.split(entry_splitter, entry[self.entry_key].lower()): + if word == '': + continue + self.word_to_entry_index[word].add(entry_index) + end = time.time() + logger.debug(f"Created word filter index: {end - start} seconds") return self.word_to_entry_index diff --git a/tests/conftest.py b/tests/conftest.py index 7545527f..ab2703da 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -58,7 +58,7 @@ def model_dir(search_config: SearchConfig): compressed_jsonl = model_dir.joinpath('notes.jsonl.gz'), embeddings_file = model_dir.joinpath('note_embeddings.pt')) - filters = [DateFilter(), WordFilter(model_dir, search_type=SearchType.Org), FileFilter()] + filters = [DateFilter(), WordFilter(), FileFilter()] text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters) return model_dir diff --git a/tests/test_client.py b/tests/test_client.py index 578c789c..b167bce0 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -132,7 +132,7 @@ def test_notes_search(content_config: ContentConfig, search_config: SearchConfig # ---------------------------------------------------------------------------------------------------- def test_notes_search_with_include_filter(content_config: ContentConfig, search_config: SearchConfig): # Arrange - filters = [WordFilter(content_config.org.compressed_jsonl.parent, search_type=SearchType.Org)] + filters = [WordFilter()] model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters) user_query = 'How to git install application? +"Emacs"' @@ -149,7 +149,7 @@ def test_notes_search_with_include_filter(content_config: ContentConfig, search_ # ---------------------------------------------------------------------------------------------------- def test_notes_search_with_exclude_filter(content_config: ContentConfig, search_config: SearchConfig): # Arrange - filters = [WordFilter(content_config.org.compressed_jsonl.parent, search_type=SearchType.Org)] + filters = [WordFilter()] model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters) user_query = 'How to git install application? -"clone"' diff --git a/tests/test_word_filter.py b/tests/test_word_filter.py index 3efe8ed9..db23c2c6 100644 --- a/tests/test_word_filter.py +++ b/tests/test_word_filter.py @@ -1,15 +1,12 @@ -# External Packages -import torch - # Application Packages from src.search_filter.word_filter import WordFilter from src.utils.config import SearchType -def test_no_word_filter(tmp_path): +def test_no_word_filter(): # Arrange - word_filter = WordFilter(tmp_path, SearchType.Org) - embeddings, entries = arrange_content() + word_filter = WordFilter() + entries = arrange_content() q_with_no_filter = 'head tail' # Act @@ -22,10 +19,10 @@ def test_no_word_filter(tmp_path): assert entry_indices == {0, 1, 2, 3} -def test_word_exclude_filter(tmp_path): +def test_word_exclude_filter(): # Arrange - word_filter = WordFilter(tmp_path, SearchType.Org) - embeddings, entries = arrange_content() + word_filter = WordFilter() + entries = arrange_content() q_with_exclude_filter = 'head -"exclude_word" tail' # Act @@ -38,10 +35,10 @@ def test_word_exclude_filter(tmp_path): assert entry_indices == {0, 2} -def test_word_include_filter(tmp_path): +def test_word_include_filter(): # Arrange - word_filter = WordFilter(tmp_path, SearchType.Org) - embeddings, entries = arrange_content() + word_filter = WordFilter() + entries = arrange_content() query_with_include_filter = 'head +"include_word" tail' # Act @@ -54,10 +51,10 @@ def test_word_include_filter(tmp_path): assert entry_indices == {2, 3} -def test_word_include_and_exclude_filter(tmp_path): +def test_word_include_and_exclude_filter(): # Arrange - word_filter = WordFilter(tmp_path, SearchType.Org) - embeddings, entries = arrange_content() + word_filter = WordFilter() + entries = arrange_content() query_with_include_and_exclude_filter = 'head +"include_word" -"exclude_word" tail' # Act @@ -71,11 +68,10 @@ def test_word_include_and_exclude_filter(tmp_path): def arrange_content(): - embeddings = torch.randn(4, 10) entries = [ {'compiled': '', 'raw': 'Minimal Entry'}, {'compiled': '', 'raw': 'Entry with exclude_word'}, {'compiled': '', 'raw': 'Entry with include_word'}, {'compiled': '', 'raw': 'Entry with include_word and exclude_word'}] - return embeddings, entries + return entries