mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Add basic chat actor test to infer scheduled queries
This commit is contained in:
parent
c11742f443
commit
c28d7d3414
1 changed files with 52 additions and 52 deletions
|
@ -12,8 +12,10 @@ from khoj.routers.helpers import (
|
|||
aget_relevant_output_modes,
|
||||
generate_online_subqueries,
|
||||
infer_webpage_urls,
|
||||
schedule_query,
|
||||
)
|
||||
from khoj.utils.helpers import ConversationCommand
|
||||
from khoj.utils.rawconfig import LocationData
|
||||
|
||||
# Initialize variables for tests
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
|
@ -490,71 +492,42 @@ async def test_websearch_khoj_website_for_info_about_khoj(chat_client):
|
|||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
async def test_use_default_response_mode(chat_client):
|
||||
# Arrange
|
||||
user_query = "What's the latest in the Israel/Palestine conflict?"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"user_query, expected_mode",
|
||||
[
|
||||
("What's the latest in the Israel/Palestine conflict?", "default"),
|
||||
("Summarize the latest tech news every Monday evening", "reminder"),
|
||||
("Paint a scenery in Timbuktu in the winter", "image"),
|
||||
("Remind me, when did I last visit the Serengeti?", "default"),
|
||||
],
|
||||
)
|
||||
async def test_use_default_response_mode(chat_client, user_query, expected_mode):
|
||||
# Act
|
||||
mode = await aget_relevant_output_modes(user_query, {})
|
||||
|
||||
# Assert
|
||||
assert mode.value == "default"
|
||||
assert mode.value == expected_mode
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
async def test_use_image_response_mode(chat_client):
|
||||
# Arrange
|
||||
user_query = "Paint a scenery in Timbuktu in the winter"
|
||||
|
||||
# Act
|
||||
mode = await aget_relevant_output_modes(user_query, {})
|
||||
|
||||
# Assert
|
||||
assert mode.value == "image"
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
async def test_select_data_sources_actor_chooses_to_search_notes(chat_client):
|
||||
# Arrange
|
||||
user_query = "Where did I learn to swim?"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"user_query, expected_conversation_commands",
|
||||
[
|
||||
("Where did I learn to swim?", [ConversationCommand.Notes]),
|
||||
("Where is the nearest hospital?", [ConversationCommand.Online]),
|
||||
("Summarize the wikipedia page on the history of the internet", [ConversationCommand.Webpage]),
|
||||
],
|
||||
)
|
||||
async def test_select_data_sources_actor_chooses_to_search_notes(
|
||||
chat_client, user_query, expected_conversation_commands
|
||||
):
|
||||
# Act
|
||||
conversation_commands = await aget_relevant_information_sources(user_query, {})
|
||||
|
||||
# Assert
|
||||
assert ConversationCommand.Notes in conversation_commands
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
async def test_select_data_sources_actor_chooses_to_search_online(chat_client):
|
||||
# Arrange
|
||||
user_query = "Where is the nearest hospital?"
|
||||
|
||||
# Act
|
||||
conversation_commands = await aget_relevant_information_sources(user_query, {})
|
||||
|
||||
# Assert
|
||||
assert ConversationCommand.Online in conversation_commands
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
async def test_select_data_sources_actor_chooses_to_read_webpage(chat_client):
|
||||
# Arrange
|
||||
user_query = "Summarize the wikipedia page on the history of the internet"
|
||||
|
||||
# Act
|
||||
conversation_commands = await aget_relevant_information_sources(user_query, {})
|
||||
|
||||
# Assert
|
||||
assert ConversationCommand.Webpage in conversation_commands
|
||||
assert expected_conversation_commands in conversation_commands
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
|
@ -571,6 +544,33 @@ async def test_infer_webpage_urls_actor_extracts_correct_links(chat_client):
|
|||
assert "https://en.wikipedia.org/wiki/History_of_the_Internet" in urls
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.parametrize(
|
||||
"user_query, location, expected_crontime, expected_queries",
|
||||
[
|
||||
(
|
||||
"Share the weather forecast for the next day at 19:30",
|
||||
("Boston", "MA", "USA"),
|
||||
"30 23 * * *",
|
||||
["weather forecast", "boston"],
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_infer_task_scheduling_request(chat_client, user_query, location, expected_crontime, expected_queries):
|
||||
# Arrange
|
||||
location_data = LocationData(city=location[0], region=location[1], country=location[2])
|
||||
|
||||
# Act
|
||||
crontime, inferred_query = await schedule_query(user_query, location_data, {})
|
||||
|
||||
# Assert
|
||||
assert expected_crontime in crontime
|
||||
for query in expected_queries:
|
||||
assert query in inferred_query.lower()
|
||||
|
||||
|
||||
# Helpers
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def populate_chat_history(message_list):
|
||||
|
|
Loading…
Reference in a new issue