khoj/tests/test_openai_chat_actors.py

525 lines
19 KiB
Python
Raw Normal View History

import json
import os
from datetime import datetime

import freezegun
import pytest
from freezegun import freeze_time

from khoj.processor.conversation.openai.gpt import converse, extract_questions
from khoj.processor.conversation.utils import message_to_log
from khoj.routers.helpers import (
    aget_relevant_information_sources,
    aget_relevant_output_modes,
    generate_online_subqueries,
)
from khoj.utils.helpers import ConversationCommand
# Initialize variables for tests
# The OpenAI key is passed explicitly to the chat actors exercised below.
api_key = os.environ.get("OPENAI_API_KEY")
if api_key is None:
    # Without a key none of these tests can run, so skip the whole module instead of failing each test.
    pytest.skip(
        allow_module_level=True,
        reason="Set OPENAI_API_KEY environment variable to run tests below. Get OpenAI API key from https://platform.openai.com/account/api-keys",
    )

# Tell freezegun to leave code from the transformers package out of time freezing.
freezegun.configure(extend_ignore_list=["transformers"])
# Test
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@freeze_time("1984-04-02", ignore=["transformers"])
def test_extract_question_with_date_filter_from_relative_day():
    """A relative day ("yesterday") should become a dt filter for the day before the frozen date.

    The clock is frozen at 1984-04-02, so any accepted filter form for 1984-04-01 passes.
    """
    # Act
    response = extract_questions("Where did I go for dinner yesterday?")

    # Assert
    # Each entry is a (start, end) pair that must both appear in the query.
    # An empty end string matches anything, so that entry only requires the start filter.
    expected_responses = [
        ("dt='1984-04-01'", ""),
        ("dt>='1984-04-01'", "dt<'1984-04-02'"),
        ("dt>'1984-03-31'", "dt<'1984-04-02'"),
    ]
    assert len(response) == 1
    assert any(start in response[0] and end in response[0] for start, end in expected_responses), (
        "Expected date filter to limit to 1st April 1984 in response but got: " + response[0]
    )
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@freeze_time("1984-04-02", ignore=["transformers"])
def test_extract_question_with_date_filter_from_relative_month():
    """A relative month ("last month") should become a dt range covering March 1984.

    The clock is frozen at 1984-04-02; both an exclusive and an inclusive upper bound are accepted.
    """
    # Act
    response = extract_questions("Which countries did I visit last month?")

    # Assert
    expected_responses = [("dt>='1984-03-01'", "dt<'1984-04-01'"), ("dt>='1984-03-01'", "dt<='1984-03-31'")]
    assert len(response) == 1
    assert any(start in response[0] and end in response[0] for start, end in expected_responses), (
        "Expected date filter to limit to March 1984 in response but got: " + response[0]
    )
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@freeze_time("1984-04-02", ignore=["transformers"])
def test_extract_question_with_date_filter_from_relative_year():
    """A relative year ("this year") should become a dt filter starting at 1984-01-01.

    The clock is frozen at 1984-04-02. An open-ended start filter, or either form of year-end bound, is accepted.
    """
    # Act
    response = extract_questions("Which countries have I visited this year?")

    # Assert
    # An empty end string matches anything, so the first entry only requires the start filter.
    expected_responses = [
        ("dt>='1984-01-01'", ""),
        ("dt>='1984-01-01'", "dt<'1985-01-01'"),
        ("dt>='1984-01-01'", "dt<='1984-12-31'"),
    ]
    assert len(response) == 1
    assert any(start in response[0] and end in response[0] for start, end in expected_responses), (
        "Expected date filter to limit to 1984 in response but got: " + response[0]
    )
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_extract_multiple_explicit_questions_from_message():
    """A message with two explicit questions should yield two queries: the sun first, then the moon."""
    # Act
    response = extract_questions("What is the Sun? What is the Moon?")

    # Assert
    expected_responses = [
        ("sun", "moon"),
    ]
    # Report the full response on failure; only showing response[0] hides the second query.
    assert len(response) == 2, "Expected two search queries in response but got: " + str(response)
    assert any(start in response[0].lower() and end in response[1].lower() for start, end in expected_responses), (
        "Expected search queries about the sun and then the moon in response but got: " + str(response)
    )
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_extract_multiple_implicit_questions_from_message():
    """A comparison question should be split into two queries: Morpheus first, then Neo."""
    # Act
    response = extract_questions("Is Morpheus taller than Neo?")

    # Assert
    expected_responses = [
        ("morpheus", "neo"),
    ]
    # Report the full response on failure; only showing response[0] hides the second query.
    assert len(response) == 2, "Expected two search queries in response but got: " + str(response)
    assert any(start in response[0].lower() and end in response[1].lower() for start, end in expected_responses), (
        "Expected search queries about Morpheus and then Neo in response but got: " + str(response)
    )
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_generate_search_query_using_question_from_chat_history():
    """The pronoun in the follow-up ("he") must be resolved from the earlier user question (Vader)."""
    # Arrange
    chat_history = populate_chat_history([("What is the name of Mr. Vader's daughter?", "Princess Leia", [])])

    # Act
    queries = extract_questions("Does he have any sons?", conversation_log=chat_history)

    # Assert
    assert len(queries) == 1
    assert "Vader" in queries[0]
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_generate_search_query_using_answer_from_chat_history():
    """The pronoun in the follow-up ("she") must be resolved from the earlier assistant answer (Leia)."""
    # Arrange
    chat_history = populate_chat_history([("What is the name of Mr. Vader's daughter?", "Princess Leia", [])])

    # Act
    queries = extract_questions("Is she a Jedi?", conversation_log=chat_history)

    # Assert
    assert len(queries) == 1
    assert "Leia" in queries[0]
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_generate_search_query_using_question_and_answer_from_chat_history():
    """"Their" must be resolved using both the question (Luke) and the answer (Leia) from chat history."""
    # Arrange
    chat_history = populate_chat_history([("Does Luke Skywalker have any Siblings?", "Yes, Princess Leia", [])])

    # Act
    queries = extract_questions("Who is their father?", conversation_log=chat_history)

    # Assert
    assert len(queries) == 1
    query = queries[0]
    assert "Leia" in query
    assert "Luke" in query
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_chat_with_no_chat_history_or_retrieved_content():
    """With no chat history and no retrieved notes, the chat actor should still introduce itself as Khoj."""
    # Act
    response_gen = converse(
        references=[],  # Assume no context retrieved from notes for the user_query
        user_query="Hello, my name is Testatron. Who are you?",
        api_key=api_key,
    )
    # converse yields the reply in text chunks; join them into the full response
    response = "".join(response_gen)

    # Assert
    expected_responses = ["Khoj", "khoj"]
    assert len(response) > 0
    assert any(expected_response in response for expected_response in expected_responses), (
        "Expected assistant's name, [K|k]hoj, in response but got: " + response
    )
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_from_chat_history_and_no_content():
    """Chat actor should recall the user's name from chat history when no notes are retrieved."""
    # Arrange
    message_list = [
        ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
        ("When was I born?", "You were born on 1st April 1984.", []),
    ]

    # Act
    response_gen = converse(
        references=[],  # Assume no context retrieved from notes for the user_query
        user_query="What is my name?",
        conversation_log=populate_chat_history(message_list),
        api_key=api_key,
    )
    # converse yields the reply in text chunks; join them into the full response
    response = "".join(response_gen)

    # Assert
    expected_responses = ["Testatron", "testatron"]
    assert len(response) > 0
    assert any(expected_response in response for expected_response in expected_responses), (
        "Expected [T|t]estatron in response but got: " + response
    )
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_from_chat_history_and_previously_retrieved_content():
    """Chat actor needs to use context in previous notes and chat history to answer question."""
    # Arrange
    message_list = [
        ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
        (
            "When was I born?",
            "You were born on 1st April 1984.",
            # Notes retrieved for this earlier turn; the only place the birthplace appears
            ["Testatron was born on 1st April 1984 in Testville."],
        ),
    ]

    # Act
    response_gen = converse(
        references=[],  # Assume no context retrieved from notes for the user_query
        user_query="Where was I born?",
        conversation_log=populate_chat_history(message_list),
        api_key=api_key,
    )
    # converse yields the reply in text chunks; join them into the full response
    response = "".join(response_gen)

    # Assert
    assert len(response) > 0
    # Infer who I am and use that to infer I was born in Testville using chat history and previously retrieved notes
    assert "Testville" in response
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_from_chat_history_and_currently_retrieved_content():
    """Chat actor needs to use context across currently retrieved notes and chat history to answer question."""
    # Arrange
    message_list = [
        ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
        ("When was I born?", "You were born on 1st April 1984.", []),
    ]

    # Act
    response_gen = converse(
        references=[
            "Testatron was born on 1st April 1984 in Testville."
        ],  # Assume context retrieved from notes for the user_query
        user_query="Where was I born?",
        conversation_log=populate_chat_history(message_list),
        api_key=api_key,
    )
    # converse yields the reply in text chunks; join them into the full response
    response = "".join(response_gen)

    # Assert
    assert len(response) > 0
    assert "Testville" in response
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_refuse_answering_unanswerable_question():
    """Chat actor should not try to make up answers to unanswerable questions.

    Neither the chat history nor the (empty) references mention a birthplace, so the reply should admit ignorance.
    """
    # Arrange
    message_list = [
        ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
        ("When was I born?", "You were born on 1st April 1984.", []),
    ]

    # Act
    response_gen = converse(
        references=[],  # Assume no context retrieved from notes for the user_query
        user_query="Where was I born?",
        conversation_log=populate_chat_history(message_list),
        api_key=api_key,
    )
    # converse yields the reply in text chunks; join them into the full response
    response = "".join(response_gen)

    # Assert
    # Phrases are lowercase; they are matched against a normalized copy of the response below
    expected_responses = [
        "don't know",
        "do not know",
        "no information",
        "do not have",
        "don't have",
        "cannot answer",
        "i'm sorry",
    ]
    # Compare case-insensitively and fold typographic apostrophes (U+2019) to ASCII,
    # so replies like "I don’t know" or "Don't know" are not missed.
    normalized_response = response.lower().replace("\u2019", "'")
    assert len(response) > 0
    assert any(expected_response in normalized_response for expected_response in expected_responses), (
        "Expected chat actor to say they don't know in response, but got: " + response
    )
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_requires_current_date_awareness():
    """Chat actor should be able to answer questions relative to current date using provided notes."""
    # Arrange
    # Compute today's date once, so both "today" notes carry the same date even if the test straddles midnight
    today = datetime.now().strftime("%Y-%m-%d")
    context = [
        f'{today} "Naco Taco" "Tacos for Dinner"\nExpenses:Food:Dining 10.00 USD',
        f'{today} "Sagar Ratna" "Dosa for Lunch"\nExpenses:Food:Dining 10.00 USD',
        '2020-04-01 "SuperMercado" "Bananas"\nExpenses:Food:Groceries 10.00 USD',
        '2020-01-01 "Naco Taco" "Burittos for Dinner"\nExpenses:Food:Dining 10.00 USD',
    ]

    # Act
    response_gen = converse(
        references=context,  # Assume context retrieved from notes for the user_query
        user_query="What did I have for Dinner today?",
        api_key=api_key,
    )
    # converse yields the reply in text chunks; join them into the full response
    response = "".join(response_gen)

    # Assert
    expected_responses = ["tacos", "Tacos"]
    assert len(response) > 0
    assert any(expected_response in response for expected_response in expected_responses), (
        "Expected [T|t]acos in response, but got: " + response
    )
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_requires_date_aware_aggregation_across_provided_notes():
    """Chat actor should be able to answer questions that require date aware aggregation across multiple notes.

    Only the two dining entries dated today fall in "this year", so the expected total is 20.
    """
    # Arrange
    # Compute today's date once, so both "today" notes carry the same date even if the test straddles midnight
    today = datetime.now().strftime("%Y-%m-%d")
    context = [
        f'# {today} "Naco Taco" "Tacos for Dinner"\nExpenses:Food:Dining 10.00 USD',
        f'{today} "Sagar Ratna" "Dosa for Lunch"\nExpenses:Food:Dining 10.00 USD',
        '2020-04-01 "SuperMercado" "Bananas"\nExpenses:Food:Groceries 10.00 USD',
        '2020-01-01 "Naco Taco" "Burittos for Dinner"\nExpenses:Food:Dining 10.00 USD',
    ]

    # Act
    response_gen = converse(
        references=context,  # Assume context retrieved from notes for the user_query
        user_query="How much did I spend on dining this year?",
        api_key=api_key,
    )
    # converse yields the reply in text chunks; join them into the full response
    response = "".join(response_gen)

    # Assert
    assert len(response) > 0
    assert "20" in response
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_general_question_not_in_chat_history_or_retrieved_content():
    """Chat actor should be able to answer general questions not requiring looking at chat history or notes."""
    # Arrange
    message_list = [
        ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
        ("When was I born?", "You were born on 1st April 1984.", []),
        ("Where was I born?", "You were born in Testville.", []),
    ]

    # Act
    response_gen = converse(
        references=[],  # Assume no context retrieved from notes for the user_query
        user_query="Write a haiku about unit testing in 3 lines. Do not say anything else",
        conversation_log=populate_chat_history(message_list),
        api_key=api_key,
    )
    # converse yields the reply in text chunks; join them into the full response
    response = "".join(response_gen)

    # Assert
    expected_responses = ["test", "Test"]
    # Count only non-blank lines: leading/trailing newlines or blank separator lines are not haiku lines
    haiku_lines = [line for line in response.strip().splitlines() if line.strip()]
    assert len(haiku_lines) == 3, "Expected a 3 line haiku, but got: " + response  # haikus are 3 lines long
    assert any(expected_response in response for expected_response in expected_responses), (
        "Expected [T|t]est in response, but got: " + response
    )
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat actor not consistently capable of asking for clarification yet.")
@pytest.mark.chatquality
def test_ask_for_clarification_if_not_enough_context_in_question():
    """Chat actor should ask for clarification if question cannot be answered unambiguously with the provided context.

    The notes describe three sisters but none says which is the oldest, so "my older sister" is ambiguous.
    """
    # Arrange
    context = [
        "# Ramya\nMy sister, Ramya, is married to Kali Devi. They have 2 kids, Ravi and Rani.",
        "# Fang\nMy sister, Fang Liu is married to Xi Li. They have 1 kid, Xiao Li.",
        "# Aiyla\nMy sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.",
    ]

    # Act
    response_gen = converse(
        references=context,  # Assume context retrieved from notes for the user_query
        user_query="How many kids does my older sister have?",
        api_key=api_key,
    )
    # converse yields the reply in text chunks; join them into the full response
    response = "".join(response_gen)

    # Assert
    expected_responses = ["which sister", "Which sister", "which of your sister", "Which of your sister"]
    assert any(expected_response in response for expected_response in expected_responses), (
        "Expected chat actor to ask for clarification in response, but got: " + response
    )
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
@freeze_time("2024-04-04", ignore=["transformers"])
async def test_websearch_with_operators(chat_client):
    """Online subqueries should use search operators: a reddit site filter and an after: date for "this month".

    The clock is frozen at 2024-04-04, so "this month" should map to after:2024/04/01.
    `chat_client` is not used directly; presumably it is requested for its fixture setup (TODO confirm).
    """
    # Arrange
    user_query = "Share popular posts on r/worldnews this month"

    # Act
    responses = await generate_online_subqueries(user_query, {}, None)

    # Assert
    assert any("reddit.com/r/worldnews" in response for response in responses), (
        "Expected a search query to include reddit.com/r/worldnews but got: " + str(responses)
    )
    assert any("site:reddit.com" in response for response in responses), (
        "Expected a search query to include site:reddit.com but got: " + str(responses)
    )
    assert any("after:2024/04/01" in response for response in responses), (
        "Expected a search query to include after:2024/04/01 but got: " + str(responses)
    )
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_websearch_khoj_website_for_info_about_khoj(chat_client):
    """A question about Khoj's own features should produce a subquery restricted to site:khoj.dev."""
    # Arrange
    user_query = "Do you support image search?"

    # Act
    responses = await generate_online_subqueries(user_query, {}, None)

    # Assert
    assert any("site:khoj.dev" in response for response in responses), (
        "Expected search query to include site:khoj.dev but got: " + str(responses)
    )
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_use_default_response_mode(chat_client):
    """A current-events question should be answered in the "default" output mode."""
    # Act
    selected_mode = await aget_relevant_output_modes("What's the latest in the Israel/Palestine conflict?", {})

    # Assert
    assert selected_mode.value == "default"
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_use_image_response_mode(chat_client):
    """A request to paint a scene should be answered in the "image" output mode."""
    # Act
    selected_mode = await aget_relevant_output_modes("Paint a scenery in Timbuktu in the winter", {})

    # Assert
    assert selected_mode.value == "image"
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_select_data_sources_actor_chooses_to_search_notes(chat_client):
    """A question about the user's own past should include the Notes data source."""
    # Act
    selected_sources = await aget_relevant_information_sources("Where did I learn to swim?", {})

    # Assert
    assert ConversationCommand.Notes in selected_sources
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_select_data_sources_actor_chooses_to_search_online(chat_client):
    """A question needing live, external information should include the Online data source."""
    # Act
    selected_sources = await aget_relevant_information_sources("Where is the nearest hospital?", {})

    # Assert
    assert ConversationCommand.Online in selected_sources
# Helpers
# ----------------------------------------------------------------------------------------------------
def populate_chat_history(message_list):
    """Build a conversation log from (user message, assistant reply, context) tuples.

    Each turn is passed to ``message_to_log``, and the returned entries are appended to ``conversation_log["chat"]``.
    The turn's intent records the user message as the query, plus a JSON-encoded list holding just that message
    as the inferred queries.

    Args:
        message_list: Iterable of ``(user_message, gpt_message, context)`` tuples.

    Returns:
        dict: ``{"chat": [...]}`` containing the logged entries for every turn, in order.
    """
    conversation_log = {"chat": []}
    for user_message, gpt_message, context in message_list:
        # json.dumps escapes quotes/backslashes, so this stays valid JSON for any message text.
        # ensure_ascii=False keeps non-ASCII characters verbatim, matching the previous f-string output.
        inferred_queries = json.dumps([user_message], ensure_ascii=False)
        conversation_log["chat"] += message_to_log(
            user_message,
            gpt_message,
            {"context": context, "intent": {"query": user_message, "inferred-queries": inferred_queries}},
        )
    return conversation_log