Rename explicit filter to word filter to be more specific

This commit is contained in:
Debanjum Singh Solanky 2022-09-04 17:18:47 +03:00
parent d153d420fc
commit f930324350
6 changed files with 28 additions and 32 deletions

View file

@ -125,7 +125,6 @@ pip install --upgrade khoj-assistant
- Semantic search using the bi-encoder is fairly fast at \<50 ms - Semantic search using the bi-encoder is fairly fast at \<50 ms
- Reranking using the cross-encoder is slower at \<2s on 15 results. Tweak `top_k` to tradeoff speed for accuracy of results - Reranking using the cross-encoder is slower at \<2s on 15 results. Tweak `top_k` to tradeoff speed for accuracy of results
- Applying explicit filters is very slow currently at \~6s. This is because the filters are rudimentary. Considerable speed-ups can be achieved using indexes etc
### Indexing performance ### Indexing performance

View file

@ -16,8 +16,6 @@ from fastapi.templating import Jinja2Templates
from src.configure import configure_search from src.configure import configure_search
from src.search_type import image_search, text_search from src.search_type import image_search, text_search
from src.processor.conversation.gpt import converse, extract_search_type, message_to_log, message_to_prompt, understand, summarize from src.processor.conversation.gpt import converse, extract_search_type, message_to_log, message_to_prompt, understand, summarize
from src.search_filter.explicit_filter import ExplicitFilter
from src.search_filter.date_filter import DateFilter
from src.utils.rawconfig import FullConfig from src.utils.rawconfig import FullConfig
from src.utils.config import SearchType from src.utils.config import SearchType
from src.utils.helpers import get_absolute_path, get_from_dict from src.utils.helpers import get_absolute_path, get_from_dict

View file

@ -15,13 +15,13 @@ from src.utils.config import SearchType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ExplicitFilter: class WordFilter:
# Filter Regex # Filter Regex
required_regex = r'\+"(\w+)" ?' required_regex = r'\+"(\w+)" ?'
blocked_regex = r'\-"(\w+)" ?' blocked_regex = r'\-"(\w+)" ?'
def __init__(self, filter_directory, search_type: SearchType, entry_key='raw'): def __init__(self, filter_directory, search_type: SearchType, entry_key='raw'):
self.filter_file = resolve_absolute_path(filter_directory / f"{search_type.name.lower()}_explicit_filter_entry_word_sets.pkl") self.filter_file = resolve_absolute_path(filter_directory / f"word_filter_{search_type.name.lower()}_index.pkl")
self.entry_key = entry_key self.entry_key = entry_key
self.search_type = search_type self.search_type = search_type
self.word_to_entry_index = dict() self.word_to_entry_index = dict()
@ -34,7 +34,7 @@ class ExplicitFilter:
with self.filter_file.open('rb') as f: with self.filter_file.open('rb') as f:
self.word_to_entry_index = pickle.load(f) self.word_to_entry_index = pickle.load(f)
end = time.time() end = time.time()
logger.debug(f"Load {self.search_type} entries by word set from file: {end - start} seconds") logger.debug(f"Load word filter index for {self.search_type} from {self.filter_file}: {end - start} seconds")
else: else:
start = time.time() start = time.time()
self.cache = {} # Clear cache on (re-)generating entries_by_word_set self.cache = {} # Clear cache on (re-)generating entries_by_word_set
@ -51,14 +51,13 @@ class ExplicitFilter:
with self.filter_file.open('wb') as f: with self.filter_file.open('wb') as f:
pickle.dump(self.word_to_entry_index, f) pickle.dump(self.word_to_entry_index, f)
end = time.time() end = time.time()
logger.debug(f"Convert all {self.search_type} entries to word sets: {end - start} seconds") logger.debug(f"Index {self.search_type} for word filter to {self.filter_file}: {end - start} seconds")
return self.word_to_entry_index return self.word_to_entry_index
def can_filter(self, raw_query): def can_filter(self, raw_query):
"Check if query contains explicit filters" "Check if query contains word filters"
# Extract explicit query portion with required, blocked words to filter from natural query
required_words = re.findall(self.required_regex, raw_query) required_words = re.findall(self.required_regex, raw_query)
blocked_words = re.findall(self.blocked_regex, raw_query) blocked_words = re.findall(self.blocked_regex, raw_query)
@ -67,7 +66,7 @@ class ExplicitFilter:
def apply(self, raw_query, raw_entries, raw_embeddings): def apply(self, raw_query, raw_entries, raw_embeddings):
"Find entries containing required and not blocked words specified in query" "Find entries containing required and not blocked words specified in query"
# Separate natural query from explicit required, blocked words filters # Separate natural query from required, blocked words filters
start = time.time() start = time.time()
required_words = set([word.lower() for word in re.findall(self.required_regex, raw_query)]) required_words = set([word.lower() for word in re.findall(self.required_regex, raw_query)])
@ -83,7 +82,7 @@ class ExplicitFilter:
# Return item from cache if exists # Return item from cache if exists
cache_key = tuple(sorted(required_words)), tuple(sorted(blocked_words)) cache_key = tuple(sorted(required_words)), tuple(sorted(blocked_words))
if cache_key in self.cache: if cache_key in self.cache:
logger.info(f"Explicit filter results from cache") logger.info(f"Return word filter results from cache")
entries, embeddings = self.cache[cache_key] entries, embeddings = self.cache[cache_key]
return query, entries, embeddings return query, entries, embeddings

