diff --git a/tests/test_openai_chat_actors.py b/tests/test_openai_chat_actors.py index df9d8f07..848a6139 100644 --- a/tests/test_openai_chat_actors.py +++ b/tests/test_openai_chat_actors.py @@ -12,8 +12,10 @@ from khoj.routers.helpers import ( aget_relevant_output_modes, generate_online_subqueries, infer_webpage_urls, + schedule_query, ) from khoj.utils.helpers import ConversationCommand +from khoj.utils.rawconfig import LocationData # Initialize variables for tests api_key = os.getenv("OPENAI_API_KEY") @@ -490,71 +492,42 @@ async def test_websearch_khoj_website_for_info_about_khoj(chat_client): # ---------------------------------------------------------------------------------------------------- @pytest.mark.anyio @pytest.mark.django_db(transaction=True) -async def test_use_default_response_mode(chat_client): - # Arrange - user_query = "What's the latest in the Israel/Palestine conflict?" - +@pytest.mark.parametrize( + "user_query, expected_mode", + [ + ("What's the latest in the Israel/Palestine conflict?", "default"), + ("Summarize the latest tech news every Monday evening", "reminder"), + ("Paint a scenery in Timbuktu in the winter", "image"), + ("Remind me, when did I last visit the Serengeti?", "default"), + ], +) +async def test_use_default_response_mode(chat_client, user_query, expected_mode): # Act mode = await aget_relevant_output_modes(user_query, {}) # Assert - assert mode.value == "default" + assert mode.value == expected_mode # ---------------------------------------------------------------------------------------------------- @pytest.mark.anyio @pytest.mark.django_db(transaction=True) -async def test_use_image_response_mode(chat_client): - # Arrange - user_query = "Paint a scenery in Timbuktu in the winter" - - # Act - mode = await aget_relevant_output_modes(user_query, {}) - - # Assert - assert mode.value == "image" - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.anyio -@pytest.mark.django_db(transaction=True) -async def test_select_data_sources_actor_chooses_to_search_notes(chat_client): - # Arrange - user_query = "Where did I learn to swim?" - +@pytest.mark.parametrize( + "user_query, expected_conversation_commands", + [ + ("Where did I learn to swim?", [ConversationCommand.Notes]), + ("Where is the nearest hospital?", [ConversationCommand.Online]), + ("Summarize the wikipedia page on the history of the internet", [ConversationCommand.Webpage]), + ], +) +async def test_select_data_sources_actor_chooses_to_search_notes( + chat_client, user_query, expected_conversation_commands +): # Act conversation_commands = await aget_relevant_information_sources(user_query, {}) # Assert - assert ConversationCommand.Notes in conversation_commands - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.anyio -@pytest.mark.django_db(transaction=True) -async def test_select_data_sources_actor_chooses_to_search_online(chat_client): - # Arrange - user_query = "Where is the nearest hospital?" - - # Act - conversation_commands = await aget_relevant_information_sources(user_query, {}) - - # Assert - assert ConversationCommand.Online in conversation_commands - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.anyio -@pytest.mark.django_db(transaction=True) -async def test_select_data_sources_actor_chooses_to_read_webpage(chat_client): - # Arrange - user_query = "Summarize the wikipedia page on the history of the internet" - - # Act - conversation_commands = await aget_relevant_information_sources(user_query, {}) - - # Assert - assert ConversationCommand.Webpage in conversation_commands + assert expected_conversation_commands in conversation_commands # ---------------------------------------------------------------------------------------------------- @@ -571,6 +544,33 @@ async def test_infer_webpage_urls_actor_extracts_correct_links(chat_client): assert "https://en.wikipedia.org/wiki/History_of_the_Internet" in urls +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.anyio +@pytest.mark.django_db(transaction=True) +@pytest.mark.parametrize( + "user_query, location, expected_crontime, expected_queries", + [ + ( + "Share the weather forecast for the next day at 19:30", + ("Boston", "MA", "USA"), + "30 23 * * *", + ["weather forecast", "boston"], + ), + ], +) +async def test_infer_task_scheduling_request(chat_client, user_query, location, expected_crontime, expected_queries): + # Arrange + location_data = LocationData(city=location[0], region=location[1], country=location[2]) + + # Act + crontime, inferred_query = await schedule_query(user_query, location_data, {}) + + # Assert + assert expected_crontime in crontime + for query in expected_queries: + assert query in inferred_query.lower() + + # Helpers # ---------------------------------------------------------------------------------------------------- def populate_chat_history(message_list):