khoj/tests/test_chat_actors.py
Debanjum Singh Solanky 08f5fb315f Add answers to context for Search Actor to generate relevant queries
Update Search Actor prompt with answers, more precise primer and
two more examples for context

Mark the 3 chat quality tests using answer as context to generate
queries as expected to pass. Verify that the 3 tests pass now, unlike
before when the Search Actor did not have the answers for context
2023-03-18 16:30:55 -06:00

432 lines
16 KiB
Python

# Standard Packages
import os
import re
from datetime import datetime

# External Packages
import pytest
from freezegun import freeze_time

# Internal Packages
from khoj.processor.conversation.gpt import converse, extract_questions
from khoj.processor.conversation.utils import message_to_log
# Initialize variables for tests
# OpenAI API key read from the environment; it is passed to converse() by the tests in this module
api_key = os.getenv("OPENAI_API_KEY")
# Skip the whole module when the key is unset or blank (e.g. an empty CI secret),
# instead of letting every test run with an unusable key
if not api_key:
    pytest.skip(
        reason="Set OPENAI_API_KEY environment variable to run tests below. Get OpenAI API key from https://platform.openai.com/account/api-keys",
        allow_module_level=True,
    )
# Test
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@freeze_time("1984-04-02")
def test_extract_question_with_date_filter_from_relative_day():
    """With the clock frozen at 1984-04-02, "yesterday" should produce a date filter for 1984-04-01."""
    # Arrange
    # Accepted (lower bound, upper bound) spellings of the date filter; an empty bound always matches
    accepted_filters = (
        ('dt="1984-04-01"', ""),
        ('dt>="1984-04-01"', 'dt<"1984-04-02"'),
        ('dt>"1984-03-31"', 'dt<"1984-04-02"'),
    )

    # Act
    queries = extract_questions("Where did I go for dinner yesterday?")

    # Assert
    assert len(queries) == 1
    query = queries[0]
    filter_found = any(lower in query and upper in query for lower, upper in accepted_filters)
    assert filter_found, "Expected date filter to limit to 1st April 1984 in response but got: " + query
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@freeze_time("1984-04-02")
def test_extract_question_with_date_filter_from_relative_month():
    """With the clock frozen at 1984-04-02, "last month" should produce a date filter for March 1984."""
    # Arrange
    # Accepted (lower bound, upper bound) spellings of the date filter
    accepted_filters = (
        ('dt>="1984-03-01"', 'dt<"1984-04-01"'),
        ('dt>="1984-03-01"', 'dt<="1984-03-31"'),
    )

    # Act
    queries = extract_questions("Which countries did I visit last month?")

    # Assert
    assert len(queries) == 1
    query = queries[0]
    filter_found = any(lower in query and upper in query for lower, upper in accepted_filters)
    assert filter_found, "Expected date filter to limit to March 1984 in response but got: " + query
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@freeze_time("1984-04-02")
def test_extract_question_with_date_filter_from_relative_year():
    """With the clock frozen at 1984-04-02, "this year" should produce a date filter for 1984."""
    # Arrange
    # Accepted (lower bound, upper bound) spellings of the date filter; an empty bound always matches
    accepted_filters = (
        ('dt>="1984-01-01"', ""),
        ('dt>="1984-01-01"', 'dt<"1985-01-01"'),
        ('dt>="1984-01-01"', 'dt<="1984-12-31"'),
    )

    # Act
    queries = extract_questions("Which countries have I visited this year?")

    # Assert
    assert len(queries) == 1
    query = queries[0]
    filter_found = any(lower in query and upper in query for lower, upper in accepted_filters)
    assert filter_found, "Expected date filter to limit to 1984 in response but got: " + query
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_extract_multiple_explicit_questions_from_message():
    """Two explicit questions in one message should yield two queries: first about the sun, then the moon."""
    # Act
    response = extract_questions("What is the Sun? What is the Moon?")

    # Assert
    # (term expected in first query, term expected in second query), compared case-insensitively
    expected_responses = [
        ("sun", "moon"),
    ]
    # Report every generated query on failure, not just the first one
    assert len(response) == 2, f"Expected two search queries in response but got: {response}"
    assert any(start in response[0].lower() and end in response[1].lower() for start, end in expected_responses), (
        f"Expected two search queries in response but got: {response}"
    )
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_extract_multiple_implicit_questions_from_message():
    """A comparison question should yield two queries: first about Morpheus, then about Neo."""
    # Act
    response = extract_questions("Is Morpheus taller than Neo?")

    # Assert
    # (term expected in first query, term expected in second query), compared case-insensitively
    expected_responses = [
        ("morpheus", "neo"),
    ]
    # Report every generated query on failure, not just the first one
    assert len(response) == 2, f"Expected two search queries in response but got: {response}"
    assert any(start in response[0].lower() and end in response[1].lower() for start, end in expected_responses), (
        f"Expected two search queries in response but got: {response}"
    )
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_generate_search_query_using_question_from_chat_history():
    """The single generated query should mention "Vader", who is named only in the earlier question."""
    # Arrange
    chat_turns = [
        ("What is the name of Mr. Vaders daughter?", "Princess Leia", ""),
    ]
    history = populate_chat_history(chat_turns)

    # Act
    queries = extract_questions("Does he have any sons?", conversation_log=history)

    # Assert
    assert len(queries) == 1
    assert "Vader" in queries[0]
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_generate_search_query_using_answer_from_chat_history():
    """The single generated query should mention "Leia", who is named only in the earlier answer."""
    # Arrange
    chat_turns = [
        ("What is the name of Mr. Vaders daughter?", "Princess Leia", ""),
    ]
    history = populate_chat_history(chat_turns)

    # Act
    queries = extract_questions("Is she a Jedi?", conversation_log=history)

    # Assert
    assert len(queries) == 1
    assert "Leia" in queries[0]
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_generate_search_query_using_question_and_answer_from_chat_history():
    """The single generated query should mention both "Leia" (from the answer) and "Luke" (from the question)."""
    # Arrange
    chat_turns = [
        ("Does Luke Skywalker have any Siblings?", "Yes, Princess Leia", ""),
    ]
    history = populate_chat_history(chat_turns)

    # Act
    queries = extract_questions("Who is their father?", conversation_log=history)

    # Assert
    assert len(queries) == 1
    query = queries[0]
    assert all(name in query for name in ("Leia", "Luke"))
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_generate_search_query_with_date_and_context_from_chat_history():
    """A follow-up question should yield one query reusing the place and the date range from the chat history."""
    # Arrange
    message_list = [
        ("When did I visit Masai Mara?", "You visited Masai Mara in April 2000", ""),
    ]

    # Act
    response = extract_questions(
        "What was the Pizza place we ate at over there?", conversation_log=populate_chat_history(message_list)
    )

    # Assert
    # Accepted (lower bound, upper bound) spellings of a date filter spanning April 2000.
    # April has 30 days, so the inclusive upper bound is 2000-04-30
    expected_responses = [
        ('dt>="2000-04-01"', 'dt<"2000-05-01"'),
        ('dt>="2000-04-01"', 'dt<="2000-04-30"'),
    ]
    assert len(response) == 1
    assert "Masai Mara" in response[0]
    assert any(start in response[0] and end in response[0] for start, end in expected_responses), (
        "Expected date filter to limit to April 2000 in response but got: " + response[0]
    )
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_chat_with_no_chat_history_or_retrieved_content():
    """With no notes context and no chat history, the reply should contain the name Khoj."""
    # Act
    reply = converse(
        text="",  # Assume no context retrieved from notes for the user_query
        user_query="Hello, my name is Testatron. Who are you?",
        api_key=api_key,
    )

    # Assert
    assert len(reply) > 0
    name_mentioned = any(name in reply for name in ("Khoj", "khoj"))
    assert name_mentioned, "Expected assistants name, [K|k]hoj, in response but got: " + reply
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_from_chat_history_and_no_content():
    """With no notes context, the reply should contain the user's name given earlier in the chat history."""
    # Arrange
    chat_turns = [
        ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", ""),
        ("When was I born?", "You were born on 1st April 1984.", ""),
    ]
    history = populate_chat_history(chat_turns)

    # Act
    reply = converse(
        text="",  # Assume no context retrieved from notes for the user_query
        user_query="What is my name?",
        conversation_log=history,
        api_key=api_key,
    )

    # Assert
    assert len(reply) > 0
    name_mentioned = any(name in reply for name in ("Testatron", "testatron"))
    assert name_mentioned, "Expected [T|t]estatron in response but got: " + reply
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_from_chat_history_and_previously_retrieved_content():
    """Chat actor needs to use context in previous notes and chat history to answer question"""
    # Arrange
    # The second turn carries the note that was retrieved for it as its context
    chat_turns = [
        ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", ""),
        ("When was I born?", "You were born on 1st April 1984.", "Testatron was born on 1st April 1984 in Testville."),
    ]
    history = populate_chat_history(chat_turns)

    # Act
    reply = converse(
        text="",  # Assume no context retrieved from notes for the user_query
        user_query="Where was I born?",
        conversation_log=history,
        api_key=api_key,
    )

    # Assert
    assert len(reply) > 0
    # Infer who I am and use that to infer I was born in Testville using chat history and previously retrieved notes
    assert "Testville" in reply
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_from_chat_history_and_currently_retrieved_content():
    """Chat actor needs to use context across currently retrieved notes and chat history to answer question"""
    # Arrange
    chat_turns = [
        ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", ""),
        ("When was I born?", "You were born on 1st April 1984.", ""),
    ]
    history = populate_chat_history(chat_turns)

    # Act
    reply = converse(
        text="Testatron was born on 1st April 1984 in Testville.",  # Assume context retrieved from notes for the user_query
        user_query="Where was I born?",
        conversation_log=history,
        api_key=api_key,
    )

    # Assert
    assert len(reply) > 0
    assert "Testville" in reply
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_no_answer_in_chat_history_or_retrieved_content():
    """Chat actor should say don't know as not enough contexts in chat history or retrieved to answer question"""
    # Arrange
    chat_turns = [
        ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", ""),
        ("When was I born?", "You were born on 1st April 1984.", ""),
    ]
    history = populate_chat_history(chat_turns)
    # Phrasings accepted as the chat actor admitting it lacks the answer
    dont_know_phrases = ("don't know", "do not know", "no information", "do not have", "don't have")

    # Act
    reply = converse(
        text="",  # Assume no context retrieved from notes for the user_query
        user_query="Where was I born?",
        conversation_log=history,
        api_key=api_key,
    )

    # Assert
    assert len(reply) > 0
    admitted_unknown = any(phrase in reply for phrase in dont_know_phrases)
    assert admitted_unknown, "Expected chat actor to say they don't know in response, but got: " + reply
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_requires_current_date_awareness():
    """Chat actor should be able to answer questions relative to current date using provided notes"""
    # Arrange
    # Notes with two entries dated today and two entries dated in 2020
    notes = f"""
# {datetime.now().strftime("%Y-%m-%d")} "Naco Taco" "Tacos for Dinner"
Expenses:Food:Dining 10.00 USD
# {datetime.now().strftime("%Y-%m-%d")} "Sagar Ratna" "Dosa for Lunch"
Expenses:Food:Dining 10.00 USD
# 2020-04-01 "SuperMercado" "Bananas"
Expenses:Food:Groceries 10.00 USD
# 2020-01-01 "Naco Taco" "Burittos for Dinner"
Expenses:Food:Dining 10.00 USD
"""

    # Act
    reply = converse(
        text=notes,  # Assume context retrieved from notes for the user_query
        user_query="What did I have for Dinner today?",
        api_key=api_key,
    )

    # Assert
    assert len(reply) > 0
    tacos_mentioned = any(dish in reply for dish in ("tacos", "Tacos"))
    assert tacos_mentioned, "Expected [T|t]acos in response, but got: " + reply
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_requires_date_aware_aggregation_across_provided_notes():
    """Chat actor should be able to answer questions that require date aware aggregation across multiple notes"""
    # Arrange
    # Two dining entries of 10.00 USD are dated today; the remaining entries are dated in 2020
    context = f"""
# {datetime.now().strftime("%Y-%m-%d")} "Naco Taco" "Tacos for Dinner"
Expenses:Food:Dining 10.00 USD
# {datetime.now().strftime("%Y-%m-%d")} "Sagar Ratna" "Dosa for Lunch"
Expenses:Food:Dining 10.00 USD
# 2020-04-01 "SuperMercado" "Bananas"
Expenses:Food:Groceries 10.00 USD
# 2020-01-01 "Naco Taco" "Burittos for Dinner"
Expenses:Food:Dining 10.00 USD
"""

    # Act
    response = converse(
        text=context,  # Assume context retrieved from notes for the user_query
        user_query="How much did I spend on dining this year?",
        api_key=api_key,
    )

    # Assert
    assert len(response) > 0
    # Match 20 only as a standalone amount (e.g. "20", "$20", "20.00"), not as digits inside
    # another number, so a reply that merely mentions a year such as "2020" does not pass
    standalone_twenty = re.compile(r"(?<![\d.,])20(?:\.0+)?(?![.,]?\d)")
    assert standalone_twenty.search(response), (
        "Expected total dining expense of 20 in response, but got: " + response
    )
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_general_question_not_in_chat_history_or_retrieved_content():
    """Chat actor should be able to answer general questions not requiring looking at chat history or notes"""
    # Arrange
    chat_turns = [
        ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", ""),
        ("When was I born?", "You were born on 1st April 1984.", ""),
        ("Where was I born?", "You were born Testville.", ""),
    ]
    history = populate_chat_history(chat_turns)

    # Act
    reply = converse(
        text="",  # Assume no context retrieved from notes for the user_query
        user_query="Write a haiku about unit testing in 3 lines",
        conversation_log=history,
        api_key=api_key,
    )

    # Assert
    haiku_lines = reply.splitlines()
    assert len(haiku_lines) == 3  # haikus are 3 lines long
    topic_mentioned = any(word in reply for word in ("test", "Test"))
    assert topic_mentioned, "Expected [T|t]est in response, but got: " + reply
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat actor not consistently capable of asking for clarification yet.")
@pytest.mark.chatquality
def test_ask_for_clarification_if_not_enough_context_in_question():
    """Chat actor should ask for clarification if question cannot be answered unambiguously with the provided context"""
    # Arrange
    # Notes describe three sisters without saying which one is older
    notes = """
# Ramya
My sister, Ramya, is married to Kali Devi. They have 2 kids, Ravi and Rani.
# Fang
My sister, Fang Liu is married to Xi Li. They have 1 kid, Xiao Li.
# Aiyla
My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.
"""
    # Phrasings accepted as the chat actor asking which sister is meant
    clarifying_phrases = ("which sister", "Which sister", "which of your sister", "Which of your sister")

    # Act
    reply = converse(
        text=notes,  # Assume context retrieved from notes for the user_query
        user_query="How many kids does my older sister have?",
        api_key=api_key,
    )

    # Assert
    asked_to_clarify = any(phrase in reply for phrase in clarifying_phrases)
    assert asked_to_clarify, "Expected chat actor to ask for clarification in response, but got: " + reply
# Helpers
# ----------------------------------------------------------------------------------------------------
def populate_chat_history(message_list):
    """Build a conversation log dict from (user message, gpt message, context) tuples.

    Each tuple is passed to message_to_log together with metadata holding the context
    and an intent whose query is the user message. The entries returned are collected,
    in order, under the "chat" key of the returned dict.
    """
    chat_entries = []
    for question, reply, notes in message_list:
        # The only inferred query recorded for a turn is the user's message itself
        intent = {"query": question, "inferred-queries": f'["{question}"]'}
        metadata = {"context": notes, "intent": intent}
        chat_entries.extend(message_to_log(question, reply, metadata))
    return {"chat": chat_entries}