From d5a2de622247664289eec11bc4d92f2f1ac5942f Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Sun, 27 Aug 2023 19:14:50 -0700 Subject: [PATCH] Add method to extract filter terms from query to all filters - Test the get_filter_terms method in all three filters: word, file and date - Give the existing can_filter method a default implementation in the base filter abstract class --- src/khoj/search_filter/base_filter.py | 5 ++++- src/khoj/search_filter/date_filter.py | 7 ++++--- src/khoj/search_filter/file_filter.py | 6 ++++-- src/khoj/search_filter/word_filter.py | 12 ++++++------ tests/test_date_filter.py | 20 ++++++++++++++++++++ tests/test_file_filter.py | 12 ++++++++++++ tests/test_word_filter.py | 12 ++++++++++++ 7 files changed, 62 insertions(+), 12 deletions(-) diff --git a/src/khoj/search_filter/base_filter.py b/src/khoj/search_filter/base_filter.py index aa4fa2e4..470f7341 100644 --- a/src/khoj/search_filter/base_filter.py +++ b/src/khoj/search_filter/base_filter.py @@ -12,9 +12,12 @@ class BaseFilter(ABC): ... @abstractmethod - def can_filter(self, raw_query: str) -> bool: + def get_filter_terms(self, query: str) -> List[str]: ... + def can_filter(self, raw_query: str) -> bool: + return len(self.get_filter_terms(raw_query)) > 0 + @abstractmethod def apply(self, query: str, entries: List[Entry]) -> Tuple[str, Set[int]]: ... 
diff --git a/src/khoj/search_filter/date_filter.py b/src/khoj/search_filter/date_filter.py index b612f16b..ca35f7c7 100644 --- a/src/khoj/search_filter/date_filter.py +++ b/src/khoj/search_filter/date_filter.py @@ -3,6 +3,7 @@ import re import logging from collections import defaultdict from datetime import timedelta, datetime +from typing import List from dateutil.relativedelta import relativedelta from math import inf @@ -45,9 +46,9 @@ class DateFilter(BaseFilter): continue self.date_to_entry_ids[date_in_entry].add(id) - def can_filter(self, raw_query): - "Check if query contains date filters" - return self.extract_date_range(raw_query) is not None + def get_filter_terms(self, query: str) -> List[str]: + "Get all filter terms in query" + return [f"dt{item[0]}'{item[1]}'" for item in re.findall(self.date_regex, query)] def defilter(self, query): # remove date range filter from query diff --git a/src/khoj/search_filter/file_filter.py b/src/khoj/search_filter/file_filter.py index 26f416fe..420bf9e7 100644 --- a/src/khoj/search_filter/file_filter.py +++ b/src/khoj/search_filter/file_filter.py @@ -3,6 +3,7 @@ import re import fnmatch import logging from collections import defaultdict +from typing import List # Internal Packages from khoj.search_filter.base_filter import BaseFilter @@ -25,8 +26,9 @@ class FileFilter(BaseFilter): for id, entry in enumerate(entries): self.file_to_entry_map[getattr(entry, self.entry_key)].add(id) - def can_filter(self, raw_query): - return re.search(self.file_filter_regex, raw_query) is not None + def get_filter_terms(self, query: str) -> List[str]: + "Get all filter terms in query" + return [f'file:"{term}"' for term in re.findall(self.file_filter_regex, query)] def defilter(self, query: str) -> str: return re.sub(self.file_filter_regex, "", query).strip() diff --git a/src/khoj/search_filter/word_filter.py b/src/khoj/search_filter/word_filter.py index 9c98e848..ebf64b34 100644 --- a/src/khoj/search_filter/word_filter.py +++ 
b/src/khoj/search_filter/word_filter.py @@ -2,6 +2,7 @@ import re import logging from collections import defaultdict +from typing import List # Internal Packages from khoj.search_filter.base_filter import BaseFilter @@ -36,12 +37,11 @@ class WordFilter(BaseFilter): return self.word_to_entry_index - def can_filter(self, raw_query): - "Check if query contains word filters" - required_words = re.findall(self.required_regex, raw_query) - blocked_words = re.findall(self.blocked_regex, raw_query) - - return len(required_words) != 0 or len(blocked_words) != 0 + def get_filter_terms(self, query: str) -> List[str]: + "Get all filter terms in query" + required_terms = [f"+{required_term}" for required_term in re.findall(self.required_regex, query)] + blocked_terms = [f"-{blocked_term}" for blocked_term in re.findall(self.blocked_regex, query)] + return required_terms + blocked_terms def defilter(self, query: str) -> str: return re.sub(self.blocked_regex, "", re.sub(self.required_regex, "", query)).strip() diff --git a/tests/test_date_filter.py b/tests/test_date_filter.py index 05f19ae3..90cc5b93 100644 --- a/tests/test_date_filter.py +++ b/tests/test_date_filter.py @@ -158,3 +158,23 @@ def test_date_filter_regex(): dtrange_match = re.findall(DateFilter().date_regex, "head tail") assert dtrange_match == [] + + +def test_get_file_filter_terms(): + dtrange_match = DateFilter().get_filter_terms('multi word head dt>"today" dt:"1984-01-01"') + assert dtrange_match == ["dt>'today'", "dt:'1984-01-01'"] + + dtrange_match = DateFilter().get_filter_terms('head dt>"today" dt:"1984-01-01" multi word tail') + assert dtrange_match == ["dt>'today'", "dt:'1984-01-01'"] + + dtrange_match = DateFilter().get_filter_terms('multi word head dt>="today" dt="1984-01-01"') + assert dtrange_match == ["dt>='today'", "dt='1984-01-01'"] + + dtrange_match = DateFilter().get_filter_terms('dt<"multi word date" multi word tail') + assert dtrange_match == ["dt<'multi word date'"] + + dtrange_match = 
DateFilter().get_filter_terms('head dt<="multi word date"') + assert dtrange_match == ["dt<='multi word date'"] + + dtrange_match = DateFilter().get_filter_terms("head tail") + assert dtrange_match == [] diff --git a/tests/test_file_filter.py b/tests/test_file_filter.py index 2ae82f66..ed632d32 100644 --- a/tests/test_file_filter.py +++ b/tests/test_file_filter.py @@ -99,6 +99,18 @@ def test_multiple_file_filter(): assert entry_indices == {0, 1, 2, 3} +def test_get_file_filter_terms(): + # Arrange + file_filter = FileFilter() + q_with_filter_terms = 'head tail file:"file 1.org" file:"/path/to/dir/*.org"' + + # Act + filter_terms = file_filter.get_filter_terms(q_with_filter_terms) + + # Assert + assert filter_terms == ['file:"file 1.org"', 'file:"/path/to/dir/*.org"'] + + def arrange_content(): entries = [ Entry(compiled="", raw="First Entry", file="file 1.org"), diff --git a/tests/test_word_filter.py b/tests/test_word_filter.py index 82d0dce8..04f45506 100644 --- a/tests/test_word_filter.py +++ b/tests/test_word_filter.py @@ -67,6 +67,18 @@ def test_word_include_and_exclude_filter(): assert entry_indices == {2} +def test_get_word_filter_terms(): + # Arrange + word_filter = WordFilter() + query_with_include_and_exclude_filter = 'head +"include_word" -"exclude_word" tail' + + # Act + filter_terms = word_filter.get_filter_terms(query_with_include_and_exclude_filter) + + # Assert + assert filter_terms == ["+include_word", "-exclude_word"] + + def arrange_content(): entries = [ Entry(compiled="", raw="Minimal Entry"),