mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Filter knowledge base used by chat to respond (#469)
- Overview - Allow applying word, file or date filters on your knowledge base from the chat interface - This will limit the portion of the knowledge base Khoj chat can use to respond to your query
This commit is contained in:
commit
bc5e60defb
10 changed files with 147 additions and 44 deletions
|
@ -702,7 +702,7 @@ async def chat(
|
||||||
) -> Response:
|
) -> Response:
|
||||||
perform_chat_checks()
|
perform_chat_checks()
|
||||||
conversation_command = get_conversation_command(query=q, any_references=True)
|
conversation_command = get_conversation_command(query=q, any_references=True)
|
||||||
compiled_references, inferred_queries = await extract_references_and_questions(
|
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
|
||||||
request, q, (n or 5), conversation_command
|
request, q, (n or 5), conversation_command
|
||||||
)
|
)
|
||||||
conversation_command = get_conversation_command(query=q, any_references=not is_none_or_empty(compiled_references))
|
conversation_command = get_conversation_command(query=q, any_references=not is_none_or_empty(compiled_references))
|
||||||
|
@ -713,7 +713,7 @@ async def chat(
|
||||||
|
|
||||||
# Get the (streamed) chat response from the LLM of choice.
|
# Get the (streamed) chat response from the LLM of choice.
|
||||||
llm_response = generate_chat_response(
|
llm_response = generate_chat_response(
|
||||||
q,
|
defiltered_query,
|
||||||
meta_log=state.processor_config.conversation.meta_log,
|
meta_log=state.processor_config.conversation.meta_log,
|
||||||
compiled_references=compiled_references,
|
compiled_references=compiled_references,
|
||||||
inferred_queries=inferred_queries,
|
inferred_queries=inferred_queries,
|
||||||
|
@ -770,19 +770,31 @@ async def extract_references_and_questions(
|
||||||
)
|
)
|
||||||
return compiled_references, inferred_queries
|
return compiled_references, inferred_queries
|
||||||
|
|
||||||
if conversation_type != ConversationCommand.General:
|
if conversation_type == ConversationCommand.General:
|
||||||
|
return compiled_references, inferred_queries
|
||||||
|
|
||||||
|
# Extract filter terms from user message
|
||||||
|
defiltered_query = q
|
||||||
|
filter_terms = []
|
||||||
|
for filter in [DateFilter(), WordFilter(), FileFilter()]:
|
||||||
|
filter_terms += filter.get_filter_terms(q)
|
||||||
|
defiltered_query = filter.defilter(q)
|
||||||
|
filters_in_query = " ".join(filter_terms)
|
||||||
|
|
||||||
# Infer search queries from user message
|
# Infer search queries from user message
|
||||||
with timer("Extracting search queries took", logger):
|
with timer("Extracting search queries took", logger):
|
||||||
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
|
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
|
||||||
if state.processor_config.conversation.enable_offline_chat:
|
if state.processor_config.conversation.enable_offline_chat:
|
||||||
loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
|
loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
|
||||||
inferred_queries = extract_questions_offline(
|
inferred_queries = extract_questions_offline(
|
||||||
q, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
|
defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
|
||||||
)
|
)
|
||||||
elif state.processor_config.conversation.openai_model:
|
elif state.processor_config.conversation.openai_model:
|
||||||
api_key = state.processor_config.conversation.openai_model.api_key
|
api_key = state.processor_config.conversation.openai_model.api_key
|
||||||
chat_model = state.processor_config.conversation.openai_model.chat_model
|
chat_model = state.processor_config.conversation.openai_model.chat_model
|
||||||
inferred_queries = extract_questions(q, model=chat_model, api_key=api_key, conversation_log=meta_log)
|
inferred_queries = extract_questions(
|
||||||
|
defiltered_query, model=chat_model, api_key=api_key, conversation_log=meta_log
|
||||||
|
)
|
||||||
|
|
||||||
# Collate search results as context for GPT
|
# Collate search results as context for GPT
|
||||||
with timer("Searching knowledge base took", logger):
|
with timer("Searching knowledge base took", logger):
|
||||||
|
@ -790,8 +802,15 @@ async def extract_references_and_questions(
|
||||||
for query in inferred_queries:
|
for query in inferred_queries:
|
||||||
n_items = min(n, 3) if state.processor_config.conversation.enable_offline_chat else n
|
n_items = min(n, 3) if state.processor_config.conversation.enable_offline_chat else n
|
||||||
result_list.extend(
|
result_list.extend(
|
||||||
await search(query, request=request, n=n_items, r=True, score_threshold=-5.0, dedupe=False)
|
await search(
|
||||||
|
f"{query} {filters_in_query}",
|
||||||
|
request=request,
|
||||||
|
n=n_items,
|
||||||
|
r=True,
|
||||||
|
score_threshold=-5.0,
|
||||||
|
dedupe=False,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
compiled_references = [item.additional["compiled"] for item in result_list]
|
compiled_references = [item.additional["compiled"] for item in result_list]
|
||||||
|
|
||||||
return compiled_references, inferred_queries
|
return compiled_references, inferred_queries, defiltered_query
|
||||||
|
|
|
@ -12,9 +12,12 @@ class BaseFilter(ABC):
|
||||||
...
|
...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def can_filter(self, raw_query: str) -> bool:
|
def get_filter_terms(self, query: str) -> List[str]:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
def can_filter(self, raw_query: str) -> bool:
|
||||||
|
return len(self.get_filter_terms(raw_query)) > 0
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def apply(self, query: str, entries: List[Entry]) -> Tuple[str, Set[int]]:
|
def apply(self, query: str, entries: List[Entry]) -> Tuple[str, Set[int]]:
|
||||||
...
|
...
|
||||||
|
|
|
@ -3,6 +3,7 @@ import re
|
||||||
import logging
|
import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from datetime import timedelta, datetime
|
from datetime import timedelta, datetime
|
||||||
|
from typing import List
|
||||||
from dateutil.relativedelta import relativedelta
|
from dateutil.relativedelta import relativedelta
|
||||||
from math import inf
|
from math import inf
|
||||||
|
|
||||||
|
@ -45,9 +46,9 @@ class DateFilter(BaseFilter):
|
||||||
continue
|
continue
|
||||||
self.date_to_entry_ids[date_in_entry].add(id)
|
self.date_to_entry_ids[date_in_entry].add(id)
|
||||||
|
|
||||||
def can_filter(self, raw_query):
|
def get_filter_terms(self, query: str) -> List[str]:
|
||||||
"Check if query contains date filters"
|
"Get all filter terms in query"
|
||||||
return self.extract_date_range(raw_query) is not None
|
return [f"dt{item[0]}'{item[1]}'" for item in re.findall(self.date_regex, query)]
|
||||||
|
|
||||||
def defilter(self, query):
|
def defilter(self, query):
|
||||||
# remove date range filter from query
|
# remove date range filter from query
|
||||||
|
@ -62,7 +63,7 @@ class DateFilter(BaseFilter):
|
||||||
query_daterange = self.extract_date_range(query)
|
query_daterange = self.extract_date_range(query)
|
||||||
|
|
||||||
# if no date in query, return all entries
|
# if no date in query, return all entries
|
||||||
if query_daterange is None:
|
if query_daterange == []:
|
||||||
return query, set(range(len(entries)))
|
return query, set(range(len(entries)))
|
||||||
|
|
||||||
query = self.defilter(query)
|
query = self.defilter(query)
|
||||||
|
@ -95,7 +96,7 @@ class DateFilter(BaseFilter):
|
||||||
date_range_matches = re.findall(self.date_regex, query)
|
date_range_matches = re.findall(self.date_regex, query)
|
||||||
|
|
||||||
if len(date_range_matches) == 0:
|
if len(date_range_matches) == 0:
|
||||||
return None
|
return []
|
||||||
|
|
||||||
# extract, parse natural dates ranges from date range filter passed in query
|
# extract, parse natural dates ranges from date range filter passed in query
|
||||||
# e.g today maps to (start_of_day, start_of_tomorrow)
|
# e.g today maps to (start_of_day, start_of_tomorrow)
|
||||||
|
@ -110,7 +111,7 @@ class DateFilter(BaseFilter):
|
||||||
# >=yesterday maps to [start_of_yesterday, inf)
|
# >=yesterday maps to [start_of_yesterday, inf)
|
||||||
# <tomorrow maps to [0, start_of_tomorrow)
|
# <tomorrow maps to [0, start_of_tomorrow)
|
||||||
# ---
|
# ---
|
||||||
effective_date_range = [0, inf]
|
effective_date_range: List = [0, inf]
|
||||||
date_range_considering_comparator = []
|
date_range_considering_comparator = []
|
||||||
for cmp, (dtrange_start, dtrange_end) in date_ranges_from_filter:
|
for cmp, (dtrange_start, dtrange_end) in date_ranges_from_filter:
|
||||||
if cmp == ">":
|
if cmp == ">":
|
||||||
|
@ -135,7 +136,7 @@ class DateFilter(BaseFilter):
|
||||||
]
|
]
|
||||||
|
|
||||||
if effective_date_range == [0, inf] or effective_date_range[0] > effective_date_range[1]:
|
if effective_date_range == [0, inf] or effective_date_range[0] > effective_date_range[1]:
|
||||||
return None
|
return []
|
||||||
else:
|
else:
|
||||||
return effective_date_range
|
return effective_date_range
|
||||||
|
|
||||||
|
|
|
@ -3,6 +3,7 @@ import re
|
||||||
import fnmatch
|
import fnmatch
|
||||||
import logging
|
import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from typing import List
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.search_filter.base_filter import BaseFilter
|
from khoj.search_filter.base_filter import BaseFilter
|
||||||
|
@ -25,8 +26,9 @@ class FileFilter(BaseFilter):
|
||||||
for id, entry in enumerate(entries):
|
for id, entry in enumerate(entries):
|
||||||
self.file_to_entry_map[getattr(entry, self.entry_key)].add(id)
|
self.file_to_entry_map[getattr(entry, self.entry_key)].add(id)
|
||||||
|
|
||||||
def can_filter(self, raw_query):
|
def get_filter_terms(self, query: str) -> List[str]:
|
||||||
return re.search(self.file_filter_regex, raw_query) is not None
|
"Get all filter terms in query"
|
||||||
|
return [f'file:"{term}"' for term in re.findall(self.file_filter_regex, query)]
|
||||||
|
|
||||||
def defilter(self, query: str) -> str:
|
def defilter(self, query: str) -> str:
|
||||||
return re.sub(self.file_filter_regex, "", query).strip()
|
return re.sub(self.file_filter_regex, "", query).strip()
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
import re
|
import re
|
||||||
import logging
|
import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from typing import List
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.search_filter.base_filter import BaseFilter
|
from khoj.search_filter.base_filter import BaseFilter
|
||||||
|
@ -36,12 +37,11 @@ class WordFilter(BaseFilter):
|
||||||
|
|
||||||
return self.word_to_entry_index
|
return self.word_to_entry_index
|
||||||
|
|
||||||
def can_filter(self, raw_query):
|
def get_filter_terms(self, query: str) -> List[str]:
|
||||||
"Check if query contains word filters"
|
"Get all filter terms in query"
|
||||||
required_words = re.findall(self.required_regex, raw_query)
|
required_terms = [f"+{required_term}" for required_term in re.findall(self.required_regex, query)]
|
||||||
blocked_words = re.findall(self.blocked_regex, raw_query)
|
blocked_terms = [f"-{blocked_term}" for blocked_term in re.findall(self.blocked_regex, query)]
|
||||||
|
return required_terms + blocked_terms
|
||||||
return len(required_words) != 0 or len(blocked_words) != 0
|
|
||||||
|
|
||||||
def defilter(self, query: str) -> str:
|
def defilter(self, query: str) -> str:
|
||||||
return re.sub(self.blocked_regex, "", re.sub(self.required_regex, "", query)).strip()
|
return re.sub(self.blocked_regex, "", re.sub(self.required_regex, "", query)).strip()
|
||||||
|
|
|
@ -68,13 +68,13 @@ def test_extract_date_range():
|
||||||
]
|
]
|
||||||
|
|
||||||
# Unparseable date filter specified in query
|
# Unparseable date filter specified in query
|
||||||
assert DateFilter().extract_date_range('head dt:"Summer of 69" tail') == None
|
assert DateFilter().extract_date_range('head dt:"Summer of 69" tail') == []
|
||||||
|
|
||||||
# No date filter specified in query
|
# No date filter specified in query
|
||||||
assert DateFilter().extract_date_range("head tail") == None
|
assert DateFilter().extract_date_range("head tail") == []
|
||||||
|
|
||||||
# Non intersecting date ranges
|
# Non intersecting date ranges
|
||||||
assert DateFilter().extract_date_range('head dt>"1984-01-01" dt<"1984-01-01" tail') == None
|
assert DateFilter().extract_date_range('head dt>"1984-01-01" dt<"1984-01-01" tail') == []
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.filterwarnings("ignore:The localize method is no longer necessary.")
|
@pytest.mark.filterwarnings("ignore:The localize method is no longer necessary.")
|
||||||
|
@ -158,3 +158,23 @@ def test_date_filter_regex():
|
||||||
|
|
||||||
dtrange_match = re.findall(DateFilter().date_regex, "head tail")
|
dtrange_match = re.findall(DateFilter().date_regex, "head tail")
|
||||||
assert dtrange_match == []
|
assert dtrange_match == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_file_filter_terms():
|
||||||
|
dtrange_match = DateFilter().get_filter_terms('multi word head dt>"today" dt:"1984-01-01"')
|
||||||
|
assert dtrange_match == ["dt>'today'", "dt:'1984-01-01'"]
|
||||||
|
|
||||||
|
dtrange_match = DateFilter().get_filter_terms('head dt>"today" dt:"1984-01-01" multi word tail')
|
||||||
|
assert dtrange_match == ["dt>'today'", "dt:'1984-01-01'"]
|
||||||
|
|
||||||
|
dtrange_match = DateFilter().get_filter_terms('multi word head dt>="today" dt="1984-01-01"')
|
||||||
|
assert dtrange_match == ["dt>='today'", "dt='1984-01-01'"]
|
||||||
|
|
||||||
|
dtrange_match = DateFilter().get_filter_terms('dt<"multi word date" multi word tail')
|
||||||
|
assert dtrange_match == ["dt<'multi word date'"]
|
||||||
|
|
||||||
|
dtrange_match = DateFilter().get_filter_terms('head dt<="multi word date"')
|
||||||
|
assert dtrange_match == ["dt<='multi word date'"]
|
||||||
|
|
||||||
|
dtrange_match = DateFilter().get_filter_terms("head tail")
|
||||||
|
assert dtrange_match == []
|
||||||
|
|
|
@ -99,6 +99,18 @@ def test_multiple_file_filter():
|
||||||
assert entry_indices == {0, 1, 2, 3}
|
assert entry_indices == {0, 1, 2, 3}
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_file_filter_terms():
|
||||||
|
# Arrange
|
||||||
|
file_filter = FileFilter()
|
||||||
|
q_with_filter_terms = 'head tail file:"file 1.org" file:"/path/to/dir/*.org"'
|
||||||
|
|
||||||
|
# Act
|
||||||
|
filter_terms = file_filter.get_filter_terms(q_with_filter_terms)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert filter_terms == ['file:"file 1.org"', 'file:"/path/to/dir/*.org"']
|
||||||
|
|
||||||
|
|
||||||
def arrange_content():
|
def arrange_content():
|
||||||
entries = [
|
entries = [
|
||||||
Entry(compiled="", raw="First Entry", file="file 1.org"),
|
Entry(compiled="", raw="First Entry", file="file 1.org"),
|
||||||
|
|
|
@ -209,6 +209,24 @@ def test_answer_from_retrieved_content_using_notes_command(client_offline_chat):
|
||||||
assert "Fujiang" in response_message
|
assert "Fujiang" in response_message
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------------------------------------
|
||||||
|
@pytest.mark.chatquality
|
||||||
|
def test_answer_using_file_filter(client_offline_chat):
|
||||||
|
# Arrange
|
||||||
|
no_answer_query = urllib.parse.quote('Where was Xi Li born? file:"Namita.markdown"')
|
||||||
|
answer_query = urllib.parse.quote('Where was Xi Li born? file:"Xi Li.markdown"')
|
||||||
|
message_list = []
|
||||||
|
populate_chat_history(message_list)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
no_answer_response = client_offline_chat.get(f"/api/chat?q={no_answer_query}&stream=true").content.decode("utf-8")
|
||||||
|
answer_response = client_offline_chat.get(f"/api/chat?q={answer_query}&stream=true").content.decode("utf-8")
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert "Fujiang" not in no_answer_response
|
||||||
|
assert "Fujiang" in answer_response
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
@pytest.mark.chatquality
|
@pytest.mark.chatquality
|
||||||
def test_answer_not_known_using_notes_command(client_offline_chat):
|
def test_answer_not_known_using_notes_command(client_offline_chat):
|
||||||
|
|
|
@ -337,3 +337,19 @@ def test_answer_requires_multiple_independent_searches(chat_client):
|
||||||
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
|
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
|
||||||
"Expected Xi is older than Namita, but got: " + response_message
|
"Expected Xi is older than Namita, but got: " + response_message
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------------------------------------
|
||||||
|
def test_answer_using_file_filter(chat_client):
|
||||||
|
"Chat should be able to use search filters in the query"
|
||||||
|
# Act
|
||||||
|
query = urllib.parse.quote('Is Xi older than Namita? file:"Namita.markdown" file:"Xi Li.markdown"')
|
||||||
|
response = chat_client.get(f"/api/chat?q={query}&stream=true")
|
||||||
|
response_message = response.content.decode("utf-8")
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
expected_responses = ["he is older than namita", "xi is older than namita", "xi li is older than namita"]
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
|
||||||
|
"Expected Xi is older than Namita, but got: " + response_message
|
||||||
|
)
|
||||||
|
|
|
@ -67,6 +67,18 @@ def test_word_include_and_exclude_filter():
|
||||||
assert entry_indices == {2}
|
assert entry_indices == {2}
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_word_filter_terms():
|
||||||
|
# Arrange
|
||||||
|
word_filter = WordFilter()
|
||||||
|
query_with_include_and_exclude_filter = 'head +"include_word" -"exclude_word" tail'
|
||||||
|
|
||||||
|
# Act
|
||||||
|
filter_terms = word_filter.get_filter_terms(query_with_include_and_exclude_filter)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert filter_terms == ["+include_word", "-exclude_word"]
|
||||||
|
|
||||||
|
|
||||||
def arrange_content():
|
def arrange_content():
|
||||||
entries = [
|
entries = [
|
||||||
Entry(compiled="", raw="Minimal Entry"),
|
Entry(compiled="", raw="Minimal Entry"),
|
||||||
|
|
Loading…
Reference in a new issue