mirror of
https://github.com/khoj-ai/khoj.git
synced 2025-02-20 06:55:08 +00:00
Add method to extract filter terms from query to all filters
- Test the get_filter_term method in all 3 word, file, date filters - Make the existing can_filter method by default in base filter abstract class
This commit is contained in:
parent
150105505b
commit
d5a2de6222
7 changed files with 62 additions and 12 deletions
|
@ -12,9 +12,12 @@ class BaseFilter(ABC):
|
||||||
...
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def can_filter(self, raw_query: str) -> bool:
|
def get_filter_terms(self, query: str) -> List[str]:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
def can_filter(self, raw_query: str) -> bool:
|
||||||
|
return len(self.get_filter_terms(raw_query)) > 0
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def apply(self, query: str, entries: List[Entry]) -> Tuple[str, Set[int]]:
|
def apply(self, query: str, entries: List[Entry]) -> Tuple[str, Set[int]]:
|
||||||
...
|
...
|
||||||
|
|
|
@ -3,6 +3,7 @@ import re
|
||||||
import logging
|
import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from datetime import timedelta, datetime
|
from datetime import timedelta, datetime
|
||||||
|
from typing import List
|
||||||
from dateutil.relativedelta import relativedelta
|
from dateutil.relativedelta import relativedelta
|
||||||
from math import inf
|
from math import inf
|
||||||
|
|
||||||
|
@ -45,9 +46,9 @@ class DateFilter(BaseFilter):
|
||||||
continue
|
continue
|
||||||
self.date_to_entry_ids[date_in_entry].add(id)
|
self.date_to_entry_ids[date_in_entry].add(id)
|
||||||
|
|
||||||
def can_filter(self, raw_query):
|
def get_filter_terms(self, query: str) -> List[str]:
|
||||||
"Check if query contains date filters"
|
"Get all filter terms in query"
|
||||||
return self.extract_date_range(raw_query) is not None
|
return [f"dt{item[0]}'{item[1]}'" for item in re.findall(self.date_regex, query)]
|
||||||
|
|
||||||
def defilter(self, query):
|
def defilter(self, query):
|
||||||
# remove date range filter from query
|
# remove date range filter from query
|
||||||
|
|
|
@ -3,6 +3,7 @@ import re
|
||||||
import fnmatch
|
import fnmatch
|
||||||
import logging
|
import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from typing import List
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.search_filter.base_filter import BaseFilter
|
from khoj.search_filter.base_filter import BaseFilter
|
||||||
|
@ -25,8 +26,9 @@ class FileFilter(BaseFilter):
|
||||||
for id, entry in enumerate(entries):
|
for id, entry in enumerate(entries):
|
||||||
self.file_to_entry_map[getattr(entry, self.entry_key)].add(id)
|
self.file_to_entry_map[getattr(entry, self.entry_key)].add(id)
|
||||||
|
|
||||||
def can_filter(self, raw_query):
|
def get_filter_terms(self, query: str) -> List[str]:
|
||||||
return re.search(self.file_filter_regex, raw_query) is not None
|
"Get all filter terms in query"
|
||||||
|
return [f'file:"{term}"' for term in re.findall(self.file_filter_regex, query)]
|
||||||
|
|
||||||
def defilter(self, query: str) -> str:
|
def defilter(self, query: str) -> str:
|
||||||
return re.sub(self.file_filter_regex, "", query).strip()
|
return re.sub(self.file_filter_regex, "", query).strip()
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
import re
|
import re
|
||||||
import logging
|
import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from typing import List
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.search_filter.base_filter import BaseFilter
|
from khoj.search_filter.base_filter import BaseFilter
|
||||||
|
@ -36,12 +37,11 @@ class WordFilter(BaseFilter):
|
||||||
|
|
||||||
return self.word_to_entry_index
|
return self.word_to_entry_index
|
||||||
|
|
||||||
def can_filter(self, raw_query):
|
def get_filter_terms(self, query: str) -> List[str]:
|
||||||
"Check if query contains word filters"
|
"Get all filter terms in query"
|
||||||
required_words = re.findall(self.required_regex, raw_query)
|
required_terms = [f"+{required_term}" for required_term in re.findall(self.required_regex, query)]
|
||||||
blocked_words = re.findall(self.blocked_regex, raw_query)
|
blocked_terms = [f"-{blocked_term}" for blocked_term in re.findall(self.blocked_regex, query)]
|
||||||
|
return required_terms + blocked_terms
|
||||||
return len(required_words) != 0 or len(blocked_words) != 0
|
|
||||||
|
|
||||||
def defilter(self, query: str) -> str:
|
def defilter(self, query: str) -> str:
|
||||||
return re.sub(self.blocked_regex, "", re.sub(self.required_regex, "", query)).strip()
|
return re.sub(self.blocked_regex, "", re.sub(self.required_regex, "", query)).strip()
|
||||||
|
|
|
@ -158,3 +158,23 @@ def test_date_filter_regex():
|
||||||
|
|
||||||
dtrange_match = re.findall(DateFilter().date_regex, "head tail")
|
dtrange_match = re.findall(DateFilter().date_regex, "head tail")
|
||||||
assert dtrange_match == []
|
assert dtrange_match == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_file_filter_terms():
|
||||||
|
dtrange_match = DateFilter().get_filter_terms('multi word head dt>"today" dt:"1984-01-01"')
|
||||||
|
assert dtrange_match == ["dt>'today'", "dt:'1984-01-01'"]
|
||||||
|
|
||||||
|
dtrange_match = DateFilter().get_filter_terms('head dt>"today" dt:"1984-01-01" multi word tail')
|
||||||
|
assert dtrange_match == ["dt>'today'", "dt:'1984-01-01'"]
|
||||||
|
|
||||||
|
dtrange_match = DateFilter().get_filter_terms('multi word head dt>="today" dt="1984-01-01"')
|
||||||
|
assert dtrange_match == ["dt>='today'", "dt='1984-01-01'"]
|
||||||
|
|
||||||
|
dtrange_match = DateFilter().get_filter_terms('dt<"multi word date" multi word tail')
|
||||||
|
assert dtrange_match == ["dt<'multi word date'"]
|
||||||
|
|
||||||
|
dtrange_match = DateFilter().get_filter_terms('head dt<="multi word date"')
|
||||||
|
assert dtrange_match == ["dt<='multi word date'"]
|
||||||
|
|
||||||
|
dtrange_match = DateFilter().get_filter_terms("head tail")
|
||||||
|
assert dtrange_match == []
|
||||||
|
|
|
@ -99,6 +99,18 @@ def test_multiple_file_filter():
|
||||||
assert entry_indices == {0, 1, 2, 3}
|
assert entry_indices == {0, 1, 2, 3}
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_file_filter_terms():
|
||||||
|
# Arrange
|
||||||
|
file_filter = FileFilter()
|
||||||
|
q_with_filter_terms = 'head tail file:"file 1.org" file:"/path/to/dir/*.org"'
|
||||||
|
|
||||||
|
# Act
|
||||||
|
filter_terms = file_filter.get_filter_terms(q_with_filter_terms)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert filter_terms == ['file:"file 1.org"', 'file:"/path/to/dir/*.org"']
|
||||||
|
|
||||||
|
|
||||||
def arrange_content():
|
def arrange_content():
|
||||||
entries = [
|
entries = [
|
||||||
Entry(compiled="", raw="First Entry", file="file 1.org"),
|
Entry(compiled="", raw="First Entry", file="file 1.org"),
|
||||||
|
|
|
@ -67,6 +67,18 @@ def test_word_include_and_exclude_filter():
|
||||||
assert entry_indices == {2}
|
assert entry_indices == {2}
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_word_filter_terms():
|
||||||
|
# Arrange
|
||||||
|
word_filter = WordFilter()
|
||||||
|
query_with_include_and_exclude_filter = 'head +"include_word" -"exclude_word" tail'
|
||||||
|
|
||||||
|
# Act
|
||||||
|
filter_terms = word_filter.get_filter_terms(query_with_include_and_exclude_filter)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert filter_terms == ["+include_word", "-exclude_word"]
|
||||||
|
|
||||||
|
|
||||||
def arrange_content():
|
def arrange_content():
|
||||||
entries = [
|
entries = [
|
||||||
Entry(compiled="", raw="Minimal Entry"),
|
Entry(compiled="", raw="Minimal Entry"),
|
||||||
|
|
Loading…
Add table
Reference in a new issue