Test data source, output mode selector, web search query chat actors

This commit is contained in:
Debanjum Singh Solanky 2024-03-13 16:49:13 +05:30
parent 14682d5354
commit 70b04d16c0

View file

@ -7,7 +7,12 @@ from freezegun import freeze_time
from khoj.processor.conversation.openai.gpt import converse, extract_questions
from khoj.processor.conversation.utils import message_to_log
from khoj.routers.helpers import aget_relevant_output_modes
from khoj.routers.helpers import (
aget_relevant_information_sources,
aget_relevant_output_modes,
generate_online_subqueries,
)
from khoj.utils.helpers import ConversationCommand
# Initialize variables for tests
api_key = os.getenv("OPENAI_API_KEY")
@ -435,6 +440,47 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
@freeze_time("2024-04-04", ignore=["transformers"])
async def test_websearch_with_operators(chat_client):
# Arrange
user_query = "Share popular posts on r/worldnews this month"
# Act
responses = await generate_online_subqueries(user_query, {}, None)
# Assert
assert any(
["reddit.com/r/worldnews" in response for response in responses]
), "Expected a search query to include site:reddit.com but got: " + str(responses)
assert any(
["site:reddit.com" in response for response in responses]
), "Expected a search query to include site:reddit.com but got: " + str(responses)
assert any(
["after:2024/04/01" in response for response in responses]
), "Expected a search query to include after:2024/04/01 but got: " + str(responses)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_websearch_khoj_website_for_info_about_khoj(chat_client):
# Arrange
user_query = "Do you support image search?"
# Act
responses = await generate_online_subqueries(user_query, {}, None)
# Assert
assert any(
["site:khoj.dev" in response for response in responses]
), "Expected search query to include site:khoj.dev but got: " + str(responses)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
@ -463,6 +509,48 @@ async def test_use_image_response_mode(chat_client):
assert mode.value == "image"
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_select_data_sources_actor_chooses_default(chat_client):
# Arrange
user_query = "How can I improve my swimming compared to my last lesson?"
# Act
conversation_commands = await aget_relevant_information_sources(user_query, {})
# Assert
assert ConversationCommand.Default in conversation_commands
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_select_data_sources_actor_chooses_to_search_notes(chat_client):
# Arrange
user_query = "Where did I learn to swim?"
# Act
conversation_commands = await aget_relevant_information_sources(user_query, {})
# Assert
assert ConversationCommand.Notes in conversation_commands
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_select_data_sources_actor_chooses_to_search_online(chat_client):
# Arrange
user_query = "Where is the nearest hospital?"
# Act
conversation_commands = await aget_relevant_information_sources(user_query, {})
# Assert
assert ConversationCommand.Online in conversation_commands
# Helpers
# ----------------------------------------------------------------------------------------------------
def populate_chat_history(message_list):