mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-12-03 12:23:02 +01:00
Fix chat tests since streaming. Pass args correctly to chat methods
- Fix testing gpt converse method after it started streaming responses - Pass stop in model_kwargs dictionary and api key in openai_api_key parameter to chat completion methods. This should resolve the arg warning thrown by OpenAI module
This commit is contained in:
parent
48870d9170
commit
11f0a9f196
3 changed files with 42 additions and 33 deletions
|
@ -31,8 +31,8 @@ def answer(text, user_query, model, api_key=None, temperature=0.5, max_tokens=50
|
|||
model_name=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
stop='"""',
|
||||
api_key=api_key,
|
||||
model_kwargs={"stop": ['"""']},
|
||||
openai_api_key=api_key,
|
||||
)
|
||||
|
||||
# Extract, Clean Message from GPT's Response
|
||||
|
@ -59,8 +59,8 @@ def summarize(text, summary_type, model, user_query=None, api_key=None, temperat
|
|||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
frequency_penalty=0.2,
|
||||
stop='"""',
|
||||
api_key=api_key,
|
||||
model_kwargs={"stop": ['"""']},
|
||||
openai_api_key=api_key,
|
||||
)
|
||||
|
||||
# Extract, Clean Message from GPT's Response
|
||||
|
@ -104,8 +104,8 @@ def extract_questions(
|
|||
model_name=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
stop=["A: ", "\n"],
|
||||
api_key=api_key,
|
||||
model_kwargs={"stop": ["A: ", "\n"]},
|
||||
openai_api_key=api_key,
|
||||
)
|
||||
|
||||
# Extract, Clean Message from GPT's Response
|
||||
|
@ -143,8 +143,8 @@ def extract_search_type(text, model, api_key=None, temperature=0.5, max_tokens=1
|
|||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
frequency_penalty=0.2,
|
||||
stop=["\n"],
|
||||
api_key=api_key,
|
||||
model_kwargs={"stop": ["\n"]},
|
||||
openai_api_key=api_key,
|
||||
)
|
||||
|
||||
# Extract, Clean Message from GPT's Response
|
||||
|
@ -155,9 +155,9 @@ def converse(
|
|||
references,
|
||||
user_query,
|
||||
conversation_log={},
|
||||
model: Optional[str] = "gpt-3.5-turbo",
|
||||
api_key=None,
|
||||
temperature=0.2,
|
||||
model: str = "gpt-3.5-turbo",
|
||||
api_key: Optional[str] = None,
|
||||
temperature: float = 0.2,
|
||||
completion_func=None,
|
||||
):
|
||||
"""
|
||||
|
|
|
@ -186,11 +186,12 @@ def test_generate_search_query_with_date_and_context_from_chat_history():
|
|||
@pytest.mark.chatquality
|
||||
def test_chat_with_no_chat_history_or_retrieved_content():
|
||||
# Act
|
||||
response = converse(
|
||||
response_gen = converse(
|
||||
references=[], # Assume no context retrieved from notes for the user_query
|
||||
user_query="Hello, my name is Testatron. Who are you?",
|
||||
api_key=api_key,
|
||||
)
|
||||
response = "".join([response_chunk for response_chunk in response_gen])
|
||||
|
||||
# Assert
|
||||
expected_responses = ["Khoj", "khoj"]
|
||||
|
@ -210,12 +211,13 @@ def test_answer_from_chat_history_and_no_content():
|
|||
]
|
||||
|
||||
# Act
|
||||
response = converse(
|
||||
response_gen = converse(
|
||||
references=[], # Assume no context retrieved from notes for the user_query
|
||||
user_query="What is my name?",
|
||||
conversation_log=populate_chat_history(message_list),
|
||||
api_key=api_key,
|
||||
)
|
||||
response = "".join([response_chunk for response_chunk in response_gen])
|
||||
|
||||
# Assert
|
||||
expected_responses = ["Testatron", "testatron"]
|
||||
|
@ -240,12 +242,13 @@ def test_answer_from_chat_history_and_previously_retrieved_content():
|
|||
]
|
||||
|
||||
# Act
|
||||
response = converse(
|
||||
response_gen = converse(
|
||||
references=[], # Assume no context retrieved from notes for the user_query
|
||||
user_query="Where was I born?",
|
||||
conversation_log=populate_chat_history(message_list),
|
||||
api_key=api_key,
|
||||
)
|
||||
response = "".join([response_chunk for response_chunk in response_gen])
|
||||
|
||||
# Assert
|
||||
assert len(response) > 0
|
||||
|
@ -264,7 +267,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content():
|
|||
]
|
||||
|
||||
# Act
|
||||
response = converse(
|
||||
response_gen = converse(
|
||||
references=[
|
||||
"Testatron was born on 1st April 1984 in Testville."
|
||||
], # Assume context retrieved from notes for the user_query
|
||||
|
@ -272,6 +275,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content():
|
|||
conversation_log=populate_chat_history(message_list),
|
||||
api_key=api_key,
|
||||
)
|
||||
response = "".join([response_chunk for response_chunk in response_gen])
|
||||
|
||||
# Assert
|
||||
assert len(response) > 0
|
||||
|
@ -289,12 +293,13 @@ def test_refuse_answering_unanswerable_question():
|
|||
]
|
||||
|
||||
# Act
|
||||
response = converse(
|
||||
response_gen = converse(
|
||||
references=[], # Assume no context retrieved from notes for the user_query
|
||||
user_query="Where was I born?",
|
||||
conversation_log=populate_chat_history(message_list),
|
||||
api_key=api_key,
|
||||
)
|
||||
response = "".join([response_chunk for response_chunk in response_gen])
|
||||
|
||||
# Assert
|
||||
expected_responses = [
|
||||
|
@ -329,11 +334,12 @@ Expenses:Food:Dining 10.00 USD""",
|
|||
]
|
||||
|
||||
# Act
|
||||
response = converse(
|
||||
response_gen = converse(
|
||||
references=context, # Assume context retrieved from notes for the user_query
|
||||
user_query="What did I have for Dinner today?",
|
||||
api_key=api_key,
|
||||
)
|
||||
response = "".join([response_chunk for response_chunk in response_gen])
|
||||
|
||||
# Assert
|
||||
expected_responses = ["tacos", "Tacos"]
|
||||
|
@ -360,11 +366,12 @@ Expenses:Food:Dining 10.00 USD""",
|
|||
]
|
||||
|
||||
# Act
|
||||
response = converse(
|
||||
response_gen = converse(
|
||||
references=context, # Assume context retrieved from notes for the user_query
|
||||
user_query="How much did I spend on dining this year?",
|
||||
api_key=api_key,
|
||||
)
|
||||
response = "".join([response_chunk for response_chunk in response_gen])
|
||||
|
||||
# Assert
|
||||
assert len(response) > 0
|
||||
|
@ -383,12 +390,13 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content():
|
|||
]
|
||||
|
||||
# Act
|
||||
response = converse(
|
||||
response_gen = converse(
|
||||
references=[], # Assume no context retrieved from notes for the user_query
|
||||
user_query="Write a haiku about unit testing in 3 lines",
|
||||
conversation_log=populate_chat_history(message_list),
|
||||
api_key=api_key,
|
||||
)
|
||||
response = "".join([response_chunk for response_chunk in response_gen])
|
||||
|
||||
# Assert
|
||||
expected_responses = ["test", "Test"]
|
||||
|
@ -414,11 +422,12 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""
|
|||
]
|
||||
|
||||
# Act
|
||||
response = converse(
|
||||
response_gen = converse(
|
||||
references=context, # Assume context retrieved from notes for the user_query
|
||||
user_query="How many kids does my older sister have?",
|
||||
api_key=api_key,
|
||||
)
|
||||
response = "".join([response_chunk for response_chunk in response_gen])
|
||||
|
||||
# Assert
|
||||
expected_responses = ["which sister", "Which sister", "which of your sister", "Which of your sister"]
|
||||
|
|
|
@ -41,7 +41,7 @@ def populate_chat_history(message_list):
|
|||
def test_chat_with_no_chat_history_or_retrieved_content(chat_client):
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"')
|
||||
response_message = response.json()["response"]
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
expected_responses = ["Khoj", "khoj"]
|
||||
|
@ -63,7 +63,7 @@ def test_answer_from_chat_history(chat_client):
|
|||
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q="What is my name?"')
|
||||
response_message = response.json()["response"]
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
expected_responses = ["Testatron", "testatron"]
|
||||
|
@ -89,7 +89,7 @@ def test_answer_from_currently_retrieved_content(chat_client):
|
|||
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q="Where was Xi Li born?"')
|
||||
response_message = response.json()["response"]
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
|
@ -112,7 +112,7 @@ def test_answer_from_chat_history_and_previously_retrieved_content(chat_client):
|
|||
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q="Where was I born?"')
|
||||
response_message = response.json()["response"]
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
|
@ -134,7 +134,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content(chat_client):
|
|||
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q="Where was I born?"')
|
||||
response_message = response.json()["response"]
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
|
@ -158,13 +158,13 @@ def test_no_answer_in_chat_history_or_retrieved_content(chat_client):
|
|||
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q="Where was I born?"')
|
||||
response_message = response.json()["response"]
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
expected_responses = ["don't know", "do not know", "no information", "do not have", "don't have"]
|
||||
assert response.status_code == 200
|
||||
assert any([expected_response in response_message for expected_response in expected_responses]), (
|
||||
"Expected chat director to say they don't know in response, but got: " + response
|
||||
"Expected chat director to say they don't know in response, but got: " + response_message
|
||||
)
|
||||
|
||||
|
||||
|
@ -176,7 +176,7 @@ def test_answer_requires_current_date_awareness(chat_client):
|
|||
"Chat actor should be able to answer questions relative to current date using provided notes"
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q="Where did I have lunch today?"')
|
||||
response_message = response.json()["response"]
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
expected_responses = ["Arak", "Medellin"]
|
||||
|
@ -193,7 +193,7 @@ def test_answer_requires_date_aware_aggregation_across_provided_notes(chat_clien
|
|||
"Chat director should be able to answer questions that require date aware aggregation across multiple notes"
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q="How much did I spend on dining this year?"')
|
||||
response_message = response.json()["response"]
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
|
@ -213,7 +213,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c
|
|||
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q=""Write a haiku about unit testing. Do not say anything else."')
|
||||
response_message = response.json()["response"]
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
expected_responses = ["test", "Test"]
|
||||
|
@ -230,7 +230,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c
|
|||
def test_ask_for_clarification_if_not_enough_context_in_question(chat_client):
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q="What is the name of Namitas older son"')
|
||||
response_message = response.json()["response"]
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
expected_responses = [
|
||||
|
@ -259,7 +259,7 @@ def test_answer_in_chat_history_beyond_lookback_window(chat_client):
|
|||
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q="What is my name?"')
|
||||
response_message = response.json()["response"]
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
expected_responses = ["Testatron", "testatron"]
|
||||
|
@ -275,7 +275,7 @@ def test_answer_requires_multiple_independent_searches(chat_client):
|
|||
"Chat director should be able to answer by doing multiple independent searches for required information"
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q="Is Xi older than Namita?"')
|
||||
response_message = response.json()["response"]
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
expected_responses = ["he is older than namita", "xi is older than namita"]
|
||||
|
|
Loading…
Reference in a new issue