Fix chat tests broken by response streaming. Pass args correctly to chat methods

- Fix the tests of the GPT converse method, which broke after it started
  streaming responses
- Pass the stop sequences in the model_kwargs dictionary and the API key in
  the openai_api_key parameter to the chat completion methods. This should
  resolve the argument warning thrown by the OpenAI module
This commit is contained in:
Debanjum Singh Solanky 2023-07-07 15:23:44 -07:00
parent 48870d9170
commit 11f0a9f196
3 changed files with 42 additions and 33 deletions

View file

@ -31,8 +31,8 @@ def answer(text, user_query, model, api_key=None, temperature=0.5, max_tokens=50
model_name=model,
temperature=temperature,
max_tokens=max_tokens,
stop='"""',
api_key=api_key,
model_kwargs={"stop": ['"""']},
openai_api_key=api_key,
)
# Extract, Clean Message from GPT's Response
@ -59,8 +59,8 @@ def summarize(text, summary_type, model, user_query=None, api_key=None, temperat
temperature=temperature,
max_tokens=max_tokens,
frequency_penalty=0.2,
stop='"""',
api_key=api_key,
model_kwargs={"stop": ['"""']},
openai_api_key=api_key,
)
# Extract, Clean Message from GPT's Response
@ -104,8 +104,8 @@ def extract_questions(
model_name=model,
temperature=temperature,
max_tokens=max_tokens,
stop=["A: ", "\n"],
api_key=api_key,
model_kwargs={"stop": ["A: ", "\n"]},
openai_api_key=api_key,
)
# Extract, Clean Message from GPT's Response
@ -143,8 +143,8 @@ def extract_search_type(text, model, api_key=None, temperature=0.5, max_tokens=1
temperature=temperature,
max_tokens=max_tokens,
frequency_penalty=0.2,
stop=["\n"],
api_key=api_key,
model_kwargs={"stop": ["\n"]},
openai_api_key=api_key,
)
# Extract, Clean Message from GPT's Response
@ -155,9 +155,9 @@ def converse(
references,
user_query,
conversation_log={},
model: Optional[str] = "gpt-3.5-turbo",
api_key=None,
temperature=0.2,
model: str = "gpt-3.5-turbo",
api_key: Optional[str] = None,
temperature: float = 0.2,
completion_func=None,
):
"""

View file

@ -186,11 +186,12 @@ def test_generate_search_query_with_date_and_context_from_chat_history():
@pytest.mark.chatquality
def test_chat_with_no_chat_history_or_retrieved_content():
# Act
response = converse(
response_gen = converse(
references=[], # Assume no context retrieved from notes for the user_query
user_query="Hello, my name is Testatron. Who are you?",
api_key=api_key,
)
response = "".join([response_chunk for response_chunk in response_gen])
# Assert
expected_responses = ["Khoj", "khoj"]
@ -210,12 +211,13 @@ def test_answer_from_chat_history_and_no_content():
]
# Act
response = converse(
response_gen = converse(
references=[], # Assume no context retrieved from notes for the user_query
user_query="What is my name?",
conversation_log=populate_chat_history(message_list),
api_key=api_key,
)
response = "".join([response_chunk for response_chunk in response_gen])
# Assert
expected_responses = ["Testatron", "testatron"]
@ -240,12 +242,13 @@ def test_answer_from_chat_history_and_previously_retrieved_content():
]
# Act
response = converse(
response_gen = converse(
references=[], # Assume no context retrieved from notes for the user_query
user_query="Where was I born?",
conversation_log=populate_chat_history(message_list),
api_key=api_key,
)
response = "".join([response_chunk for response_chunk in response_gen])
# Assert
assert len(response) > 0
@ -264,7 +267,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content():
]
# Act
response = converse(
response_gen = converse(
references=[
"Testatron was born on 1st April 1984 in Testville."
], # Assume context retrieved from notes for the user_query
@ -272,6 +275,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content():
conversation_log=populate_chat_history(message_list),
api_key=api_key,
)
response = "".join([response_chunk for response_chunk in response_gen])
# Assert
assert len(response) > 0
@ -289,12 +293,13 @@ def test_refuse_answering_unanswerable_question():
]
# Act
response = converse(
response_gen = converse(
references=[], # Assume no context retrieved from notes for the user_query
user_query="Where was I born?",
conversation_log=populate_chat_history(message_list),
api_key=api_key,
)
response = "".join([response_chunk for response_chunk in response_gen])
# Assert
expected_responses = [
@ -329,11 +334,12 @@ Expenses:Food:Dining 10.00 USD""",
]
# Act
response = converse(
response_gen = converse(
references=context, # Assume context retrieved from notes for the user_query
user_query="What did I have for Dinner today?",
api_key=api_key,
)
response = "".join([response_chunk for response_chunk in response_gen])
# Assert
expected_responses = ["tacos", "Tacos"]
@ -360,11 +366,12 @@ Expenses:Food:Dining 10.00 USD""",
]
# Act
response = converse(
response_gen = converse(
references=context, # Assume context retrieved from notes for the user_query
user_query="How much did I spend on dining this year?",
api_key=api_key,
)
response = "".join([response_chunk for response_chunk in response_gen])
# Assert
assert len(response) > 0
@ -383,12 +390,13 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content():
]
# Act
response = converse(
response_gen = converse(
references=[], # Assume no context retrieved from notes for the user_query
user_query="Write a haiku about unit testing in 3 lines",
conversation_log=populate_chat_history(message_list),
api_key=api_key,
)
response = "".join([response_chunk for response_chunk in response_gen])
# Assert
expected_responses = ["test", "Test"]
@ -414,11 +422,12 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""
]
# Act
response = converse(
response_gen = converse(
references=context, # Assume context retrieved from notes for the user_query
user_query="How many kids does my older sister have?",
api_key=api_key,
)
response = "".join([response_chunk for response_chunk in response_gen])
# Assert
expected_responses = ["which sister", "Which sister", "which of your sister", "Which of your sister"]

View file

@ -41,7 +41,7 @@ def populate_chat_history(message_list):
def test_chat_with_no_chat_history_or_retrieved_content(chat_client):
# Act
response = chat_client.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"')
response_message = response.json()["response"]
response_message = response.content.decode("utf-8")
# Assert
expected_responses = ["Khoj", "khoj"]
@ -63,7 +63,7 @@ def test_answer_from_chat_history(chat_client):
# Act
response = chat_client.get(f'/api/chat?q="What is my name?"')
response_message = response.json()["response"]
response_message = response.content.decode("utf-8")
# Assert
expected_responses = ["Testatron", "testatron"]
@ -89,7 +89,7 @@ def test_answer_from_currently_retrieved_content(chat_client):
# Act
response = chat_client.get(f'/api/chat?q="Where was Xi Li born?"')
response_message = response.json()["response"]
response_message = response.content.decode("utf-8")
# Assert
assert response.status_code == 200
@ -112,7 +112,7 @@ def test_answer_from_chat_history_and_previously_retrieved_content(chat_client):
# Act
response = chat_client.get(f'/api/chat?q="Where was I born?"')
response_message = response.json()["response"]
response_message = response.content.decode("utf-8")
# Assert
assert response.status_code == 200
@ -134,7 +134,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content(chat_client):
# Act
response = chat_client.get(f'/api/chat?q="Where was I born?"')
response_message = response.json()["response"]
response_message = response.content.decode("utf-8")
# Assert
assert response.status_code == 200
@ -158,13 +158,13 @@ def test_no_answer_in_chat_history_or_retrieved_content(chat_client):
# Act
response = chat_client.get(f'/api/chat?q="Where was I born?"')
response_message = response.json()["response"]
response_message = response.content.decode("utf-8")
# Assert
expected_responses = ["don't know", "do not know", "no information", "do not have", "don't have"]
assert response.status_code == 200
assert any([expected_response in response_message for expected_response in expected_responses]), (
"Expected chat director to say they don't know in response, but got: " + response
"Expected chat director to say they don't know in response, but got: " + response_message
)
@ -176,7 +176,7 @@ def test_answer_requires_current_date_awareness(chat_client):
"Chat actor should be able to answer questions relative to current date using provided notes"
# Act
response = chat_client.get(f'/api/chat?q="Where did I have lunch today?"')
response_message = response.json()["response"]
response_message = response.content.decode("utf-8")
# Assert
expected_responses = ["Arak", "Medellin"]
@ -193,7 +193,7 @@ def test_answer_requires_date_aware_aggregation_across_provided_notes(chat_clien
"Chat director should be able to answer questions that require date aware aggregation across multiple notes"
# Act
response = chat_client.get(f'/api/chat?q="How much did I spend on dining this year?"')
response_message = response.json()["response"]
response_message = response.content.decode("utf-8")
# Assert
assert response.status_code == 200
@ -213,7 +213,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c
# Act
response = chat_client.get(f'/api/chat?q=""Write a haiku about unit testing. Do not say anything else."')
response_message = response.json()["response"]
response_message = response.content.decode("utf-8")
# Assert
expected_responses = ["test", "Test"]
@ -230,7 +230,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c
def test_ask_for_clarification_if_not_enough_context_in_question(chat_client):
# Act
response = chat_client.get(f'/api/chat?q="What is the name of Namitas older son"')
response_message = response.json()["response"]
response_message = response.content.decode("utf-8")
# Assert
expected_responses = [
@ -259,7 +259,7 @@ def test_answer_in_chat_history_beyond_lookback_window(chat_client):
# Act
response = chat_client.get(f'/api/chat?q="What is my name?"')
response_message = response.json()["response"]
response_message = response.content.decode("utf-8")
# Assert
expected_responses = ["Testatron", "testatron"]
@ -275,7 +275,7 @@ def test_answer_requires_multiple_independent_searches(chat_client):
"Chat director should be able to answer by doing multiple independent searches for required information"
# Act
response = chat_client.get(f'/api/chat?q="Is Xi older than Namita?"')
response_message = response.json()["response"]
response_message = response.content.decode("utf-8")
# Assert
expected_responses = ["he is older than namita", "xi is older than namita"]