Enable passing search query filters via chat and test it

This commit is contained in:
Debanjum Singh Solanky 2023-08-28 00:14:40 -07:00
parent 794bad8bcb
commit 01b310635e
3 changed files with 78 additions and 25 deletions

View file

@ -702,7 +702,7 @@ async def chat(
) -> Response:
perform_chat_checks()
conversation_command = get_conversation_command(query=q, any_references=True)
compiled_references, inferred_queries = await extract_references_and_questions(
compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
request, q, (n or 5), conversation_command
)
conversation_command = get_conversation_command(query=q, any_references=not is_none_or_empty(compiled_references))
@ -713,7 +713,7 @@ async def chat(
# Get the (streamed) chat response from the LLM of choice.
llm_response = generate_chat_response(
q,
defiltered_query,
meta_log=state.processor_config.conversation.meta_log,
compiled_references=compiled_references,
inferred_queries=inferred_queries,
@ -770,28 +770,47 @@ async def extract_references_and_questions(
)
return compiled_references, inferred_queries
if conversation_type != ConversationCommand.General:
# Infer search queries from user message
with timer("Extracting search queries took", logger):
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
if state.processor_config.conversation.enable_offline_chat:
loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
inferred_queries = extract_questions_offline(
q, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
)
elif state.processor_config.conversation.openai_model:
api_key = state.processor_config.conversation.openai_model.api_key
chat_model = state.processor_config.conversation.openai_model.chat_model
inferred_queries = extract_questions(q, model=chat_model, api_key=api_key, conversation_log=meta_log)
if conversation_type == ConversationCommand.General:
return compiled_references, inferred_queries
# Collate search results as context for GPT
with timer("Searching knowledge base took", logger):
result_list = []
for query in inferred_queries:
n_items = min(n, 3) if state.processor_config.conversation.enable_offline_chat else n
result_list.extend(
await search(query, request=request, n=n_items, r=True, score_threshold=-5.0, dedupe=False)
)
compiled_references = [item.additional["compiled"] for item in result_list]
# Extract filter terms from user message
defiltered_query = q
filter_terms = []
for filter in [DateFilter(), WordFilter(), FileFilter()]:
filter_terms += filter.get_filter_terms(q)
defiltered_query = filter.defilter(q)
filters_in_query = " ".join(filter_terms)
return compiled_references, inferred_queries
# Infer search queries from user message
with timer("Extracting search queries took", logger):
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
if state.processor_config.conversation.enable_offline_chat:
loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
inferred_queries = extract_questions_offline(
defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
)
elif state.processor_config.conversation.openai_model:
api_key = state.processor_config.conversation.openai_model.api_key
chat_model = state.processor_config.conversation.openai_model.chat_model
inferred_queries = extract_questions(
defiltered_query, model=chat_model, api_key=api_key, conversation_log=meta_log
)
# Collate search results as context for GPT
with timer("Searching knowledge base took", logger):
result_list = []
for query in inferred_queries:
n_items = min(n, 3) if state.processor_config.conversation.enable_offline_chat else n
result_list.extend(
await search(
f"{query} {filters_in_query}",
request=request,
n=n_items,
r=True,
score_threshold=-5.0,
dedupe=False,
)
)
compiled_references = [item.additional["compiled"] for item in result_list]
return compiled_references, inferred_queries, defiltered_query

View file

@ -209,6 +209,24 @@ def test_answer_from_retrieved_content_using_notes_command(client_offline_chat):
assert "Fujiang" in response_message
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_using_file_filter(client_offline_chat):
# Arrange
no_answer_query = urllib.parse.quote('Where was Xi Li born? file:"Namita.markdown"')
answer_query = urllib.parse.quote('Where was Xi Li born? file:"Xi Li.markdown"')
message_list = []
populate_chat_history(message_list)
# Act
no_answer_response = client_offline_chat.get(f"/api/chat?q={no_answer_query}&stream=true").content.decode("utf-8")
answer_response = client_offline_chat.get(f"/api/chat?q={answer_query}&stream=true").content.decode("utf-8")
# Assert
assert "Fujiang" not in no_answer_response
assert "Fujiang" in answer_response
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_not_known_using_notes_command(client_offline_chat):

View file

@ -337,3 +337,19 @@ def test_answer_requires_multiple_independent_searches(chat_client):
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
"Expected Xi is older than Namita, but got: " + response_message
)
# ----------------------------------------------------------------------------------------------------
def test_answer_using_file_filter(chat_client):
"Chat should be able to use search filters in the query"
# Act
query = urllib.parse.quote('Is Xi older than Namita? file:"Namita.markdown" file:"Xi Li.markdown"')
response = chat_client.get(f"/api/chat?q={query}&stream=true")
response_message = response.content.decode("utf-8")
# Assert
expected_responses = ["he is older than namita", "xi is older than namita", "xi li is older than namita"]
assert response.status_code == 200
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
"Expected Xi is older than Namita, but got: " + response_message
)