Mirror of https://github.com/khoj-ai/khoj.git
# Improve Latency of Explicit Filter
### Goal
- Improve explicit filter latency to work better with incremental search

### Reasons for High Explicit Filter Latency
- Deleting entries to be excluded from the existing list of entries, embeddings
- Explicit filtering on partial words during incremental search
- Creating word sets for all entries on the fly during query
- Deep copying entries, embeddings before applying the filter

### Improvement Details
- **Major**
  - 191a656 Use word to entry map, list comprehension to speed up explicit filter
    - Use list comprehension and `torch.index_select` methods to speed up selection of entries, embedding tensors satisfying the filter and avoid deep copy and direct manipulation of entries, embeddings
    - Use word to entry map and set operations to mark entries satisfying inclusion, exclusion filters
  - c7de57b Pre-compute entry word sets to improve explicit filter query performance
  - 3308e68 Cache explicitly filtered entries, embeddings by required, blocked words
  - cdcee89 Wrap explicit filter words in quotes to trigger filter
    - E.g. `+"word_to_include"` instead of `+word_to_include`
    - Signals that the explicit filter term is complete
    - Prevents latency due to incremental search with explicit filtering on partial terms
- **Minor**
  - 28d3dc1 Deep copy entries, embeddings in filters; defer it till actual filtering
  - 8d9f507 Load entries_by_word_set from file only once, on first load of the explicit filter
  - 546fad5 Use regex to check for and extract include, exclude filter words from query
  - b7d259b Test explicit include, exclude filters

### Results
- Improved exclude word filter latency from **20s+ to 0.02s** on a 120K line notes corpus

A minimal sketch of the filtering approach is included after the change summary below.
Commit d153d420fc: 13 changed files with 272 additions and 78 deletions.
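Before the diff, here is a minimal, self-contained sketch of the technique the major commits describe: extract quoted `+"word"` / `-"word"` terms with a regex, look them up in a pre-computed word to entry index, combine the resulting id sets, and select the surviving embedding rows with `torch.index_select`. This is an illustration, not the project's code: the function and variable names are hypothetical, entries are assumed to be dicts with a `raw` key, embeddings an N x D torch tensor, and the LRU result cache and on-disk pickling of the index shown in the diff below are omitted for brevity.

```python
# Illustrative sketch only; names are hypothetical and simplified from the diff below.
import re
import torch

REQUIRED_REGEX = r'\+"(\w+)"'   # matches +"word_to_include"
BLOCKED_REGEX = r'-"(\w+)"'     # matches -"word_to_exclude"
WORD_SPLITTER = r'[\s,.:\(\)\[\]\{\}]'


def build_word_index(entries):
    "Pre-compute a word -> {entry ids} map once, instead of tokenizing every entry per query."
    index = {}
    for i, entry in enumerate(entries):
        for word in re.split(WORD_SPLITTER, entry['raw'].lower()):
            if word:
                index.setdefault(word, set()).add(i)
    return index


def apply_explicit_filter(query, entries, embeddings, word_index):
    "Drop entries (and their embedding rows) that fail the quoted include/exclude filters in query."
    required = {word.lower() for word in re.findall(REQUIRED_REGEX, query)}
    blocked = {word.lower() for word in re.findall(BLOCKED_REGEX, query)}
    query = re.sub(BLOCKED_REGEX, '', re.sub(REQUIRED_REGEX, '', query)).strip()
    if not required and not blocked:
        return query, entries, embeddings

    # Set operations over the pre-computed index mark entries satisfying the filters
    has_all_required = (set.intersection(*[word_index.get(w, set()) for w in required])
                        if required else set(range(len(entries))))
    has_any_blocked = (set.union(*[word_index.get(w, set()) for w in blocked])
                       if blocked else set())
    keep = sorted(has_all_required - has_any_blocked)

    # Select surviving entries and embedding rows; no deep copy, no per-row deletion
    filtered_entries = [entries[i] for i in keep]
    filtered_embeddings = torch.index_select(embeddings, 0, torch.tensor(keep, dtype=torch.long))
    return query, filtered_entries, filtered_embeddings


# Example: the index is built once at setup time, then reused for every incremental query
entries = [{'raw': 'Notes on Emacs'}, {'raw': 'Notes on git clone'}]
embeddings = torch.randn(2, 4)
index = build_word_index(entries)
print(apply_explicit_filter('how to install? +"emacs" -"clone"', entries, embeddings, index))
```

Because the per-query work reduces to a few set lookups and one indexed tensor selection, filtering no longer re-tokenizes every entry on each keystroke, consistent with the 20s+ to 0.02s improvement reported above.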
@@ -40,22 +40,22 @@ def configure_search(model: SearchModels, config: FullConfig, regenerate: bool,
     # Initialize Org Notes Search
     if (t == SearchType.Org or t == None) and config.content_type.org:
         # Extract Entries, Generate Notes Embeddings
-        model.orgmode_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate)
+        model.orgmode_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, search_type=SearchType.Org, regenerate=regenerate)

     # Initialize Org Music Search
     if (t == SearchType.Music or t == None) and config.content_type.music:
         # Extract Entries, Generate Music Embeddings
-        model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate)
+        model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, search_type=SearchType.Music, regenerate=regenerate)

     # Initialize Markdown Search
     if (t == SearchType.Markdown or t == None) and config.content_type.markdown:
         # Extract Entries, Generate Markdown Embeddings
-        model.markdown_search = text_search.setup(markdown_to_jsonl, config.content_type.markdown, search_config=config.search_type.asymmetric, regenerate=regenerate)
+        model.markdown_search = text_search.setup(markdown_to_jsonl, config.content_type.markdown, search_config=config.search_type.asymmetric, search_type=SearchType.Markdown, regenerate=regenerate)

     # Initialize Ledger Search
     if (t == SearchType.Ledger or t == None) and config.content_type.ledger:
         # Extract Entries, Generate Ledger Embeddings
-        model.ledger_search = text_search.setup(beancount_to_jsonl, config.content_type.ledger, search_config=config.search_type.symmetric, regenerate=regenerate)
+        model.ledger_search = text_search.setup(beancount_to_jsonl, config.content_type.ledger, search_config=config.search_type.symmetric, search_type=SearchType.Ledger, regenerate=regenerate)

     # Initialize Image Search
     if (t == SearchType.Image or t == None) and config.content_type.image:
@@ -65,7 +65,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
     if (t == SearchType.Org or t == None) and state.model.orgmode_search:
         # query org-mode notes
         query_start = time.time()
-        hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose)
+        hits, entries = text_search.query(user_query, state.model.orgmode_search, rank_results=r)
         query_end = time.time()

         # collate and return results

@@ -76,7 +76,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
     if (t == SearchType.Music or t == None) and state.model.music_search:
         # query music library
         query_start = time.time()
-        hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r, filters=[DateFilter(), ExplicitFilter()], verbose=state.verbose)
+        hits, entries = text_search.query(user_query, state.model.music_search, rank_results=r)
        query_end = time.time()

         # collate and return results

@@ -87,7 +87,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
     if (t == SearchType.Markdown or t == None) and state.model.markdown_search:
         # query markdown files
         query_start = time.time()
-        hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose)
+        hits, entries = text_search.query(user_query, state.model.markdown_search, rank_results=r)
         query_end = time.time()

         # collate and return results

@@ -98,7 +98,7 @@ def search(q: str, n: Optional[int] = 5, t: Optional[SearchType] = None, r: Opti
     if (t == SearchType.Ledger or t == None) and state.model.ledger_search:
         # query transactions
         query_start = time.time()
-        hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r, filters=[ExplicitFilter(), DateFilter()], verbose=state.verbose)
+        hits, entries = text_search.query(user_query, state.model.ledger_search, rank_results=r)
         query_end = time.time()

         # collate and return results
@@ -3,6 +3,7 @@ import re
 from datetime import timedelta, datetime
 from dateutil.relativedelta import relativedelta, MO
 from math import inf
+from copy import deepcopy

 # External Packages
 import torch

@@ -17,29 +18,42 @@ class DateFilter:
     # - dt:"2 years ago"
     date_regex = r"dt([:><=]{1,2})\"(.*?)\""

+    def __init__(self, entry_key='raw'):
+        self.entry_key = entry_key
+
+    def load(*args, **kwargs):
+        pass
+
     def can_filter(self, raw_query):
         "Check if query contains date filters"
         return self.extract_date_range(raw_query) is not None

-    def filter(self, query, entries, embeddings, entry_key='raw'):
+    def apply(self, query, raw_entries, raw_embeddings):
         "Find entries containing any dates that fall within date range specified in query"
         # extract date range specified in date filter of query
         query_daterange = self.extract_date_range(query)

         # if no date in query, return all entries
         if query_daterange is None:
-            return query, entries, embeddings
+            return query, raw_entries, raw_embeddings

         # remove date range filter from query
         query = re.sub(rf'\s+{self.date_regex}', ' ', query)
         query = re.sub(r'\s{2,}', ' ', query).strip() # remove multiple spaces

+        # deep copy original embeddings, entries before filtering
+        embeddings = deepcopy(raw_embeddings)
+        entries = deepcopy(raw_entries)
+
         # find entries containing any dates that fall with date range specified in query
         entries_to_include = set()
         for id, entry in enumerate(entries):
             # Extract dates from entry
-            for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[entry_key]):
+            for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[self.entry_key]):
                 # Convert date string in entry to unix timestamp
                 try:
                     date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp()
@@ -1,57 +1,121 @@
 # Standard Packages
 import re
+import time
+import pickle
+import logging

 # External Packages
 import torch

+# Internal Packages
+from src.utils.helpers import LRU, resolve_absolute_path
+from src.utils.config import SearchType
+
+
+logger = logging.getLogger(__name__)
+

 class ExplicitFilter:
+    # Filter Regex
+    required_regex = r'\+"(\w+)" ?'
+    blocked_regex = r'\-"(\w+)" ?'
+
+    def __init__(self, filter_directory, search_type: SearchType, entry_key='raw'):
+        self.filter_file = resolve_absolute_path(filter_directory / f"{search_type.name.lower()}_explicit_filter_entry_word_sets.pkl")
+        self.entry_key = entry_key
+        self.search_type = search_type
+        self.word_to_entry_index = dict()
+        self.cache = LRU()
+
+    def load(self, entries, regenerate=False):
+        if self.filter_file.exists() and not regenerate:
+            start = time.time()
+            with self.filter_file.open('rb') as f:
+                self.word_to_entry_index = pickle.load(f)
+            end = time.time()
+            logger.debug(f"Load {self.search_type} entries by word set from file: {end - start} seconds")
+        else:
+            start = time.time()
+            self.cache = {}  # Clear cache on (re-)generating entries_by_word_set
+            entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:'
+            # Create map of words to entries they exist in
+            for entry_index, entry in enumerate(entries):
+                for word in re.split(entry_splitter, entry[self.entry_key].lower()):
+                    if word == '':
+                        continue
+                    if word not in self.word_to_entry_index:
+                        self.word_to_entry_index[word] = set()
+                    self.word_to_entry_index[word].add(entry_index)
+
+            with self.filter_file.open('wb') as f:
+                pickle.dump(self.word_to_entry_index, f)
+            end = time.time()
+            logger.debug(f"Convert all {self.search_type} entries to word sets: {end - start} seconds")
+
+        return self.word_to_entry_index
+
     def can_filter(self, raw_query):
         "Check if query contains explicit filters"
         # Extract explicit query portion with required, blocked words to filter from natural query
-        required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
-        blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
+        required_words = re.findall(self.required_regex, raw_query)
+        blocked_words = re.findall(self.blocked_regex, raw_query)

         return len(required_words) != 0 or len(blocked_words) != 0

-    def filter(self, raw_query, entries, embeddings, entry_key='raw'):
+    def apply(self, raw_query, raw_entries, raw_embeddings):
         "Find entries containing required and not blocked words specified in query"
         # Separate natural query from explicit required, blocked words filters
-        query = " ".join([word for word in raw_query.split() if not word.startswith("+") and not word.startswith("-")])
-        required_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("+")])
-        blocked_words = set([word[1:].lower() for word in raw_query.split() if word.startswith("-")])
+        start = time.time()
+
+        required_words = set([word.lower() for word in re.findall(self.required_regex, raw_query)])
+        blocked_words = set([word.lower() for word in re.findall(self.blocked_regex, raw_query)])
+        query = re.sub(self.blocked_regex, '', re.sub(self.required_regex, '', raw_query)).strip()
+
+        end = time.time()
+        logger.debug(f"Extract required, blocked filters from query: {end - start} seconds")

+        if len(required_words) == 0 and len(blocked_words) == 0:
+            return query, raw_entries, raw_embeddings
+
+        # Return item from cache if exists
+        cache_key = tuple(sorted(required_words)), tuple(sorted(blocked_words))
+        if cache_key in self.cache:
+            logger.info(f"Explicit filter results from cache")
+            entries, embeddings = self.cache[cache_key]
+            return query, entries, embeddings
+
-        # convert each entry to a set of words
-        # split on fullstop, comma, colon, tab, newline or any brackets
-        entry_splitter = r',|\.| |\]|\[\(|\)|\{|\}|\t|\n|\:'
-        entries_by_word_set = [set(word.lower()
-                                   for word
-                                   in re.split(entry_splitter, entry[entry_key])
-                                   if word != "")
-                               for entry in entries]
+        if not self.word_to_entry_index:
+            self.load(raw_entries, regenerate=False)

-        # track id of entries to exclude
-        entries_to_exclude = set()
+        start = time.time()

-        # mark entries that do not contain all required_words for exclusion
+        # mark entries that contain all required_words for inclusion
+        entries_with_all_required_words = set(range(len(raw_entries)))
         if len(required_words) > 0:
-            for id, words_in_entry in enumerate(entries_by_word_set):
-                if not required_words.issubset(words_in_entry):
-                    entries_to_exclude.add(id)
+            entries_with_all_required_words = set.intersection(*[self.word_to_entry_index.get(word, set()) for word in required_words])

         # mark entries that contain any blocked_words for exclusion
+        entries_with_any_blocked_words = set()
         if len(blocked_words) > 0:
-            for id, words_in_entry in enumerate(entries_by_word_set):
-                if words_in_entry.intersection(blocked_words):
-                    entries_to_exclude.add(id)
+            entries_with_any_blocked_words = set.union(*[self.word_to_entry_index.get(word, set()) for word in blocked_words])

-        # delete entries (and their embeddings) marked for exclusion
-        for id in sorted(list(entries_to_exclude), reverse=True):
-            del entries[id]
-            embeddings = torch.cat((embeddings[:id], embeddings[id+1:]))
+        end = time.time()
+        logger.debug(f"Mark entries satisfying filter: {end - start} seconds")

+        # get entries (and their embeddings) satisfying inclusion and exclusion filters
+        start = time.time()
+
+        included_entry_indices = entries_with_all_required_words - entries_with_any_blocked_words
+        entries = [entry for id, entry in enumerate(raw_entries) if id in included_entry_indices]
+        embeddings = torch.index_select(raw_embeddings, 0, torch.tensor(list(included_entry_indices)))
+
+        end = time.time()
+        logger.debug(f"Keep entries satisfying filter: {end - start} seconds")
+
+        # Cache results
+        self.cache[cache_key] = entries, embeddings
+
         return query, entries, embeddings
@@ -3,16 +3,17 @@ import argparse
 import pathlib
 import logging
 import time
-from copy import deepcopy

 # External Packages
 import torch
 from sentence_transformers import SentenceTransformer, CrossEncoder, util
+from src.search_filter.date_filter import DateFilter
+from src.search_filter.explicit_filter import ExplicitFilter

 # Internal Packages
 from src.utils import state
 from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model
-from src.utils.config import TextSearchModel
+from src.utils.config import SearchType, TextSearchModel
 from src.utils.rawconfig import TextSearchConfig, TextContentConfig
 from src.utils.jsonl import load_jsonl

@@ -73,26 +74,15 @@ def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False):
     return corpus_embeddings


-def query(raw_query: str, model: TextSearchModel, rank_results=False, filters: list = []):
+def query(raw_query: str, model: TextSearchModel, rank_results=False):
     "Search for entries that answer the query"
-    query = raw_query
-
-    # Use deep copy of original embeddings, entries to filter if query contains filters
-    start = time.time()
-    filters_in_query = [filter for filter in filters if filter.can_filter(query)]
-    if filters_in_query:
-        corpus_embeddings = deepcopy(model.corpus_embeddings)
-        entries = deepcopy(model.entries)
-    else:
-        corpus_embeddings = model.corpus_embeddings
-        entries = model.entries
-    end = time.time()
-    logger.debug(f"Copy Time: {end - start:.3f} seconds")
+    query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings

     # Filter query, entries and embeddings before semantic search
     start = time.time()
+    filters_in_query = [filter for filter in model.filters if filter.can_filter(query)]
     for filter in filters_in_query:
-        query, entries, corpus_embeddings = filter.filter(query, entries, corpus_embeddings)
+        query, entries, corpus_embeddings = filter.apply(query, entries, corpus_embeddings)
     end = time.time()
     logger.debug(f"Filter Time: {end - start:.3f} seconds")

@@ -163,7 +153,7 @@ def collate_results(hits, entries, count=5):
            in hits[0:count]]


-def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool) -> TextSearchModel:
+def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, search_type: SearchType, regenerate: bool) -> TextSearchModel:
     # Initialize Model
     bi_encoder, cross_encoder, top_k = initialize_model(search_config)

@@ -180,7 +170,12 @@ def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchCon
     config.embeddings_file = resolve_absolute_path(config.embeddings_file)
     corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate)

-    return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, top_k)
+    filter_directory = resolve_absolute_path(config.compressed_jsonl.parent)
+    filters = [DateFilter(), ExplicitFilter(filter_directory, search_type=search_type)]
+    for filter in filters:
+        filter.load(entries, regenerate=regenerate)
+
+    return TextSearchModel(entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k)


 if __name__ == '__main__':
@@ -20,11 +20,12 @@ class ProcessorType(str, Enum):


 class TextSearchModel():
-    def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, top_k):
+    def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k):
         self.entries = entries
         self.corpus_embeddings = corpus_embeddings
         self.bi_encoder = bi_encoder
         self.cross_encoder = cross_encoder
+        self.filters = filters
         self.top_k = top_k
@@ -2,6 +2,7 @@
 import pathlib
 import sys
 from os.path import join
+from collections import OrderedDict


 def is_none_or_empty(item):

@@ -61,3 +62,20 @@ def load_model(model_name, model_dir, model_type, device:str=None):
 def is_pyinstaller_app():
     "Returns true if the app is running from Native GUI created by PyInstaller"
     return getattr(sys, 'frozen', False) and hasattr(sys, '_MEIPASS')
+
+
+class LRU(OrderedDict):
+    def __init__(self, *args, capacity=128, **kwargs):
+        self.capacity = capacity
+        super().__init__(*args, **kwargs)
+
+    def __getitem__(self, key):
+        value = super().__getitem__(key)
+        self.move_to_end(key)
+        return value
+
+    def __setitem__(self, key, value):
+        super().__setitem__(key, value)
+        if len(self) > self.capacity:
+            oldest = next(iter(self))
+            del self[oldest]
@@ -3,9 +3,9 @@ import pytest

 # Internal Packages
 from src.search_type import image_search, text_search
+from src.utils.config import SearchType
 from src.utils.rawconfig import ContentConfig, TextContentConfig, ImageContentConfig, SearchConfig, TextSearchConfig, ImageSearchConfig
 from src.processor.org_mode.org_to_jsonl import org_to_jsonl
 from src.utils import state


 @pytest.fixture(scope='session')

@@ -46,7 +46,7 @@ def model_dir(search_config):
         batch_size = 10,
         use_xmp_metadata = False)

-    image_search.setup(content_config.image, search_config.image, regenerate=False, verbose=True)
+    image_search.setup(content_config.image, search_config.image, regenerate=False)

     # Generate Notes Embeddings from Test Notes
     content_config.org = TextContentConfig(

@@ -55,7 +55,7 @@ def model_dir(search_config):
         compressed_jsonl = model_dir.joinpath('notes.jsonl.gz'),
         embeddings_file = model_dir.joinpath('note_embeddings.pt'))

-    text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, verbose=True)
+    text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False)

     return model_dir
@@ -8,6 +8,7 @@ import pytest

 # Internal Packages
 from src.main import app
+from src.utils.config import SearchType
 from src.utils.state import model, config
 from src.search_type import text_search, image_search
 from src.utils.rawconfig import ContentConfig, SearchConfig

@@ -115,7 +116,7 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig
 # ----------------------------------------------------------------------------------------------------
 def test_notes_search(content_config: ContentConfig, search_config: SearchConfig):
     # Arrange
-    model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
+    model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False)
     user_query = "How to git install application?"

     # Act

@@ -131,8 +132,8 @@ def test_notes_search(content_config: ContentConfig, search_config: SearchConfig
 # ----------------------------------------------------------------------------------------------------
 def test_notes_search_with_include_filter(content_config: ContentConfig, search_config: SearchConfig):
     # Arrange
-    model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
-    user_query = "How to git install application? +Emacs"
+    model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False)
+    user_query = 'How to git install application? +"Emacs"'

     # Act
     response = client.get(f"/search?q={user_query}&n=1&t=org")

@@ -147,8 +148,8 @@ def test_notes_search_with_include_filter(content_config: ContentConfig, search_
 # ----------------------------------------------------------------------------------------------------
 def test_notes_search_with_exclude_filter(content_config: ContentConfig, search_config: SearchConfig):
     # Arrange
-    model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
-    user_query = "How to git install application? -clone"
+    model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False)
+    user_query = 'How to git install application? -"clone"'

     # Act
     response = client.get(f"/search?q={user_query}&n=1&t=org")
@@ -18,37 +18,37 @@ def test_date_filter():
         {'compiled': '', 'raw': 'Entry with date:1984-04-02'}]

     q_with_no_date_filter = 'head tail'
-    ret_query, ret_entries, ret_emb = DateFilter().filter(q_with_no_date_filter, entries.copy(), embeddings)
+    ret_query, ret_entries, ret_emb = DateFilter().apply(q_with_no_date_filter, entries.copy(), embeddings)
     assert ret_query == 'head tail'
     assert len(ret_emb) == 3
     assert ret_entries == entries

     q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail'
-    ret_query, ret_entries, ret_emb = DateFilter().filter(q_with_dtrange_non_overlapping_at_boundary, entries.copy(), embeddings)
+    ret_query, ret_entries, ret_emb = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries.copy(), embeddings)
     assert ret_query == 'head tail'
     assert len(ret_emb) == 0
     assert ret_entries == []

     query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail'
-    ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
+    ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings)
     assert ret_query == 'head tail'
     assert ret_entries == [entries[2]]
     assert len(ret_emb) == 1

     query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail'
-    ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
+    ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings)
     assert ret_query == 'head tail'
     assert ret_entries == [entries[1]]
     assert len(ret_emb) == 1

     query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail'
-    ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
+    ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings)
     assert ret_query == 'head tail'
     assert ret_entries == [entries[2]]
     assert len(ret_emb) == 1

     query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail'
-    ret_query, ret_entries, ret_emb = DateFilter().filter(query_with_overlapping_dtrange, entries.copy(), embeddings)
+    ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings)
     assert ret_query == 'head tail'
     assert ret_entries == [entries[1], entries[2]]
     assert len(ret_emb) == 2
tests/test_explicit_filter.py (new file, 85 lines)

@@ -0,0 +1,85 @@
# External Packages
import torch

# Application Packages
from src.search_filter.explicit_filter import ExplicitFilter
from src.utils.config import SearchType


def test_no_explicit_filter(tmp_path):
    # Arrange
    explicit_filter = ExplicitFilter(tmp_path, SearchType.Org)
    embeddings, entries = arrange_content()
    q_with_no_filter = 'head tail'

    # Act
    can_filter = explicit_filter.can_filter(q_with_no_filter)
    ret_query, ret_entries, ret_emb = explicit_filter.apply(q_with_no_filter, entries.copy(), embeddings)

    # Assert
    assert can_filter == False
    assert ret_query == 'head tail'
    assert len(ret_emb) == 4
    assert ret_entries == entries


def test_explicit_exclude_filter(tmp_path):
    # Arrange
    explicit_filter = ExplicitFilter(tmp_path, SearchType.Org)
    embeddings, entries = arrange_content()
    q_with_exclude_filter = 'head -"exclude_word" tail'

    # Act
    can_filter = explicit_filter.can_filter(q_with_exclude_filter)
    ret_query, ret_entries, ret_emb = explicit_filter.apply(q_with_exclude_filter, entries.copy(), embeddings)

    # Assert
    assert can_filter == True
    assert ret_query == 'head tail'
    assert len(ret_emb) == 2
    assert ret_entries == [entries[0], entries[2]]


def test_explicit_include_filter(tmp_path):
    # Arrange
    explicit_filter = ExplicitFilter(tmp_path, SearchType.Org)
    embeddings, entries = arrange_content()
    query_with_include_filter = 'head +"include_word" tail'

    # Act
    can_filter = explicit_filter.can_filter(query_with_include_filter)
    ret_query, ret_entries, ret_emb = explicit_filter.apply(query_with_include_filter, entries.copy(), embeddings)

    # Assert
    assert can_filter == True
    assert ret_query == 'head tail'
    assert len(ret_emb) == 2
    assert ret_entries == [entries[2], entries[3]]


def test_explicit_include_and_exclude_filter(tmp_path):
    # Arrange
    explicit_filter = ExplicitFilter(tmp_path, SearchType.Org)
    embeddings, entries = arrange_content()
    query_with_include_and_exclude_filter = 'head +"include_word" -"exclude_word" tail'

    # Act
    can_filter = explicit_filter.can_filter(query_with_include_and_exclude_filter)
    ret_query, ret_entries, ret_emb = explicit_filter.apply(query_with_include_and_exclude_filter, entries.copy(), embeddings)

    # Assert
    assert can_filter == True
    assert ret_query == 'head tail'
    assert len(ret_emb) == 1
    assert ret_entries == [entries[2]]


def arrange_content():
    embeddings = torch.randn(4, 10)
    entries = [
        {'compiled': '', 'raw': 'Minimal Entry'},
        {'compiled': '', 'raw': 'Entry with exclude_word'},
        {'compiled': '', 'raw': 'Entry with include_word'},
        {'compiled': '', 'raw': 'Entry with include_word and exclude_word'}]

    return embeddings, entries
@@ -28,3 +28,18 @@ def test_merge_dicts():
     # do not override existing key in priority_dict with default dict
     assert helpers.merge_dicts(priority_dict={'a': 1}, default_dict={'a': 2}) == {'a': 1}
+
+
+def test_lru_cache():
+    # Test initializing cache
+    cache = helpers.LRU({'a': 1, 'b': 2}, capacity=2)
+    assert cache == {'a': 1, 'b': 2}
+
+    # Test capacity overflow
+    cache['c'] = 3
+    assert cache == {'b': 2, 'c': 3}
+
+    # Test delete least recently used item from LRU cache on capacity overflow
+    cache['b']  # accessing 'b' makes it the most recently used item
+    cache['d'] = 4  # so 'c' is deleted from the cache instead of 'b'
+    assert cache == {'b': 2, 'd': 4}
@@ -1,5 +1,6 @@
 # System Packages
 from pathlib import Path
+from src.utils.config import SearchType

 # Internal Packages
 from src.utils.state import model

@@ -13,7 +14,7 @@ from src.processor.org_mode.org_to_jsonl import org_to_jsonl
 def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchConfig):
     # Act
     # Regenerate notes embeddings during asymmetric setup
-    notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True)
+    notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=True)

     # Assert
     assert len(notes_model.entries) == 10

@@ -23,7 +24,7 @@ def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchCo
 # ----------------------------------------------------------------------------------------------------
 def test_asymmetric_search(content_config: ContentConfig, search_config: SearchConfig):
     # Arrange
-    model.notes_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
+    model.notes_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False)
     query = "How to git install application?"

     # Act

@@ -46,7 +47,7 @@ def test_asymmetric_search(content_config: ContentConfig, search_config: SearchC
 # ----------------------------------------------------------------------------------------------------
 def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchConfig):
     # Arrange
-    initial_notes_model= text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
+    initial_notes_model= text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False)

     assert len(initial_notes_model.entries) == 10
     assert len(initial_notes_model.corpus_embeddings) == 10

@@ -59,11 +60,11 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchC
         f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")

     # regenerate notes jsonl, model embeddings and model to include entry from new file
-    regenerated_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True)
+    regenerated_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=True)

     # Act
     # reload embeddings, entries, notes model from previously generated notes jsonl and model embeddings files
-    initial_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
+    initial_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False)

     # Assert
     assert len(regenerated_notes_model.entries) == 11