Improve Latency of Explicit Filter

### Goal
  - Improve explicit filter latency so that explicit filters work well with incremental search

### Reasons for High Explicit Filter Latency
  - Deleting excluded entries one-by-one from the existing list of entries and the embeddings tensor (the old flow is sketched after this list)
  - Explicit filtering on partial words during incremental search
  - Creating a word set for every entry on the fly at query time
  - Deep copying entries and embeddings before applying the filter
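
For context, the removed code (visible further down in the `ExplicitFilter` diff) did roughly the following on every query. This is a simplified, self-contained sketch, assuming `entries` is the project's list of `{'compiled', 'raw'}` dicts and `embeddings` a 2D tensor:

```python
# Sketch of the pre-change explicit filter: every query re-tokenizes the whole
# corpus and removes excluded rows one at a time.
import re
import torch

entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:'

def old_explicit_filter(required_words: set, blocked_words: set, entries: list, embeddings: torch.Tensor):
    # Rebuild a word set for every entry, on every query
    entries_by_word_set = [set(word.lower() for word in re.split(entry_splitter, entry['raw']) if word)
                           for entry in entries]

    # Mark entries missing any required word or containing any blocked word
    entries_to_exclude = set()
    for id, words_in_entry in enumerate(entries_by_word_set):
        if (required_words and not required_words.issubset(words_in_entry)) or words_in_entry & blocked_words:
            entries_to_exclude.add(id)

    # Delete excluded rows one by one, re-concatenating the embedding tensor each time
    for id in sorted(entries_to_exclude, reverse=True):
        del entries[id]
        embeddings = torch.cat((embeddings[:id], embeddings[id+1:]))

    return entries, embeddings
```

The per-query tokenization of the whole corpus and the repeated `torch.cat` copies are the costs the changes below target.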

### Improvement Details
  - **Major**
    - 191a656 Use word to entry map, list comprehension to speed up explicit filter
      - Use list comprehensions and `torch.index_select`
        - to speed up selecting the entries and embedding tensors that satisfy the filter
        - avoids deep copies and direct manipulation of the entries and embeddings
      - Use a word to entry map and set operations to mark entries satisfying the
        inclusion and exclusion filters (see the sketch after this list)
    - c7de57b Pre-compute entry word sets to improve explicit filter query performance
    - 3308e68 Cache explicitly filtered entries, embeddings by required, blocked words
    - cdcee89 Wrap explicit filter words in quotes to trigger the filter
      - E.g. `+"word_to_include"` instead of `+word_to_include`
      - Signals that the explicit filter term is complete
      - Prevents the latency hit from incremental search running the explicit filter on partial terms
  - **Minor**
    - 28d3dc1 Move the deep copy of entries, embeddings into the filters; defer it until a filter actually applies
    - 8d9f507 Load `entries_by_word_set` from file only once, on first load of the explicit filter
    - 546fad5 Use regex to check for and extract include, exclude filter words from the query
    - b7d259b Test Explicit Include, Exclude Filters
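
Putting the major pieces together, the new filter flow looks roughly like the following simplified sketch. It mirrors the `ExplicitFilter` changes in the diff further down, but leaves out the pickled index file, timing logs and the LRU result cache:

```python
# Sketch of the new explicit filter: a word-to-entry index is built once at load
# time; each query then only does set operations and a single index_select.
import re
import torch

entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:'
required_regex = r'\+"(\w+)" ?'
blocked_regex = r'\-"(\w+)" ?'

def build_word_to_entry_index(entries: list, entry_key='raw') -> dict:
    "Map each word to the set of entry indices it occurs in. Computed once, not per query."
    word_to_entry_index = {}
    for entry_index, entry in enumerate(entries):
        for word in re.split(entry_splitter, entry[entry_key].lower()):
            if word:
                word_to_entry_index.setdefault(word, set()).add(entry_index)
    return word_to_entry_index

def apply_explicit_filter(raw_query: str, entries: list, embeddings: torch.Tensor, word_to_entry_index: dict):
    "Keep entries containing all +\"words\" and none of the -\"words\" in the query."
    required = {word.lower() for word in re.findall(required_regex, raw_query)}
    blocked = {word.lower() for word in re.findall(blocked_regex, raw_query)}
    query = re.sub(blocked_regex, '', re.sub(required_regex, '', raw_query)).strip()
    if not required and not blocked:
        return query, entries, embeddings

    # Set intersection/union over the pre-computed index replaces per-entry word set scans
    included = set(range(len(entries)))
    if required:
        included = set.intersection(*[word_to_entry_index.get(word, set()) for word in required])
    if blocked:
        included -= set.union(*[word_to_entry_index.get(word, set()) for word in blocked])

    # List comprehension and torch.index_select avoid deep copies and row-by-row deletion
    filtered_entries = [entry for id, entry in enumerate(entries) if id in included]
    filtered_embeddings = torch.index_select(embeddings, 0, torch.tensor(sorted(included), dtype=torch.long))
    return query, filtered_entries, filtered_embeddings
```

In the actual implementation the index is pickled to disk per search type, rebuilt only on regenerate, and filter results are memoized in an LRU cache keyed by the sorted required and blocked word tuples, so repeated queries with the same filters skip even this work.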
    
### Results
  - Improve exclude word filter latency from **20s+ to 0.02s** on a 120K-line notes corpus
Commit d153d420fc by Debanjum, committed via GitHub on 2022-09-04 · 13 changed files with 272 additions and 78 deletions

```diff
@@ -40,22 +40,22 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
     # Initialize Org Notes Search
     if (t == SearchType.Org or t == None) and config.content_type.org:
         # Extract Entries, Generate Notes Embeddings
-        model.orgmode_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate)
+        model.orgmode_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, search_type=SearchType.Org, regenerate=regenerate)
     # Initialize Org Music Search
     if (t == SearchType.Music or t == None) and config.content_type.music:
         # Extract Entries, Generate Music Embeddings
-        model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate)
+        model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, search_type=SearchType.Music, regenerate=regenerate)
     # Initialize Markdown Search
     if (t == SearchType.Markdown or t == None) and config.content_type.markdown:
         # Extract Entries, Generate Markdown Embeddings
-        model.markdown_search = text_search.setup(markdown_to_jsonl, config.content_type.markdown, search_config=config.search_type.asymmetric, regenerate=regenerate)
+        model.markdown_search = text_search.setup(markdown_to_jsonl, config.content_type.markdown, search_config=config.search_type.asymmetric, search_type=SearchType.Markdown, regenerate=regenerate)
     # Initialize Ledger Search
     if (t == SearchType.Ledger or t == None) and config.content_type.ledger:
         # Extract Entries, Generate Ledger Embeddings
-        model.ledger_search = text_search.setup(beancount_to_jsonl, config.content_type.ledger, search_config=config.search_type.symmetric, regenerate=regenerate)
+        model.ledger_search = text_search.setup(beancount_to_jsonl, config.content_type.ledger, search_config=config.search_type.symmetric, search_type=SearchType.Ledger, regenerate=regenerate)
     # Initialize Image Search
     if (t == SearchType.Image or t == None) and config.content_type.image:
```

```diff
@@ -65,7 +65,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
     if (t == SearchType.Org or t == None) and state.model.orgmode_search:
         # query org-mode notes
         query_start = time.time()
-        hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose)
+        hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r)
         query_end = time.time()
         # collate and return results
@@ -76,7 +76,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
     if (t == SearchType.Music or t == None) and state.model.music_search:
         # query music library
         query_start = time.time()
-        hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose)
+        hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r)
         query_end = time.time()
         # collate and return results
@@ -87,7 +87,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
     if (t == SearchType.Markdown or t == None) and state.model.markdown_search:
         # query markdown files
         query_start = time.time()
-        hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose)
+        hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r)
         query_end = time.time()
         # collate and return results
@@ -98,7 +98,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
     if (t == SearchType.Ledger or t == None) and state.model.ledger_search:
         # query transactions
         query_start = time.time()
-        hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose)
+        hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r)
         query_end = time.time()
         # collate and return results
```

```diff
@@ -3,6 +3,7 @@ import re
 from datetime import timedelta, datetime
 from dateutil.relativedelta import relativedelta, MO
 from math import inf
+from copy import deepcopy

 # External Packages
 import torch
@@ -17,29 +18,42 @@ class DateFilter:
     # - dt:"2 years ago"
     date_regex = r"dt([:><=]{1,2})\"(.*?)\""

+    def __init__(self, entry_key='raw'):
+        self.entry_key = entry_key
+
+    def load(*args, **kwargs):
+        pass
+
     def can_filter(self, raw_query):
         "Check if query contains date filters"
         return self.extract_date_range(raw_query) is not None

-    def filter(self, query, entries, embeddings, entry_key='raw'):
+    def apply(self, query, raw_entries, raw_embeddings):
         "Find entries containing any dates that fall within date range specified in query"
         # extract date range specified in date filter of query
         query_daterange = self.extract_date_range(query)

         # if no date in query, return all entries
         if query_daterange is None:
-            return query, entries, embeddings
+            return query, raw_entries, raw_embeddings

         # remove date range filter from query
         query = re.sub(rf'\s+{self.date_regex}', ' ', query)
         query = re.sub(r'\s{2,}', ' ', query).strip()  # remove multiple spaces

+        # deep copy original embeddings, entries before filtering
+        embeddings = deepcopy(raw_embeddings)
+        entries = deepcopy(raw_entries)
+
         # find entries containing any dates that fall with date range specified in query
         entries_to_include = set()
         for id, entry in enumerate(entries):
             # Extract dates from entry
-            for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[entry_key]):
+            for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[self.entry_key]):
                 # Convert date string in entry to unix timestamp
                 try:
                     date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp()
```

```diff
@@ -1,57 +1,121 @@
 # Standard Packages
 import re
+import time
+import pickle
+import logging

 # External Packages
 import torch

+# Internal Packages
+from src.utils.helpers import LRU, resolve_absolute_path
+from src.utils.config import SearchType
+
+logger = logging.getLogger(__name__)
+

 class ExplicitFilter:
+    # Filter Regex
+    required_regex = r'\+"(\w+)" ?'
+    blocked_regex = r'\-"(\w+)" ?'
+
+    def __init__(self, filter_directory, search_type: SearchType, entry_key='raw'):
+        self.filter_file = resolve_absolute_path(filter_directory / f"{search_type.name.lower()}_explicit_filter_entry_word_sets.pkl")
+        self.entry_key = entry_key
+        self.search_type = search_type
+        self.word_to_entry_index = dict()
+        self.cache = LRU()
+
+    def load(self, entries, regenerate=False):
+        if self.filter_file.exists() and not regenerate:
+            start = time.time()
+            with self.filter_file.open('rb') as f:
+                self.word_to_entry_index = pickle.load(f)
+            end = time.time()
+            logger.debug(f"Load {self.search_type} entries by word set from file: {end - start} seconds")
+        else:
+            start = time.time()
+            self.cache = {}  # Clear cache on (re-)generating entries_by_word_set
+            entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:'
+            # Create map of words to entries they exist in
+            for entry_index, entry in enumerate(entries):
+                for word in re.split(entry_splitter, entry[self.entry_key].lower()):
+                    if word == '':
+                        continue
+                    if word not in self.word_to_entry_index:
+                        self.word_to_entry_index[word] = set()
+                    self.word_to_entry_index[word].add(entry_index)
+            with self.filter_file.open('wb') as f:
+                pickle.dump(self.word_to_entry_index, f)
+            end = time.time()
+            logger.debug(f"Convert all {self.search_type} entries to word sets: {end - start} seconds")
+
+        return self.word_to_entry_index
+
     def can_filter(self, raw_query):
         "Check if query contains explicit filters"
         # Extract explicit query portion with required, blocked words to filter from natural query
-        required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
-        blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
+        required_words = re.findall(self.required_regex, raw_query)
+        blocked_words = re.findall(self.blocked_regex, raw_query)

         return len(required_words) != 0 or len(blocked_words) != 0

-    def filter(self, raw_query, entries, embeddings, entry_key='raw'):
+    def apply(self, raw_query, raw_entries, raw_embeddings):
         "Find entries containing required and not blocked words specified in query"
         # Separate natural query from explicit required, blocked words filters
-        query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
-        required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
-        blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
+        start = time.time()
+        required_words = set([word.lower() for word in re.findall(self.required_regex, raw_query)])
+        blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, raw_query)])
+        query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', raw_query)).strip()
+        end = time.time()
+        logger.debug(f"Extract required, blocked filters from query: {end - start} seconds")

         if len(required_words) == 0 and len(blocked_words) == 0:
-            return query, entries, embeddings
+            return query, raw_entries, raw_embeddings
+
+        # Return item from cache if exists
+        cache_key = tuple(sorted(required_words)), tuple(sorted(blocked_words))
+        if cache_key in self.cache:
+            logger.info(f"Explicit filter results from cache")
+            entries, embeddings = self.cache[cache_key]
+            return query, entries, embeddings

-        # convert each entry to a set of words
-        # split on fullstop, comma, colon, tab, newline or any brackets
-        entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:'
-        entries_by_word_set = [set(word.lower()
-                                   for word
-                                   in re.split(entry_splitter, entry[entry_key])
-                                   if word != "")
-                               for entry in entries]
+        if not self.word_to_entry_index:
+            self.load(raw_entries, regenerate=False)

-        # track id of entries to exclude
-        entries_to_exclude = set()
+        start = time.time()

-        # mark entries that do not contain all required_words for exclusion
+        # mark entries that contain all required_words for inclusion
+        entries_with_all_required_words = set(range(len(raw_entries)))
         if len(required_words) > 0:
-            for id, words_in_entry in enumerate(entries_by_word_set):
-                if not required_words.issubset(words_in_entry):
-                    entries_to_exclude.add(id)
+            entries_with_all_required_words = set.intersection(*[self.word_to_entry_index.get(word, set()) for word in required_words])

         # mark entries that contain any blocked_words for exclusion
+        entries_with_any_blocked_words = set()
         if len(blocked_words) > 0:
-            for id, words_in_entry in enumerate(entries_by_word_set):
-                if words_in_entry.intersection(blocked_words):
-                    entries_to_exclude.add(id)
+            entries_with_any_blocked_words = set.union(*[self.word_to_entry_index.get(word, set()) for word in blocked_words])

-        # delete entries (and their embeddings) marked for exclusion
-        for id in sorted(list(entries_to_exclude), reverse=True):
-            del entries[id]
-            embeddings = torch.cat((embeddings[:id], embeddings[id+1:]))
+        end = time.time()
+        logger.debug(f"Mark entries satisfying filter: {end - start} seconds")
+
+        # get entries (and their embeddings) satisfying inclusion and exclusion filters
+        start = time.time()
+        included_entry_indices = entries_with_all_required_words - entries_with_any_blocked_words
+        entries = [entry for id, entry in enumerate(raw_entries) if id in included_entry_indices]
+        embeddings = torch.index_select(raw_embeddings, 0, torch.tensor(list(included_entry_indices)))
+        end = time.time()
+        logger.debug(f"Keep entries satisfying filter: {end - start} seconds")
+
+        # Cache results
+        self.cache[cache_key] = entries, embeddings

         return query, entries, embeddings
```

```diff
@@ -3,16 +3,17 @@ import argparse
 import pathlib
 import logging
 import time
-from copy import deepcopy

 # External Packages
 import torch
 from sentence_transformers import SentenceTransformer, CrossEncoder, util
+from src.search_filter.date_filter import DateFilter
+from src.search_filter.explicit_filter import ExplicitFilter

 # Internal Packages
 from src.utils import state
 from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model
-from src.utils.config import TextSearchModel
+from src.utils.config import SearchType, TextSearchModel
 from src.utils.rawconfig import TextSearchConfig, TextContentConfig
 from src.utils.jsonl import load_jsonl
@@ -73,26 +74,15 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False):
     return corpus_embeddings

-def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: list = []):
+def query(raw_query: str, model: TextSearchModel, rank_results=False):
     "Search for entries that answer the query"
-    query = raw_query
-    # Use deep copy of original embeddings, entries to filter if query contains filters
-    start = time.time()
-    filters_in_query = [filter for filter in filters if filter.can_filter(query)]
-    if filters_in_query:
-        corpus_embeddings = deepcopy(model.corpus_embeddings)
-        entries = deepcopy(model.entries)
-    else:
-        corpus_embeddings = model.corpus_embeddings
-        entries = model.entries
-    end = time.time()
-    logger.debug(f"Copy Time: {end - start:.3f} seconds")
+    query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings

     # Filter query, entries and embeddings before semantic search
     start = time.time()
+    filters_in_query = [filter for filter in model.filters if filter.can_filter(query)]
     for filter in filters_in_query:
-        query, entries, corpus_embeddings = filter.filter(query, entries, corpus_embeddings)
+        query, entries, corpus_embeddings = filter.apply(query, entries, corpus_embeddings)
     end = time.time()
     logger.debug(f"Filter Time: {end - start:.3f} seconds")
@@ -163,7 +153,7 @@ def collate_results(hits, entries, count=5):
             in hits[0:count]]

-def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool) -> TextSearchModel:
+def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, search_type: SearchType, regenerate: bool) -> TextSearchModel:
     # Initialize Model
     bi_encoder, cross_encoder, top_k = initialize_model(search_config)
@@ -180,7 +170,12 @@ def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchCon
     config.embeddings_file = resolve_absolute_path(config.embeddings_file)
     corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate)

-    return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k)
+    filter_directory = resolve_absolute_path(config.compressed_jsonl.parent)
+    filters = [DateFilter(), ExplicitFilter(filter_directory, search_type=search_type)]
+    for filter in filters:
+        filter.load(entries, regenerate=regenerate)
+
+    return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k)

 if __name__ == '__main__':
```

```diff
@@ -20,11 +20,12 @@ class ProcessorType(str, Enum):
 class TextSearchModel():
-    def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, top_k):
+    def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k):
         self.entries = entries
         self.corpus_embeddings = corpus_embeddings
         self.bi_encoder = bi_encoder
         self.cross_encoder = cross_encoder
+        self.filters = filters
         self.top_k = top_k
```

```diff
@@ -2,6 +2,7 @@
 import pathlib
 import sys
 from os.path import join
+from collections import OrderedDict

 def is_none_or_empty(item):
@@ -60,4 +61,21 @@ def load_model(model_name, model_dir, model_type, device:str=None):
 def is_pyinstaller_app():
     "Returns true if the app is running from Native GUI created by PyInstaller"
     return getattr(sys, 'frozen', False) and hasattr(sys, '_MEIPASS')
+
+
+class LRU(OrderedDict):
+    def __init__(self, *args, capacity=128, **kwargs):
+        self.capacity = capacity
+        super().__init__(*args, **kwargs)
+
+    def __getitem__(self, key):
+        value = super().__getitem__(key)
+        self.move_to_end(key)
+        return value
+
+    def __setitem__(self, key, value):
+        super().__setitem__(key, value)
+        if len(self) > self.capacity:
+            oldest = next(iter(self))
+            del self[oldest]
```

```diff
@@ -3,9 +3,9 @@ import pytest
 # Internal Packages
 from src.search_type import image_search, text_search
+from src.utils.config import SearchType
 from src.utils.rawconfig import ContentConfig, TextContentConfig, ImageContentConfig, SearchConfig, TextSearchConfig, ImageSearchConfig
 from src.processor.org_mode.org_to_jsonl import org_to_jsonl
-from src.utils import state

 @pytest.fixture(scope='session')
@@ -46,7 +46,7 @@ def model_dir(search_config):
         batch_size = 10,
         use_xmp_metadata = False)
-    image_search.setup(content_config.image, search_config.image, regenerate=False, verbose=True)
+    image_search.setup(content_config.image, search_config.image, regenerate=False)

     # Generate Notes Embeddings from Test Notes
     content_config.org = TextContentConfig(
@@ -55,7 +55,7 @@ def model_dir(search_config):
         compressed_jsonl = model_dir.joinpath('notes.jsonl.gz'),
         embeddings_file = model_dir.joinpath('note_embeddings.pt'))
-    text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, verbose=True)
+    text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False)

     return model_dir
```

```diff
@@ -8,6 +8,7 @@ import pytest
 # Internal Packages
 from src.main import app
+from src.utils.config import SearchType
 from src.utils.state import model, config
 from src.search_type import text_search, image_search
 from src.utils.rawconfig import ContentConfig, SearchConfig
@@ -115,7 +116,7 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig
 # ----------------------------------------------------------------------------------------------------
 def test_notes_search(content_config: ContentConfig, search_config: SearchConfig):
     # Arrange
-    model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
+    model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False)
     user_query = "How to git install application?"

     # Act
@@ -131,8 +132,8 @@ def test_notes_search(content_config: ContentConfig, search_config: SearchConfig
 # ----------------------------------------------------------------------------------------------------
 def test_notes_search_with_include_filter(content_config: ContentConfig, search_config: SearchConfig):
     # Arrange
-    model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
-    user_query = "How to git install application? +Emacs"
+    model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False)
+    user_query = 'How to git install application? +"Emacs"'

     # Act
     response = client.get(f"/search?q={user_query}&n=1&t=org")
@@ -147,8 +148,8 @@ def test_notes_search_with_include_filter(content_config: ContentConfig, search_
 # ----------------------------------------------------------------------------------------------------
 def test_notes_search_with_exclude_filter(content_config: ContentConfig, search_config: SearchConfig):
     # Arrange
-    model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
-    user_query = "How to git install application? -clone"
+    model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False)
+    user_query = 'How to git install application? -"clone"'

     # Act
     response = client.get(f"/search?q={user_query}&n=1&t=org")
```

```diff
@@ -18,37 +18,37 @@ def test_date_filter():
                {'compiled': '', 'raw': 'Entry with date:1984-04-02'}]

     q_with_no_date_filter = 'head tail'
-    ret_query, ret_entries, ret_emb = DateFilter().filter(q_with_no_date_filter, entries.copy(), embeddings)
+    ret_query, ret_entries, ret_emb = DateFilter().apply(q_with_no_date_filter, entries.copy(), embeddings)
     assert ret_query == 'head tail'
     assert len(ret_emb) == 3
     assert ret_entries == entries

     q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail'
-    ret_query, ret_entries, ret_emb = DateFilter().filter(q_with_dtrange_non_overlapping_at_boundary, entries.copy(), embeddings)
+    ret_query, ret_entries, ret_emb = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries.copy(), embeddings)
     assert ret_query == 'head tail'
     assert len(ret_emb) == 0
     assert ret_entries == []

     query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail'
-    ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
+    ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings)
     assert ret_query == 'head tail'
     assert ret_entries == [entries[2]]
     assert len(ret_emb) == 1

     query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail'
-    ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
+    ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings)
     assert ret_query == 'head tail'
     assert ret_entries == [entries[1]]
     assert len(ret_emb) == 1

     query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail'
-    ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
+    ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings)
     assert ret_query == 'head tail'
     assert ret_entries == [entries[2]]
     assert len(ret_emb) == 1

     query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail'
-    ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
+    ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings)
     assert ret_query == 'head tail'
     assert ret_entries == [entries[1], entries[2]]
     assert len(ret_emb) == 2
```

```diff
@@ -0,0 +1,85 @@
+# External Packages
+import torch
+
+# Application Packages
+from src.search_filter.explicit_filter import ExplicitFilter
+from src.utils.config import SearchType
+
+
+def test_no_explicit_filter(tmp_path):
+    # Arrange
+    explicit_filter = ExplicitFilter(tmp_path, SearchType.Org)
+    embeddings, entries = arrange_content()
+    q_with_no_filter = 'head tail'
+
+    # Act
+    can_filter = explicit_filter.can_filter(q_with_no_filter)
+    ret_query, ret_entries, ret_emb = explicit_filter.apply(q_with_no_filter, entries.copy(), embeddings)
+
+    # Assert
+    assert can_filter == False
+    assert ret_query == 'head tail'
+    assert len(ret_emb) == 4
+    assert ret_entries == entries
+
+
+def test_explicit_exclude_filter(tmp_path):
+    # Arrange
+    explicit_filter = ExplicitFilter(tmp_path, SearchType.Org)
+    embeddings, entries = arrange_content()
+    q_with_exclude_filter = 'head -"exclude_word" tail'
+
+    # Act
+    can_filter = explicit_filter.can_filter(q_with_exclude_filter)
+    ret_query, ret_entries, ret_emb = explicit_filter.apply(q_with_exclude_filter, entries.copy(), embeddings)
+
+    # Assert
+    assert can_filter == True
+    assert ret_query == 'head tail'
+    assert len(ret_emb) == 2
+    assert ret_entries == [entries[0], entries[2]]
+
+
+def test_explicit_include_filter(tmp_path):
+    # Arrange
+    explicit_filter = ExplicitFilter(tmp_path, SearchType.Org)
+    embeddings, entries = arrange_content()
+    query_with_include_filter = 'head +"include_word" tail'
+
+    # Act
+    can_filter = explicit_filter.can_filter(query_with_include_filter)
+    ret_query, ret_entries, ret_emb = explicit_filter.apply(query_with_include_filter, entries.copy(), embeddings)
+
+    # Assert
+    assert can_filter == True
+    assert ret_query == 'head tail'
+    assert len(ret_emb) == 2
+    assert ret_entries == [entries[2], entries[3]]
+
+
+def test_explicit_include_and_exclude_filter(tmp_path):
+    # Arrange
+    explicit_filter = ExplicitFilter(tmp_path, SearchType.Org)
+    embeddings, entries = arrange_content()
+    query_with_include_and_exclude_filter = 'head +"include_word" -"exclude_word" tail'
+
+    # Act
+    can_filter = explicit_filter.can_filter(query_with_include_and_exclude_filter)
+    ret_query, ret_entries, ret_emb = explicit_filter.apply(query_with_include_and_exclude_filter, entries.copy(), embeddings)
+
+    # Assert
+    assert can_filter == True
+    assert ret_query == 'head tail'
+    assert len(ret_emb) == 1
+    assert ret_entries == [entries[2]]
+
+
+def arrange_content():
+    embeddings = torch.randn(4, 10)
+    entries = [
+        {'compiled': '', 'raw': 'Minimal Entry'},
+        {'compiled': '', 'raw': 'Entry with exclude_word'},
+        {'compiled': '', 'raw': 'Entry with include_word'},
+        {'compiled': '', 'raw': 'Entry with include_word and exclude_word'}]
+
+    return embeddings, entries
```

```diff
@@ -28,3 +28,18 @@ def test_merge_dicts():
     # do not override existing key in priority_dict with default dict
     assert helpers.merge_dicts(priority_dict={'a': 1}, default_dict={'a': 2}) == {'a': 1}
+
+
+def test_lru_cache():
+    # Test initializing cache
+    cache = helpers.LRU({'a': 1, 'b': 2}, capacity=2)
+    assert cache == {'a': 1, 'b': 2}
+
+    # Test capacity overflow
+    cache['c'] = 3
+    assert cache == {'b': 2, 'c': 3}
+
+    # Test delete least recently used item from LRU cache on capacity overflow
+    cache['b']  # accessing 'b' makes it the most recently used item
+    cache['d'] = 4  # so 'c' is deleted from the cache instead of 'b'
+    assert cache == {'b': 2, 'd': 4}
```

```diff
@@ -1,5 +1,6 @@
 # System Packages
 from pathlib import Path
+from src.utils.config import SearchType

 # Internal Packages
 from src.utils.state import model
@@ -13,7 +14,7 @@ from src.processor.org_mode.org_to_jsonl import org_to_jsonl
 def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchConfig):
     # Act
     # Regenerate notes embeddings during asymmetric setup
-    notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True)
+    notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=True)

     # Assert
     assert len(notes_model.entries) == 10
@@ -23,7 +24,7 @@ def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchCo
 # ----------------------------------------------------------------------------------------------------
 def test_asymmetric_search(content_config: ContentConfig, search_config: SearchConfig):
     # Arrange
-    model.notes_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
+    model.notes_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False)
     query = "How to git install application?"

     # Act
@@ -46,7 +47,7 @@ def test_asymmetric_search(content_config: ContentConfig, search_config: SearchC
 # ----------------------------------------------------------------------------------------------------
 def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchConfig):
     # Arrange
-    initial_notes_model= text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
+    initial_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False)

     assert len(initial_notes_model.entries) == 10
     assert len(initial_notes_model.corpus_embeddings) == 10
@@ -59,11 +60,11 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchC
         f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")

     # regenerate notes jsonl, model embeddings and model to include entry from new file
-    regenerated_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True)
+    regenerated_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=True)

     # Act
     # reload embeddings, entries, notes model from previously generated notes jsonl and model embeddings files
-    initial_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
+    initial_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False)

     # Assert
     assert len(regenerated_notes_model.entries) == 11
```