Filter knowledge base used by chat to respond (#469)

- Overview
  - Allow applying word, file or date filters on your knowledge base from the chat interface
  - This will limit the portion of the knowledge base Khoj chat can use to respond to your query
This commit is contained in:
Debanjum 2023-08-28 09:32:33 -07:00 committed by GitHub
commit bc5e60defb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 147 additions and 44 deletions

View file

@ -702,7 +702,7 @@ async def chat(
) -> Response: ) -> Response:
perform_chat_checks() perform_chat_checks()
conversation_command = get_conversation_command(query=q, any_references=True) conversation_command = get_conversation_command(query=q, any_references=True)
compiled_references, inferred_queries = await extract_references_and_questions( compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
request, q, (n or 5), conversation_command request, q, (n or 5), conversation_command
) )
conversation_command = get_conversation_command(query=q, any_references=not is_none_or_empty(compiled_references)) conversation_command = get_conversation_command(query=q, any_references=not is_none_or_empty(compiled_references))
@ -713,7 +713,7 @@ async def chat(
# Get the (streamed) chat response from the LLM of choice. # Get the (streamed) chat response from the LLM of choice.
llm_response = generate_chat_response( llm_response = generate_chat_response(
q, defiltered_query,
meta_log=state.processor_config.conversation.meta_log, meta_log=state.processor_config.conversation.meta_log,
compiled_references=compiled_references, compiled_references=compiled_references,
inferred_queries=inferred_queries, inferred_queries=inferred_queries,
@ -770,19 +770,31 @@ async def extract_references_and_questions(
) )
return compiled_references, inferred_queries return compiled_references, inferred_queries
if conversation_type != ConversationCommand.General: if conversation_type == ConversationCommand.General:
return compiled_references, inferred_queries
# Extract filter terms from user message
# Filter terms are collected from the original query, while each filter is
# stripped from the already-defiltered text so that every filter type is removed
# (defiltering the raw query each time would keep only the last filter's result).
defiltered_query = q
filter_terms = []
for search_filter in [DateFilter(), WordFilter(), FileFilter()]:
    filter_terms += search_filter.get_filter_terms(q)
    defiltered_query = search_filter.defilter(defiltered_query)
# Space separated filter terms, appended to each inferred query when searching
filters_in_query = " ".join(filter_terms)
# Infer search queries from user message # Infer search queries from user message
with timer("Extracting search queries took", logger): with timer("Extracting search queries took", logger):
# If we've reached here, either the user has enabled offline chat or the openai model is enabled. # If we've reached here, either the user has enabled offline chat or the openai model is enabled.
if state.processor_config.conversation.enable_offline_chat: if state.processor_config.conversation.enable_offline_chat:
loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
inferred_queries = extract_questions_offline( inferred_queries = extract_questions_offline(
q, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
) )
elif state.processor_config.conversation.openai_model: elif state.processor_config.conversation.openai_model:
api_key = state.processor_config.conversation.openai_model.api_key api_key = state.processor_config.conversation.openai_model.api_key
chat_model = state.processor_config.conversation.openai_model.chat_model chat_model = state.processor_config.conversation.openai_model.chat_model
inferred_queries = extract_questions(q, model=chat_model, api_key=api_key, conversation_log=meta_log) inferred_queries = extract_questions(
defiltered_query, model=chat_model, api_key=api_key, conversation_log=meta_log
)
# Collate search results as context for GPT # Collate search results as context for GPT
with timer("Searching knowledge base took", logger): with timer("Searching knowledge base took", logger):
@ -790,8 +802,15 @@ async def extract_references_and_questions(
for query in inferred_queries: for query in inferred_queries:
n_items = min(n, 3) if state.processor_config.conversation.enable_offline_chat else n n_items = min(n, 3) if state.processor_config.conversation.enable_offline_chat else n
result_list.extend( result_list.extend(
await search(query, request=request, n=n_items, r=True, score_threshold=-5.0, dedupe=False) await search(
f"{query} {filters_in_query}",
request=request,
n=n_items,
r=True,
score_threshold=-5.0,
dedupe=False,
)
) )
compiled_references = [item.additional["compiled"] for item in result_list] compiled_references = [item.additional["compiled"] for item in result_list]
return compiled_references, inferred_queries return compiled_references, inferred_queries, defiltered_query

View file

@ -12,9 +12,12 @@ class BaseFilter(ABC):
... ...
@abstractmethod @abstractmethod
def can_filter(self, raw_query: str) -> bool: def get_filter_terms(self, query: str) -> List[str]:
... ...
def can_filter(self, raw_query: str) -> bool:
return len(self.get_filter_terms(raw_query)) > 0
@abstractmethod @abstractmethod
def apply(self, query: str, entries: List[Entry]) -> Tuple[str, Set[int]]: def apply(self, query: str, entries: List[Entry]) -> Tuple[str, Set[int]]:
... ...

View file

@ -3,6 +3,7 @@ import re
import logging import logging
from collections import defaultdict from collections import defaultdict
from datetime import timedelta, datetime from datetime import timedelta, datetime
from typing import List
from dateutil.relativedelta import relativedelta from dateutil.relativedelta import relativedelta
from math import inf from math import inf
@ -45,9 +46,9 @@ class DateFilter(BaseFilter):
continue continue
self.date_to_entry_ids[date_in_entry].add(id) self.date_to_entry_ids[date_in_entry].add(id)
def can_filter(self, raw_query): def get_filter_terms(self, query: str) -> List[str]:
"Check if query contains date filters" "Get all filter terms in query"
return self.extract_date_range(raw_query) is not None return [f"dt{item[0]}'{item[1]}'" for item in re.findall(self.date_regex, query)]
def defilter(self, query): def defilter(self, query):
# remove date range filter from query # remove date range filter from query
@ -62,7 +63,7 @@ class DateFilter(BaseFilter):
query_daterange = self.extract_date_range(query) query_daterange = self.extract_date_range(query)
# if no date in query, return all entries # if no date in query, return all entries
if query_daterange is None: if query_daterange == []:
return query, set(range(len(entries))) return query, set(range(len(entries)))
query = self.defilter(query) query = self.defilter(query)
@ -95,7 +96,7 @@ class DateFilter(BaseFilter):
date_range_matches = re.findall(self.date_regex, query) date_range_matches = re.findall(self.date_regex, query)
if len(date_range_matches) == 0: if len(date_range_matches) == 0:
return None return []
# extract, parse natural dates ranges from date range filter passed in query # extract, parse natural dates ranges from date range filter passed in query
# e.g today maps to (start_of_day, start_of_tomorrow) # e.g today maps to (start_of_day, start_of_tomorrow)
@ -110,7 +111,7 @@ class DateFilter(BaseFilter):
# >=yesterday maps to [start_of_yesterday, inf) # >=yesterday maps to [start_of_yesterday, inf)
# <tomorrow maps to [0, start_of_tomorrow) # <tomorrow maps to [0, start_of_tomorrow)
# --- # ---
effective_date_range = [0, inf] effective_date_range: List = [0, inf]
date_range_considering_comparator = [] date_range_considering_comparator = []
for cmp, (dtrange_start, dtrange_end) in date_ranges_from_filter: for cmp, (dtrange_start, dtrange_end) in date_ranges_from_filter:
if cmp == ">": if cmp == ">":
@ -135,7 +136,7 @@ class DateFilter(BaseFilter):
] ]
if effective_date_range == [0, inf] or effective_date_range[0] > effective_date_range[1]: if effective_date_range == [0, inf] or effective_date_range[0] > effective_date_range[1]:
return None return []
else: else:
return effective_date_range return effective_date_range

View file

@ -3,6 +3,7 @@ import re
import fnmatch import fnmatch
import logging import logging
from collections import defaultdict from collections import defaultdict
from typing import List
# Internal Packages # Internal Packages
from khoj.search_filter.base_filter import BaseFilter from khoj.search_filter.base_filter import BaseFilter
@ -25,8 +26,9 @@ class FileFilter(BaseFilter):
for id, entry in enumerate(entries): for id, entry in enumerate(entries):
self.file_to_entry_map[getattr(entry, self.entry_key)].add(id) self.file_to_entry_map[getattr(entry, self.entry_key)].add(id)
def can_filter(self, raw_query): def get_filter_terms(self, query: str) -> List[str]:
return re.search(self.file_filter_regex, raw_query) is not None "Get all filter terms in query"
return [f'file:"{term}"' for term in re.findall(self.file_filter_regex, query)]
def defilter(self, query: str) -> str: def defilter(self, query: str) -> str:
return re.sub(self.file_filter_regex, "", query).strip() return re.sub(self.file_filter_regex, "", query).strip()

View file

@ -2,6 +2,7 @@
import re import re
import logging import logging
from collections import defaultdict from collections import defaultdict
from typing import List
# Internal Packages # Internal Packages
from khoj.search_filter.base_filter import BaseFilter from khoj.search_filter.base_filter import BaseFilter
@ -36,12 +37,11 @@ class WordFilter(BaseFilter):
return self.word_to_entry_index return self.word_to_entry_index
def can_filter(self, raw_query): def get_filter_terms(self, query: str) -> List[str]:
"Check if query contains word filters" "Get all filter terms in query"
required_words = re.findall(self.required_regex, raw_query) required_terms = [f"+{required_term}" for required_term in re.findall(self.required_regex, query)]
blocked_words = re.findall(self.blocked_regex, raw_query) blocked_terms = [f"-{blocked_term}" for blocked_term in re.findall(self.blocked_regex, query)]
return required_terms + blocked_terms
return len(required_words) != 0 or len(blocked_words) != 0
def defilter(self, query: str) -> str: def defilter(self, query: str) -> str:
return re.sub(self.blocked_regex, "", re.sub(self.required_regex, "", query)).strip() return re.sub(self.blocked_regex, "", re.sub(self.required_regex, "", query)).strip()

View file

@ -68,13 +68,13 @@ def test_extract_date_range():
] ]
# Unparseable date filter specified in query # Unparseable date filter specified in query
assert DateFilter().extract_date_range('head dt:"Summer of 69" tail') == None assert DateFilter().extract_date_range('head dt:"Summer of 69" tail') == []
# No date filter specified in query # No date filter specified in query
assert DateFilter().extract_date_range("head tail") == None assert DateFilter().extract_date_range("head tail") == []
# Non intersecting date ranges # Non intersecting date ranges
assert DateFilter().extract_date_range('head dt>"1984-01-01" dt<"1984-01-01" tail') == None assert DateFilter().extract_date_range('head dt>"1984-01-01" dt<"1984-01-01" tail') == []
@pytest.mark.filterwarnings("ignore:The localize method is no longer necessary.") @pytest.mark.filterwarnings("ignore:The localize method is no longer necessary.")
@ -158,3 +158,23 @@ def test_date_filter_regex():
dtrange_match = re.findall(DateFilter().date_regex, "head tail") dtrange_match = re.findall(DateFilter().date_regex, "head tail")
assert dtrange_match == [] assert dtrange_match == []
def test_get_file_filter_terms():
    """Check the filter terms DateFilter.get_filter_terms extracts from queries.

    NOTE(review): the name says "file" but this test exercises DateFilter;
    consider renaming it to match.
    """
    # (query, expected filter terms) pairs, checked in order
    cases = [
        ('multi word head dt>"today" dt:"1984-01-01"', ["dt>'today'", "dt:'1984-01-01'"]),
        ('head dt>"today" dt:"1984-01-01" multi word tail', ["dt>'today'", "dt:'1984-01-01'"]),
        ('multi word head dt>="today" dt="1984-01-01"', ["dt>='today'", "dt='1984-01-01'"]),
        ('dt<"multi word date" multi word tail', ["dt<'multi word date'"]),
        ('head dt<="multi word date"', ["dt<='multi word date'"]),
        ("head tail", []),
    ]

    for query, expected_terms in cases:
        # A fresh filter per query, as no state needs to carry over between cases
        extracted_terms = DateFilter().get_filter_terms(query)
        assert extracted_terms == expected_terms