View file

@ -8,7 +8,7 @@ import time
import torch import torch
from sentence_transformers import SentenceTransformer, CrossEncoder, util from sentence_transformers import SentenceTransformer, CrossEncoder, util
from src.search_filter.date_filter import DateFilter from src.search_filter.date_filter import DateFilter
from src.search_filter.explicit_filter import ExplicitFilter from src.search_filter.word_filter import WordFilter
# Internal Packages # Internal Packages
from src.utils import state from src.utils import state
@ -171,7 +171,7 @@ def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchCon
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate) corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate)
filter_directory = resolve_absolute_path(config.compressed_jsonl.parent) filter_directory = resolve_absolute_path(config.compressed_jsonl.parent)
filters = [DateFilter(), ExplicitFilter(filter_directory, search_type=search_type)] filters = [DateFilter(), WordFilter(filter_directory, search_type=search_type)]
for filter in filters: for filter in filters:
filter.load(entries, regenerate=regenerate) filter.load(entries, regenerate=regenerate)

View file

@ -140,7 +140,7 @@ def test_notes_search_with_include_filter(content_config: ContentConfig, search_
# Assert # Assert
assert response.status_code == 200 assert response.status_code == 200
# assert actual_data contains explicitly included word "Emacs" # assert actual_data contains word "Emacs"
search_result = response.json()[0]["entry"] search_result = response.json()[0]["entry"]
assert "Emacs" in search_result assert "Emacs" in search_result
@ -156,6 +156,6 @@ def test_notes_search_with_exclude_filter(content_config: ContentConfig, search_
# Assert # Assert
assert response.status_code == 200 assert response.status_code == 200
# assert actual_data does not contain explicitly excluded word "Emacs" # assert actual_data does not contain word "Emacs"
search_result = response.json()[0]["entry"] search_result = response.json()[0]["entry"]
assert "clone" not in search_result assert "clone" not in search_result

View file

@ -2,19 +2,19 @@
import torch import torch
# Application Packages # Application Packages
from src.search_filter.explicit_filter import ExplicitFilter from src.search_filter.word_filter import WordFilter
from src.utils.config import SearchType from src.utils.config import SearchType
def test_no_explicit_filter(tmp_path): def test_no_word_filter(tmp_path):
# Arrange # Arrange
explicit_filter = ExplicitFilter(tmp_path, SearchType.Org) word_filter = WordFilter(tmp_path, SearchType.Org)
embeddings, entries = arrange_content() embeddings, entries = arrange_content()
q_with_no_filter = 'head tail' q_with_no_filter = 'head tail'
# Act # Act
can_filter = explicit_filter.can_filter(q_with_no_filter) can_filter = word_filter.can_filter(q_with_no_filter)
ret_query, ret_entries, ret_emb = explicit_filter.apply(q_with_no_filter, entries.copy(), embeddings) ret_query, ret_entries, ret_emb = word_filter.apply(q_with_no_filter, entries.copy(), embeddings)
# Assert # Assert
assert can_filter == False assert can_filter == False
@ -23,15 +23,15 @@ def test_no_explicit_filter(tmp_path):
assert ret_entries == entries assert ret_entries == entries
def test_explicit_exclude_filter(tmp_path): def test_word_exclude_filter(tmp_path):
# Arrange # Arrange
explicit_filter = ExplicitFilter(tmp_path, SearchType.Org) word_filter = WordFilter(tmp_path, SearchType.Org)
embeddings, entries = arrange_content() embeddings, entries = arrange_content()
q_with_exclude_filter = 'head -"exclude_word" tail' q_with_exclude_filter = 'head -"exclude_word" tail'
# Act # Act
can_filter = explicit_filter.can_filter(q_with_exclude_filter) can_filter = word_filter.can_filter(q_with_exclude_filter)
ret_query, ret_entries, ret_emb = explicit_filter.apply(q_with_exclude_filter, entries.copy(), embeddings) ret_query, ret_entries, ret_emb = word_filter.apply(q_with_exclude_filter, entries.copy(), embeddings)
# Assert # Assert
assert can_filter == True assert can_filter == True
@ -40,15 +40,15 @@ def test_explicit_exclude_filter(tmp_path):
assert ret_entries == [entries[0], entries[2]] assert ret_entries == [entries[0], entries[2]]
def test_explicit_include_filter(tmp_path): def test_word_include_filter(tmp_path):
# Arrange # Arrange
explicit_filter = ExplicitFilter(tmp_path, SearchType.Org) word_filter = WordFilter(tmp_path, SearchType.Org)
embeddings, entries = arrange_content() embeddings, entries = arrange_content()
query_with_include_filter = 'head +"include_word" tail' query_with_include_filter = 'head +"include_word" tail'
# Act # Act
can_filter = explicit_filter.can_filter(query_with_include_filter) can_filter = word_filter.can_filter(query_with_include_filter)
ret_query, ret_entries, ret_emb = explicit_filter.apply(query_with_include_filter, entries.copy(), embeddings) ret_query, ret_entries, ret_emb = word_filter.apply(query_with_include_filter, entries.copy(), embeddings)
# Assert # Assert
assert can_filter == True assert can_filter == True
@ -57,15 +57,15 @@ def test_explicit_include_filter(tmp_path):
assert ret_entries == [entries[2], entries[3]] assert ret_entries == [entries[2], entries[3]]
def test_explicit_include_and_exclude_filter(tmp_path): def test_word_include_and_exclude_filter(tmp_path):
# Arrange # Arrange
explicit_filter = ExplicitFilter(tmp_path, SearchType.Org) word_filter = WordFilter(tmp_path, SearchType.Org)
embeddings, entries = arrange_content() embeddings, entries = arrange_content()
query_with_include_and_exclude_filter = 'head +"include_word" -"exclude_word" tail' query_with_include_and_exclude_filter = 'head +"include_word" -"exclude_word" tail'
# Act # Act
can_filter = explicit_filter.can_filter(query_with_include_and_exclude_filter) can_filter = word_filter.can_filter(query_with_include_and_exclude_filter)
ret_query, ret_entries, ret_emb = explicit_filter.apply(query_with_include_and_exclude_filter, entries.copy(), embeddings) ret_query, ret_entries, ret_emb = word_filter.apply(query_with_include_and_exclude_filter, entries.copy(), embeddings)
# Assert # Assert
assert can_filter == True assert can_filter == True