diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 4f7c6f42..ab547be5 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -702,7 +702,7 @@ async def chat( ) -> Response: perform_chat_checks() conversation_command = get_conversation_command(query=q, any_references=True) - compiled_references, inferred_queries = await extract_references_and_questions( + compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions( request, q, (n or 5), conversation_command ) conversation_command = get_conversation_command(query=q, any_references=not is_none_or_empty(compiled_references)) @@ -713,7 +713,7 @@ async def chat( # Get the (streamed) chat response from the LLM of choice. llm_response = generate_chat_response( - q, + defiltered_query, meta_log=state.processor_config.conversation.meta_log, compiled_references=compiled_references, inferred_queries=inferred_queries, @@ -770,28 +770,47 @@ async def extract_references_and_questions( ) - return compiled_references, inferred_queries + return compiled_references, inferred_queries, q - if conversation_type != ConversationCommand.General: - # Infer search queries from user message - with timer("Extracting search queries took", logger): - # If we've reached here, either the user has enabled offline chat or the openai model is enabled. 
- if state.processor_config.conversation.enable_offline_chat: - loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model - inferred_queries = extract_questions_offline( - q, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False - ) - elif state.processor_config.conversation.openai_model: - api_key = state.processor_config.conversation.openai_model.api_key - chat_model = state.processor_config.conversation.openai_model.chat_model - inferred_queries = extract_questions(q, model=chat_model, api_key=api_key, conversation_log=meta_log) + if conversation_type == ConversationCommand.General: + return compiled_references, inferred_queries, q - # Collate search results as context for GPT - with timer("Searching knowledge base took", logger): - result_list = [] - for query in inferred_queries: - n_items = min(n, 3) if state.processor_config.conversation.enable_offline_chat else n - result_list.extend( - await search(query, request=request, n=n_items, r=True, score_threshold=-5.0, dedupe=False) - ) - compiled_references = [item.additional["compiled"] for item in result_list] + # Extract filter terms from user message + defiltered_query = q + filter_terms = [] + for filter in [DateFilter(), WordFilter(), FileFilter()]: + filter_terms += filter.get_filter_terms(q) + defiltered_query = filter.defilter(defiltered_query) + filters_in_query = " ".join(filter_terms) - # Infer search queries from user message + with timer("Extracting search queries took", logger): + # If we've reached here, either the user has enabled offline chat or the openai model is enabled. 
+ if state.processor_config.conversation.enable_offline_chat: + loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model + inferred_queries = extract_questions_offline( + defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False + ) + elif state.processor_config.conversation.openai_model: + api_key = state.processor_config.conversation.openai_model.api_key + chat_model = state.processor_config.conversation.openai_model.chat_model + inferred_queries = extract_questions( + defiltered_query, model=chat_model, api_key=api_key, conversation_log=meta_log + ) + + # Collate search results as context for GPT + with timer("Searching knowledge base took", logger): + result_list = [] + for query in inferred_queries: + n_items = min(n, 3) if state.processor_config.conversation.enable_offline_chat else n + result_list.extend( + await search( + f"{query} {filters_in_query}", + request=request, + n=n_items, + r=True, + score_threshold=-5.0, + dedupe=False, + ) + ) + compiled_references = [item.additional["compiled"] for item in result_list] + + return compiled_references, inferred_queries, defiltered_query diff --git a/src/khoj/search_filter/base_filter.py b/src/khoj/search_filter/base_filter.py index aa4fa2e4..470f7341 100644 --- a/src/khoj/search_filter/base_filter.py +++ b/src/khoj/search_filter/base_filter.py @@ -12,9 +12,12 @@ class BaseFilter(ABC): ... @abstractmethod - def can_filter(self, raw_query: str) -> bool: + def get_filter_terms(self, query: str) -> List[str]: ... + def can_filter(self, raw_query: str) -> bool: + return len(self.get_filter_terms(raw_query)) > 0 + @abstractmethod def apply(self, query: str, entries: List[Entry]) -> Tuple[str, Set[int]]: ... 
diff --git a/src/khoj/search_filter/date_filter.py b/src/khoj/search_filter/date_filter.py index b612f16b..16a418be 100644 --- a/src/khoj/search_filter/date_filter.py +++ b/src/khoj/search_filter/date_filter.py @@ -3,6 +3,7 @@ import re import logging from collections import defaultdict from datetime import timedelta, datetime +from typing import List from dateutil.relativedelta import relativedelta from math import inf @@ -45,9 +46,9 @@ class DateFilter(BaseFilter): continue self.date_to_entry_ids[date_in_entry].add(id) - def can_filter(self, raw_query): - "Check if query contains date filters" - return self.extract_date_range(raw_query) is not None + def get_filter_terms(self, query: str) -> List[str]: + "Get all filter terms in query" + return [f"dt{item[0]}'{item[1]}'" for item in re.findall(self.date_regex, query)] def defilter(self, query): # remove date range filter from query @@ -62,7 +63,7 @@ class DateFilter(BaseFilter): query_daterange = self.extract_date_range(query) # if no date in query, return all entries - if query_daterange is None: + if query_daterange == []: return query, set(range(len(entries))) query = self.defilter(query) @@ -95,7 +96,7 @@ class DateFilter(BaseFilter): date_range_matches = re.findall(self.date_regex, query) if len(date_range_matches) == 0: - return None + return [] # extract, parse natural dates ranges from date range filter passed in query # e.g today maps to (start_of_day, start_of_tomorrow) @@ -110,7 +111,7 @@ class DateFilter(BaseFilter): # >=yesterday maps to [start_of_yesterday, inf) # ": @@ -135,7 +136,7 @@ class DateFilter(BaseFilter): ] if effective_date_range == [0, inf] or effective_date_range[0] > effective_date_range[1]: - return None + return [] else: return effective_date_range diff --git a/src/khoj/search_filter/file_filter.py b/src/khoj/search_filter/file_filter.py index 26f416fe..420bf9e7 100644 --- a/src/khoj/search_filter/file_filter.py +++ b/src/khoj/search_filter/file_filter.py @@ -3,6 +3,7 @@ import re 
import fnmatch import logging from collections import defaultdict +from typing import List # Internal Packages from khoj.search_filter.base_filter import BaseFilter @@ -25,8 +26,9 @@ class FileFilter(BaseFilter): for id, entry in enumerate(entries): self.file_to_entry_map[getattr(entry, self.entry_key)].add(id) - def can_filter(self, raw_query): - return re.search(self.file_filter_regex, raw_query) is not None + def get_filter_terms(self, query: str) -> List[str]: + "Get all filter terms in query" + return [f'file:"{term}"' for term in re.findall(self.file_filter_regex, query)] def defilter(self, query: str) -> str: return re.sub(self.file_filter_regex, "", query).strip() diff --git a/src/khoj/search_filter/word_filter.py b/src/khoj/search_filter/word_filter.py index 9c98e848..ebf64b34 100644 --- a/src/khoj/search_filter/word_filter.py +++ b/src/khoj/search_filter/word_filter.py @@ -2,6 +2,7 @@ import re import logging from collections import defaultdict +from typing import List # Internal Packages from khoj.search_filter.base_filter import BaseFilter @@ -36,12 +37,11 @@ class WordFilter(BaseFilter): return self.word_to_entry_index - def can_filter(self, raw_query): - "Check if query contains word filters" - required_words = re.findall(self.required_regex, raw_query) - blocked_words = re.findall(self.blocked_regex, raw_query) - - return len(required_words) != 0 or len(blocked_words) != 0 + def get_filter_terms(self, query: str) -> List[str]: + "Get all filter terms in query" + required_terms = [f"+{required_term}" for required_term in re.findall(self.required_regex, query)] + blocked_terms = [f"-{blocked_term}" for blocked_term in re.findall(self.blocked_regex, query)] + return required_terms + blocked_terms def defilter(self, query: str) -> str: return re.sub(self.blocked_regex, "", re.sub(self.required_regex, "", query)).strip() diff --git a/tests/test_date_filter.py b/tests/test_date_filter.py index 05f19ae3..00e6bfa3 100644 --- a/tests/test_date_filter.py +++ 
b/tests/test_date_filter.py @@ -68,13 +68,13 @@ def test_extract_date_range(): ] # Unparseable date filter specified in query - assert DateFilter().extract_date_range('head dt:"Summer of 69" tail') == None + assert DateFilter().extract_date_range('head dt:"Summer of 69" tail') == [] # No date filter specified in query - assert DateFilter().extract_date_range("head tail") == None + assert DateFilter().extract_date_range("head tail") == [] # Non intersecting date ranges - assert DateFilter().extract_date_range('head dt>"1984-01-01" dt<"1984-01-01" tail') == None + assert DateFilter().extract_date_range('head dt>"1984-01-01" dt<"1984-01-01" tail') == [] @pytest.mark.filterwarnings("ignore:The localize method is no longer necessary.") @@ -158,3 +158,23 @@ def test_date_filter_regex(): dtrange_match = re.findall(DateFilter().date_regex, "head tail") assert dtrange_match == [] + + +def test_get_date_filter_terms(): + dtrange_match = DateFilter().get_filter_terms('multi word head dt>"today" dt:"1984-01-01"') + assert dtrange_match == ["dt>'today'", "dt:'1984-01-01'"] + + dtrange_match = DateFilter().get_filter_terms('head dt>"today" dt:"1984-01-01" multi word tail') + assert dtrange_match == ["dt>'today'", "dt:'1984-01-01'"] + + dtrange_match = DateFilter().get_filter_terms('multi word head dt>="today" dt="1984-01-01"') + assert dtrange_match == ["dt>='today'", "dt='1984-01-01'"] + + dtrange_match = DateFilter().get_filter_terms('dt<"multi word date" multi word tail') + assert dtrange_match == ["dt<'multi word date'"] + + dtrange_match = DateFilter().get_filter_terms('head dt<="multi word date"') + assert dtrange_match == ["dt<='multi word date'"] + + dtrange_match = DateFilter().get_filter_terms("head tail") + assert dtrange_match == [] diff --git a/tests/test_file_filter.py b/tests/test_file_filter.py index 2ae82f66..ed632d32 100644 --- a/tests/test_file_filter.py +++ b/tests/test_file_filter.py @@ -99,6 +99,18 @@ def test_multiple_file_filter(): assert entry_indices ==
{0, 1, 2, 3} +def test_get_file_filter_terms(): + # Arrange + file_filter = FileFilter() + q_with_filter_terms = 'head tail file:"file 1.org" file:"/path/to/dir/*.org"' + + # Act + filter_terms = file_filter.get_filter_terms(q_with_filter_terms) + + # Assert + assert filter_terms == ['file:"file 1.org"', 'file:"/path/to/dir/*.org"'] + + def arrange_content(): entries = [ Entry(compiled="", raw="First Entry", file="file 1.org"), diff --git a/tests/test_gpt4all_chat_director.py b/tests/test_gpt4all_chat_director.py index 6da7f759..3e72a7e2 100644 --- a/tests/test_gpt4all_chat_director.py +++ b/tests/test_gpt4all_chat_director.py @@ -209,6 +209,24 @@ def test_answer_from_retrieved_content_using_notes_command(client_offline_chat): assert "Fujiang" in response_message +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.chatquality +def test_answer_using_file_filter(client_offline_chat): + # Arrange + no_answer_query = urllib.parse.quote('Where was Xi Li born? file:"Namita.markdown"') + answer_query = urllib.parse.quote('Where was Xi Li born? 
file:"Xi Li.markdown"') + message_list = [] + populate_chat_history(message_list) + + # Act + no_answer_response = client_offline_chat.get(f"/api/chat?q={no_answer_query}&stream=true").content.decode("utf-8") + answer_response = client_offline_chat.get(f"/api/chat?q={answer_query}&stream=true").content.decode("utf-8") + + # Assert + assert "Fujiang" not in no_answer_response + assert "Fujiang" in answer_response + + # ---------------------------------------------------------------------------------------------------- @pytest.mark.chatquality def test_answer_not_known_using_notes_command(client_offline_chat): diff --git a/tests/test_openai_chat_director.py b/tests/test_openai_chat_director.py index 4f05fc52..abbd1831 100644 --- a/tests/test_openai_chat_director.py +++ b/tests/test_openai_chat_director.py @@ -337,3 +337,19 @@ def test_answer_requires_multiple_independent_searches(chat_client): assert any([expected_response in response_message.lower() for expected_response in expected_responses]), ( "Expected Xi is older than Namita, but got: " + response_message ) + + +# ---------------------------------------------------------------------------------------------------- +def test_answer_using_file_filter(chat_client): + "Chat should be able to use search filters in the query" + # Act + query = urllib.parse.quote('Is Xi older than Namita? 
file:"Namita.markdown" file:"Xi Li.markdown"') + response = chat_client.get(f"/api/chat?q={query}&stream=true") + response_message = response.content.decode("utf-8") + + # Assert + expected_responses = ["he is older than namita", "xi is older than namita", "xi li is older than namita"] + assert response.status_code == 200 + assert any([expected_response in response_message.lower() for expected_response in expected_responses]), ( + "Expected Xi is older than Namita, but got: " + response_message + ) diff --git a/tests/test_word_filter.py b/tests/test_word_filter.py index 82d0dce8..04f45506 100644 --- a/tests/test_word_filter.py +++ b/tests/test_word_filter.py @@ -67,6 +67,18 @@ def test_word_include_and_exclude_filter(): assert entry_indices == {2} +def test_get_word_filter_terms(): + # Arrange + word_filter = WordFilter() + query_with_include_and_exclude_filter = 'head +"include_word" -"exclude_word" tail' + + # Act + filter_terms = word_filter.get_filter_terms(query_with_include_and_exclude_filter) + + # Assert + assert filter_terms == ["+include_word", "-exclude_word"] + + def arrange_content(): entries = [ Entry(compiled="", raw="Minimal Entry"),