Create File Filter. Improve, Consolidate Filter Code

### General Filter Improvements
- e441874 Create Abstract Base Class for all filters to inherit from
- 965bd05 Make search filters return entry ids satisfying the filter (see the interface sketch after this list)
- 092b9e3 Set up filters when configuring Text Search for each Search Type
- 31503e7 Stop passing embeddings as an argument to the `filter.apply` method, as they were unused
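
Together these changes consolidate every filter behind one contract: `can_filter` decides whether a filter triggers on a query, and `apply` returns the cleaned query plus the set of entry ids satisfying the filter, which the caller intersects instead of copying entries and embeddings around. A minimal sketch of that contract follows; the `filter_entries` helper is illustrative, not part of this commit:

```python
from abc import ABC, abstractmethod
from typing import List, Set, Tuple


class BaseFilter(ABC):
    "Contract each search filter implements (mirrors src/search_filter/base_filter.py below)"
    @abstractmethod
    def can_filter(self, raw_query: str) -> bool: ...

    @abstractmethod
    def apply(self, query: str, raw_entries: List[str]) -> Tuple[str, Set[int]]: ...


def filter_entries(query: str, entries: list, filters: List[BaseFilter]) -> Tuple[str, Set[int]]:
    "Apply every filter that triggers on the query; keep only entry ids accepted by all of them"
    included_entry_indices = set(range(len(entries)))
    for search_filter in [f for f in filters if f.can_filter(query)]:
        query, entry_indices = search_filter.apply(query, entries)
        included_entry_indices &= entry_indices
    return query, included_entry_indices
```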

### Create File Filter
- 7606724 Add file associated with each entry to entry dict in `org_to_jsonl` converter
- 1f9fd28 Create File Filter to filter files specified in query
- 7dd20d7 Pre-compute file to entry map to speed up file-based filtering
- 7e083d3 Cache results for file filters passed in query for faster filtering
- 2890b4c Simplify extracting entries satisfying file filter (see the `file:` syntax sketch after this list)
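
To illustrate the `file:` query syntax: bare file names (no path separator or wildcard) are widened into glob patterns, then matched against each entry's source file with `fnmatch`. A sketch distilled from the `FileFilter` diff below, not the exact implementation:

```python
import fnmatch
import re

FILE_FILTER_REGEX = r'file:"(.+?)" ?'

def extract_file_globs(raw_query: str) -> list:
    "Widen bare file names into glob patterns, e.g. file:\"notes.org\" -> *notes.org"
    globs = []
    for file in sorted(re.findall(FILE_FILTER_REGEX, raw_query)):
        is_bare_name = '/' not in file and '\\' not in file and '*' not in file
        globs.append(f'*{file}' if is_bare_name else file)
    return globs

assert extract_file_globs('head file:"notes.org" tail') == ['*notes.org']
assert fnmatch.fnmatch('/home/user/org/notes.org', '*notes.org')
```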

### Miscellaneous
- f930324 Rename `explicit filter` to the more appropriate name `word filter`
- 3707a4c Improve date filter performance: precompute date-to-entry map and cache results (index sketch below)
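
The date filter speed-up follows the same pattern: scan entries once to build a date-to-entry-ids index, answer each query's date range against the distinct dates in that index, and keep an LRU cache over recently queried ranges. A rough sketch of the index construction, mirroring `DateFilter.load` below:

```python
import re
from collections import defaultdict
from datetime import datetime

def build_date_index(entries: list, entry_key: str = 'raw') -> dict:
    "Map each ISO date found in an entry to the set of entry ids mentioning it"
    date_to_entry_ids = defaultdict(set)
    for id, entry in enumerate(entries):
        for date_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[entry_key]):
            try:
                date = datetime.strptime(date_string, '%Y-%m-%d').timestamp()
            except ValueError:
                continue  # skip date-like strings that are not valid dates, e.g. 1984-13-99
            date_to_entry_ids[date].add(id)
    return date_to_entry_ids
```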
Committed by Debanjum on 2022-09-05 15:29:55 +00:00 (merged via GitHub) in commit 0a78cd5477
19 changed files with 457 additions and 221 deletions

.gitignore

@@ -13,3 +13,4 @@ src/.data
/dist/
/khoj_assistant.egg-info/
/config/khoj*.yml
.pytest_cache

README.md

@@ -125,7 +125,6 @@ pip install --upgrade khoj-assistant
- Semantic search using the bi-encoder is fairly fast at \<50 ms
- Reranking using the cross-encoder is slower at \<2s on 15 results. Tweak `top_k` to tradeoff speed for accuracy of results
- Applying explicit filters is very slow currently at \~6s. This is because the filters are rudimentary. Considerable speed-ups can be achieved using indexes etc
### Indexing performance

src/configure.py

@@ -14,6 +14,9 @@ from src.utils.config import SearchType, SearchModels, ProcessorConfigModel, Con
from src.utils import state
from src.utils.helpers import resolve_absolute_path
from src.utils.rawconfig import FullConfig, ProcessorConfig
from src.search_filter.date_filter import DateFilter
from src.search_filter.word_filter import WordFilter
from src.search_filter.file_filter import FileFilter
logger = logging.getLogger(__name__)
@@ -39,23 +42,25 @@ def configure_server(args, required=False):
def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: SearchType = None):
# Initialize Org Notes Search
if (t == SearchType.Org or t == None) and config.content_type.org:
filter_directory = resolve_absolute_path(config.content_type.org.compressed_jsonl.parent)
filters = [DateFilter(), WordFilter(filter_directory, search_type=SearchType.Org), FileFilter()]
# Extract Entries, Generate Notes Embeddings
model.orgmode_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, search_type=SearchType.Org, regenerate=regenerate)
model.orgmode_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate, filters=filters)
# Initialize Org Music Search
if (t == SearchType.Music or t == None) and config.content_type.music:
# Extract Entries, Generate Music Embeddings
model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, search_type=SearchType.Music, regenerate=regenerate)
model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate)
# Initialize Markdown Search
if (t == SearchType.Markdown or t == None) and config.content_type.markdown:
# Extract Entries, Generate Markdown Embeddings
model.markdown_search = text_search.setup(markdown_to_jsonl, config.content_type.markdown, search_config=config.search_type.asymmetric, search_type=SearchType.Markdown, regenerate=regenerate)
model.markdown_search = text_search.setup(markdown_to_jsonl, config.content_type.markdown, search_config=config.search_type.asymmetric, regenerate=regenerate)
# Initialize Ledger Search
if (t == SearchType.Ledger or t == None) and config.content_type.ledger:
# Extract Entries, Generate Ledger Embeddings
model.ledger_search = text_search.setup(beancount_to_jsonl, config.content_type.ledger, search_config=config.search_type.symmetric, search_type=SearchType.Ledger, regenerate=regenerate)
model.ledger_search = text_search.setup(beancount_to_jsonl, config.content_type.ledger, search_config=config.search_type.symmetric, regenerate=regenerate)
# Initialize Image Search
if (t == SearchType.Image or t == None) and config.content_type.image:

src/processor/org_mode/org_to_jsonl.py

@@ -28,10 +28,10 @@ def org_to_jsonl(org_files, org_file_filter, output_file):
org_files = get_org_files(org_files, org_file_filter)
# Extract Entries from specified Org files
entries = extract_org_entries(org_files)
entries, file_to_entries = extract_org_entries(org_files)
# Process Each Entry from All Notes Files
jsonl_data = convert_org_entries_to_jsonl(entries)
jsonl_data = convert_org_entries_to_jsonl(entries, file_to_entries)
# Compress JSONL formatted Data
if output_file.suffix == ".gz":
@@ -66,18 +66,19 @@ def get_org_files(org_files=None, org_file_filter=None):
def extract_org_entries(org_files):
"Extract entries from specified Org files"
entries = []
entry_to_file_map = []
for org_file in org_files:
entries.extend(
orgnode.makelist(
str(org_file)))
org_file_entries = orgnode.makelist(str(org_file))
entry_to_file_map += [org_file]*len(org_file_entries)
entries.extend(org_file_entries)
return entries
return entries, entry_to_file_map
def convert_org_entries_to_jsonl(entries) -> str:
def convert_org_entries_to_jsonl(entries, entry_to_file_map) -> str:
"Convert each Org-Mode entries to JSON and collate as JSONL"
jsonl = ''
for entry in entries:
for entry_id, entry in enumerate(entries):
entry_dict = dict()
# Ignore title notes i.e notes with just headings and empty body
@@ -106,6 +107,7 @@ def convert_org_entries_to_jsonl(entries) -> str:
if entry_dict:
entry_dict["raw"] = f'{entry}'
entry_dict["file"] = f'{entry_to_file_map[entry_id]}'
# Convert Dictionary to JSON and Append to JSONL string
jsonl += f'{json.dumps(entry_dict, ensure_ascii=False)}\n'

src/main.py

@@ -16,8 +16,6 @@ from fastapi.templating import Jinja2Templates
from src.configure import configure_search
from src.search_type import image_search, text_search
from src.processor.conversation.gpt import converse, extract_search_type, message_to_log, message_to_prompt, understand, summarize
from src.search_filter.explicit_filter import ExplicitFilter
from src.search_filter.date_filter import DateFilter
from src.utils.rawconfig import FullConfig
from src.utils.config import SearchType
from src.utils.helpers import get_absolute_path, get_from_dict

src/search_filter/base_filter.py (new file)

@@ -0,0 +1,20 @@
# Standard Packages
from abc import ABC, abstractmethod
from typing import List, Set, Tuple
# External Packages
import torch
class BaseFilter(ABC):
@abstractmethod
def load(self, *args, **kwargs):
pass
@abstractmethod
def can_filter(self, raw_query:str) -> bool:
pass
@abstractmethod
def apply(self, query:str, raw_entries:List[str]) -> Tuple[str, Set[int]]:
pass
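
For example, a trivial concrete filter satisfying this ABC might look like the following hypothetical `NoOpFilter` (shown only to illustrate the interface, using the imports above; the real implementations are `DateFilter`, `WordFilter` and `FileFilter` below):

```python
class NoOpFilter(BaseFilter):
    "Hypothetical filter that never triggers and keeps every entry"
    def load(self, *args, **kwargs):
        pass

    def can_filter(self, raw_query: str) -> bool:
        return False

    def apply(self, query: str, raw_entries: List[str]) -> Tuple[str, Set[int]]:
        return query, set(range(len(raw_entries)))
```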

src/search_filter/date_filter.py

@@ -1,56 +1,40 @@
# Standard Packages
import re
import time
import logging
from collections import defaultdict
from datetime import timedelta, datetime
from dateutil.relativedelta import relativedelta, MO
from dateutil.relativedelta import relativedelta
from math import inf
from copy import deepcopy
# External Packages
import torch
import dateparser as dtparse
# Internal Packages
from src.search_filter.base_filter import BaseFilter
from src.utils.helpers import LRU
class DateFilter:
logger = logging.getLogger(__name__)
class DateFilter(BaseFilter):
# Date Range Filter Regexes
# Example filter queries:
# - dt>="yesterday" dt<"tomorrow"
# - dt>="last week"
# - dt:"2 years ago"
# - dt>="yesterday" dt<"tomorrow"
# - dt>="last week"
# - dt:"2 years ago"
date_regex = r"dt([:><=]{1,2})\"(.*?)\""
def __init__(self, entry_key='raw'):
self.entry_key = entry_key
self.date_to_entry_ids = defaultdict(set)
self.cache = LRU()
def load(*args, **kwargs):
pass
def can_filter(self, raw_query):
"Check if query contains date filters"
return self.extract_date_range(raw_query) is not None
def apply(self, query, raw_entries, raw_embeddings):
"Find entries containing any dates that fall within date range specified in query"
# extract date range specified in date filter of query
query_daterange = self.extract_date_range(query)
# if no date in query, return all entries
if query_daterange is None:
return query, raw_entries, raw_embeddings
# remove date range filter from query
query = re.sub(rf'\s+{self.date_regex}', ' ', query)
query = re.sub(r'\s{2,}', ' ', query).strip() # remove multiple spaces
# deep copy original embeddings, entries before filtering
embeddings= deepcopy(raw_embeddings)
entries = deepcopy(raw_entries)
# find entries containing any dates that fall with date range specified in query
entries_to_include = set()
def load(self, entries, **_):
start = time.time()
for id, entry in enumerate(entries):
# Extract dates from entry
for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[self.entry_key]):
@@ -59,18 +43,56 @@ class DateFilter:
date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp()
except ValueError:
continue
# Check if date in entry is within date range specified in query
if query_daterange[0] <= date_in_entry < query_daterange[1]:
entries_to_include.add(id)
break
self.date_to_entry_ids[date_in_entry].add(id)
end = time.time()
logger.debug(f"Created file filter index: {end - start} seconds")
# delete entries (and their embeddings) marked for exclusion
entries_to_exclude = set(range(len(entries))) - entries_to_include
for id in sorted(list(entries_to_exclude), reverse=True):
del entries[id]
embeddings = torch.cat((embeddings[:id], embeddings[id+1:]))
return query, entries, embeddings
def can_filter(self, raw_query):
"Check if query contains date filters"
return self.extract_date_range(raw_query) is not None
def apply(self, query, raw_entries):
"Find entries containing any dates that fall within date range specified in query"
# extract date range specified in date filter of query
start = time.time()
query_daterange = self.extract_date_range(query)
end = time.time()
logger.debug(f"Extract date range to filter from query: {end - start} seconds")
# if no date in query, return all entries
if query_daterange is None:
return query, set(range(len(raw_entries)))
# remove date range filter from query
query = re.sub(rf'\s+{self.date_regex}', ' ', query)
query = re.sub(r'\s{2,}', ' ', query).strip() # remove multiple spaces
# return results from cache if exists
cache_key = tuple(query_daterange)
if cache_key in self.cache:
logger.info(f"Return date filter results from cache")
entries_to_include = self.cache[cache_key]
return query, entries_to_include
if not self.date_to_entry_ids:
self.load(raw_entries)
# find entries containing any dates that fall with date range specified in query
start = time.time()
entries_to_include = set()
for date_in_entry in self.date_to_entry_ids.keys():
# Check if date in entry is within date range specified in query
if query_daterange[0] <= date_in_entry < query_daterange[1]:
entries_to_include |= self.date_to_entry_ids[date_in_entry]
end = time.time()
logger.debug(f"Mark entries satisfying filter: {end - start} seconds")
# cache results
self.cache[cache_key] = entries_to_include
return query, entries_to_include
def extract_date_range(self, query):

src/search_filter/file_filter.py (new file)

@@ -0,0 +1,79 @@
# Standard Packages
import re
import fnmatch
import time
import logging
from collections import defaultdict
# Internal Packages
from src.search_filter.base_filter import BaseFilter
from src.utils.helpers import LRU
logger = logging.getLogger(__name__)
class FileFilter(BaseFilter):
file_filter_regex = r'file:"(.+?)" ?'
def __init__(self, entry_key='file'):
self.entry_key = entry_key
self.file_to_entry_map = defaultdict(set)
self.cache = LRU()
def load(self, entries, *args, **kwargs):
start = time.time()
for id, entry in enumerate(entries):
self.file_to_entry_map[entry[self.entry_key]].add(id)
end = time.time()
logger.debug(f"Created file filter index: {end - start} seconds")
def can_filter(self, raw_query):
return re.search(self.file_filter_regex, raw_query) is not None
def apply(self, raw_query, raw_entries):
# Extract file filters from raw query
start = time.time()
raw_files_to_search = re.findall(self.file_filter_regex, raw_query)
if not raw_files_to_search:
return raw_query, set(range(len(raw_entries)))
# Convert simple file filters with no path separator into glob patterns
# e.g. "file:notes.org" -> "*notes.org"
files_to_search = []
for file in sorted(raw_files_to_search):
if '/' not in file and '\\' not in file and '*' not in file:
files_to_search += [f'*{file}']
else:
files_to_search += [file]
end = time.time()
logger.debug(f"Extract files_to_search from query: {end - start} seconds")
# Return item from cache if exists
query = re.sub(self.file_filter_regex, '', raw_query).strip()
cache_key = tuple(files_to_search)
if cache_key in self.cache:
logger.info(f"Return file filter results from cache")
included_entry_indices = self.cache[cache_key]
return query, included_entry_indices
if not self.file_to_entry_map:
self.load(raw_entries, regenerate=False)
# Mark entries whose source file matches any of the specified file patterns for inclusion
start = time.time()
included_entry_indices = set.union(*[self.file_to_entry_map[entry_file]
for entry_file in self.file_to_entry_map.keys()
for search_file in files_to_search
if fnmatch.fnmatch(entry_file, search_file)], set())
if not included_entry_indices:
return query, set()
end = time.time()
logger.debug(f"Mark entries satisfying filter: {end - start} seconds")
# Cache results
self.cache[cache_key] = included_entry_indices
return query, included_entry_indices
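
A quick usage sketch of the filter above, with hypothetical entry data (see tests/test_file_filter.py below for the full behavior):

```python
entries = [
    {'compiled': '', 'raw': 'First Entry', 'file': '/notes/file1.org'},
    {'compiled': '', 'raw': 'Second Entry', 'file': '/notes/file2.org'}]

file_filter = FileFilter()
query, entry_ids = file_filter.apply('head file:"file1.org" tail', entries)
assert query == 'head tail' and entry_ids == {0}
```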

src/search_filter/explicit_filter.py → src/search_filter/word_filter.py (renamed)

@@ -8,6 +8,7 @@ import logging
import torch
# Internal Packages
from src.search_filter.base_filter import BaseFilter
from src.utils.helpers import LRU, resolve_absolute_path
from src.utils.config import SearchType
@@ -15,13 +16,13 @@ from src.utils.config import SearchType
logger = logging.getLogger(__name__)
class ExplicitFilter:
class WordFilter(BaseFilter):
# Filter Regex
required_regex = r'\+"(\w+)" ?'
blocked_regex = r'\-"(\w+)" ?'
def __init__(self, filter_directory, search_type: SearchType, entry_key='raw'):
self.filter_file = resolve_absolute_path(filter_directory / f"{search_type.name.lower()}_explicit_filter_entry_word_sets.pkl")
self.filter_file = resolve_absolute_path(filter_directory / f"word_filter_{search_type.name.lower()}_index.pkl")
self.entry_key = entry_key
self.search_type = search_type
self.word_to_entry_index = dict()
@@ -34,7 +35,7 @@ class ExplicitFilter:
with self.filter_file.open('rb') as f:
self.word_to_entry_index = pickle.load(f)
end = time.time()
logger.debug(f"Load {self.search_type} entries by word set from file: {end - start} seconds")
logger.debug(f"Load word filter index for {self.search_type} from {self.filter_file}: {end - start} seconds")
else:
start = time.time()
self.cache = {} # Clear cache on (re-)generating entries_by_word_set
@@ -51,23 +52,22 @@ class ExplicitFilter:
with self.filter_file.open('wb') as f:
pickle.dump(self.word_to_entry_index, f)
end = time.time()
logger.debug(f"Convert all {self.search_type} entries to word sets: {end - start} seconds")
logger.debug(f"Index {self.search_type} for word filter to {self.filter_file}: {end - start} seconds")
return self.word_to_entry_index
def can_filter(self, raw_query):
"Check if query contains explicit filters"
# Extract explicit query portion with required, blocked words to filter from natural query
"Check if query contains word filters"
required_words = re.findall(self.required_regex, raw_query)
blocked_words = re.findall(self.blocked_regex, raw_query)
return len(required_words) != 0 or len(blocked_words) != 0
def apply(self, raw_query, raw_entries, raw_embeddings):
def apply(self, raw_query, raw_entries):
"Find entries containing required and not blocked words specified in query"
# Separate natural query from explicit required, blocked words filters
# Separate natural query from required, blocked words filters
start = time.time()
required_words = set([word.lower() for word in re.findall(self.required_regex, raw_query)])
@@ -78,14 +78,14 @@ class ExplicitFilter:
logger.debug(f"Extract required, blocked filters from query: {end - start} seconds")
if len(required_words) == 0 and len(blocked_words) == 0:
return query, raw_entries, raw_embeddings
return query, set(range(len(raw_entries)))
# Return item from cache if exists
cache_key = tuple(sorted(required_words)), tuple(sorted(blocked_words))
if cache_key in self.cache:
logger.info(f"Explicit filter results from cache")
entries, embeddings = self.cache[cache_key]
return query, entries, embeddings
logger.info(f"Return word filter results from cache")
included_entry_indices = self.cache[cache_key]
return query, included_entry_indices
if not self.word_to_entry_index:
self.load(raw_entries, regenerate=False)
@@ -105,17 +105,10 @@ class ExplicitFilter:
end = time.time()
logger.debug(f"Mark entries satisfying filter: {end - start} seconds")
# get entries (and their embeddings) satisfying inclusion and exclusion filters
start = time.time()
# get entries satisfying inclusion and exclusion filters
included_entry_indices = entries_with_all_required_words - entries_with_any_blocked_words
entries = [entry for id, entry in enumerate(raw_entries) if id in included_entry_indices]
embeddings = torch.index_select(raw_embeddings, 0, torch.tensor(list(included_entry_indices)))
end = time.time()
logger.debug(f"Keep entries satisfying filter: {end - start} seconds")
# Cache results
self.cache[cache_key] = entries, embeddings
self.cache[cache_key] = included_entry_indices
return query, entries, embeddings
return query, included_entry_indices

src/search_type/text_search.py

@@ -7,13 +7,12 @@ import time
# External Packages
import torch
from sentence_transformers import SentenceTransformer, CrossEncoder, util
from src.search_filter.date_filter import DateFilter
from src.search_filter.explicit_filter import ExplicitFilter
from src.search_filter.base_filter import BaseFilter
# Internal Packages
from src.utils import state
from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model
from src.utils.config import SearchType, TextSearchModel
from src.utils.config import TextSearchModel
from src.utils.rawconfig import TextSearchConfig, TextContentConfig
from src.utils.jsonl import load_jsonl
@@ -53,9 +52,7 @@ def initialize_model(search_config: TextSearchConfig):
def extract_entries(jsonl_file):
"Load entries from compressed jsonl"
return [{'compiled': f'{entry["compiled"]}', 'raw': f'{entry["raw"]}'}
for entry
in load_jsonl(jsonl_file)]
return load_jsonl(jsonl_file)
def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False):
@@ -79,12 +76,25 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False):
query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings
# Filter query, entries and embeddings before semantic search
start = time.time()
start_filter = time.time()
included_entry_indices = set(range(len(entries)))
filters_in_query = [filter for filter in model.filters if filter.can_filter(query)]
for filter in filters_in_query:
query, entries, corpus_embeddings = filter.apply(query, entries, corpus_embeddings)
end = time.time()
logger.debug(f"Filter Time: {end - start:.3f} seconds")
query, included_entry_indices_by_filter = filter.apply(query, entries)
included_entry_indices.intersection_update(included_entry_indices_by_filter)
# Get entries (and associated embeddings) satisfying all filters
if not included_entry_indices:
return [], []
else:
start = time.time()
entries = [entries[id] for id in included_entry_indices]
corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices)))
end = time.time()
logger.debug(f"Keep entries satisfying all filters: {end - start} seconds")
end_filter = time.time()
logger.debug(f"Total Filter Time: {end_filter - start_filter:.3f} seconds")
if entries is None or len(entries) == 0:
return [], []
@@ -153,7 +163,7 @@ def collate_results(hits, entries, count=5):
in hits[0:count]]
def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, search_type: SearchType, regenerate: bool) -> TextSearchModel:
def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, filters: list[BaseFilter] = []) -> TextSearchModel:
# Initialize Model
bi_encoder, cross_encoder, top_k = initialize_model(search_config)
@@ -170,8 +180,6 @@ def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchCon
config.embeddings_file = resolve_absolute_path(config.embeddings_file)
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate)
filter_directory = resolve_absolute_path(config.compressed_jsonl.parent)
filters = [DateFilter(), ExplicitFilter(filter_directory, search_type=search_type)]
for filter in filters:
filter.load(entries, regenerate=regenerate)

src/utils/config.py

@@ -2,9 +2,11 @@
from enum import Enum
from dataclasses import dataclass
from pathlib import Path
from typing import List
# Internal Packages
from src.utils.rawconfig import ConversationProcessorConfig
from src.search_filter.base_filter import BaseFilter
class SearchType(str, Enum):
@@ -20,7 +22,7 @@ class ProcessorType(str, Enum):
class TextSearchModel():
def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k):
def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, filters: List[BaseFilter], top_k):
self.entries = entries
self.corpus_embeddings = corpus_embeddings
self.bi_encoder = bi_encoder

tests/conftest.py

@@ -1,4 +1,4 @@
# Standard Packages
# External Packages
import pytest
# Internal Packages
@@ -6,10 +6,13 @@ from src.search_type import image_search, text_search
from src.utils.config import SearchType
from src.utils.rawconfig import ContentConfig, TextContentConfig, ImageContentConfig, SearchConfig, TextSearchConfig, ImageSearchConfig
from src.processor.org_mode.org_to_jsonl import org_to_jsonl
from src.search_filter.date_filter import DateFilter
from src.search_filter.word_filter import WordFilter
from src.search_filter.file_filter import FileFilter
@pytest.fixture(scope='session')
def search_config(tmp_path_factory):
def search_config(tmp_path_factory) -> SearchConfig:
model_dir = tmp_path_factory.mktemp('data')
search_config = SearchConfig()
@@ -35,7 +38,7 @@ def search_config(tmp_path_factory):
@pytest.fixture(scope='session')
def model_dir(search_config):
def model_dir(search_config: SearchConfig):
model_dir = search_config.asymmetric.model_directory
# Generate Image Embeddings from Test Images
@@ -55,13 +58,14 @@ def model_dir(search_config):
compressed_jsonl = model_dir.joinpath('notes.jsonl.gz'),
embeddings_file = model_dir.joinpath('note_embeddings.pt'))
text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False)
filters = [DateFilter(), WordFilter(model_dir, search_type=SearchType.Org), FileFilter()]
text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
return model_dir
@pytest.fixture(scope='session')
def content_config(model_dir):
def content_config(model_dir) -> ContentConfig:
content_config = ContentConfig()
content_config.org = TextContentConfig(
input_files = None,

tests/test_client.py

@@ -4,7 +4,6 @@ from PIL import Image
# External Packages
from fastapi.testclient import TestClient
import pytest
# Internal Packages
from src.main import app
@@ -12,7 +11,8 @@ from src.utils.config import SearchType
from src.utils.state import model, config
from src.search_type import text_search, image_search
from src.utils.rawconfig import ContentConfig, SearchConfig
from src.processor.org_mode import org_to_jsonl
from src.processor.org_mode.org_to_jsonl import org_to_jsonl
from src.search_filter.word_filter import WordFilter
# Arrange
@@ -116,7 +116,7 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig
# ----------------------------------------------------------------------------------------------------
def test_notes_search(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False)
model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
user_query = "How to git install application?"
# Act
@@ -132,7 +132,8 @@ def test_notes_search(content_config: ContentConfig, search_config: SearchConfig
# ----------------------------------------------------------------------------------------------------
def test_notes_search_with_include_filter(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False)
filters = [WordFilter(content_config.org.compressed_jsonl.parent, search_type=SearchType.Org)]
model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
user_query = 'How to git install application? +"Emacs"'
# Act
@@ -140,7 +141,7 @@ def test_notes_search_with_include_filter(content_config: ContentConfig, search_
# Assert
assert response.status_code == 200
# assert actual_data contains explicitly included word "Emacs"
# assert actual_data contains word "Emacs"
search_result = response.json()[0]["entry"]
assert "Emacs" in search_result
@@ -148,7 +149,8 @@ def test_notes_search_with_include_filter(content_config: ContentConfig, search_
# ----------------------------------------------------------------------------------------------------
def test_notes_search_with_exclude_filter(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False)
filters = [WordFilter(content_config.org.compressed_jsonl.parent, search_type=SearchType.Org)]
model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
user_query = 'How to git install application? -"clone"'
# Act
@@ -156,6 +158,6 @@ def test_notes_search_with_exclude_filter(content_config: ContentConfig, search_
# Assert
assert response.status_code == 200
# assert actual_data does not contains explicitly excluded word "Emacs"
# assert actual_data does not contain word "clone"
search_result = response.json()[0]["entry"]
assert "clone" not in search_result

tests/test_date_filter.py

@@ -18,40 +18,34 @@ def test_date_filter():
{'compiled': '', 'raw': 'Entry with date:1984-04-02'}]
q_with_no_date_filter = 'head tail'
ret_query, ret_entries, ret_emb = DateFilter().apply(q_with_no_date_filter, entries.copy(), embeddings)
ret_query, entry_indices = DateFilter().apply(q_with_no_date_filter, entries)
assert ret_query == 'head tail'
assert len(ret_emb) == 3
assert ret_entries == entries
assert entry_indices == {0, 1, 2}
q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail'
ret_query, ret_entries, ret_emb = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries.copy(), embeddings)
ret_query, entry_indices = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries)
assert ret_query == 'head tail'
assert len(ret_emb) == 0
assert ret_entries == []
assert entry_indices == set()
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail'
ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings)
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
assert ret_query == 'head tail'
assert ret_entries == [entries[2]]
assert len(ret_emb) == 1
assert entry_indices == {2}
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail'
ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings)
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
assert ret_query == 'head tail'
assert ret_entries == [entries[1]]
assert len(ret_emb) == 1
assert entry_indices == {1}
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail'
ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings)
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
assert ret_query == 'head tail'
assert ret_entries == [entries[2]]
assert len(ret_emb) == 1
assert entry_indices == {2}
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail'
ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings)
ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
assert ret_query == 'head tail'
assert ret_entries == [entries[1], entries[2]]
assert len(ret_emb) == 2
assert entry_indices == {1, 2}
def test_extract_date_range():

tests/test_explicit_filter.py (deleted)

@@ -1,85 +0,0 @@
# External Packages
import torch
# Application Packages
from src.search_filter.explicit_filter import ExplicitFilter
from src.utils.config import SearchType
def test_no_explicit_filter(tmp_path):
# Arrange
explicit_filter = ExplicitFilter(tmp_path, SearchType.Org)
embeddings, entries = arrange_content()
q_with_no_filter = 'head tail'
# Act
can_filter = explicit_filter.can_filter(q_with_no_filter)
ret_query, ret_entries, ret_emb = explicit_filter.apply(q_with_no_filter, entries.copy(), embeddings)
# Assert
assert can_filter == False
assert ret_query == 'head tail'
assert len(ret_emb) == 4
assert ret_entries == entries
def test_explicit_exclude_filter(tmp_path):
# Arrange
explicit_filter = ExplicitFilter(tmp_path, SearchType.Org)
embeddings, entries = arrange_content()
q_with_exclude_filter = 'head -"exclude_word" tail'
# Act
can_filter = explicit_filter.can_filter(q_with_exclude_filter)
ret_query, ret_entries, ret_emb = explicit_filter.apply(q_with_exclude_filter, entries.copy(), embeddings)
# Assert
assert can_filter == True
assert ret_query == 'head tail'
assert len(ret_emb) == 2
assert ret_entries == [entries[0], entries[2]]
def test_explicit_include_filter(tmp_path):
# Arrange
explicit_filter = ExplicitFilter(tmp_path, SearchType.Org)
embeddings, entries = arrange_content()
query_with_include_filter = 'head +"include_word" tail'
# Act
can_filter = explicit_filter.can_filter(query_with_include_filter)
ret_query, ret_entries, ret_emb = explicit_filter.apply(query_with_include_filter, entries.copy(), embeddings)
# Assert
assert can_filter == True
assert ret_query == 'head tail'
assert len(ret_emb) == 2
assert ret_entries == [entries[2], entries[3]]
def test_explicit_include_and_exclude_filter(tmp_path):
# Arrange
explicit_filter = ExplicitFilter(tmp_path, SearchType.Org)
embeddings, entries = arrange_content()
query_with_include_and_exclude_filter = 'head +"include_word" -"exclude_word" tail'
# Act
can_filter = explicit_filter.can_filter(query_with_include_and_exclude_filter)
ret_query, ret_entries, ret_emb = explicit_filter.apply(query_with_include_and_exclude_filter, entries.copy(), embeddings)
# Assert
assert can_filter == True
assert ret_query == 'head tail'
assert len(ret_emb) == 1
assert ret_entries == [entries[2]]
def arrange_content():
embeddings = torch.randn(4, 10)
entries = [
{'compiled': '', 'raw': 'Minimal Entry'},
{'compiled': '', 'raw': 'Entry with exclude_word'},
{'compiled': '', 'raw': 'Entry with include_word'},
{'compiled': '', 'raw': 'Entry with include_word and exclude_word'}]
return embeddings, entries

tests/test_file_filter.py (new file, 112 lines)

@@ -0,0 +1,112 @@
# External Packages
import torch
# Application Packages
from src.search_filter.file_filter import FileFilter
def test_no_file_filter():
# Arrange
file_filter = FileFilter()
embeddings, entries = arrange_content()
q_with_no_filter = 'head tail'
# Act
can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert
assert can_filter == False
assert ret_query == 'head tail'
assert entry_indices == {0, 1, 2, 3}
def test_file_filter_with_non_existent_file():
# Arrange
file_filter = FileFilter()
embeddings, entries = arrange_content()
q_with_no_filter = 'head file:"nonexistent.org" tail'
# Act
can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert
assert can_filter == True
assert ret_query == 'head tail'
assert entry_indices == set()
def test_single_file_filter():
# Arrange
file_filter = FileFilter()
embeddings, entries = arrange_content()
q_with_no_filter = 'head file:"file 1.org" tail'
# Act
can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert
assert can_filter == True
assert ret_query == 'head tail'
assert entry_indices == {0, 2}
def test_file_filter_with_partial_match():
# Arrange
file_filter = FileFilter()
embeddings, entries = arrange_content()
q_with_no_filter = 'head file:"1.org" tail'
# Act
can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert
assert can_filter == True
assert ret_query == 'head tail'
assert entry_indices == {0, 2}
def test_file_filter_with_regex_match():
# Arrange
file_filter = FileFilter()
embeddings, entries = arrange_content()
q_with_no_filter = 'head file:"*.org" tail'
# Act
can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert
assert can_filter == True
assert ret_query == 'head tail'
assert entry_indices == {0, 1, 2, 3}
def test_multiple_file_filter():
# Arrange
file_filter = FileFilter()
embeddings, entries = arrange_content()
q_with_no_filter = 'head tail file:"file 1.org" file:"file2.org"'
# Act
can_filter = file_filter.can_filter(q_with_no_filter)
ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
# Assert
assert can_filter == True
assert ret_query == 'head tail'
assert entry_indices == {0, 1, 2, 3}
def arrange_content():
embeddings = torch.randn(4, 10)
entries = [
{'compiled': '', 'raw': 'First Entry', 'file': 'file 1.org'},
{'compiled': '', 'raw': 'Second Entry', 'file': 'file2.org'},
{'compiled': '', 'raw': 'Third Entry', 'file': 'file 1.org'},
{'compiled': '', 'raw': 'Fourth Entry', 'file': 'file2.org'}]
return embeddings, entries

tests/test_org_to_jsonl.py

@@ -21,10 +21,10 @@ def test_entry_with_empty_body_line_to_jsonl(tmp_path):
# Act
# Extract Entries from specified Org files
entries = extract_org_entries(org_files=[orgfile])
entries, entry_to_file_map = extract_org_entries(org_files=[orgfile])
# Process Each Entry from All Notes Files
jsonl_data = convert_org_entries_to_jsonl(entries)
jsonl_data = convert_org_entries_to_jsonl(entries, entry_to_file_map)
# Assert
assert is_none_or_empty(jsonl_data)
@@ -43,10 +43,10 @@ def test_entry_with_body_to_jsonl(tmp_path):
# Act
# Extract Entries from specified Org files
entries = extract_org_entries(org_files=[orgfile])
entries, entry_to_file_map = extract_org_entries(org_files=[orgfile])
# Process Each Entry from All Notes Files
jsonl_string = convert_org_entries_to_jsonl(entries)
jsonl_string = convert_org_entries_to_jsonl(entries, entry_to_file_map)
jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
# Assert

tests/test_text_search.py

@@ -1,6 +1,5 @@
# System Packages
from pathlib import Path
from src.utils.config import SearchType
# Internal Packages
from src.utils.state import model
@@ -14,7 +13,7 @@ from src.processor.org_mode.org_to_jsonl import org_to_jsonl
def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchConfig):
# Act
# Regenerate notes embeddings during asymmetric setup
notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=True)
notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True)
# Assert
assert len(notes_model.entries) == 10
@@ -24,7 +23,7 @@ def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchCo
# ----------------------------------------------------------------------------------------------------
def test_asymmetric_search(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
model.notes_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False)
model.notes_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
query = "How to git install application?"
# Act
@@ -47,7 +46,7 @@ def test_asymmetric_search(content_config: ContentConfig, search_config: SearchC
# ----------------------------------------------------------------------------------------------------
def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
initial_notes_model= text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False)
initial_notes_model= text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
assert len(initial_notes_model.entries) == 10
assert len(initial_notes_model.corpus_embeddings) == 10
@@ -60,11 +59,11 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchC
f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")
# regenerate notes jsonl, model embeddings and model to include entry from new file
regenerated_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=True)
regenerated_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True)
# Act
# reload embeddings, entries, notes model from previously generated notes jsonl and model embeddings files
initial_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False)
initial_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
# Assert
assert len(regenerated_notes_model.entries) == 11

tests/test_word_filter.py (new file, 81 lines)

@@ -0,0 +1,81 @@
# External Packages
import torch
# Application Packages
from src.search_filter.word_filter import WordFilter
from src.utils.config import SearchType
def test_no_word_filter(tmp_path):
# Arrange
word_filter = WordFilter(tmp_path, SearchType.Org)
embeddings, entries = arrange_content()
q_with_no_filter = 'head tail'
# Act
can_filter = word_filter.can_filter(q_with_no_filter)
ret_query, entry_indices = word_filter.apply(q_with_no_filter, entries)
# Assert
assert can_filter == False
assert ret_query == 'head tail'
assert entry_indices == {0, 1, 2, 3}
def test_word_exclude_filter(tmp_path):
# Arrange
word_filter = WordFilter(tmp_path, SearchType.Org)
embeddings, entries = arrange_content()
q_with_exclude_filter = 'head -"exclude_word" tail'
# Act
can_filter = word_filter.can_filter(q_with_exclude_filter)
ret_query, entry_indices = word_filter.apply(q_with_exclude_filter, entries)
# Assert
assert can_filter == True
assert ret_query == 'head tail'
assert entry_indices == {0, 2}
def test_word_include_filter(tmp_path):
# Arrange
word_filter = WordFilter(tmp_path, SearchType.Org)
embeddings, entries = arrange_content()
query_with_include_filter = 'head +"include_word" tail'
# Act
can_filter = word_filter.can_filter(query_with_include_filter)
ret_query, entry_indices = word_filter.apply(query_with_include_filter, entries)
# Assert
assert can_filter == True
assert ret_query == 'head tail'
assert entry_indices == {2, 3}
def test_word_include_and_exclude_filter(tmp_path):
# Arrange
word_filter = WordFilter(tmp_path, SearchType.Org)
embeddings, entries = arrange_content()
query_with_include_and_exclude_filter = 'head +"include_word" -"exclude_word" tail'
# Act
can_filter = word_filter.can_filter(query_with_include_and_exclude_filter)
ret_query, entry_indices = word_filter.apply(query_with_include_and_exclude_filter, entries)
# Assert
assert can_filter == True
assert ret_query == 'head tail'
assert entry_indices == {2}
def arrange_content():
embeddings = torch.randn(4, 10)
entries = [
{'compiled': '', 'raw': 'Minimal Entry'},
{'compiled': '', 'raw': 'Entry with exclude_word'},
{'compiled': '', 'raw': 'Entry with include_word'},
{'compiled': '', 'raw': 'Entry with include_word and exclude_word'}]
return embeddings, entries