From 11f0a9f196aed2f784350dde98d2e9678212987b Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Fri, 7 Jul 2023 15:23:44 -0700 Subject: [PATCH] Fix chat tests since streaming. Pass args correctly to chat methods - Fix testing gpt converse method after it started streaming responses - Pass stop in model_kwargs dictionary and api key in openai_api_key parameter to chat completion methods. This should resolve the arg warning thrown by OpenAI module --- src/khoj/processor/conversation/gpt.py | 22 ++++++++++----------- tests/test_chat_actors.py | 27 +++++++++++++++++--------- tests/test_chat_director.py | 26 ++++++++++++------------- 3 files changed, 42 insertions(+), 33 deletions(-) diff --git a/src/khoj/processor/conversation/gpt.py b/src/khoj/processor/conversation/gpt.py index ed8ea455..8051afdb 100644 --- a/src/khoj/processor/conversation/gpt.py +++ b/src/khoj/processor/conversation/gpt.py @@ -31,8 +31,8 @@ def answer(text, user_query, model, api_key=None, temperature=0.5, max_tokens=50 model_name=model, temperature=temperature, max_tokens=max_tokens, - stop='"""', - api_key=api_key, + model_kwargs={"stop": ['"""']}, + openai_api_key=api_key, ) # Extract, Clean Message from GPT's Response @@ -59,8 +59,8 @@ def summarize(text, summary_type, model, user_query=None, api_key=None, temperat temperature=temperature, max_tokens=max_tokens, frequency_penalty=0.2, - stop='"""', - api_key=api_key, + model_kwargs={"stop": ['"""']}, + openai_api_key=api_key, ) # Extract, Clean Message from GPT's Response @@ -104,8 +104,8 @@ def extract_questions( model_name=model, temperature=temperature, max_tokens=max_tokens, - stop=["A: ", "\n"], - api_key=api_key, + model_kwargs={"stop": ["A: ", "\n"]}, + openai_api_key=api_key, ) # Extract, Clean Message from GPT's Response @@ -143,8 +143,8 @@ def extract_search_type(text, model, api_key=None, temperature=0.5, max_tokens=1 temperature=temperature, max_tokens=max_tokens, frequency_penalty=0.2, - stop=["\n"], - api_key=api_key, + model_kwargs={"stop": ["\n"]}, + openai_api_key=api_key, ) # Extract, Clean Message from GPT's Response @@ -155,9 +155,9 @@ def converse( references, user_query, conversation_log={}, - model: Optional[str] = "gpt-3.5-turbo", - api_key=None, - temperature=0.2, + model: str = "gpt-3.5-turbo", + api_key: Optional[str] = None, + temperature: float = 0.2, completion_func=None, ): """ diff --git a/tests/test_chat_actors.py b/tests/test_chat_actors.py index 524f0813..9f8e821d 100644 --- a/tests/test_chat_actors.py +++ b/tests/test_chat_actors.py @@ -186,11 +186,12 @@ def test_generate_search_query_with_date_and_context_from_chat_history(): @pytest.mark.chatquality def test_chat_with_no_chat_history_or_retrieved_content(): # Act - response = converse( + response_gen = converse( references=[], # Assume no context retrieved from notes for the user_query user_query="Hello, my name is Testatron. Who are you?", api_key=api_key, ) + response = "".join([response_chunk for response_chunk in response_gen]) # Assert expected_responses = ["Khoj", "khoj"] @@ -210,12 +211,13 @@ def test_answer_from_chat_history_and_no_content(): ] # Act - response = converse( + response_gen = converse( references=[], # Assume no context retrieved from notes for the user_query user_query="What is my name?", conversation_log=populate_chat_history(message_list), api_key=api_key, ) + response = "".join([response_chunk for response_chunk in response_gen]) # Assert expected_responses = ["Testatron", "testatron"] @@ -240,12 +242,13 @@ def test_answer_from_chat_history_and_previously_retrieved_content(): ] # Act - response = converse( + response_gen = converse( references=[], # Assume no context retrieved from notes for the user_query user_query="Where was I born?", conversation_log=populate_chat_history(message_list), api_key=api_key, ) + response = "".join([response_chunk for response_chunk in response_gen]) # Assert assert len(response) > 0 @@ -264,7 +267,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content(): ] # Act - response = converse( + response_gen = converse( references=[ "Testatron was born on 1st April 1984 in Testville." ], # Assume context retrieved from notes for the user_query @@ -272,6 +275,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content(): conversation_log=populate_chat_history(message_list), api_key=api_key, ) + response = "".join([response_chunk for response_chunk in response_gen]) # Assert assert len(response) > 0 @@ -289,12 +293,13 @@ def test_refuse_answering_unanswerable_question(): ] # Act - response = converse( + response_gen = converse( references=[], # Assume no context retrieved from notes for the user_query user_query="Where was I born?", conversation_log=populate_chat_history(message_list), api_key=api_key, ) + response = "".join([response_chunk for response_chunk in response_gen]) # Assert expected_responses = [ @@ -329,11 +334,12 @@ Expenses:Food:Dining 10.00 USD""", ] # Act - response = converse( + response_gen = converse( references=context, # Assume context retrieved from notes for the user_query user_query="What did I have for Dinner today?", api_key=api_key, ) + response = "".join([response_chunk for response_chunk in response_gen]) # Assert expected_responses = ["tacos", "Tacos"] @@ -360,11 +366,12 @@ Expenses:Food:Dining 10.00 USD""", ] # Act - response = converse( + response_gen = converse( references=context, # Assume context retrieved from notes for the user_query user_query="How much did I spend on dining this year?", api_key=api_key, ) + response = "".join([response_chunk for response_chunk in response_gen]) # Assert assert len(response) > 0 @@ -383,12 +390,13 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(): ] # Act - response = converse( + response_gen = converse( references=[], # Assume no context retrieved from notes for the user_query user_query="Write a haiku about unit testing in 3 lines", conversation_log=populate_chat_history(message_list), api_key=api_key, ) + response = "".join([response_chunk for response_chunk in response_gen]) # Assert expected_responses = ["test", "Test"] @@ -414,11 +422,12 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet."" ] # Act - response = converse( + response_gen = converse( references=context, # Assume context retrieved from notes for the user_query user_query="How many kids does my older sister have?", api_key=api_key, ) + response = "".join([response_chunk for response_chunk in response_gen]) # Assert expected_responses = ["which sister", "Which sister", "which of your sister", "Which of your sister"] diff --git a/tests/test_chat_director.py b/tests/test_chat_director.py index 0d29ae53..4dbf808d 100644 --- a/tests/test_chat_director.py +++ b/tests/test_chat_director.py @@ -41,7 +41,7 @@ def populate_chat_history(message_list): def test_chat_with_no_chat_history_or_retrieved_content(chat_client): # Act response = chat_client.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"') - response_message = response.json()["response"] + response_message = response.content.decode("utf-8") # Assert expected_responses = ["Khoj", "khoj"] @@ -63,7 +63,7 @@ def test_answer_from_chat_history(chat_client): # Act response = chat_client.get(f'/api/chat?q="What is my name?"') - response_message = response.json()["response"] + response_message = response.content.decode("utf-8") # Assert expected_responses = ["Testatron", "testatron"] @@ -89,7 +89,7 @@ def test_answer_from_currently_retrieved_content(chat_client): # Act response = chat_client.get(f'/api/chat?q="Where was Xi Li born?"') - response_message = response.json()["response"] + response_message = response.content.decode("utf-8") # Assert assert response.status_code == 200 @@ -112,7 +112,7 @@ def test_answer_from_chat_history_and_previously_retrieved_content(chat_client): # Act response = chat_client.get(f'/api/chat?q="Where was I born?"') - response_message = response.json()["response"] + response_message = response.content.decode("utf-8") # Assert assert response.status_code == 200 @@ -134,7 +134,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content(chat_client): # Act response = chat_client.get(f'/api/chat?q="Where was I born?"') - response_message = response.json()["response"] + response_message = response.content.decode("utf-8") # Assert assert response.status_code == 200 @@ -158,13 +158,13 @@ def test_no_answer_in_chat_history_or_retrieved_content(chat_client): # Act response = chat_client.get(f'/api/chat?q="Where was I born?"') - response_message = response.json()["response"] + response_message = response.content.decode("utf-8") # Assert expected_responses = ["don't know", "do not know", "no information", "do not have", "don't have"] assert response.status_code == 200 assert any([expected_response in response_message for expected_response in expected_responses]), ( - "Expected chat director to say they don't know in response, but got: " + response + "Expected chat director to say they don't know in response, but got: " + response_message ) @@ -176,7 +176,7 @@ def test_answer_requires_current_date_awareness(chat_client): "Chat actor should be able to answer questions relative to current date using provided notes" # Act response = chat_client.get(f'/api/chat?q="Where did I have lunch today?"') - response_message = response.json()["response"] + response_message = response.content.decode("utf-8") # Assert expected_responses = ["Arak", "Medellin"] @@ -193,7 +193,7 @@ def test_answer_requires_date_aware_aggregation_across_provided_notes(chat_clien "Chat director should be able to answer questions that require date aware aggregation across multiple notes" # Act response = chat_client.get(f'/api/chat?q="How much did I spend on dining this year?"') - response_message = response.json()["response"] + response_message = response.content.decode("utf-8") # Assert assert response.status_code == 200 @@ -213,7 +213,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c # Act response = chat_client.get(f'/api/chat?q=""Write a haiku about unit testing. Do not say anything else."') - response_message = response.json()["response"] + response_message = response.content.decode("utf-8") # Assert expected_responses = ["test", "Test"] @@ -230,7 +230,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c def test_ask_for_clarification_if_not_enough_context_in_question(chat_client): # Act response = chat_client.get(f'/api/chat?q="What is the name of Namitas older son"') - response_message = response.json()["response"] + response_message = response.content.decode("utf-8") # Assert expected_responses = [ @@ -259,7 +259,7 @@ def test_answer_in_chat_history_beyond_lookback_window(chat_client): # Act response = chat_client.get(f'/api/chat?q="What is my name?"') - response_message = response.json()["response"] + response_message = response.content.decode("utf-8") # Assert expected_responses = ["Testatron", "testatron"] @@ -275,7 +275,7 @@ def test_answer_requires_multiple_independent_searches(chat_client): "Chat director should be able to answer by doing multiple independent searches for required information" # Act response = chat_client.get(f'/api/chat?q="Is Xi older than Namita?"') - response_message = response.json()["response"] + response_message = response.content.decode("utf-8") # Assert expected_responses = ["he is older than namita", "xi is older than namita"]