From 01b310635e190cce0d04c92a21e2b8973216a302 Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Mon, 28 Aug 2023 00:14:40 -0700
Subject: [PATCH] Enable passing search query filters via chat and test it

---
 src/khoj/routers/api.py             | 69 ++++++++++++++++++-----------
 tests/test_gpt4all_chat_director.py | 18 ++++++++
 tests/test_openai_chat_director.py  | 17 ++++++++
 3 files changed, 79 insertions(+), 25 deletions(-)

diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index 4f7c6f42..ab547be5 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -702,7 +702,7 @@ async def chat(
 ) -> Response:
     perform_chat_checks()
     conversation_command = get_conversation_command(query=q, any_references=True)
-    compiled_references, inferred_queries = await extract_references_and_questions(
+    compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
         request, q, (n or 5), conversation_command
     )
     conversation_command = get_conversation_command(query=q, any_references=not is_none_or_empty(compiled_references))
@@ -713,7 +713,7 @@
 
     # Get the (streamed) chat response from the LLM of choice.
     llm_response = generate_chat_response(
-        q,
+        defiltered_query,
         meta_log=state.processor_config.conversation.meta_log,
         compiled_references=compiled_references,
         inferred_queries=inferred_queries,
@@ -770,28 +770,47 @@ async def extract_references_and_questions(
         )
         return compiled_references, inferred_queries
 
-    if conversation_type != ConversationCommand.General:
-        # Infer search queries from user message
-        with timer("Extracting search queries took", logger):
-            # If we've reached here, either the user has enabled offline chat or the openai model is enabled.
-            if state.processor_config.conversation.enable_offline_chat:
-                loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
-                inferred_queries = extract_questions_offline(
-                    q, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
-                )
-            elif state.processor_config.conversation.openai_model:
-                api_key = state.processor_config.conversation.openai_model.api_key
-                chat_model = state.processor_config.conversation.openai_model.chat_model
-                inferred_queries = extract_questions(q, model=chat_model, api_key=api_key, conversation_log=meta_log)
+    if conversation_type == ConversationCommand.General:
+        return compiled_references, inferred_queries, q
 
-        # Collate search results as context for GPT
-        with timer("Searching knowledge base took", logger):
-            result_list = []
-            for query in inferred_queries:
-                n_items = min(n, 3) if state.processor_config.conversation.enable_offline_chat else n
-                result_list.extend(
-                    await search(query, request=request, n=n_items, r=True, score_threshold=-5.0, dedupe=False)
-                )
-            compiled_references = [item.additional["compiled"] for item in result_list]
+    # Extract filter terms from user message
+    defiltered_query = q
+    filter_terms = []
+    for filter in [DateFilter(), WordFilter(), FileFilter()]:
+        filter_terms += filter.get_filter_terms(q)
+        defiltered_query = filter.defilter(defiltered_query)
+    filters_in_query = " ".join(filter_terms)
 
-    return compiled_references, inferred_queries
+    # Infer search queries from user message
+    with timer("Extracting search queries took", logger):
+        # If we've reached here, either the user has enabled offline chat or the openai model is enabled.
+        if state.processor_config.conversation.enable_offline_chat:
+            loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
+            inferred_queries = extract_questions_offline(
+                defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
+            )
+        elif state.processor_config.conversation.openai_model:
+            api_key = state.processor_config.conversation.openai_model.api_key
+            chat_model = state.processor_config.conversation.openai_model.chat_model
+            inferred_queries = extract_questions(
+                defiltered_query, model=chat_model, api_key=api_key, conversation_log=meta_log
+            )
+
+    # Collate search results as context for GPT
+    with timer("Searching knowledge base took", logger):
+        result_list = []
+        for query in inferred_queries:
+            n_items = min(n, 3) if state.processor_config.conversation.enable_offline_chat else n
+            result_list.extend(
+                await search(
+                    f"{query} {filters_in_query}",
+                    request=request,
+                    n=n_items,
+                    r=True,
+                    score_threshold=-5.0,
+                    dedupe=False,
+                )
+            )
+        compiled_references = [item.additional["compiled"] for item in result_list]
+
+    return compiled_references, inferred_queries, defiltered_query
diff --git a/tests/test_gpt4all_chat_director.py b/tests/test_gpt4all_chat_director.py
index 6da7f759..3e72a7e2 100644
--- a/tests/test_gpt4all_chat_director.py
+++ b/tests/test_gpt4all_chat_director.py
@@ -209,6 +209,24 @@ def test_answer_from_retrieved_content_using_notes_command(client_offline_chat):
     assert "Fujiang" in response_message
 
 
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.chatquality
+def test_answer_using_file_filter(client_offline_chat):
+    # Arrange
+    no_answer_query = urllib.parse.quote('Where was Xi Li born? file:"Namita.markdown"')
+    answer_query = urllib.parse.quote('Where was Xi Li born? file:"Xi Li.markdown"')
+    message_list = []
+    populate_chat_history(message_list)
+
+    # Act
+    no_answer_response = client_offline_chat.get(f"/api/chat?q={no_answer_query}&stream=true").content.decode("utf-8")
+    answer_response = client_offline_chat.get(f"/api/chat?q={answer_query}&stream=true").content.decode("utf-8")
+
+    # Assert
+    assert "Fujiang" not in no_answer_response
+    assert "Fujiang" in answer_response
+
+
 # ----------------------------------------------------------------------------------------------------
 @pytest.mark.chatquality
 def test_answer_not_known_using_notes_command(client_offline_chat):
diff --git a/tests/test_openai_chat_director.py b/tests/test_openai_chat_director.py
index 4f05fc52..abbd1831 100644
--- a/tests/test_openai_chat_director.py
+++ b/tests/test_openai_chat_director.py
@@ -337,3 +337,20 @@ def test_answer_requires_multiple_independent_searches(chat_client):
     assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
         "Expected Xi is older than Namita, but got: " + response_message
     )
+
+
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.chatquality
+def test_answer_using_file_filter(chat_client):
+    "Chat should be able to use search filters in the query"
+    # Act
+    query = urllib.parse.quote('Is Xi older than Namita? file:"Namita.markdown" file:"Xi Li.markdown"')
+    response = chat_client.get(f"/api/chat?q={query}&stream=true")
+    response_message = response.content.decode("utf-8")
+
+    # Assert
+    expected_responses = ["he is older than namita", "xi is older than namita", "xi li is older than namita"]
+    assert response.status_code == 200
+    assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
+        "Expected Xi is older than Namita, but got: " + response_message
+    )
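
A note on the filter round-trip this patch sets up: get_filter_terms() collects the filter tokens (file, word, and date filters) from the raw message, defilter() strips them so the LLM only sees the natural-language question, and the collected tokens are re-appended to every inferred query (f"{query} {filters_in_query}") so the knowledge-base search still honors them. The sketch below illustrates that chaining contract; ToyFileFilter and its regex are hypothetical stand-ins, and only the get_filter_terms/defilter call shape mirrors khoj's actual filter classes.

    import re

    class ToyFileFilter:
        # Hypothetical stand-in for khoj's FileFilter; only the method
        # names (get_filter_terms, defilter) mirror the real interface.
        regex = re.compile(r'file:"[^"]+"')

        def get_filter_terms(self, query):
            # Collect every file-filter token present in the query
            return self.regex.findall(query)

        def defilter(self, query):
            # Strip the filter tokens, leaving the natural-language question
            return self.regex.sub("", query).strip()

    q = 'Where was Xi Li born? file:"Xi Li.markdown"'
    defiltered_query = q
    filter_terms = []
    for filter in [ToyFileFilter()]:
        filter_terms += filter.get_filter_terms(q)
        # Feed each pass the previous pass's output, not the raw q,
        # or terms stripped by earlier filters would reappear here.
        defiltered_query = filter.defilter(defiltered_query)
    filters_in_query = " ".join(filter_terms)

    print(filters_in_query)   # file:"Xi Li.markdown"
    print(defiltered_query)   # Where was Xi Li born?

The chaining is why the loop in extract_references_and_questions passes defiltered_query, not the raw q, back into each defilter() call: with a single filter type in the message the two are equivalent (the file-filter tests above pass either way), but a query mixing, say, a file filter with a date filter would leak the earlier filter's tokens into the question sent to the LLM.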