Dedupe query in notes prompt. Improve OAI chat actor, director tests

- Remove stale tests
- Improve tests to pass across gpt-3.5 and gpt-4-turbo
- The haiku creation director was failing because of duplicate query in
  instantiated prompt
This commit is contained in:
Debanjum Singh Solanky 2024-03-13 18:46:26 +05:30
parent 70b04d16c0
commit dd883dc53a
4 changed files with 38 additions and 61 deletions

View file

@ -149,7 +149,7 @@ def converse(
f"{prompts.online_search_conversation.format(online_results=str(online_results))}\n{conversation_primer}"
)
if not is_none_or_empty(compiled_references):
conversation_primer = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n{conversation_primer}"
conversation_primer = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n{conversation_primer}"
# Setup Prompt with Primer or Conversation History
messages = generate_chatml_messages_with_context(

View file

@ -104,8 +104,6 @@ Ask crisp follow-up questions to get additional context, when a helpful response
Notes:
{references}
Query: {query}
""".strip()
)

View file

@ -159,33 +159,6 @@ def test_generate_search_query_using_question_and_answer_from_chat_history():
assert "Leia" in response[0] and "Luke" in response[0]
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_generate_search_query_with_date_and_context_from_chat_history():
# Arrange
message_list = [
("When did I visit Masai Mara?", "You visited Masai Mara in April 2000", []),
]
# Act
response = extract_questions(
"What was the Pizza place we ate at over there?", conversation_log=populate_chat_history(message_list)
)
# Assert
expected_responses = [
("dt>='2000-04-01'", "dt<'2000-05-01'"),
("dt>='2000-04-01'", "dt<='2000-04-30'"),
('dt>="2000-04-01"', 'dt<"2000-05-01"'),
('dt>="2000-04-01"', 'dt<="2000-04-30"'),
]
assert len(response) == 1
assert "Masai Mara" in response[0]
assert any([start in response[0] and end in response[0] for start, end in expected_responses]), (
"Expected date filter to limit to April 2000 in response but got: " + response[0]
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_chat_with_no_chat_history_or_retrieved_content():
@ -396,7 +369,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content():
# Act
response_gen = converse(
references=[], # Assume no context retrieved from notes for the user_query
user_query="Write a haiku about unit testing in 3 lines",
user_query="Write a haiku about unit testing in 3 lines. Do not say anything else",
conversation_log=populate_chat_history(message_list),
api_key=api_key,
)
@ -500,7 +473,7 @@ async def test_use_default_response_mode(chat_client):
@pytest.mark.django_db(transaction=True)
async def test_use_image_response_mode(chat_client):
# Arrange
user_query = "Paint a picture of the scenery in Timbuktu in the winter"
user_query = "Paint a scenery in Timbuktu in the winter"
# Act
mode = await aget_relevant_output_modes(user_query, {})
@ -509,20 +482,6 @@ async def test_use_image_response_mode(chat_client):
assert mode.value == "image"
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_select_data_sources_actor_chooses_default(chat_client):
# Arrange
user_query = "How can I improve my swimming compared to my last lesson?"
# Act
conversation_commands = await aget_relevant_information_sources(user_query, {})
# Assert
assert ConversationCommand.Default in conversation_commands
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)

View file

@ -222,9 +222,17 @@ def test_no_answer_in_chat_history_or_retrieved_content(chat_client, default_use
response_message = response.content.decode("utf-8")
# Assert
expected_responses = ["don't know", "do not know", "no information", "do not have", "don't have"]
expected_responses = [
"don't know",
"do not know",
"no information",
"do not have",
"don't have",
"where were you born?",
]
assert response.status_code == 200
assert any([expected_response in response_message for expected_response in expected_responses]), (
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
"Expected chat director to say they don't know in response, but got: " + response_message
)
@ -330,10 +338,8 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c
populate_chat_history(message_list, default_user2)
# Act
response = chat_client.get(
f'/api/chat?q=""Write a haiku about unit testing. Do not say anything else."&stream=true'
)
response_message = response.content.decode("utf-8")
response = chat_client.get(f'/api/chat?q="Write a haiku about unit testing. Do not say anything else."&stream=true')
response_message = response.content.decode("utf-8").split("### compiled references")[0]
# Assert
expected_responses = ["test", "Test"]
@ -350,8 +356,8 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c
def test_ask_for_clarification_if_not_enough_context_in_question(chat_client_no_background):
# Act
response = chat_client_no_background.get(f'/api/chat?q="What is the name of Namitas older son"&stream=true')
response_message = response.content.decode("utf-8")
response = chat_client_no_background.get(f'/api/chat?q="What is the name of Namitas older son?"&stream=true')
response_message = response.content.decode("utf-8").split("### compiled references")[0].lower()
# Assert
expected_responses = [
@ -361,9 +367,11 @@ def test_ask_for_clarification_if_not_enough_context_in_question(chat_client_no_
"the birth order",
"provide more context",
"provide me with more context",
"don't have that",
"haven't provided me",
]
assert response.status_code == 200
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
assert any([expected_response in response_message for expected_response in expected_responses]), (
"Expected chat director to ask for clarification in response, but got: " + response_message
)
@ -399,13 +407,18 @@ def test_answer_in_chat_history_beyond_lookback_window(chat_client, default_user
def test_answer_requires_multiple_independent_searches(chat_client):
"Chat director should be able to answer by doing multiple independent searches for required information"
# Act
response = chat_client.get(f'/api/chat?q="Is Xi older than Namita?"&stream=true')
response_message = response.content.decode("utf-8")
response = chat_client.get(f'/api/chat?q="Is Xi older than Namita? Just the older persons full name"&stream=true')
response_message = response.content.decode("utf-8").split("### compiled references")[0].lower()
# Assert
expected_responses = ["he is older than namita", "xi is older than namita", "xi li is older than namita"]
only_full_name_check = "xi li" in response_message and "namita" not in response_message
comparative_statement_check = any(
[expected_response in response_message for expected_response in expected_responses]
)
assert response.status_code == 200
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
assert only_full_name_check or comparative_statement_check, (
"Expected Xi is older than Namita, but got: " + response_message
)
@ -415,15 +428,22 @@ def test_answer_requires_multiple_independent_searches(chat_client):
def test_answer_using_file_filter(chat_client):
"Chat should be able to use search filters in the query"
# Act
query = urllib.parse.quote('Is Xi older than Namita? file:"Namita.markdown" file:"Xi Li.markdown"')
query = urllib.parse.quote(
'Is Xi older than Namita? Just say the older persons full name. file:"Namita.markdown" file:"Xi Li.markdown"'
)
response = chat_client.get(f"/api/chat?q={query}&stream=true")
response_message = response.content.decode("utf-8")
response_message = response.content.decode("utf-8").split("### compiled references")[0].lower()
# Assert
expected_responses = ["he is older than namita", "xi is older than namita", "xi li is older than namita"]
only_full_name_check = "xi li" in response_message and "namita" not in response_message
comparative_statement_check = any(
[expected_response in response_message for expected_response in expected_responses]
)
assert response.status_code == 200
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
assert only_full_name_check or comparative_statement_check, (
"Expected Xi is older than Namita, but got: " + response_message
)