Fix chat tests broken by streaming. Pass args correctly to chat methods

- Fix tests of the GPT converse method after it started streaming responses
- Pass stop in the model_kwargs dictionary and the API key in the
  openai_api_key parameter to chat completion methods. This should
  resolve the argument warning thrown by the OpenAI module
This commit is contained in:
Debanjum Singh Solanky 2023-07-07 15:23:44 -07:00
parent 48870d9170
commit 11f0a9f196
3 changed files with 42 additions and 33 deletions

View file

@ -31,8 +31,8 @@ def answer(text, user_query, model, api_key=None, temperature=0.5, max_tokens=50
model_name=model, model_name=model,
temperature=temperature, temperature=temperature,
max_tokens=max_tokens, max_tokens=max_tokens,
stop='"""', model_kwargs={"stop": ['"""']},
api_key=api_key, openai_api_key=api_key,
) )
# Extract, Clean Message from GPT's Response # Extract, Clean Message from GPT's Response
@ -59,8 +59,8 @@ def summarize(text, summary_type, model, user_query=None, api_key=None, temperat
temperature=temperature, temperature=temperature,
max_tokens=max_tokens, max_tokens=max_tokens,
frequency_penalty=0.2, frequency_penalty=0.2,
stop='"""', model_kwargs={"stop": ['"""']},
api_key=api_key, openai_api_key=api_key,
) )
# Extract, Clean Message from GPT's Response # Extract, Clean Message from GPT's Response
@ -104,8 +104,8 @@ def extract_questions(
model_name=model, model_name=model,
temperature=temperature, temperature=temperature,
max_tokens=max_tokens, max_tokens=max_tokens,
stop=["A: ", "\n"], model_kwargs={"stop": ["A: ", "\n"]},
api_key=api_key, openai_api_key=api_key,
) )
# Extract, Clean Message from GPT's Response # Extract, Clean Message from GPT's Response
@ -143,8 +143,8 @@ def extract_search_type(text, model, api_key=None, temperature=0.5, max_tokens=1
temperature=temperature, temperature=temperature,
max_tokens=max_tokens, max_tokens=max_tokens,
frequency_penalty=0.2, frequency_penalty=0.2,
stop=["\n"], model_kwargs={"stop": ["\n"]},
api_key=api_key, openai_api_key=api_key,
) )
# Extract, Clean Message from GPT's Response # Extract, Clean Message from GPT's Response
@ -155,9 +155,9 @@ def converse(
references, references,
user_query, user_query,
conversation_log={}, conversation_log={},
model: Optional[str] = "gpt-3.5-turbo", model: str = "gpt-3.5-turbo",
api_key=None, api_key: Optional[str] = None,
temperature=0.2, temperature: float = 0.2,
completion_func=None, completion_func=None,
): ):
""" """

View file

