Create File Filter. Improve, Consolidate Filter Code

### General Filter Improvements
- e441874 Create Abstract Base Class for all filters to inherit from
- 965bd05 Make search filters return entry ids satisfying filter
- 092b9e3 Setup Filters when configuring Text Search for each Search Type
- 31503e7 Do not pass embeddings in argument to `filter.apply` method as unused

### Create File Filter
- 7606724 Add file associated with each entry to entry dict in `org_to_jsonl` converter
- 1f9fd28 Create File Filter to filter files specified in query
- 7dd20d7 Pre-compute file to entry map to speed up file based filter
- 7e083d3 Cache results for file filters passed in query for faster filtering
- 2890b4c Simplify extracting entries satisfying file filter

### Miscellaneous
- f930324 Rename `explicit filter` to more appropriate name `word filter`
- 3707a4c Improve date filter perf. Precompute date to entry map, Cache results
commit 0a78cd5477
19 changed files with 457 additions and 221 deletions
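The changes below hang off one idea: filters no longer copy entries and embeddings around, they return the set of entry ids they keep, so results from several filters can be intersected cheaply before semantic search. A minimal sketch of that query-time flow, using hypothetical entries (the filter classes and the `can_filter`/`apply` contract come from this commit; everything else is illustrative):

```python
from src.search_filter.date_filter import DateFilter
from src.search_filter.file_filter import FileFilter

# Hypothetical entries, shaped like the output of org_to_jsonl (note the new 'file' key)
entries = [
    {'compiled': '', 'raw': 'Entry with date:1984-04-01', 'file': 'notes.org'},
    {'compiled': '', 'raw': 'Entry with date:1984-04-02', 'file': 'archive.org'},
]

query = 'head dt>="1984-04-01" file:"notes.org" tail'
filters = [DateFilter(), FileFilter()]

# Mirrors the new filter loop in text_search.query
included_entry_indices = set(range(len(entries)))
filters_in_query = [filter for filter in filters if filter.can_filter(query)]
for filter in filters_in_query:
    query, included_by_filter = filter.apply(query, entries)
    included_entry_indices.intersection_update(included_by_filter)

# Keep only the entries that satisfy every filter before running semantic search
entries = [entries[id] for id in included_entry_indices]
```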
.gitignore
@@ -13,3 +13,4 @@ src/.data
 /dist/
 /khoj_assistant.egg-info/
 /config/khoj*.yml
+.pytest_cache
@@ -125,7 +125,6 @@ pip install --upgrade khoj-assistant
 - Semantic search using the bi-encoder is fairly fast at \<50 ms
 - Reranking using the cross-encoder is slower at \<2s on 15 results. Tweak `top_k` to tradeoff speed for accuracy of results
-- Applying explicit filters is very slow currently at \~6s. This is because the filters are rudimentary. Considerable speed-ups can be achieved using indexes etc
 
 ### Indexing performance
 
 
@@ -14,6 +14,9 @@ from src.utils.config import SearchType, SearchModels, ProcessorConfigModel, Con
 from src.utils import state
 from src.utils.helpers import resolve_absolute_path
 from src.utils.rawconfig import FullConfig, ProcessorConfig
+from src.search_filter.date_filter import DateFilter
+from src.search_filter.word_filter import WordFilter
+from src.search_filter.file_filter import FileFilter
 
 
 logger = logging.getLogger(__name__)
@@ -39,23 +42,25 @@ def configure_server(args, required=False):
 def configure_search(model: SearchModels, config: FullConfig, regenerate: bool, t: SearchType = None):
     # Initialize Org Notes Search
     if (t == SearchType.Org or t == None) and config.content_type.org:
+        filter_directory = resolve_absolute_path(config.content_type.org.compressed_jsonl.parent)
+        filters = [DateFilter(), WordFilter(filter_directory, search_type=SearchType.Org), FileFilter()]
         # Extract Entries, Generate Notes Embeddings
-        model.orgmode_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, search_type=SearchType.Org, regenerate=regenerate)
+        model.orgmode_search = text_search.setup(org_to_jsonl, config.content_type.org, search_config=config.search_type.asymmetric, regenerate=regenerate, filters=filters)
 
     # Initialize Org Music Search
     if (t == SearchType.Music or t == None) and config.content_type.music:
         # Extract Entries, Generate Music Embeddings
-        model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, search_type=SearchType.Music, regenerate=regenerate)
+        model.music_search = text_search.setup(org_to_jsonl, config.content_type.music, search_config=config.search_type.asymmetric, regenerate=regenerate)
 
     # Initialize Markdown Search
     if (t == SearchType.Markdown or t == None) and config.content_type.markdown:
         # Extract Entries, Generate Markdown Embeddings
-        model.markdown_search = text_search.setup(markdown_to_jsonl, config.content_type.markdown, search_config=config.search_type.asymmetric, search_type=SearchType.Markdown, regenerate=regenerate)
+        model.markdown_search = text_search.setup(markdown_to_jsonl, config.content_type.markdown, search_config=config.search_type.asymmetric, regenerate=regenerate)
 
     # Initialize Ledger Search
     if (t == SearchType.Ledger or t == None) and config.content_type.ledger:
         # Extract Entries, Generate Ledger Embeddings
-        model.ledger_search = text_search.setup(beancount_to_jsonl, config.content_type.ledger, search_config=config.search_type.symmetric, search_type=SearchType.Ledger, regenerate=regenerate)
+        model.ledger_search = text_search.setup(beancount_to_jsonl, config.content_type.ledger, search_config=config.search_type.symmetric, regenerate=regenerate)
 
     # Initialize Image Search
     if (t == SearchType.Image or t == None) and config.content_type.image:
@@ -28,10 +28,10 @@ def org_to_jsonl(org_files, org_file_filter, output_file):
     org_files = get_org_files(org_files, org_file_filter)
 
     # Extract Entries from specified Org files
-    entries = extract_org_entries(org_files)
+    entries, file_to_entries = extract_org_entries(org_files)
 
     # Process Each Entry from All Notes Files
-    jsonl_data = convert_org_entries_to_jsonl(entries)
+    jsonl_data = convert_org_entries_to_jsonl(entries, file_to_entries)
 
     # Compress JSONL formatted Data
     if output_file.suffix == ".gz":
@@ -66,18 +66,19 @@ def get_org_files(org_files=None, org_file_filter=None):
 def extract_org_entries(org_files):
     "Extract entries from specified Org files"
     entries = []
+    entry_to_file_map = []
     for org_file in org_files:
-        entries.extend(
-            orgnode.makelist(
-                str(org_file)))
+        org_file_entries = orgnode.makelist(str(org_file))
+        entry_to_file_map += [org_file]*len(org_file_entries)
+        entries.extend(org_file_entries)
 
-    return entries
+    return entries, entry_to_file_map
 
 
-def convert_org_entries_to_jsonl(entries) -> str:
+def convert_org_entries_to_jsonl(entries, entry_to_file_map) -> str:
     "Convert each Org-Mode entries to JSON and collate as JSONL"
     jsonl = ''
-    for entry in entries:
+    for entry_id, entry in enumerate(entries):
         entry_dict = dict()
 
         # Ignore title notes i.e notes with just headings and empty body
@@ -106,6 +107,7 @@ def convert_org_entries_to_jsonl(entries) -> str:
 
         if entry_dict:
             entry_dict["raw"] = f'{entry}'
+            entry_dict["file"] = f'{entry_to_file_map[entry_id]}'
 
             # Convert Dictionary to JSON and Append to JSONL string
             jsonl += f'{json.dumps(entry_dict, ensure_ascii=False)}\n'
@@ -16,8 +16,6 @@ from fastapi.templating import Jinja2Templates
 from src.configure import configure_search
 from src.search_type import image_search, text_search
 from src.processor.conversation.gpt import converse, extract_search_type, message_to_log, message_to_prompt, understand, summarize
-from src.search_filter.explicit_filter import ExplicitFilter
-from src.search_filter.date_filter import DateFilter
 from src.utils.rawconfig import FullConfig
 from src.utils.config import SearchType
 from src.utils.helpers import get_absolute_path, get_from_dict
src/search_filter/base_filter.py (new file)
@@ -0,0 +1,20 @@
+# Standard Packages
+from abc import ABC, abstractmethod
+from typing import List, Set, Tuple
+
+# External Packages
+import torch
+
+
+class BaseFilter(ABC):
+    @abstractmethod
+    def load(self, *args, **kwargs):
+        pass
+
+    @abstractmethod
+    def can_filter(self, raw_query:str) -> bool:
+        pass
+
+    @abstractmethod
+    def apply(self, query:str, raw_entries:List[str]) -> Tuple[str, Set[int]]:
+        pass
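A minimal, purely illustrative subclass shows the contract the abstract base class fixes: `load` builds any index up front, `can_filter` inspects the raw query, and `apply` returns the cleaned query together with the set of entry ids that survive the filter. The `TodoFilter` name and its `todo:` query syntax are hypothetical, not part of this commit:

```python
from src.search_filter.base_filter import BaseFilter


class TodoFilter(BaseFilter):
    "Hypothetical filter keeping entries that contain TODO when the query includes 'todo:'"

    def __init__(self, entry_key='raw'):
        self.entry_key = entry_key
        self.todo_entry_ids = set()

    def load(self, entries, *args, **kwargs):
        # Precompute matching entry ids once, like the date and file filter indices
        self.todo_entry_ids = {id for id, entry in enumerate(entries) if 'TODO' in entry[self.entry_key]}

    def can_filter(self, raw_query: str) -> bool:
        return 'todo:' in raw_query

    def apply(self, query: str, raw_entries):
        # No filter in query: keep every entry
        if not self.can_filter(query):
            return query, set(range(len(raw_entries)))
        # Lazily build the index if setup did not call load
        if not self.todo_entry_ids:
            self.load(raw_entries)
        return query.replace('todo:', '').strip(), self.todo_entry_ids
```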
@@ -1,56 +1,40 @@
 # Standard Packages
 import re
+import time
+import logging
+from collections import defaultdict
 from datetime import timedelta, datetime
-from dateutil.relativedelta import relativedelta, MO
+from dateutil.relativedelta import relativedelta
 from math import inf
-from copy import deepcopy
 
 # External Packages
-import torch
 import dateparser as dtparse
 
+# Internal Packages
+from src.search_filter.base_filter import BaseFilter
+from src.utils.helpers import LRU
 
-class DateFilter:
 
+logger = logging.getLogger(__name__)
+
+
+class DateFilter(BaseFilter):
     # Date Range Filter Regexes
     # Example filter queries:
-    # - dt>="yesterday" dt<"tomorrow"
-    # - dt>="last week"
-    # - dt:"2 years ago"
+    # - dt>="yesterday" dt<"tomorrow"
+    # - dt>="last week"
+    # - dt:"2 years ago"
     date_regex = r"dt([:><=]{1,2})\"(.*?)\""
 
 
     def __init__(self, entry_key='raw'):
         self.entry_key = entry_key
+        self.date_to_entry_ids = defaultdict(set)
+        self.cache = LRU()
 
 
-    def load(*args, **kwargs):
-        pass
-
-
-    def can_filter(self, raw_query):
-        "Check if query contains date filters"
-        return self.extract_date_range(raw_query) is not None
-
-
-    def apply(self, query, raw_entries, raw_embeddings):
-        "Find entries containing any dates that fall within date range specified in query"
-        # extract date range specified in date filter of query
-        query_daterange = self.extract_date_range(query)
-
-        # if no date in query, return all entries
-        if query_daterange is None:
-            return query, raw_entries, raw_embeddings
-
-        # remove date range filter from query
-        query = re.sub(rf'\s+{self.date_regex}', ' ', query)
-        query = re.sub(r'\s{2,}', ' ', query).strip()  # remove multiple spaces
-
-        # deep copy original embeddings, entries before filtering
-        embeddings= deepcopy(raw_embeddings)
-        entries = deepcopy(raw_entries)
-
-        # find entries containing any dates that fall with date range specified in query
-        entries_to_include = set()
+    def load(self, entries, **_):
+        start = time.time()
         for id, entry in enumerate(entries):
             # Extract dates from entry
             for date_in_entry_string in re.findall(r'\d{4}-\d{2}-\d{2}', entry[self.entry_key]):
@@ -59,18 +43,56 @@ class DateFilter:
                     date_in_entry = datetime.strptime(date_in_entry_string, '%Y-%m-%d').timestamp()
                 except ValueError:
                     continue
-                # Check if date in entry is within date range specified in query
-                if query_daterange[0] <= date_in_entry < query_daterange[1]:
-                    entries_to_include.add(id)
-                    break
+                self.date_to_entry_ids[date_in_entry].add(id)
+        end = time.time()
+        logger.debug(f"Created file filter index: {end - start} seconds")
 
-        # delete entries (and their embeddings) marked for exclusion
-        entries_to_exclude = set(range(len(entries))) - entries_to_include
-        for id in sorted(list(entries_to_exclude), reverse=True):
-            del entries[id]
-            embeddings = torch.cat((embeddings[:id], embeddings[id+1:]))
-
-        return query, entries, embeddings
+    def can_filter(self, raw_query):
+        "Check if query contains date filters"
+        return self.extract_date_range(raw_query) is not None
+
+
+    def apply(self, query, raw_entries):
+        "Find entries containing any dates that fall within date range specified in query"
+        # extract date range specified in date filter of query
+        start = time.time()
+        query_daterange = self.extract_date_range(query)
+        end = time.time()
+        logger.debug(f"Extract date range to filter from query: {end - start} seconds")
+
+        # if no date in query, return all entries
+        if query_daterange is None:
+            return query, set(range(len(raw_entries)))
+
+        # remove date range filter from query
+        query = re.sub(rf'\s+{self.date_regex}', ' ', query)
+        query = re.sub(r'\s{2,}', ' ', query).strip()  # remove multiple spaces
+
+        # return results from cache if exists
+        cache_key = tuple(query_daterange)
+        if cache_key in self.cache:
+            logger.info(f"Return date filter results from cache")
+            entries_to_include = self.cache[cache_key]
+            return query, entries_to_include
+
+        if not self.date_to_entry_ids:
+            self.load(raw_entries)
+
+        # find entries containing any dates that fall with date range specified in query
+        start = time.time()
+        entries_to_include = set()
+        for date_in_entry in self.date_to_entry_ids.keys():
+            # Check if date in entry is within date range specified in query
+            if query_daterange[0] <= date_in_entry < query_daterange[1]:
+                entries_to_include |= self.date_to_entry_ids[date_in_entry]
+        end = time.time()
+        logger.debug(f"Mark entries satisfying filter: {end - start} seconds")
+
+        # cache results
+        self.cache[cache_key] = entries_to_include
+
+        return query, entries_to_include
 
 
     def extract_date_range(self, query):
src/search_filter/file_filter.py (new file)
@@ -0,0 +1,79 @@
+# Standard Packages
+import re
+import fnmatch
+import time
+import logging
+from collections import defaultdict
+
+# Internal Packages
+from src.search_filter.base_filter import BaseFilter
+from src.utils.helpers import LRU
+
+
+logger = logging.getLogger(__name__)
+
+
+class FileFilter(BaseFilter):
+    file_filter_regex = r'file:"(.+?)" ?'
+
+    def __init__(self, entry_key='file'):
+        self.entry_key = entry_key
+        self.file_to_entry_map = defaultdict(set)
+        self.cache = LRU()
+
+    def load(self, entries, *args, **kwargs):
+        start = time.time()
+        for id, entry in enumerate(entries):
+            self.file_to_entry_map[entry[self.entry_key]].add(id)
+        end = time.time()
+        logger.debug(f"Created file filter index: {end - start} seconds")
+
+    def can_filter(self, raw_query):
+        return re.search(self.file_filter_regex, raw_query) is not None
+
+    def apply(self, raw_query, raw_entries):
+        # Extract file filters from raw query
+        start = time.time()
+        raw_files_to_search = re.findall(self.file_filter_regex, raw_query)
+        if not raw_files_to_search:
+            return raw_query, set(range(len(raw_entries)))
+
+        # Convert simple file filters with no path separator into regex
+        # e.g. "file:notes.org" -> "file:.*notes.org"
+        files_to_search = []
+        for file in sorted(raw_files_to_search):
+            if '/' not in file and '\\' not in file and '*' not in file:
+                files_to_search += [f'*{file}']
+            else:
+                files_to_search += [file]
+        end = time.time()
+        logger.debug(f"Extract files_to_search from query: {end - start} seconds")
+
+        # Return item from cache if exists
+        query = re.sub(self.file_filter_regex, '', raw_query).strip()
+        cache_key = tuple(files_to_search)
+        if cache_key in self.cache:
+            logger.info(f"Return file filter results from cache")
+            included_entry_indices = self.cache[cache_key]
+            return query, included_entry_indices
+
+        if not self.file_to_entry_map:
+            self.load(raw_entries, regenerate=False)
+
+        # Mark entries that contain any blocked_words for exclusion
+        start = time.time()
+
+        included_entry_indices = set.union(*[self.file_to_entry_map[entry_file]
+                                             for entry_file in self.file_to_entry_map.keys()
+                                             for search_file in files_to_search
+                                             if fnmatch.fnmatch(entry_file, search_file)], set())
+        if not included_entry_indices:
+            return query, {}
+
+        end = time.time()
+        logger.debug(f"Mark entries satisfying filter: {end - start} seconds")
+
+        # Cache results
+        self.cache[cache_key] = included_entry_indices
+
+        return query, included_entry_indices
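Usage mirrors the tests added further down; the entries and paths below are hypothetical. A bare file name with no path separator or wildcard is widened to a `*`-prefixed glob, so `file:"notes.org"` matches any entry whose source file ends in notes.org:

```python
from src.search_filter.file_filter import FileFilter

# Hypothetical entries with the 'file' key populated by org_to_jsonl
entries = [
    {'compiled': '', 'raw': 'First Entry', 'file': '/home/user/notes.org'},
    {'compiled': '', 'raw': 'Second Entry', 'file': '/home/user/archive.org'},
]

file_filter = FileFilter()
assert file_filter.can_filter('head file:"notes.org" tail')

query, included_entry_indices = file_filter.apply('head file:"notes.org" tail', entries)
# query == 'head tail'; included_entry_indices == {0}, matched via fnmatch against '*notes.org'
```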
@@ -8,6 +8,7 @@ import logging
 import torch
 
 # Internal Packages
+from src.search_filter.base_filter import BaseFilter
 from src.utils.helpers import LRU, resolve_absolute_path
 from src.utils.config import SearchType
 
@@ -15,13 +16,13 @@ from src.utils.config import SearchType
 logger = logging.getLogger(__name__)
 
 
-class ExplicitFilter:
+class WordFilter(BaseFilter):
     # Filter Regex
     required_regex = r'\+"(\w+)" ?'
     blocked_regex = r'\-"(\w+)" ?'
 
     def __init__(self, filter_directory, search_type: SearchType, entry_key='raw'):
-        self.filter_file = resolve_absolute_path(filter_directory / f"{search_type.name.lower()}_explicit_filter_entry_word_sets.pkl")
+        self.filter_file = resolve_absolute_path(filter_directory / f"word_filter_{search_type.name.lower()}_index.pkl")
         self.entry_key = entry_key
         self.search_type = search_type
         self.word_to_entry_index = dict()
@@ -34,7 +35,7 @@ class ExplicitFilter:
            with self.filter_file.open('rb') as f:
                self.word_to_entry_index = pickle.load(f)
            end = time.time()
-           logger.debug(f"Load {self.search_type} entries by word set from file: {end - start} seconds")
+           logger.debug(f"Load word filter index for {self.search_type} from {self.filter_file}: {end - start} seconds")
        else:
            start = time.time()
            self.cache = {} # Clear cache on (re-)generating entries_by_word_set
@@ -51,23 +52,22 @@ class ExplicitFilter:
            with self.filter_file.open('wb') as f:
                pickle.dump(self.word_to_entry_index, f)
            end = time.time()
-           logger.debug(f"Convert all {self.search_type} entries to word sets: {end - start} seconds")
+           logger.debug(f"Index {self.search_type} for word filter to {self.filter_file}: {end - start} seconds")
 
        return self.word_to_entry_index
 
 
    def can_filter(self, raw_query):
-       "Check if query contains explicit filters"
-       # Extract explicit query portion with required, blocked words to filter from natural query
+       "Check if query contains word filters"
        required_words = re.findall(self.required_regex, raw_query)
        blocked_words = re.findall(self.blocked_regex, raw_query)
 
        return len(required_words) != 0 or len(blocked_words) != 0
 
 
-   def apply(self, raw_query, raw_entries, raw_embeddings):
+   def apply(self, raw_query, raw_entries):
        "Find entries containing required and not blocked words specified in query"
-       # Separate natural query from explicit required, blocked words filters
+       # Separate natural query from required, blocked words filters
        start = time.time()
 
        required_words = set([word.lower() for word in re.findall(self.required_regex, raw_query)])
@@ -78,14 +78,14 @@ class ExplicitFilter:
        logger.debug(f"Extract required, blocked filters from query: {end - start} seconds")
 
        if len(required_words) == 0 and len(blocked_words) == 0:
-           return query, raw_entries, raw_embeddings
+           return query, set(range(len(raw_entries)))
 
        # Return item from cache if exists
        cache_key = tuple(sorted(required_words)), tuple(sorted(blocked_words))
        if cache_key in self.cache:
-           logger.info(f"Explicit filter results from cache")
-           entries, embeddings = self.cache[cache_key]
-           return query, entries, embeddings
+           logger.info(f"Return word filter results from cache")
+           included_entry_indices = self.cache[cache_key]
+           return query, included_entry_indices
 
        if not self.word_to_entry_index:
            self.load(raw_entries, regenerate=False)
@@ -105,17 +105,10 @@ class ExplicitFilter:
        end = time.time()
        logger.debug(f"Mark entries satisfying filter: {end - start} seconds")
 
-       # get entries (and their embeddings) satisfying inclusion and exclusion filters
-       start = time.time()
-
+       # get entries satisfying inclusion and exclusion filters
        included_entry_indices = entries_with_all_required_words - entries_with_any_blocked_words
-       entries = [entry for id, entry in enumerate(raw_entries) if id in included_entry_indices]
-       embeddings = torch.index_select(raw_embeddings, 0, torch.tensor(list(included_entry_indices)))
-
-       end = time.time()
-       logger.debug(f"Keep entries satisfying filter: {end - start} seconds")
 
        # Cache results
-       self.cache[cache_key] = entries, embeddings
+       self.cache[cache_key] = included_entry_indices
 
-       return query, entries, embeddings
+       return query, included_entry_indices
@@ -7,13 +7,12 @@ import time
 # External Packages
 import torch
 from sentence_transformers import SentenceTransformer, CrossEncoder, util
-from src.search_filter.date_filter import DateFilter
-from src.search_filter.explicit_filter import ExplicitFilter
+from src.search_filter.base_filter import BaseFilter
 
 # Internal Packages
 from src.utils import state
 from src.utils.helpers import get_absolute_path, resolve_absolute_path, load_model
-from src.utils.config import SearchType, TextSearchModel
+from src.utils.config import TextSearchModel
 from src.utils.rawconfig import TextSearchConfig, TextContentConfig
 from src.utils.jsonl import load_jsonl
 
@@ -53,9 +52,7 @@ def initialize_model(search_config: TextSearchConfig):
 
 def extract_entries(jsonl_file):
     "Load entries from compressed jsonl"
-    return [{'compiled': f'{entry["compiled"]}', 'raw': f'{entry["raw"]}'}
-            for entry
-            in load_jsonl(jsonl_file)]
+    return load_jsonl(jsonl_file)
 
 
 def compute_embeddings(entries, bi_encoder, embeddings_file, regenerate=False):
@@ -79,12 +76,25 @@ def query(raw_query: str, model: TextSearchModel, rank_results=False):
     query, entries, corpus_embeddings = raw_query, model.entries, model.corpus_embeddings
 
     # Filter query, entries and embeddings before semantic search
-    start = time.time()
+    start_filter = time.time()
+    included_entry_indices = set(range(len(entries)))
     filters_in_query = [filter for filter in model.filters if filter.can_filter(query)]
     for filter in filters_in_query:
-        query, entries, corpus_embeddings = filter.apply(query, entries, corpus_embeddings)
-    end = time.time()
-    logger.debug(f"Filter Time: {end - start:.3f} seconds")
+        query, included_entry_indices_by_filter = filter.apply(query, entries)
+        included_entry_indices.intersection_update(included_entry_indices_by_filter)
+
+    # Get entries (and associated embeddings) satisfying all filters
+    if not included_entry_indices:
+        return [], []
+    else:
+        start = time.time()
+        entries = [entries[id] for id in included_entry_indices]
+        corpus_embeddings = torch.index_select(corpus_embeddings, 0, torch.tensor(list(included_entry_indices)))
+        end = time.time()
+        logger.debug(f"Keep entries satisfying all filters: {end - start} seconds")
+
+    end_filter = time.time()
+    logger.debug(f"Total Filter Time: {end_filter - start_filter:.3f} seconds")
 
     if entries is None or len(entries) == 0:
         return [], []
@@ -153,7 +163,7 @@ def collate_results(hits, entries, count=5):
             in hits[0:count]]
 
 
-def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, search_type: SearchType, regenerate: bool) -> TextSearchModel:
+def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchConfig, regenerate: bool, filters: list[BaseFilter] = []) -> TextSearchModel:
     # Initialize Model
     bi_encoder, cross_encoder, top_k = initialize_model(search_config)
 
@@ -170,8 +180,6 @@ def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchCon
     config.embeddings_file = resolve_absolute_path(config.embeddings_file)
     corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate)
 
-    filter_directory = resolve_absolute_path(config.compressed_jsonl.parent)
-    filters = [DateFilter(), ExplicitFilter(filter_directory, search_type=search_type)]
     for filter in filters:
         filter.load(entries, regenerate=regenerate)
 
@@ -2,9 +2,11 @@
 from enum import Enum
 from dataclasses import dataclass
 from pathlib import Path
+from typing import List
 
 # Internal Packages
 from src.utils.rawconfig import ConversationProcessorConfig
+from src.search_filter.base_filter import BaseFilter
 
 
 class SearchType(str, Enum):
@@ -20,7 +22,7 @@ class ProcessorType(str, Enum):
 
 
 class TextSearchModel():
-    def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, filters, top_k):
+    def __init__(self, entries, corpus_embeddings, bi_encoder, cross_encoder, filters: List[BaseFilter], top_k):
         self.entries = entries
         self.corpus_embeddings = corpus_embeddings
         self.bi_encoder = bi_encoder
@@ -1,4 +1,4 @@
-# Standard Packages
+# External Packages
 import pytest
 
 # Internal Packages
@@ -6,10 +6,13 @@ from src.search_type import image_search, text_search
 from src.utils.config import SearchType
 from src.utils.rawconfig import ContentConfig, TextContentConfig, ImageContentConfig, SearchConfig, TextSearchConfig, ImageSearchConfig
 from src.processor.org_mode.org_to_jsonl import org_to_jsonl
+from src.search_filter.date_filter import DateFilter
+from src.search_filter.word_filter import WordFilter
+from src.search_filter.file_filter import FileFilter
 
 
 @pytest.fixture(scope='session')
-def search_config(tmp_path_factory):
+def search_config(tmp_path_factory) -> SearchConfig:
     model_dir = tmp_path_factory.mktemp('data')
 
     search_config = SearchConfig()
@@ -35,7 +38,7 @@ def search_config(tmp_path_factory):
 
 
 @pytest.fixture(scope='session')
-def model_dir(search_config):
+def model_dir(search_config: SearchConfig):
     model_dir = search_config.asymmetric.model_directory
 
     # Generate Image Embeddings from Test Images
@@ -55,13 +58,14 @@ def model_dir(search_config):
         compressed_jsonl = model_dir.joinpath('notes.jsonl.gz'),
         embeddings_file = model_dir.joinpath('note_embeddings.pt'))
 
-    text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False)
+    filters = [DateFilter(), WordFilter(model_dir, search_type=SearchType.Org), FileFilter()]
+    text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
 
     return model_dir
 
 
 @pytest.fixture(scope='session')
-def content_config(model_dir):
+def content_config(model_dir) -> ContentConfig:
     content_config = ContentConfig()
     content_config.org = TextContentConfig(
         input_files = None,
@@ -4,7 +4,6 @@ from PIL import Image
 
 # External Packages
 from fastapi.testclient import TestClient
-import pytest
 
 # Internal Packages
 from src.main import app
@@ -12,7 +11,8 @@ from src.utils.config import SearchType
 from src.utils.state import model, config
 from src.search_type import text_search, image_search
 from src.utils.rawconfig import ContentConfig, SearchConfig
-from src.processor.org_mode import org_to_jsonl
+from src.processor.org_mode.org_to_jsonl import org_to_jsonl
+from src.search_filter.word_filter import WordFilter
 
 
 # Arrange
@@ -116,7 +116,7 @@ def test_image_search(content_config: ContentConfig, search_config: SearchConfig
 # ----------------------------------------------------------------------------------------------------
 def test_notes_search(content_config: ContentConfig, search_config: SearchConfig):
     # Arrange
-    model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False)
+    model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
     user_query = "How to git install application?"
 
     # Act
@@ -132,7 +132,8 @@ def test_notes_search(content_config: ContentConfig, search_config: SearchConfig
 # ----------------------------------------------------------------------------------------------------
 def test_notes_search_with_include_filter(content_config: ContentConfig, search_config: SearchConfig):
     # Arrange
-    model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False)
+    filters = [WordFilter(content_config.org.compressed_jsonl.parent, search_type=SearchType.Org)]
+    model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
     user_query = 'How to git install application? +"Emacs"'
 
     # Act
@@ -140,7 +141,7 @@ def test_notes_search_with_include_filter(content_config: ContentConfig, search_
 
     # Assert
     assert response.status_code == 200
-    # assert actual_data contains explicitly included word "Emacs"
+    # assert actual_data contains word "Emacs"
     search_result = response.json()[0]["entry"]
     assert "Emacs" in search_result
 
@@ -148,7 +149,8 @@ def test_notes_search_with_include_filter(content_config: ContentConfig, search_
 # ----------------------------------------------------------------------------------------------------
 def test_notes_search_with_exclude_filter(content_config: ContentConfig, search_config: SearchConfig):
     # Arrange
-    model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False)
+    filters = [WordFilter(content_config.org.compressed_jsonl.parent, search_type=SearchType.Org)]
+    model.orgmode_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False, filters=filters)
     user_query = 'How to git install application? -"clone"'
 
     # Act
@@ -156,6 +158,6 @@ def test_notes_search_with_exclude_filter(content_config: ContentConfig, search_
 
     # Assert
     assert response.status_code == 200
-    # assert actual_data does not contains explicitly excluded word "Emacs"
+    # assert actual_data does not contains word "Emacs"
     search_result = response.json()[0]["entry"]
     assert "clone" not in search_result
 
@@ -18,40 +18,34 @@ def test_date_filter():
         {'compiled': '', 'raw': 'Entry with date:1984-04-02'}]
 
     q_with_no_date_filter = 'head tail'
-    ret_query, ret_entries, ret_emb = DateFilter().apply(q_with_no_date_filter, entries.copy(), embeddings)
+    ret_query, entry_indices = DateFilter().apply(q_with_no_date_filter, entries)
     assert ret_query == 'head tail'
-    assert len(ret_emb) == 3
-    assert ret_entries == entries
+    assert entry_indices == {0, 1, 2}
 
     q_with_dtrange_non_overlapping_at_boundary = 'head dt>"1984-04-01" dt<"1984-04-02" tail'
-    ret_query, ret_entries, ret_emb = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries.copy(), embeddings)
+    ret_query, entry_indices = DateFilter().apply(q_with_dtrange_non_overlapping_at_boundary, entries)
     assert ret_query == 'head tail'
-    assert len(ret_emb) == 0
-    assert ret_entries == []
+    assert entry_indices == set()
 
     query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<"1984-04-03" tail'
-    ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings)
+    ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
     assert ret_query == 'head tail'
-    assert ret_entries == [entries[2]]
-    assert len(ret_emb) == 1
+    assert entry_indices == {2}
 
     query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<"1984-04-02" tail'
-    ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings)
+    ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
     assert ret_query == 'head tail'
-    assert ret_entries == [entries[1]]
-    assert len(ret_emb) == 1
+    assert entry_indices == {1}
 
    query_with_overlapping_dtrange = 'head dt>"1984-04-01" dt<="1984-04-02" tail'
-    ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings)
+    ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
     assert ret_query == 'head tail'
-    assert ret_entries == [entries[2]]
-    assert len(ret_emb) == 1
+    assert entry_indices == {2}
 
     query_with_overlapping_dtrange = 'head dt>="1984-04-01" dt<="1984-04-02" tail'
-    ret_query, ret_entries, ret_emb = DateFilter().apply(query_with_overlapping_dtrange, entries.copy(), embeddings)
+    ret_query, entry_indices = DateFilter().apply(query_with_overlapping_dtrange, entries)
     assert ret_query == 'head tail'
-    assert ret_entries == [entries[1], entries[2]]
-    assert len(ret_emb) == 2
+    assert entry_indices == {1, 2}
 
 
 def test_extract_date_range():
@@ -1,85 +0,0 @@
-# External Packages
-import torch
-
-# Application Packages
-from src.search_filter.explicit_filter import ExplicitFilter
-from src.utils.config import SearchType
-
-
-def test_no_explicit_filter(tmp_path):
-    # Arrange
-    explicit_filter = ExplicitFilter(tmp_path, SearchType.Org)
-    embeddings, entries = arrange_content()
-    q_with_no_filter = 'head tail'
-
-    # Act
-    can_filter = explicit_filter.can_filter(q_with_no_filter)
-    ret_query, ret_entries, ret_emb = explicit_filter.apply(q_with_no_filter, entries.copy(), embeddings)
-
-    # Assert
-    assert can_filter == False
-    assert ret_query == 'head tail'
-    assert len(ret_emb) == 4
-    assert ret_entries == entries
-
-
-def test_explicit_exclude_filter(tmp_path):
-    # Arrange
-    explicit_filter = ExplicitFilter(tmp_path, SearchType.Org)
-    embeddings, entries = arrange_content()
-    q_with_exclude_filter = 'head -"exclude_word" tail'
-
-    # Act
-    can_filter = explicit_filter.can_filter(q_with_exclude_filter)
-    ret_query, ret_entries, ret_emb = explicit_filter.apply(q_with_exclude_filter, entries.copy(), embeddings)
-
-    # Assert
-    assert can_filter == True
-    assert ret_query == 'head tail'
-    assert len(ret_emb) == 2
-    assert ret_entries == [entries[0], entries[2]]
-
-
-def test_explicit_include_filter(tmp_path):
-    # Arrange
-    explicit_filter = ExplicitFilter(tmp_path, SearchType.Org)
-    embeddings, entries = arrange_content()
-    query_with_include_filter = 'head +"include_word" tail'
-
-    # Act
-    can_filter = explicit_filter.can_filter(query_with_include_filter)
-    ret_query, ret_entries, ret_emb = explicit_filter.apply(query_with_include_filter, entries.copy(), embeddings)
-
-    # Assert
-    assert can_filter == True
-    assert ret_query == 'head tail'
-    assert len(ret_emb) == 2
-    assert ret_entries == [entries[2], entries[3]]
-
-
-def test_explicit_include_and_exclude_filter(tmp_path):
-    # Arrange
-    explicit_filter = ExplicitFilter(tmp_path, SearchType.Org)
-    embeddings, entries = arrange_content()
-    query_with_include_and_exclude_filter = 'head +"include_word" -"exclude_word" tail'
-
-    # Act
-    can_filter = explicit_filter.can_filter(query_with_include_and_exclude_filter)
-    ret_query, ret_entries, ret_emb = explicit_filter.apply(query_with_include_and_exclude_filter, entries.copy(), embeddings)
-
-    # Assert
-    assert can_filter == True
-    assert ret_query == 'head tail'
-    assert len(ret_emb) == 1
-    assert ret_entries == [entries[2]]
-
-
-def arrange_content():
-    embeddings = torch.randn(4, 10)
-    entries = [
-        {'compiled': '', 'raw': 'Minimal Entry'},
-        {'compiled': '', 'raw': 'Entry with exclude_word'},
-        {'compiled': '', 'raw': 'Entry with include_word'},
-        {'compiled': '', 'raw': 'Entry with include_word and exclude_word'}]
-
-    return embeddings, entries
tests/test_file_filter.py (new file)
@@ -0,0 +1,112 @@
+# External Packages
+import torch
+
+# Application Packages
+from src.search_filter.file_filter import FileFilter
+
+
+def test_no_file_filter():
+    # Arrange
+    file_filter = FileFilter()
+    embeddings, entries = arrange_content()
+    q_with_no_filter = 'head tail'
+
+    # Act
+    can_filter = file_filter.can_filter(q_with_no_filter)
+    ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
+
+    # Assert
+    assert can_filter == False
+    assert ret_query == 'head tail'
+    assert entry_indices == {0, 1, 2, 3}
+
+
+def test_file_filter_with_non_existent_file():
+    # Arrange
+    file_filter = FileFilter()
+    embeddings, entries = arrange_content()
+    q_with_no_filter = 'head file:"nonexistent.org" tail'
+
+    # Act
+    can_filter = file_filter.can_filter(q_with_no_filter)
+    ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
+
+    # Assert
+    assert can_filter == True
+    assert ret_query == 'head tail'
+    assert entry_indices == {}
+
+
+def test_single_file_filter():
+    # Arrange
+    file_filter = FileFilter()
+    embeddings, entries = arrange_content()
+    q_with_no_filter = 'head file:"file 1.org" tail'
+
+    # Act
+    can_filter = file_filter.can_filter(q_with_no_filter)
+    ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
+
+    # Assert
+    assert can_filter == True
+    assert ret_query == 'head tail'
+    assert entry_indices == {0, 2}
+
+
+def test_file_filter_with_partial_match():
+    # Arrange
+    file_filter = FileFilter()
+    embeddings, entries = arrange_content()
+    q_with_no_filter = 'head file:"1.org" tail'
+
+    # Act
+    can_filter = file_filter.can_filter(q_with_no_filter)
+    ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
+
+    # Assert
+    assert can_filter == True
+    assert ret_query == 'head tail'
+    assert entry_indices == {0, 2}
+
+
+def test_file_filter_with_regex_match():
+    # Arrange
+    file_filter = FileFilter()
+    embeddings, entries = arrange_content()
+    q_with_no_filter = 'head file:"*.org" tail'
+
+    # Act
+    can_filter = file_filter.can_filter(q_with_no_filter)
+    ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
+
+    # Assert
+    assert can_filter == True
+    assert ret_query == 'head tail'
+    assert entry_indices == {0, 1, 2, 3}
+
+
+def test_multiple_file_filter():
+    # Arrange
+    file_filter = FileFilter()
+    embeddings, entries = arrange_content()
+    q_with_no_filter = 'head tail file:"file 1.org" file:"file2.org"'
+
+    # Act
+    can_filter = file_filter.can_filter(q_with_no_filter)
+    ret_query, entry_indices = file_filter.apply(q_with_no_filter, entries)
+
+    # Assert
+    assert can_filter == True
+    assert ret_query == 'head tail'
+    assert entry_indices == {0, 1, 2, 3}
+
+
+def arrange_content():
+    embeddings = torch.randn(4, 10)
+    entries = [
+        {'compiled': '', 'raw': 'First Entry', 'file': 'file 1.org'},
+        {'compiled': '', 'raw': 'Second Entry', 'file': 'file2.org'},
+        {'compiled': '', 'raw': 'Third Entry', 'file': 'file 1.org'},
+        {'compiled': '', 'raw': 'Fourth Entry', 'file': 'file2.org'}]
+
+    return embeddings, entries
@@ -21,10 +21,10 @@ def test_entry_with_empty_body_line_to_jsonl(tmp_path):
 
     # Act
     # Extract Entries from specified Org files
-    entries = extract_org_entries(org_files=[orgfile])
+    entries, entry_to_file_map = extract_org_entries(org_files=[orgfile])
 
     # Process Each Entry from All Notes Files
-    jsonl_data = convert_org_entries_to_jsonl(entries)
+    jsonl_data = convert_org_entries_to_jsonl(entries, entry_to_file_map)
 
     # Assert
     assert is_none_or_empty(jsonl_data)
@@ -43,10 +43,10 @@ def test_entry_with_body_to_jsonl(tmp_path):
 
     # Act
     # Extract Entries from specified Org files
-    entries = extract_org_entries(org_files=[orgfile])
+    entries, entry_to_file_map = extract_org_entries(org_files=[orgfile])
 
     # Process Each Entry from All Notes Files
-    jsonl_string = convert_org_entries_to_jsonl(entries)
+    jsonl_string = convert_org_entries_to_jsonl(entries, entry_to_file_map)
     jsonl_data = [json.loads(json_string) for json_string in jsonl_string.splitlines()]
 
     # Assert
@@ -1,6 +1,5 @@
 # System Packages
 from pathlib import Path
-from src.utils.config import SearchType
 
 # Internal Packages
 from src.utils.state import model
@@ -14,7 +13,7 @@ from src.processor.org_mode.org_to_jsonl import org_to_jsonl
 def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchConfig):
     # Act
     # Regenerate notes embeddings during asymmetric setup
-    notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=True)
+    notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True)
 
     # Assert
     assert len(notes_model.entries) == 10
@@ -24,7 +23,7 @@ def test_asymmetric_setup(content_config: ContentConfig, search_config: SearchCo
 # ----------------------------------------------------------------------------------------------------
 def test_asymmetric_search(content_config: ContentConfig, search_config: SearchConfig):
     # Arrange
-    model.notes_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False)
+    model.notes_search = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
     query = "How to git install application?"
 
     # Act
@@ -47,7 +46,7 @@ def test_asymmetric_search(content_config: ContentConfig, search_config: SearchC
 # ----------------------------------------------------------------------------------------------------
 def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchConfig):
     # Arrange
-    initial_notes_model= text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False)
+    initial_notes_model= text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
 
     assert len(initial_notes_model.entries) == 10
     assert len(initial_notes_model.corpus_embeddings) == 10
@@ -60,11 +59,11 @@ def test_asymmetric_reload(content_config: ContentConfig, search_config: SearchC
         f.write("\n* A Chihuahua doing Tango\n- Saw a super cute video of a chihuahua doing the Tango on Youtube\n")
 
     # regenerate notes jsonl, model embeddings and model to include entry from new file
-    regenerated_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=True)
+    regenerated_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=True)
 
     # Act
     # reload embeddings, entries, notes model from previously generated notes jsonl and model embeddings files
-    initial_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, SearchType.Org, regenerate=False)
+    initial_notes_model = text_search.setup(org_to_jsonl, content_config.org, search_config.asymmetric, regenerate=False)
 
     # Assert
     assert len(regenerated_notes_model.entries) == 11
tests/test_word_filter.py (new file)
@@ -0,0 +1,81 @@
+# External Packages
+import torch
+
+# Application Packages
+from src.search_filter.word_filter import WordFilter
+from src.utils.config import SearchType
+
+
+def test_no_word_filter(tmp_path):
+    # Arrange
+    word_filter = WordFilter(tmp_path, SearchType.Org)
+    embeddings, entries = arrange_content()
+    q_with_no_filter = 'head tail'
+
+    # Act
+    can_filter = word_filter.can_filter(q_with_no_filter)
+    ret_query, entry_indices = word_filter.apply(q_with_no_filter, entries)
+
+    # Assert
+    assert can_filter == False
+    assert ret_query == 'head tail'
+    assert entry_indices == {0, 1, 2, 3}
+
+
+def test_word_exclude_filter(tmp_path):
+    # Arrange
+    word_filter = WordFilter(tmp_path, SearchType.Org)
+    embeddings, entries = arrange_content()
+    q_with_exclude_filter = 'head -"exclude_word" tail'
+
+    # Act
+    can_filter = word_filter.can_filter(q_with_exclude_filter)
+    ret_query, entry_indices = word_filter.apply(q_with_exclude_filter, entries)
+
+    # Assert
+    assert can_filter == True
+    assert ret_query == 'head tail'
+    assert entry_indices == {0, 2}
+
+
+def test_word_include_filter(tmp_path):
+    # Arrange
+    word_filter = WordFilter(tmp_path, SearchType.Org)
+    embeddings, entries = arrange_content()
+    query_with_include_filter = 'head +"include_word" tail'
+
+    # Act
+    can_filter = word_filter.can_filter(query_with_include_filter)
+    ret_query, entry_indices = word_filter.apply(query_with_include_filter, entries)
+
+    # Assert
+    assert can_filter == True
+    assert ret_query == 'head tail'
+    assert entry_indices == {2, 3}
+
+
+def test_word_include_and_exclude_filter(tmp_path):
+    # Arrange
+    word_filter = WordFilter(tmp_path, SearchType.Org)
+    embeddings, entries = arrange_content()
+    query_with_include_and_exclude_filter = 'head +"include_word" -"exclude_word" tail'
+
+    # Act
+    can_filter = word_filter.can_filter(query_with_include_and_exclude_filter)
+    ret_query, entry_indices = word_filter.apply(query_with_include_and_exclude_filter, entries)
+
+    # Assert
+    assert can_filter == True
+    assert ret_query == 'head tail'
+    assert entry_indices == {2}
+
+
+def arrange_content():
+    embeddings = torch.randn(4, 10)
+    entries = [
+        {'compiled': '', 'raw': 'Minimal Entry'},
+        {'compiled': '', 'raw': 'Entry with exclude_word'},
+        {'compiled': '', 'raw': 'Entry with include_word'},
+        {'compiled': '', 'raw': 'Entry with include_word and exclude_word'}]
+
+    return embeddings, entries