View file

@ -99,6 +99,18 @@ def test_multiple_file_filter():
assert entry_indices == {0, 1, 2, 3} assert entry_indices == {0, 1, 2, 3}
def test_get_file_filter_terms():
    """FileFilter.get_filter_terms should return every file filter term in the query."""
    # Arrange
    query = 'head tail file:"file 1.org" file:"/path/to/dir/*.org"'
    expected_terms = ['file:"file 1.org"', 'file:"/path/to/dir/*.org"']

    # Act
    extracted_terms = FileFilter().get_filter_terms(query)

    # Assert
    assert extracted_terms == expected_terms
def arrange_content(): def arrange_content():
entries = [ entries = [
Entry(compiled="", raw="First Entry", file="file 1.org"), Entry(compiled="", raw="First Entry", file="file 1.org"),

View file

@ -209,6 +209,24 @@ def test_answer_from_retrieved_content_using_notes_command(client_offline_chat):
assert "Fujiang" in response_message assert "Fujiang" in response_message
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_using_file_filter(client_offline_chat):
    """The chat answer should depend on which file the file filter in the query names."""

    def ask(raw_query):
        # Percent-encode the query so quotes and spaces survive in the request URL
        encoded_query = urllib.parse.quote(raw_query)
        reply = client_offline_chat.get(f"/api/chat?q={encoded_query}&stream=true")
        return reply.content.decode("utf-8")

    # Arrange
    chat_history = []
    populate_chat_history(chat_history)

    # Act
    response_without_answer = ask('Where was Xi Li born? file:"Namita.markdown"')
    response_with_answer = ask('Where was Xi Li born? file:"Xi Li.markdown"')

    # Assert
    assert "Fujiang" not in response_without_answer
    assert "Fujiang" in response_with_answer
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality @pytest.mark.chatquality
def test_answer_not_known_using_notes_command(client_offline_chat): def test_answer_not_known_using_notes_command(client_offline_chat):

View file

@ -337,3 +337,19 @@ def test_answer_requires_multiple_independent_searches(chat_client):
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), ( assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
"Expected Xi is older than Namita, but got: " + response_message "Expected Xi is older than Namita, but got: " + response_message
) )
# ----------------------------------------------------------------------------------------------------
def test_answer_using_file_filter(chat_client):
    """Chat should be able to use search filters in the query."""
    # Act
    encoded_query = urllib.parse.quote('Is Xi older than Namita? file:"Namita.markdown" file:"Xi Li.markdown"')
    response = chat_client.get(f"/api/chat?q={encoded_query}&stream=true")
    response_message = response.content.decode("utf-8")
    # Compare case-insensitively against the accepted phrasings
    lowered_message = response_message.lower()

    # Assert
    accepted_phrasings = ("he is older than namita", "xi is older than namita", "xi li is older than namita")
    assert response.status_code == 200
    assert any(phrasing in lowered_message for phrasing in accepted_phrasings), (
        "Expected Xi is older than Namita, but got: " + response_message
    )

View file

@ -67,6 +67,18 @@ def test_word_include_and_exclude_filter():
assert entry_indices == {2} assert entry_indices == {2}
def test_get_word_filter_terms():
    """WordFilter.get_filter_terms should return the include terms followed by the exclude terms."""
    # Arrange
    query = 'head +"include_word" -"exclude_word" tail'
    expected_terms = ["+include_word", "-exclude_word"]

    # Act
    extracted_terms = WordFilter().get_filter_terms(query)

    # Assert
    assert extracted_terms == expected_terms
def arrange_content(): def arrange_content():
entries = [ entries = [
Entry(compiled="", raw="Minimal Entry"), Entry(compiled="", raw="Minimal Entry"),