@ -186,11 +186,12 @@ def test_generate_search_query_with_date_and_context_from_chat_history():
@pytest.mark.chatquality @pytest.mark.chatquality
def test_chat_with_no_chat_history_or_retrieved_content(): def test_chat_with_no_chat_history_or_retrieved_content():
# Act # Act
response = converse( response_gen = converse(
references=[], # Assume no context retrieved from notes for the user_query references=[], # Assume no context retrieved from notes for the user_query
user_query="Hello, my name is Testatron. Who are you?", user_query="Hello, my name is Testatron. Who are you?",
api_key=api_key, api_key=api_key,
) )
response = "".join([response_chunk for response_chunk in response_gen])
# Assert # Assert
expected_responses = ["Khoj", "khoj"] expected_responses = ["Khoj", "khoj"]
@ -210,12 +211,13 @@ def test_answer_from_chat_history_and_no_content():
] ]
# Act # Act
response = converse( response_gen = converse(
references=[], # Assume no context retrieved from notes for the user_query references=[], # Assume no context retrieved from notes for the user_query
user_query="What is my name?", user_query="What is my name?",
conversation_log=populate_chat_history(message_list), conversation_log=populate_chat_history(message_list),
api_key=api_key, api_key=api_key,
) )
response = "".join([response_chunk for response_chunk in response_gen])
# Assert # Assert
expected_responses = ["Testatron", "testatron"] expected_responses = ["Testatron", "testatron"]
@ -240,12 +242,13 @@ def test_answer_from_chat_history_and_previously_retrieved_content():
] ]
# Act # Act
response = converse( response_gen = converse(
references=[], # Assume no context retrieved from notes for the user_query references=[], # Assume no context retrieved from notes for the user_query
user_query="Where was I born?", user_query="Where was I born?",
conversation_log=populate_chat_history(message_list), conversation_log=populate_chat_history(message_list),
api_key=api_key, api_key=api_key,
) )
response = "".join([response_chunk for response_chunk in response_gen])
# Assert # Assert
assert len(response) > 0 assert len(response) > 0
@ -264,7 +267,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content():
] ]
# Act # Act
response = converse( response_gen = converse(
references=[ references=[
"Testatron was born on 1st April 1984 in Testville." "Testatron was born on 1st April 1984 in Testville."
], # Assume context retrieved from notes for the user_query ], # Assume context retrieved from notes for the user_query
@ -272,6 +275,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content():
conversation_log=populate_chat_history(message_list), conversation_log=populate_chat_history(message_list),
api_key=api_key, api_key=api_key,
) )
response = "".join([response_chunk for response_chunk in response_gen])
# Assert # Assert
assert len(response) > 0 assert len(response) > 0
@ -289,12 +293,13 @@ def test_refuse_answering_unanswerable_question():
] ]
# Act # Act
response = converse( response_gen = converse(
references=[], # Assume no context retrieved from notes for the user_query references=[], # Assume no context retrieved from notes for the user_query
user_query="Where was I born?", user_query="Where was I born?",
conversation_log=populate_chat_history(message_list), conversation_log=populate_chat_history(message_list),
api_key=api_key, api_key=api_key,
) )
response = "".join([response_chunk for response_chunk in response_gen])
# Assert # Assert
expected_responses = [ expected_responses = [
@ -329,11 +334,12 @@ Expenses:Food:Dining 10.00 USD""",
] ]
# Act # Act
response = converse( response_gen = converse(
references=context, # Assume context retrieved from notes for the user_query references=context, # Assume context retrieved from notes for the user_query
user_query="What did I have for Dinner today?", user_query="What did I have for Dinner today?",
api_key=api_key, api_key=api_key,
) )
response = "".join([response_chunk for response_chunk in response_gen])
# Assert # Assert
expected_responses = ["tacos", "Tacos"] expected_responses = ["tacos", "Tacos"]
@ -360,11 +366,12 @@ Expenses:Food:Dining 10.00 USD""",
] ]
# Act # Act
response = converse( response_gen = converse(
references=context, # Assume context retrieved from notes for the user_query references=context, # Assume context retrieved from notes for the user_query
user_query="How much did I spend on dining this year?", user_query="How much did I spend on dining this year?",
api_key=api_key, api_key=api_key,
) )
response = "".join([response_chunk for response_chunk in response_gen])
# Assert # Assert
assert len(response) > 0 assert len(response) > 0
@ -383,12 +390,13 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content():
] ]
# Act # Act
response = converse( response_gen = converse(
references=[], # Assume no context retrieved from notes for the user_query references=[], # Assume no context retrieved from notes for the user_query
user_query="Write a haiku about unit testing in 3 lines", user_query="Write a haiku about unit testing in 3 lines",
conversation_log=populate_chat_history(message_list), conversation_log=populate_chat_history(message_list),
api_key=api_key, api_key=api_key,
) )
response = "".join([response_chunk for response_chunk in response_gen])
# Assert # Assert
expected_responses = ["test", "Test"] expected_responses = ["test", "Test"]
@ -414,11 +422,12 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""
] ]
# Act # Act
response = converse( response_gen = converse(
references=context, # Assume context retrieved from notes for the user_query references=context, # Assume context retrieved from notes for the user_query
user_query="How many kids does my older sister have?", user_query="How many kids does my older sister have?",
api_key=api_key, api_key=api_key,
) )
response = "".join([response_chunk for response_chunk in response_gen])
# Assert # Assert
expected_responses = ["which sister", "Which sister", "which of your sister", "Which of your sister"] expected_responses = ["which sister", "Which sister", "which of your sister", "Which of your sister"]

View file

