diff --git a/tests/test_gpt4all_chat_actors.py b/tests/test_gpt4all_chat_actors.py
index 155baf8c..a7191a66 100644
--- a/tests/test_gpt4all_chat_actors.py
+++ b/tests/test_gpt4all_chat_actors.py
@@ -76,9 +76,30 @@ def test_extract_question_with_date_filter_from_relative_month(loaded_model):
 
 
 # ----------------------------------------------------------------------------------------------------
+@pytest.mark.xfail(reason="Chat actor still isn't very date aware nor capable of formatting")
 @pytest.mark.chatquality
 @freeze_time("1984-04-02")
-def test_extract_question_with_date_filter_from_relative_year(loaded_model):
+def test_extract_question_with_date_filter_from_relative_year():
+    # Act
+    response = extract_questions_offline("Which countries have I visited this year?")
+
+    # Assert
+    # An empty end filter is contained in every string, so that pair only requires the start date
+    expected_responses = [
+        ("dt>='1984-01-01'", ""),
+        ("dt>='1984-01-01'", "dt<'1985-01-01'"),
+        ("dt>='1984-01-01'", "dt<='1984-12-31'"),
+    ]
+    assert len(response) == 1
+    assert any(start in response[0] and end in response[0] for start, end in expected_responses), (
+        "Expected date filter to limit to 1984 in response but got: " + response[0]
+    )
+
+
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.chatquality
+@freeze_time("1984-04-02")
+def test_extract_question_includes_root_question(loaded_model):
     # Act
     response = extract_questions_offline("Which countries have I visited this year?", loaded_model=loaded_model)
 
@@ -107,13 +128,14 @@ def test_extract_multiple_implicit_questions_from_message(loaded_model):
     response = extract_questions_offline("Is Morpheus taller than Neo?", loaded_model=loaded_model)
 
     # Assert
-    expected_responses = [
-        ("morpheus", "neo", "height", "taller", "shorter"),
-    ]
-    assert len(response) == 3
-    assert any([start in response[0].lower() and end in response[1].lower() for start, end in expected_responses]), (
-        "Expected two search queries in response but got: " + response[0]
-    )
+    expected_responses = ["height", "taller", "shorter", "heights"]
+    # Require at least one question: with an empty response the loop below would pass vacuously
+    assert 1 <= len(response) <= 3
+
+    for question in response:
+        assert any(expected_response in question.lower() for expected_response in expected_responses), (
+            "Expected chat actor to ask follow-up questions about Morpheus and Neo, but got: " + question
+        )
 
 
 # ----------------------------------------------------------------------------------------------------