mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-27 17:35:07 +01:00
Rename explicit filter to word filter to be more specific
This commit is contained in:
parent
d153d420fc
commit
f930324350
6 changed files with 28 additions and 32 deletions
|
@ -125,7 +125,6 @@ pip install --upgrade khoj-assistant
|
||||||
|
|
||||||
- Semantic search using the bi-encoder is fairly fast at \<50 ms
|
- Semantic search using the bi-encoder is fairly fast at \<50 ms
|
||||||
- Reranking using the cross-encoder is slower at \<2s on 15 results. Tweak `top_k` to tradeoff speed for accuracy of results
|
- Reranking using the cross-encoder is slower at \<2s on 15 results. Tweak `top_k` to tradeoff speed for accuracy of results
|
||||||
- Applying explicit filters is very slow currently at \~6s. This is because the filters are rudimentary. Considerable speed-ups can be achieved using indexes etc
|
|
||||||
|
|
||||||
### Indexing performance
|
### Indexing performance
|
||||||
|
|
||||||
|
|
|
@ -16,8 +16,6 @@ from fastapi.templating import Jinja2Templates
|
||||||
from src.configure import configure_search
|
from src.configure import configure_search
|
||||||
from src.search_type import image_search, text_search
|
from src.search_type import image_search, text_search
|
||||||
from src.processor.conversation.gpt import converse, extract_search_type, message_to_log, message_to_prompt, understand, summarize
|
from src.processor.conversation.gpt import converse, extract_search_type, message_to_log, message_to_prompt, understand, summarize
|
||||||
from src.search_filter.explicit_filter import ExplicitFilter
|
|
||||||
from src.search_filter.date_filter import DateFilter
|
|
||||||
from src.utils.rawconfig import FullConfig
|
from src.utils.rawconfig import FullConfig
|
||||||
from src.utils.config import SearchType
|
from src.utils.config import SearchType
|
||||||
from src.utils.helpers import get_absolute_path, get_from_dict
|
from src.utils.helpers import get_absolute_path, get_from_dict
|
||||||
|
|
|
@ -15,13 +15,13 @@ from src.utils.config import SearchType
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ExplicitFilter:
|
class WordFilter:
|
||||||
# Filter Regex
|
# Filter Regex
|
||||||
required_regex = r'\+"(\w+)" ?'
|
required_regex = r'\+"(\w+)" ?'
|
||||||
blocked_regex = r'\-"(\w+)" ?'
|
blocked_regex = r'\-"(\w+)" ?'
|
||||||
|
|
||||||
def __init__(self, filter_directory, search_type: SearchType, entry_key='raw'):
|
def __init__(self, filter_directory, search_type: SearchType, entry_key='raw'):
|
||||||
self.filter_file = resolve_absolute_path(filter_directory / f"{search_type.name.lower()}_explicit_filter_entry_word_sets.pkl")
|
self.filter_file = resolve_absolute_path(filter_directory / f"word_filter_{search_type.name.lower()}_index.pkl")
|
||||||
self.entry_key = entry_key
|
self.entry_key = entry_key
|
||||||
self.search_type = search_type
|
self.search_type = search_type
|
||||||
self.word_to_entry_index = dict()
|
self.word_to_entry_index = dict()
|
||||||
|
@ -34,7 +34,7 @@ class ExplicitFilter:
|
||||||
with self.filter_file.open('rb') as f:
|
with self.filter_file.open('rb') as f:
|
||||||
self.word_to_entry_index = pickle.load(f)
|
self.word_to_entry_index = pickle.load(f)
|
||||||
end = time.time()
|
end = time.time()
|
||||||
logger.debug(f"Load {self.search_type} entries by word set from file: {end - start} seconds")
|
logger.debug(f"Load word filter index for {self.search_type} from {self.filter_file}: {end - start} seconds")
|
||||||
else:
|
else:
|
||||||
start = time.time()
|
start = time.time()
|
||||||
self.cache = {} # Clear cache on (re-)generating entries_by_word_set
|
self.cache = {} # Clear cache on (re-)generating entries_by_word_set
|
||||||
|
@ -51,14 +51,13 @@ class ExplicitFilter:
|
||||||
with self.filter_file.open('wb') as f:
|
with self.filter_file.open('wb') as f:
|
||||||
pickle.dump(self.word_to_entry_index, f)
|
pickle.dump(self.word_to_entry_index, f)
|
||||||
end = time.time()
|
end = time.time()
|
||||||
logger.debug(f"Convert all {self.search_type} entries to word sets: {end - start} seconds")
|
logger.debug(f"Index {self.search_type} for word filter to {self.filter_file}: {end - start} seconds")
|
||||||
|
|
||||||
return self.word_to_entry_index
|
return self.word_to_entry_index
|
||||||
|
|
||||||
|
|
||||||
def can_filter(self, raw_query):
|
def can_filter(self, raw_query):
|
||||||
"Check if query contains explicit filters"
|
"Check if query contains word filters"
|
||||||
# Extract explicit query portion with required, blocked words to filter from natural query
|
|
||||||
required_words = re.findall(self.required_regex, raw_query)
|
required_words = re.findall(self.required_regex, raw_query)
|
||||||
blocked_words = re.findall(self.blocked_regex, raw_query)
|
blocked_words = re.findall(self.blocked_regex, raw_query)
|
||||||
|
|
||||||
|
@ -67,7 +66,7 @@ class ExplicitFilter:
|
||||||
|
|
||||||
def apply(self, raw_query, raw_entries, raw_embeddings):
|
def apply(self, raw_query, raw_entries, raw_embeddings):
|
||||||
"Find entries containing required and not blocked words specified in query"
|
"Find entries containing required and not blocked words specified in query"
|
||||||
# Separate natural query from explicit required, blocked words filters
|
# Separate natural query from required, blocked words filters
|
||||||
start = time.time()
|
start = time.time()
|
||||||
|
|
||||||
required_words = set([word.lower() for word in re.findall(self.required_regex, raw_query)])
|
required_words = set([word.lower() for word in re.findall(self.required_regex, raw_query)])
|
||||||
|
@ -83,7 +82,7 @@ class ExplicitFilter:
|
||||||
# Return item from cache if exists
|
# Return item from cache if exists
|
||||||
cache_key = tuple(sorted(required_words)), tuple(sorted(blocked_words))
|
cache_key = tuple(sorted(required_words)), tuple(sorted(blocked_words))
|
||||||
if cache_key in self.cache:
|
if cache_key in self.cache:
|
||||||
logger.info(f"Explicit filter results from cache")
|
logger.info(f"Return word filter results from cache")
|
||||||
entries, embeddings = self.cache[cache_key]
|
entries, embeddings = self.cache[cache_key]
|
||||||
return query, entries, embeddings
|
return query, entries, embeddings
|
||||||
|
|
|
@ -8,7 +8,7 @@ import time
|
||||||
import torch
|
import torch
|
||||||
from sentence_transformers import SentenceTransformer, CrossEncoder, util
|
from sentence_transformers import SentenceTransformer, CrossEncoder, util
|
||||||
from src.search_filter.date_filter import DateFilter
|
from src.search_filter.date_filter import DateFilter
|
||||||
from src.search_filter.explicit_filter import ExplicitFilter
|
from src.search_filter.word_filter import WordFilter
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from src.utils import state
|
from src.utils import state
|
||||||
|
@ -171,7 +171,7 @@ def setup(text_to_jsonl, config: TextContentConfig, search_config: TextSearchCon
|
||||||
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate)
|
corpus_embeddings = compute_embeddings(entries, bi_encoder, config.embeddings_file, regenerate=regenerate)
|
||||||
|
|
||||||
filter_directory = resolve_absolute_path(config.compressed_jsonl.parent)
|
filter_directory = resolve_absolute_path(config.compressed_jsonl.parent)
|
||||||
filters = [DateFilter(), ExplicitFilter(filter_directory, search_type=search_type)]
|
filters = [DateFilter(), WordFilter(filter_directory, search_type=search_type)]
|
||||||
for filter in filters:
|
for filter in filters:
|
||||||
filter.load(entries, regenerate=regenerate)
|
filter.load(entries, regenerate=regenerate)
|
||||||
|
|
||||||
|
|
|
@ -140,7 +140,7 @@ def test_notes_search_with_include_filter(content_config: ContentConfig, search_
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
# assert actual_data contains explicitly included word "Emacs"
|
# assert actual_data contains word "Emacs"
|
||||||
search_result = response.json()[0]["entry"]
|
search_result = response.json()[0]["entry"]
|
||||||
assert "Emacs" in search_result
|
assert "Emacs" in search_result
|
||||||
|
|
||||||
|
@ -156,6 +156,6 @@ def test_notes_search_with_exclude_filter(content_config: ContentConfig, search_
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
# assert actual_data does not contains explicitly excluded word "Emacs"
|
# assert actual_data does not contains word "Emacs"
|
||||||
search_result = response.json()[0]["entry"]
|
search_result = response.json()[0]["entry"]
|
||||||
assert "clone" not in search_result
|
assert "clone" not in search_result
|
||||||
|
|
|
@ -2,19 +2,19 @@
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
# Application Packages
|
# Application Packages
|
||||||
from src.search_filter.explicit_filter import ExplicitFilter
|
from src.search_filter.word_filter import WordFilter
|
||||||
from src.utils.config import SearchType
|
from src.utils.config import SearchType
|
||||||
|
|
||||||
|
|
||||||
def test_no_explicit_filter(tmp_path):
|
def test_no_word_filter(tmp_path):
|
||||||
# Arrange
|
# Arrange
|
||||||
explicit_filter = ExplicitFilter(tmp_path, SearchType.Org)
|
word_filter = WordFilter(tmp_path, SearchType.Org)
|
||||||
embeddings, entries = arrange_content()
|
embeddings, entries = arrange_content()
|
||||||
q_with_no_filter = 'head tail'
|
q_with_no_filter = 'head tail'
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
can_filter = explicit_filter.can_filter(q_with_no_filter)
|
can_filter = word_filter.can_filter(q_with_no_filter)
|
||||||
ret_query, ret_entries, ret_emb = explicit_filter.apply(q_with_no_filter, entries.copy(), embeddings)
|
ret_query, ret_entries, ret_emb = word_filter.apply(q_with_no_filter, entries.copy(), embeddings)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert can_filter == False
|
assert can_filter == False
|
||||||
|
@ -23,15 +23,15 @@ def test_no_explicit_filter(tmp_path):
|
||||||
assert ret_entries == entries
|
assert ret_entries == entries
|
||||||
|
|
||||||
|
|
||||||
def test_explicit_exclude_filter(tmp_path):
|
def test_word_exclude_filter(tmp_path):
|
||||||
# Arrange
|
# Arrange
|
||||||
explicit_filter = ExplicitFilter(tmp_path, SearchType.Org)
|
word_filter = WordFilter(tmp_path, SearchType.Org)
|
||||||
embeddings, entries = arrange_content()
|
embeddings, entries = arrange_content()
|
||||||
q_with_exclude_filter = 'head -"exclude_word" tail'
|
q_with_exclude_filter = 'head -"exclude_word" tail'
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
can_filter = explicit_filter.can_filter(q_with_exclude_filter)
|
can_filter = word_filter.can_filter(q_with_exclude_filter)
|
||||||
ret_query, ret_entries, ret_emb = explicit_filter.apply(q_with_exclude_filter, entries.copy(), embeddings)
|
ret_query, ret_entries, ret_emb = word_filter.apply(q_with_exclude_filter, entries.copy(), embeddings)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert can_filter == True
|
assert can_filter == True
|
||||||
|
@ -40,15 +40,15 @@ def test_explicit_exclude_filter(tmp_path):
|
||||||
assert ret_entries == [entries[0], entries[2]]
|
assert ret_entries == [entries[0], entries[2]]
|
||||||
|
|
||||||
|
|
||||||
def test_explicit_include_filter(tmp_path):
|
def test_word_include_filter(tmp_path):
|
||||||
# Arrange
|
# Arrange
|
||||||
explicit_filter = ExplicitFilter(tmp_path, SearchType.Org)
|
word_filter = WordFilter(tmp_path, SearchType.Org)
|
||||||
embeddings, entries = arrange_content()
|
embeddings, entries = arrange_content()
|
||||||
query_with_include_filter = 'head +"include_word" tail'
|
query_with_include_filter = 'head +"include_word" tail'
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
can_filter = explicit_filter.can_filter(query_with_include_filter)
|
can_filter = word_filter.can_filter(query_with_include_filter)
|
||||||
ret_query, ret_entries, ret_emb = explicit_filter.apply(query_with_include_filter, entries.copy(), embeddings)
|
ret_query, ret_entries, ret_emb = word_filter.apply(query_with_include_filter, entries.copy(), embeddings)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert can_filter == True
|
assert can_filter == True
|
||||||
|
@ -57,15 +57,15 @@ def test_explicit_include_filter(tmp_path):
|
||||||
assert ret_entries == [entries[2], entries[3]]
|
assert ret_entries == [entries[2], entries[3]]
|
||||||
|
|
||||||
|
|
||||||
def test_explicit_include_and_exclude_filter(tmp_path):
|
def test_word_include_and_exclude_filter(tmp_path):
|
||||||
# Arrange
|
# Arrange
|
||||||
explicit_filter = ExplicitFilter(tmp_path, SearchType.Org)
|
word_filter = WordFilter(tmp_path, SearchType.Org)
|
||||||
embeddings, entries = arrange_content()
|
embeddings, entries = arrange_content()
|
||||||
query_with_include_and_exclude_filter = 'head +"include_word" -"exclude_word" tail'
|
query_with_include_and_exclude_filter = 'head +"include_word" -"exclude_word" tail'
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
can_filter = explicit_filter.can_filter(query_with_include_and_exclude_filter)
|
can_filter = word_filter.can_filter(query_with_include_and_exclude_filter)
|
||||||
ret_query, ret_entries, ret_emb = explicit_filter.apply(query_with_include_and_exclude_filter, entries.copy(), embeddings)
|
ret_query, ret_entries, ret_emb = word_filter.apply(query_with_include_and_exclude_filter, entries.copy(), embeddings)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert can_filter == True
|
assert can_filter == True
|
Loading…
Reference in a new issue