@ -41,7 +41,7 @@ def populate_chat_history(message_list):
def test_chat_with_no_chat_history_or_retrieved_content(chat_client): def test_chat_with_no_chat_history_or_retrieved_content(chat_client):
# Act # Act
response = chat_client.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"') response = chat_client.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"')
response_message = response.json()["response"] response_message = response.content.decode("utf-8")
# Assert # Assert
expected_responses = ["Khoj", "khoj"] expected_responses = ["Khoj", "khoj"]
@ -63,7 +63,7 @@ def test_answer_from_chat_history(chat_client):
# Act # Act
response = chat_client.get(f'/api/chat?q="What is my name?"') response = chat_client.get(f'/api/chat?q="What is my name?"')
response_message = response.json()["response"] response_message = response.content.decode("utf-8")
# Assert # Assert
expected_responses = ["Testatron", "testatron"] expected_responses = ["Testatron", "testatron"]
@ -89,7 +89,7 @@ def test_answer_from_currently_retrieved_content(chat_client):
# Act # Act
response = chat_client.get(f'/api/chat?q="Where was Xi Li born?"') response = chat_client.get(f'/api/chat?q="Where was Xi Li born?"')
response_message = response.json()["response"] response_message = response.content.decode("utf-8")
# Assert # Assert
assert response.status_code == 200 assert response.status_code == 200
@ -112,7 +112,7 @@ def test_answer_from_chat_history_and_previously_retrieved_content(chat_client):
# Act # Act
response = chat_client.get(f'/api/chat?q="Where was I born?"') response = chat_client.get(f'/api/chat?q="Where was I born?"')
response_message = response.json()["response"] response_message = response.content.decode("utf-8")
# Assert # Assert
assert response.status_code == 200 assert response.status_code == 200
@ -134,7 +134,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content(chat_client):
# Act # Act
response = chat_client.get(f'/api/chat?q="Where was I born?"') response = chat_client.get(f'/api/chat?q="Where was I born?"')
response_message = response.json()["response"] response_message = response.content.decode("utf-8")
# Assert # Assert
assert response.status_code == 200 assert response.status_code == 200
@ -158,13 +158,13 @@ def test_no_answer_in_chat_history_or_retrieved_content(chat_client):
# Act # Act
response = chat_client.get(f'/api/chat?q="Where was I born?"') response = chat_client.get(f'/api/chat?q="Where was I born?"')
response_message = response.json()["response"] response_message = response.content.decode("utf-8")
# Assert # Assert
expected_responses = ["don't know", "do not know", "no information", "do not have", "don't have"] expected_responses = ["don't know", "do not know", "no information", "do not have", "don't have"]
assert response.status_code == 200 assert response.status_code == 200
assert any([expected_response in response_message for expected_response in expected_responses]), ( assert any([expected_response in response_message for expected_response in expected_responses]), (
"Expected chat director to say they don't know in response, but got: " + response "Expected chat director to say they don't know in response, but got: " + response_message
) )
@ -176,7 +176,7 @@ def test_answer_requires_current_date_awareness(chat_client):
"Chat actor should be able to answer questions relative to current date using provided notes" "Chat actor should be able to answer questions relative to current date using provided notes"
# Act # Act
response = chat_client.get(f'/api/chat?q="Where did I have lunch today?"') response = chat_client.get(f'/api/chat?q="Where did I have lunch today?"')
response_message = response.json()["response"] response_message = response.content.decode("utf-8")
# Assert # Assert
expected_responses = ["Arak", "Medellin"] expected_responses = ["Arak", "Medellin"]
@ -193,7 +193,7 @@ def test_answer_requires_date_aware_aggregation_across_provided_notes(chat_clien
"Chat director should be able to answer questions that require date aware aggregation across multiple notes" "Chat director should be able to answer questions that require date aware aggregation across multiple notes"
# Act # Act
response = chat_client.get(f'/api/chat?q="How much did I spend on dining this year?"') response = chat_client.get(f'/api/chat?q="How much did I spend on dining this year?"')
response_message = response.json()["response"] response_message = response.content.decode("utf-8")
# Assert # Assert
assert response.status_code == 200 assert response.status_code == 200
@ -213,7 +213,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c
# Act # Act
response = chat_client.get(f'/api/chat?q=""Write a haiku about unit testing. Do not say anything else."') response = chat_client.get(f'/api/chat?q=""Write a haiku about unit testing. Do not say anything else."')
response_message = response.json()["response"] response_message = response.content.decode("utf-8")
# Assert # Assert
expected_responses = ["test", "Test"] expected_responses = ["test", "Test"]
@ -230,7 +230,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c
def test_ask_for_clarification_if_not_enough_context_in_question(chat_client): def test_ask_for_clarification_if_not_enough_context_in_question(chat_client):
# Act # Act
response = chat_client.get(f'/api/chat?q="What is the name of Namitas older son"') response = chat_client.get(f'/api/chat?q="What is the name of Namitas older son"')
response_message = response.json()["response"] response_message = response.content.decode("utf-8")
# Assert # Assert
expected_responses = [ expected_responses = [
@ -259,7 +259,7 @@ def test_answer_in_chat_history_beyond_lookback_window(chat_client):
# Act # Act
response = chat_client.get(f'/api/chat?q="What is my name?"') response = chat_client.get(f'/api/chat?q="What is my name?"')
response_message = response.json()["response"] response_message = response.content.decode("utf-8")
# Assert # Assert
expected_responses = ["Testatron", "testatron"] expected_responses = ["Testatron", "testatron"]
@ -275,7 +275,7 @@ def test_answer_requires_multiple_independent_searches(chat_client):
"Chat director should be able to answer by doing multiple independent searches for required information" "Chat director should be able to answer by doing multiple independent searches for required information"
# Act # Act
response = chat_client.get(f'/api/chat?q="Is Xi older than Namita?"') response = chat_client.get(f'/api/chat?q="Is Xi older than Namita?"')
response_message = response.json()["response"] response_message = response.content.decode("utf-8")
# Assert # Assert
expected_responses = ["he is older than namita", "xi is older than namita"] expected_responses = ["he is older than namita", "xi is older than namita"]