Add method to extract filter terms from query to all filters

- Test the get_filter_terms method in all three filters: word, file and date
- Give the existing can_filter method a default implementation in the base filter abstract class
This commit is contained in:
Debanjum Singh Solanky 2023-08-27 19:14:50 -07:00
parent 150105505b
commit d5a2de6222
7 changed files with 62 additions and 12 deletions

View file

@ -12,9 +12,12 @@ class BaseFilter(ABC):
... ...
@abstractmethod @abstractmethod
def can_filter(self, raw_query: str) -> bool: def get_filter_terms(self, query: str) -> List[str]:
... ...
def can_filter(self, raw_query: str) -> bool:
    """Return True when the query contains at least one filter term.

    Relies on get_filter_terms to find the terms, so the check stays in
    sync with whatever terms the filter extracts from the query.
    """
    # An empty list of terms is falsy, so there is no need to measure its length
    return bool(self.get_filter_terms(raw_query))
@abstractmethod @abstractmethod
def apply(self, query: str, entries: List[Entry]) -> Tuple[str, Set[int]]: def apply(self, query: str, entries: List[Entry]) -> Tuple[str, Set[int]]:
... ...

View file

@ -3,6 +3,7 @@ import re
import logging import logging
from collections import defaultdict from collections import defaultdict
from datetime import timedelta, datetime from datetime import timedelta, datetime
from typing import List
from dateutil.relativedelta import relativedelta from dateutil.relativedelta import relativedelta
from math import inf from math import inf
@ -45,9 +46,9 @@ class DateFilter(BaseFilter):
continue continue
self.date_to_entry_ids[date_in_entry].add(id) self.date_to_entry_ids[date_in_entry].add(id)
def can_filter(self, raw_query): def get_filter_terms(self, query: str) -> List[str]:
"Check if query contains date filters" "Get all filter terms in query"
return self.extract_date_range(raw_query) is not None return [f"dt{item[0]}'{item[1]}'" for item in re.findall(self.date_regex, query)]
def defilter(self, query): def defilter(self, query):
# remove date range filter from query # remove date range filter from query

View file

@ -3,6 +3,7 @@ import re
import fnmatch import fnmatch
import logging import logging
from collections import defaultdict from collections import defaultdict
from typing import List
# Internal Packages # Internal Packages
from khoj.search_filter.base_filter import BaseFilter from khoj.search_filter.base_filter import BaseFilter
@ -25,8 +26,9 @@ class FileFilter(BaseFilter):
for id, entry in enumerate(entries): for id, entry in enumerate(entries):
self.file_to_entry_map[getattr(entry, self.entry_key)].add(id) self.file_to_entry_map[getattr(entry, self.entry_key)].add(id)
def can_filter(self, raw_query): def get_filter_terms(self, query: str) -> List[str]:
return re.search(self.file_filter_regex, raw_query) is not None "Get all filter terms in query"
return [f'file:"{term}"' for term in re.findall(self.file_filter_regex, query)]
def defilter(self, query: str) -> str: def defilter(self, query: str) -> str:
return re.sub(self.file_filter_regex, "", query).strip() return re.sub(self.file_filter_regex, "", query).strip()

View file

@ -2,6 +2,7 @@
import re import re
import logging import logging
from collections import defaultdict from collections import defaultdict
from typing import List
# Internal Packages # Internal Packages
from khoj.search_filter.base_filter import BaseFilter from khoj.search_filter.base_filter import BaseFilter
@ -36,12 +37,11 @@ class WordFilter(BaseFilter):
return self.word_to_entry_index return self.word_to_entry_index
def can_filter(self, raw_query): def get_filter_terms(self, query: str) -> List[str]:
"Check if query contains word filters" "Get all filter terms in query"
required_words = re.findall(self.required_regex, raw_query) required_terms = [f"+{required_term}" for required_term in re.findall(self.required_regex, query)]
blocked_words = re.findall(self.blocked_regex, raw_query) blocked_terms = [f"-{blocked_term}" for blocked_term in re.findall(self.blocked_regex, query)]
return required_terms + blocked_terms
return len(required_words) != 0 or len(blocked_words) != 0
def defilter(self, query: str) -> str: def defilter(self, query: str) -> str:
return re.sub(self.blocked_regex, "", re.sub(self.required_regex, "", query)).strip() return re.sub(self.blocked_regex, "", re.sub(self.required_regex, "", query)).strip()

View file

@ -158,3 +158,23 @@ def test_date_filter_regex():
dtrange_match = re.findall(DateFilter().date_regex, "head tail") dtrange_match = re.findall(DateFilter().date_regex, "head tail")
assert dtrange_match == [] assert dtrange_match == []
def test_get_date_filter_terms():
    """Check DateFilter.get_filter_terms extracts every date filter term in a query.

    Queries quote their dates with double quotes, while the expected terms
    use single quotes, which is the form these assertions pin.
    """
    # Arrange
    date_filter = DateFilter()
    queries_with_expected_terms = [
        # Filter terms at the end of the query
        ('multi word head dt>"today" dt:"1984-01-01"', ["dt>'today'", "dt:'1984-01-01'"]),
        # Filter terms in the middle of the query
        ('head dt>"today" dt:"1984-01-01" multi word tail', ["dt>'today'", "dt:'1984-01-01'"]),
        # Two character and equality operators
        ('multi word head dt>="today" dt="1984-01-01"', ["dt>='today'", "dt='1984-01-01'"]),
        # Multi word dates
        ('dt<"multi word date" multi word tail', ["dt<'multi word date'"]),
        ('head dt<="multi word date"', ["dt<='multi word date'"]),
        # No filter terms in the query
        ("head tail", []),
    ]

    # Act, Assert
    for query, expected_terms in queries_with_expected_terms:
        assert date_filter.get_filter_terms(query) == expected_terms

View file

@ -99,6 +99,18 @@ def test_multiple_file_filter():
assert entry_indices == {0, 1, 2, 3} assert entry_indices == {0, 1, 2, 3}
def test_get_file_filter_terms():
    """Check FileFilter.get_filter_terms returns each file filter term in the query."""
    # Arrange
    query = 'head tail file:"file 1.org" file:"/path/to/dir/*.org"'
    expected_terms = ['file:"file 1.org"', 'file:"/path/to/dir/*.org"']

    # Act
    actual_terms = FileFilter().get_filter_terms(query)

    # Assert
    assert actual_terms == expected_terms
def arrange_content(): def arrange_content():
entries = [ entries = [
Entry(compiled="", raw="First Entry", file="file 1.org"), Entry(compiled="", raw="First Entry", file="file 1.org"),

View file

@ -67,6 +67,18 @@ def test_word_include_and_exclude_filter():
assert entry_indices == {2} assert entry_indices == {2}
def test_get_word_filter_terms():
    """Check WordFilter.get_filter_terms returns the include and exclude terms in the query."""
    # Arrange
    query = 'head +"include_word" -"exclude_word" tail'
    expected_terms = ["+include_word", "-exclude_word"]

    # Act
    actual_terms = WordFilter().get_filter_terms(query)

    # Assert
    assert actual_terms == expected_terms
def arrange_content(): def arrange_content():
entries = [ entries = [
Entry(compiled="", raw="Minimal Entry"), Entry(compiled="", raw="Minimal Entry"),