Add basic chat actor test to infer scheduled queries

This commit is contained in:
Debanjum Singh Solanky 2024-04-21 13:43:46 +05:30
parent c11742f443
commit c28d7d3414

View file

@ -12,8 +12,10 @@ from khoj.routers.helpers import (
aget_relevant_output_modes,
generate_online_subqueries,
infer_webpage_urls,
schedule_query,
)
from khoj.utils.helpers import ConversationCommand
from khoj.utils.rawconfig import LocationData
# Initialize variables for tests
api_key = os.getenv("OPENAI_API_KEY")
@ -490,71 +492,42 @@ async def test_websearch_khoj_website_for_info_about_khoj(chat_client):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_use_default_response_mode(chat_client):
# Arrange
user_query = "What's the latest in the Israel/Palestine conflict?"
@pytest.mark.parametrize(
"user_query, expected_mode",
[
("What's the latest in the Israel/Palestine conflict?", "default"),
("Summarize the latest tech news every Monday evening", "reminder"),
("Paint a scenery in Timbuktu in the winter", "image"),
("Remind me, when did I last visit the Serengeti?", "default"),
],
)
async def test_use_default_response_mode(chat_client, user_query, expected_mode):
# Act
mode = await aget_relevant_output_modes(user_query, {})
# Assert
assert mode.value == "default"
assert mode.value == expected_mode
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_use_image_response_mode(chat_client):
# Arrange
user_query = "Paint a scenery in Timbuktu in the winter"
# Act
mode = await aget_relevant_output_modes(user_query, {})
# Assert
assert mode.value == "image"
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_select_data_sources_actor_chooses_to_search_notes(chat_client):
# Arrange
user_query = "Where did I learn to swim?"
@pytest.mark.parametrize(
"user_query, expected_conversation_commands",
[
("Where did I learn to swim?", [ConversationCommand.Notes]),
("Where is the nearest hospital?", [ConversationCommand.Online]),
("Summarize the wikipedia page on the history of the internet", [ConversationCommand.Webpage]),
],
)
async def test_select_data_sources_actor_chooses_to_search_notes(
chat_client, user_query, expected_conversation_commands
):
# Act
conversation_commands = await aget_relevant_information_sources(user_query, {})
# Assert
assert ConversationCommand.Notes in conversation_commands
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_select_data_sources_actor_chooses_to_search_online(chat_client):
# Arrange
user_query = "Where is the nearest hospital?"
# Act
conversation_commands = await aget_relevant_information_sources(user_query, {})
# Assert
assert ConversationCommand.Online in conversation_commands
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_select_data_sources_actor_chooses_to_read_webpage(chat_client):
# Arrange
user_query = "Summarize the wikipedia page on the history of the internet"
# Act
conversation_commands = await aget_relevant_information_sources(user_query, {})
# Assert
assert ConversationCommand.Webpage in conversation_commands
assert expected_conversation_commands in conversation_commands
# ----------------------------------------------------------------------------------------------------
@ -571,6 +544,33 @@ async def test_infer_webpage_urls_actor_extracts_correct_links(chat_client):
assert "https://en.wikipedia.org/wiki/History_of_the_Internet" in urls
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
@pytest.mark.parametrize(
"user_query, location, expected_crontime, expected_queries",
[
(
"Share the weather forecast for the next day at 19:30",
("Boston", "MA", "USA"),
"30 23 * * *",
["weather forecast", "boston"],
),
],
)
async def test_infer_task_scheduling_request(chat_client, user_query, location, expected_crontime, expected_queries):
# Arrange
location_data = LocationData(city=location[0], region=location[1], country=location[2])
# Act
crontime, inferred_query = await schedule_query(user_query, location_data, {})
# Assert
assert expected_crontime in crontime
for query in expected_queries:
assert query in inferred_query.lower()
# Helpers
# ----------------------------------------------------------------------------------------------------
def populate_chat_history(message_list):