Improve Latency of Explicit Filter

### Goal
  - Improve explicit filter latency to work better with incremental search

### Reasons for High Explicit Filter Latency
  - Deleting excluded entries one at a time from the existing list of entries, embeddings
  - Explicit filtering on partial words during incremental search
  - Creating word sets for all entries on the fly during each query
  - Deep copying entries, embeddings before applying the filter

### Improvement Details
  - **Major**
    - 191a656 Use word to entry map, list comprehension to speed up explicit filter
      - Use list comprehension and `torch.index_select` method
        - to speed up selection of the entries, embedding tensors satisfying the filter
        - to avoid deep copying and directly manipulating entries, embeddings
      - Use word to entry map and set operations to mark entries
        satisfying the inclusion, exclusion filters (see the first sketch below this list)
    - c7de57b Pre-compute entry word sets to improve explicit filter query performance
    - 3308e68 Cache explicitly filtered entries, embeddings by required, blocked words
    - cdcee89 Wrap explicit filter words in quotes to trigger filter
      - E.g. `+"word_to_include"` instead of `+word_to_include`
      - Signals that the explicit filter term is complete
      - Prevents latency from incremental search running explicit filtering on partial terms
        (see the second sketch below this list)
  - **Minor**
    - 28d3dc1 Deep copy entries, embeddings within filters; defer the copy until filtering is actually applied
    - 8d9f507 Load entries_by_word_set from file only once on first load of explicit filter
    - 546fad5 Use regex to check for and extract include, exclude filter words from query
    - b7d259b Test Explicit Include, Exclude Filters
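
The first sketch below illustrates the major changes (191a656, c7de57b): a pre-computed word-to-entry map, set operations to pick matching entry ids, then a list comprehension and `torch.index_select` to keep the surviving entries and embeddings. It is a minimal, self-contained illustration with made-up entries and query words, not the exact implementation.

```python
# Minimal sketch of the word-to-entry map + torch.index_select filtering (toy data, assumed names)
import re
import torch

entries = [
    {'raw': 'Install Emacs from source'},
    {'raw': 'Clone the git repository'},
    {'raw': 'Install the Emacs git package'},
]
embeddings = torch.randn(len(entries), 4)

# Pre-compute a word -> entry-indices map once, instead of re-splitting every entry per query
entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:'
word_to_entry_index = {}
for idx, entry in enumerate(entries):
    for word in re.split(entry_splitter, entry['raw'].lower()):
        if word:
            word_to_entry_index.setdefault(word, set()).add(idx)

required_words, blocked_words = {'emacs'}, {'clone'}

# Set operations mark the entries satisfying the inclusion, exclusion filters
included = set.intersection(*[word_to_entry_index.get(word, set()) for word in required_words])
blocked = set.union(*[word_to_entry_index.get(word, set()) for word in blocked_words])
kept_indices = sorted(included - blocked)

# List comprehension + torch.index_select keep the matching entries, embeddings
# without deep copying or deleting from the original list and tensor
filtered_entries = [entries[idx] for idx in kept_indices]
filtered_embeddings = torch.index_select(embeddings, 0, torch.tensor(kept_indices))

print([entry['raw'] for entry in filtered_entries])  # entries 0 and 2 survive
```

Since the word map is pre-computed at index time, and filtered results are cached per required/blocked word set (3308e68), only the cheap set operations and index selection run at query time.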
    
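The second sketch shows the quoted filter-term handling (cdcee89, 546fad5). The regexes mirror those in the `ExplicitFilter` diff further down; the sample query is invented for illustration.

```python
# Sketch of quoted filter-term extraction; regexes match the ExplicitFilter diff, query is made up
import re

required_regex = r'\+"(\w+)" ?'
blocked_regex = r'\-"(\w+)" ?'

raw_query = 'how to install application +"Emacs" -"clone"'

# Quotes signal that the filter term is complete, so partially typed words
# during incremental search never match and never trigger the filter
required_words = {word.lower() for word in re.findall(required_regex, raw_query)}
blocked_words = {word.lower() for word in re.findall(blocked_regex, raw_query)}

# Strip the filter terms, leaving the natural language query for semantic search
query = re.sub(blocked_regex, '', re.sub(required_regex, '', raw_query)).strip()

print(required_words, blocked_words, query)  # {'emacs'} {'clone'} how to install application
```

Because a bare `+Ema` never matches the quoted-term regex, incremental search does not repeatedly run the filter on partial words; the filter only fires once the closing quote marks the term as complete.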
### Results
  - Improve exclude word filter latency from **20s+ to 0.02s** on a 120K-line notes corpus
Commit d153d420fc by Debanjum, committed via GitHub on 2022-09-04 13:55:17 +00:00
13 changed files with 272 additions and 78 deletions

@@ -40,22 +40,22 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
# Initialize Org Notes Search
if (t == SearchType.Org or t == None) and config.content_type.org:
# Extract Entries, Generate Notes Embeddings
-model.orgmode_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate)
+model.orgmode_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, search_type=SearchType.Org, regenerate=regenerate)
# Initialize Org Music Search
if (t == SearchType.Music or t == None) and config.content_type.music:
# Extract Entries, Generate Music Embeddings
-model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate)
+model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, search_type=SearchType.Music, regenerate=regenerate)
# Initialize Markdown Search
if (t == SearchType.Markdown or t == None) and config.content_type.markdown:
# Extract Entries, Generate Markdown Embeddings
-model.markdown_search = text_search.setup(markdown_to_jsonl, config.content_type.markdown, search_config=config.search_type.asymmetric, regenerate=regenerate)
+model.markdown_search = text_search.setup(markdown_to_jsonl, config.content_type.markdown, search_config=config.search_type.asymmetric, search_type=SearchType.Markdown, regenerate=regenerate)
# Initialize Ledger Search
if (t == SearchType.Ledger or t == None) and config.content_type.ledger:
# Extract Entries, Generate Ledger Embeddings
-model.ledger_search = text_search.setup(beancount_to_jsonl, config.content_type.ledger, search_config=config.search_type.symmetric, regenerate=regenerate)
+model.ledger_search = text_search.setup(beancount_to_jsonl, config.content_type.ledger, search_config=config.search_type.symmetric, search_type=SearchType.Ledger, regenerate=regenerate)
# Initialize Image Search
if (t == SearchType.Image or t == None) and config.content_type.image:

@@ -65,7 +65,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
if (t == SearchType.Org or t == None) and state.model.orgmode_search:
# query org-mode notes
query_start = time.time()
-hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose)
+hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r)
query_end = time.time()
# collate and return results
@@ -76,7 +76,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
if (t == SearchType.Music or t == None) and state.model.music_search:
# query music library
query_start = time.time()
-hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose)
+hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r)
query_end = time.time()
# collate and return results
@@ -87,7 +87,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
if (t == SearchType.Markdown or t == None) and state.model.markdown_search:
# query markdown files
query_start = time.time()
-hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose)
+hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r)
query_end = time.time()
# collate and return results
@@ -98,7 +98,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
if (t == SearchType.Ledger or t == None) and state.model.ledger_search:
# query transactions
query_start = time.time()
-hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose)
+hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r)
query_end = time.time()
# collate and return results

@@ -3,6 +3,7 @@ import re
from datetime import timedelta, datetime
from dateutil.relativedelta import relativedelta, MO
from math import inf
from copy import deepcopy
# External Packages
import torch
@@ -17,29 +18,42 @@ class DateFilter:
# - dt:"2 years ago"
date_regex = r"dt([:><=]{1,2})\"(.*?)\""
def __init__(self, entry_key='raw'):
self.entry_key = entry_key
def load(*args, **kwargs):
pass
def can_filter(self, raw_query):
"Check if query contains date filters"
return self.extract_date_range(raw_query) is not None
-def filter(self, query, entries, embeddings, entry_key='raw'):
+def apply(self, query, raw_entries, raw_embeddings):
"Find entries containing any dates that fall within date range specified in query"
# extract date range specified in date filter of query
query_daterange = self.extract_date_range(query)
# if no date in query, return all entries
if query_daterange is None:
-return query, entries, embeddings
+return query, raw_entries, raw_embeddings
# remove date range filter from query
query = re.sub(rf'\s+{self.date_regex}', ' ', query)
query = re.sub(r'\s{2,}', ' ', query).strip() # remove multiple spaces
# deep copy original embeddings, entries before filtering
embeddings= deepcopy(raw_embeddings)
entries = deepcopy(raw_entries)
# find entries containing any dates that fall with date range specified in query
entries_to_include = set()
for id, entry in enumerate(entries):
# Extract dates from entry
-for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[entry_key]):
+for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[self.entry_key]):
# Convert date string in entry to unix timestamp
try:
date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp()

@@ -1,57 +1,121 @@
# Standard Packages
import re
import time
import pickle
import logging
# External Packages
import torch
# Internal Packages
from src.utils.helpers import LRU, resolve_absolute_path
from src.utils.config import SearchType
logger = logging.getLogger(__name__)
class ExplicitFilter:
# Filter Regex
required_regex = r'\+"(\w+)" ?'
blocked_regex = r'\-"(\w+)" ?'
def __init__(self, filter_directory, search_type: SearchType, entry_key='raw'):
self.filter_file = resolve_absolute_path(filter_directory / f"{search_type.name.lower()}_explicit_filter_entry_word_sets.pkl")
self.entry_key = entry_key
self.search_type = search_type
self.word_to_entry_index = dict()
self.cache = LRU()
def load(self, entries, regenerate=False):
if self.filter_file.exists() and not regenerate:
start = time.time()
with self.filter_file.open('rb') as f:
self.word_to_entry_index = pickle.load(f)
end = time.time()
logger.debug(f"Load {self.search_type} entries by word set from file: {end - start} seconds")
else:
start = time.time()
self.cache = {} # Clear cache on (re-)generating entries_by_word_set
entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:'
# Create map of words to entries they exist in
for entry_index, entry in enumerate(entries):
for word in re.split(entry_splitter, entry[self.entry_key].lower()):
if word == '':
continue
if word not in self.word_to_entry_index:
self.word_to_entry_index[word] = set()
self.word_to_entry_index[word].add(entry_index)
with self.filter_file.open('wb') as f:
pickle.dump(self.word_to_entry_index, f)
end = time.time()
logger.debug(f"Convert all {self.search_type} entries to word sets: {end - start} seconds")
return self.word_to_entry_index
def can_filter(self, raw_query):
"Check if query contains explicit filters"
# Extract explicit query portion with required, blocked words to filter from natural query
-required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
-blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
+required_words = re.findall(self.required_regex, raw_query)
+blocked_words = re.findall(self.blocked_regex, raw_query)
return len(required_words) != 0 or len(blocked_words) != 0
-def filter(self, raw_query, entries, embeddings, entry_key='raw'):
+def apply(self, raw_query, raw_entries, raw_embeddings):
"Find entries containing required and not blocked words specified in query"
# Separate natural query from explicit required, blocked words filters
query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
start = time.time()
required_words = set([word.lower() for word in re.findall(self.required_regex, raw_query)])
blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, raw_query)])
query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', raw_query)).strip()
end = time.time()
logger.debug(f"Extract required, blocked filters from query: {end - start} seconds")
if len(required_words) == 0 and len(blocked_words) == 0:
+return query, raw_entries, raw_embeddings
+# Return item from cache if exists
+cache_key = tuple(sorted(required_words)), tuple(sorted(blocked_words))
+if cache_key in self.cache:
+logger.info(f"Explicit filter results from cache")
+entries, embeddings = self.cache[cache_key]
+return query, entries, embeddings
-# convert each entry to a set of words
-# split on fullstop, comma, colon, tab, newline or any brackets
-entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:'
-entries_by_word_set = [set(word.lower()
-for word
-in re.split(entry_splitter, entry[entry_key])
-if word != "")
-for entry in entries]
+if not self.word_to_entry_index:
+self.load(raw_entries, regenerate=False)
-# track id of entries to exclude
-entries_to_exclude = set()
+start = time.time()
-# mark entries that do not contain all required_words for exclusion
+# mark entries that contain all required_words for inclusion
+entries_with_all_required_words = set(range(len(raw_entries)))
if len(required_words) > 0:
-for id, words_in_entry in enumerate(entries_by_word_set):
-if not required_words.issubset(words_in_entry):
-entries_to_exclude.add(id)
+entries_with_all_required_words = set.intersection(*[self.word_to_entry_index.get(word, set()) for word in required_words])
# mark entries that contain any blocked_words for exclusion
+entries_with_any_blocked_words = set()
if len(blocked_words) > 0:
-for id, words_in_entry in enumerate(entries_by_word_set):
-if words_in_entry.intersection(blocked_words):
-entries_to_exclude.add(id)
+entries_with_any_blocked_words = set.union(*[self.word_to_entry_index.get(word, set()) for word in blocked_words])
-# delete entries (and their embeddings) marked for exclusion
-for id in sorted(list(entries_to_exclude), reverse=True):
-del entries[id]
-embeddings = torch.cat((embeddings[:id], embeddings[id+1:]))
+end = time.time()
+logger.debug(f"Mark entries satisfying filter: {end - start} seconds")
+# get entries (and their embeddings) satisfying inclusion and exclusion filters
+start = time.time()
+included_entry_indices = entries_with_all_required_words - entries_with_any_blocked_words
+entries = [entry for id, entry in enumerate(raw_entries) if id in included_entry_indices]
+embeddings = torch.index_select(raw_embeddings, 0, torch.tensor(list(included_entry_indices)))
+end = time.time()
+logger.debug(f"Keep entries satisfying filter: {end - start} seconds")
+# Cache results
+self.cache[cache_key] = entries, embeddings
return query, entries, embeddings

@@ -3,16 +3,17 @@ import argparse
import pathlib
import logging
import time
from copy import deepcopy
# External Packages
import torch
from sentence_transformers import SentenceTransformer, CrossEncoder, util
from src.search_filter.date_filter import DateFilter
from src.search_filter.explicit_filter import ExplicitFilter
# Internal Packages
from src.utils import state
from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model
-from src.utils.config import TextSearchModel
+from src.utils.config import SearchType, TextSearchModel
from src.utils.rawconfig import TextSearchConfig, TextContentConfig
from src.utils.jsonl import load_jsonl
@@ -73,26 +74,15 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False):
return corpus_embeddings
-def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: list = []):
+def query(raw_query: str, model: TextSearchModel, rank_results=False):
"Search for entries that answer the query"
-query = raw_query
-# Use deep copy of original embeddings, entries to filter if query contains filters
-start = time.time()
-filters_in_query = [filter for filter in filters if filter.can_filter(query)]
-if filters_in_query:
-corpus_embeddings = deepcopy(model.corpus_embeddings)
-entries = deepcopy(model.entries)
-else:
-corpus_embeddings = model.corpus_embeddings
-entries = model.entries
-end = time.time()
-logger.debug(f"Copy Time: {end - start:.3f} seconds")
+query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings
# Filter query, entries and embeddings before semantic search
start = time.time()
+filters_in_query = [filter for filter in model.filters if filter.can_filter(query)]
for filter in filters_in_query:
-query, entries, corpus_embeddings = filter.filter(query, entries, corpus_embeddings)
+query, entries, corpus_embeddings = filter.apply(query, entries, corpus_embeddings)
end = time.time()
logger.debug(f"Filter Time: {end - start:.3f} seconds")
@@ -163,7 +153,7 @@ def collate_results(hits, entries, count=5):
in hits[0:count]]
-def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool) -> TextSearchModel:
+def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, search_type: SearchType, regenerate: bool) -> TextSearchModel:
# Initialize Model
bi_encoder, cross_encoder, top_k = initialize_model(search_config)
@@ -180,7 +170,12 @@ def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchCon
config.embeddings_file = resolve_absolute_path(config.embeddings_file)
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate)
-return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k)
+filter_directory = resolve_absolute_path(config.compressed_jsonl.parent)
+filters = [DateFilter(), ExplicitFilter(filter_directory, search_type=search_type)]
+for filter in filters:
+filter.load(entries, regenerate=regenerate)
+return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k)
if __name__ == '__main__':

@@ -20,11 +20,12 @@ class ProcessorType(str, Enum):
class TextSearchModel():
-def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, top_k):
+def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k):
self.entries = entries
self.corpus_embeddings = corpus_embeddings
self.bi_encoder = bi_encoder
self.cross_encoder = cross_encoder
self.filters = filters
self.top_k = top_k

@@ -2,6 +2,7 @@
import pathlib
import sys
from os.path import join
from collections import OrderedDict
def is_none_or_empty(item):
@@ -61,3 +62,20 @@ def load_model(model_name, model_dir, model_type, device:str=None):
def is_pyinstaller_app():
"Returns true if the app is running from Native GUI created by PyInstaller"
return getattr(sys, 'frozen', False) and hasattr(sys, '_MEIPASS')
class LRU(OrderedDict):
def __init__(self, *args, capacity=128, **kwargs):
self.capacity = capacity
super().__init__(*args, **kwargs)
def __getitem__(self, key):
value = super().__getitem__(key)
self.move_to_end(key)
return value
def __setitem__(self, key, value):
super().__setitem__(key, value)
if len(self) > self.capacity:
oldest = next(iter(self))
del self[oldest]

@@ -3,9 +3,9 @@ import pytest
# Internal Packages
from src.search_type import image_search, text_search
from src.utils.config import SearchType
from src.utils.rawconfig import ContentConfig, TextContentConfig, ImageContentConfig, SearchConfig, TextSearchConfig, ImageSearchConfig
from src.processor.org_mode.org_to_jsonl import org_to_jsonl
from src.utils import state
@pytest.fixture(scope='session')
@@ -46,7 +46,7 @@ def model_dir(search_config):
batch_size = 10,
use_xmp_metadata = False)
-image_search.setup(content_config.image, search_config.image, regenerate=False, verbose=True)
+image_search.setup(content_config.image, search_config.image, regenerate=False)
# Generate Notes Embeddings from Test Notes
content_config.org = TextContentConfig(
@@ -55,7 +55,7 @@ def model_dir(search_config):
compressed_jsonl = model_dir.joinpath('notes.jsonl.gz'),
embeddings_file = model_dir.joinpath('note_embeddings.pt'))
-text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, verbose=True)
+text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False)
return model_dir

@@ -8,6 +8,7 @@ import pytest
# Internal Packages
from src.main import app
from src.utils.config import SearchType
from src.utils.state import model, config
from src.search_type import text_search, image_search
from src.utils.rawconfig import ContentConfig, SearchConfig
@@ -115,7 +116,7 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig
# ----------------------------------------------------------------------------------------------------
def test_notes_search(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
-model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
+model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False)
user_query = "How to git install application?"
# Act
@@ -131,8 +132,8 @@ def test_notes_search(content_config: ContentConfig, search_config: SearchConfig
# ----------------------------------------------------------------------------------------------------
def test_notes_search_with_include_filter(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
-model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
-user_query = "How to git install application? +Emacs"
+model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False)
+user_query = 'How to git install application? +"Emacs"'
# Act
response = client.get(f"/search?q={user_query}&n=1&t=org")
@@ -147,8 +148,8 @@ def test_notes_search_with_include_filter(content_config: ContentConfig, search_
# ----------------------------------------------------------------------------------------------------
def test_notes_search_with_exclude_filter(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
-model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
-user_query = "How to git install application? -clone"
+model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False)
+user_query = 'How to git install application? -"clone"'
# Act
response = client.get(f"/search?q={user_query}&n=1&t=org")

@@ -18,37 +18,37 @@ def test_date_filter():
{'compiled': '', 'raw': 'Entry with date:1984-04-02'}]
q_with_no_date_filter = 'head tail'
-ret_query, ret_entries, ret_emb = DateFilter().filter(q_with_no_date_filter, entries.copy(), embeddings)
+ret_query, ret_entries, ret_emb = DateFilter().apply(q_with_no_date_filter, entries.copy(), embeddings)
assert ret_query == 'head tail'
assert len(ret_emb) == 3
assert ret_entries == entries
q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail'
-ret_query, ret_entries, ret_emb = DateFilter().filter(q_with_dtrange_non_overlapping_at_boundary, entries.copy(), embeddings)
+ret_query, ret_entries, ret_emb = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries.copy(), embeddings)
assert ret_query == 'head tail'
assert len(ret_emb) == 0
assert ret_entries == []
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail'
-ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
+ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings)
assert ret_query == 'head tail'
assert ret_entries == [entries[2]]
assert len(ret_emb) == 1
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail'
-ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
+ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings)
assert ret_query == 'head tail'
assert ret_entries == [entries[1]]
assert len(ret_emb) == 1
query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail'
-ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
+ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings)
assert ret_query == 'head tail'
assert ret_entries == [entries[2]]
assert len(ret_emb) == 1
query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail'
-ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
+ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings)
assert ret_query == 'head tail'
assert ret_entries == [entries[1], entries[2]]
assert len(ret_emb) == 2

@@ -0,0 +1,85 @@
# External Packages
import torch
# Application Packages
from src.search_filter.explicit_filter import ExplicitFilter
from src.utils.config import SearchType
def test_no_explicit_filter(tmp_path):
# Arrange
explicit_filter = ExplicitFilter(tmp_path, SearchType.Org)
embeddings, entries = arrange_content()
q_with_no_filter = 'head tail'
# Act
can_filter = explicit_filter.can_filter(q_with_no_filter)
ret_query, ret_entries, ret_emb = explicit_filter.apply(q_with_no_filter, entries.copy(), embeddings)
# Assert
assert can_filter == False
assert ret_query == 'head tail'
assert len(ret_emb) == 4
assert ret_entries == entries
def test_explicit_exclude_filter(tmp_path):
# Arrange
explicit_filter = ExplicitFilter(tmp_path, SearchType.Org)
embeddings, entries = arrange_content()
q_with_exclude_filter = 'head -"exclude_word" tail'
# Act
can_filter = explicit_filter.can_filter(q_with_exclude_filter)
ret_query, ret_entries, ret_emb = explicit_filter.apply(q_with_exclude_filter, entries.copy(), embeddings)
# Assert
assert can_filter == True
assert ret_query == 'head tail'
assert len(ret_emb) == 2
assert ret_entries == [entries[0], entries[2]]
def test_explicit_include_filter(tmp_path):
# Arrange
explicit_filter = ExplicitFilter(tmp_path, SearchType.Org)
embeddings, entries = arrange_content()
query_with_include_filter = 'head +"include_word" tail'
# Act
can_filter = explicit_filter.can_filter(query_with_include_filter)
ret_query, ret_entries, ret_emb = explicit_filter.apply(query_with_include_filter, entries.copy(), embeddings)
# Assert
assert can_filter == True
assert ret_query == 'head tail'
assert len(ret_emb) == 2
assert ret_entries == [entries[2], entries[3]]
def test_explicit_include_and_exclude_filter(tmp_path):
# Arrange
explicit_filter = ExplicitFilter(tmp_path, SearchType.Org)
embeddings, entries = arrange_content()
query_with_include_and_exclude_filter = 'head +"include_word" -"exclude_word" tail'
# Act
can_filter = explicit_filter.can_filter(query_with_include_and_exclude_filter)
ret_query, ret_entries, ret_emb = explicit_filter.apply(query_with_include_and_exclude_filter, entries.copy(), embeddings)
# Assert
assert can_filter == True
assert ret_query == 'head tail'
assert len(ret_emb) == 1
assert ret_entries == [entries[2]]
def arrange_content():
embeddings = torch.randn(4, 10)
entries = [
{'compiled': '', 'raw': 'Minimal Entry'},
{'compiled': '', 'raw': 'Entry with exclude_word'},
{'compiled': '', 'raw': 'Entry with include_word'},
{'compiled': '', 'raw': 'Entry with include_word and exclude_word'}]
return embeddings, entries

@@ -28,3 +28,18 @@ def test_merge_dicts():
# do not override existing key in priority_dict with default dict
assert helpers.merge_dicts(priority_dict={'a': 1}, default_dict={'a': 2}) == {'a': 1}
def test_lru_cache():
# Test initializing cache
cache = helpers.LRU({'a': 1, 'b': 2}, capacity=2)
assert cache == {'a': 1, 'b': 2}
# Test capacity overflow
cache['c'] = 3
assert cache == {'b': 2, 'c': 3}
# Test delete least recently used item from LRU cache on capacity overflow
cache['b'] # accessing 'b' makes it the most recently used item
cache['d'] = 4 # so 'c' is deleted from the cache instead of 'b'
assert cache == {'b': 2, 'd': 4}

@@ -1,5 +1,6 @@
# System Packages
from pathlib import Path
from src.utils.config import SearchType
# Internal Packages
from src.utils.state import model
@@ -13,7 +14,7 @@ from src.processor.org_mode.org_to_jsonl import org_to_jsonl
def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchConfig):
# Act
# Regenerate notes embeddings during asymmetric setup
-notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True)
+notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=True)
# Assert
assert len(notes_model.entries) == 10
@@ -23,7 +24,7 @@ def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchCo
# ----------------------------------------------------------------------------------------------------
def test_asymmetric_search(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
-model.notes_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
+model.notes_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False)
query = "How to git install application?"
# Act
@@ -46,7 +47,7 @@ def test_asymmetric_search(content_config: ContentConfig, search_config: SearchC
# ----------------------------------------------------------------------------------------------------
def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchConfig):
# Arrange
-initial_notes_model= text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
+initial_notes_model= text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False)
assert len(initial_notes_model.entries) == 10
assert len(initial_notes_model.corpus_embeddings) == 10
@@ -59,11 +60,11 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchC
f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")
# regenerate notes jsonl, model embeddings and model to include entry from new file
-regenerated_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True)
+regenerated_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=True)
# Act
# reload embeddings, entries, notes model from previously generated notes jsonl and model embeddings files
-initial_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
+initial_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False)
# Assert
assert len(regenerated_notes_model.entries) == 11