Update chat API client tests to mix testing of batch and streaming mode

This commit is contained in:
Debanjum Singh Solanky 2024-07-23 15:05:06 +05:30
parent 3f5f418d0e
commit 54b4203683
4 changed files with 30 additions and 38 deletions

View file

@ -22,7 +22,7 @@ magika = Magika()
def collect_files(search_type: Optional[SearchType] = SearchType.All, user=None) -> dict: def collect_files(search_type: Optional[SearchType] = SearchType.All, user=None) -> dict:
files = {} files: dict[str, dict] = {"docx": {}, "image": {}}
if search_type == SearchType.All or search_type == SearchType.Org: if search_type == SearchType.All or search_type == SearchType.Org:
org_config = LocalOrgConfig.objects.filter(user=user).first() org_config = LocalOrgConfig.objects.filter(user=user).first()

View file

@ -455,13 +455,13 @@ def test_user_no_data_returns_empty(client, sample_org_data, api_user3: KhojApiU
@pytest.mark.skipif(os.getenv("OPENAI_API_KEY") is None, reason="requires OPENAI_API_KEY") @pytest.mark.skipif(os.getenv("OPENAI_API_KEY") is None, reason="requires OPENAI_API_KEY")
@pytest.mark.django_db(transaction=True) @pytest.mark.django_db(transaction=True)
def test_chat_with_unauthenticated_user(chat_client_with_auth, api_user2: KhojApiUser): async def test_chat_with_unauthenticated_user(chat_client_with_auth, api_user2: KhojApiUser):
# Arrange # Arrange
headers = {"Authorization": f"Bearer {api_user2.token}"} headers = {"Authorization": f"Bearer {api_user2.token}"}
# Act # Act
auth_response = chat_client_with_auth.get(f'/api/chat?q="Hello!"&stream=true', headers=headers) auth_response = chat_client_with_auth.get(f'/api/chat?q="Hello!"', headers=headers)
no_auth_response = chat_client_with_auth.get(f'/api/chat?q="Hello!"&stream=true') no_auth_response = chat_client_with_auth.get(f'/api/chat?q="Hello!"')
# Assert # Assert
assert auth_response.status_code == 200 assert auth_response.status_code == 200

View file

@ -68,10 +68,8 @@ def test_chat_with_online_content(client_offline_chat):
# Act # Act
q = "/online give me the link to paul graham's essay how to do great work" q = "/online give me the link to paul graham's essay how to do great work"
encoded_q = quote(q, safe="") encoded_q = quote(q, safe="")
response = client_offline_chat.get(f"/api/chat?q={encoded_q}&stream=true") response = client_offline_chat.get(f"/api/chat?q={encoded_q}")
response_message = response.content.decode("utf-8") response_message = response.json()["response"]
response_message = response_message.split("### compiled references")[0]
# Assert # Assert
expected_responses = [ expected_responses = [
@ -92,10 +90,8 @@ def test_chat_with_online_webpage_content(client_offline_chat):
# Act # Act
q = "/online how many firefighters were involved in the great chicago fire and which year did it take place?" q = "/online how many firefighters were involved in the great chicago fire and which year did it take place?"
encoded_q = quote(q, safe="") encoded_q = quote(q, safe="")
response = client_offline_chat.get(f"/api/chat?q={encoded_q}&stream=true") response = client_offline_chat.get(f"/api/chat?q={encoded_q}")
response_message = response.content.decode("utf-8") response_message = response.json()["response"]
response_message = response_message.split("### compiled references")[0]
# Assert # Assert
expected_responses = ["185", "1871", "horse"] expected_responses = ["185", "1871", "horse"]

View file

@ -49,8 +49,8 @@ def create_conversation(message_list, user, agent=None):
@pytest.mark.django_db(transaction=True) @pytest.mark.django_db(transaction=True)
def test_chat_with_no_chat_history_or_retrieved_content(chat_client): def test_chat_with_no_chat_history_or_retrieved_content(chat_client):
# Act # Act
response = chat_client.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"&stream=true') response = chat_client.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"')
response_message = response.content.decode("utf-8") response_message = response.json()["response"]
# Assert # Assert
expected_responses = ["Khoj", "khoj"] expected_responses = ["Khoj", "khoj"]
@ -67,10 +67,8 @@ def test_chat_with_online_content(chat_client):
# Act # Act
q = "/online give me the link to paul graham's essay how to do great work" q = "/online give me the link to paul graham's essay how to do great work"
encoded_q = quote(q, safe="") encoded_q = quote(q, safe="")
response = chat_client.get(f"/api/chat?q={encoded_q}&stream=true") response = chat_client.get(f"/api/chat?q={encoded_q}")
response_message = response.content.decode("utf-8") response_message = response.json()["response"]
response_message = response_message.split("### compiled references")[0]
# Assert # Assert
expected_responses = [ expected_responses = [
@ -91,10 +89,8 @@ def test_chat_with_online_webpage_content(chat_client):
# Act # Act
q = "/online how many firefighters were involved in the great chicago fire and which year did it take place?" q = "/online how many firefighters were involved in the great chicago fire and which year did it take place?"
encoded_q = quote(q, safe="") encoded_q = quote(q, safe="")
response = chat_client.get(f"/api/chat?q={encoded_q}&stream=true") response = chat_client.get(f"/api/chat?q={encoded_q}")
response_message = response.content.decode("utf-8") response_message = response.json()["response"]
response_message = response_message.split("### compiled references")[0]
# Assert # Assert
expected_responses = ["185", "1871", "horse"] expected_responses = ["185", "1871", "horse"]
@ -144,7 +140,7 @@ def test_answer_from_currently_retrieved_content(chat_client, default_user2: Kho
# Act # Act
response = chat_client.get(f'/api/chat?q="Where was Xi Li born?"') response = chat_client.get(f'/api/chat?q="Where was Xi Li born?"')
response_message = response.content.decode("utf-8") response_message = response.json()["response"]
# Assert # Assert
assert response.status_code == 200 assert response.status_code == 200
@ -168,7 +164,7 @@ def test_answer_from_chat_history_and_previously_retrieved_content(chat_client_n
# Act # Act
response = chat_client_no_background.get(f'/api/chat?q="Where was I born?"') response = chat_client_no_background.get(f'/api/chat?q="Where was I born?"')
response_message = response.content.decode("utf-8") response_message = response.json()["response"]
# Assert # Assert
assert response.status_code == 200 assert response.status_code == 200
@ -191,7 +187,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content(chat_client, d
# Act # Act
response = chat_client.get(f'/api/chat?q="Where was I born?"') response = chat_client.get(f'/api/chat?q="Where was I born?"')
response_message = response.content.decode("utf-8") response_message = response.json()["response"]
# Assert # Assert
assert response.status_code == 200 assert response.status_code == 200
@ -215,8 +211,8 @@ def test_no_answer_in_chat_history_or_retrieved_content(chat_client, default_use
create_conversation(message_list, default_user2) create_conversation(message_list, default_user2)
# Act # Act
response = chat_client.get(f'/api/chat?q="Where was I born?"&stream=true') response = chat_client.get(f'/api/chat?q="Where was I born?"')
response_message = response.content.decode("utf-8") response_message = response.json()["response"]
# Assert # Assert
expected_responses = [ expected_responses = [
@ -226,6 +222,7 @@ def test_no_answer_in_chat_history_or_retrieved_content(chat_client, default_use
"do not have", "do not have",
"don't have", "don't have",
"where were you born?", "where were you born?",
"where you were born?",
] ]
assert response.status_code == 200 assert response.status_code == 200
@ -280,8 +277,8 @@ def test_answer_not_known_using_notes_command(chat_client_no_background, default
create_conversation(message_list, default_user2) create_conversation(message_list, default_user2)
# Act # Act
response = chat_client_no_background.get(f"/api/chat?q={query}&stream=true") response = chat_client_no_background.get(f"/api/chat?q={query}")
response_message = response.content.decode("utf-8") response_message = response.json()["response"]
# Assert # Assert
assert response.status_code == 200 assert response.status_code == 200
@ -527,8 +524,8 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c
create_conversation(message_list, default_user2) create_conversation(message_list, default_user2)
# Act # Act
response = chat_client.get(f'/api/chat?q="Write a haiku about unit testing. Do not say anything else."&stream=true') response = chat_client.get(f'/api/chat?q="Write a haiku about unit testing. Do not say anything else.')
response_message = response.content.decode("utf-8").split("### compiled references")[0] response_message = response.json()["response"]
# Assert # Assert
expected_responses = ["test", "Test"] expected_responses = ["test", "Test"]
@ -544,9 +541,8 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c
@pytest.mark.chatquality @pytest.mark.chatquality
def test_ask_for_clarification_if_not_enough_context_in_question(chat_client_no_background): def test_ask_for_clarification_if_not_enough_context_in_question(chat_client_no_background):
# Act # Act
response = chat_client_no_background.get(f'/api/chat?q="What is the name of Namitas older son?"')
response = chat_client_no_background.get(f'/api/chat?q="What is the name of Namitas older son?"&stream=true') response_message = response.json()["response"].lower()
response_message = response.content.decode("utf-8").split("### compiled references")[0].lower()
# Assert # Assert
expected_responses = [ expected_responses = [
@ -658,8 +654,8 @@ def test_answer_in_chat_history_by_conversation_id_with_agent(
def test_answer_requires_multiple_independent_searches(chat_client): def test_answer_requires_multiple_independent_searches(chat_client):
"Chat director should be able to answer by doing multiple independent searches for required information" "Chat director should be able to answer by doing multiple independent searches for required information"
# Act # Act
response = chat_client.get(f'/api/chat?q="Is Xi older than Namita? Just the older persons full name"&stream=true') response = chat_client.get(f'/api/chat?q="Is Xi older than Namita? Just the older persons full name"')
response_message = response.content.decode("utf-8").split("### compiled references")[0].lower() response_message = response.json()["response"].lower()
# Assert # Assert
expected_responses = ["he is older than namita", "xi is older than namita", "xi li is older than namita"] expected_responses = ["he is older than namita", "xi is older than namita", "xi li is older than namita"]
@ -683,8 +679,8 @@ def test_answer_using_file_filter(chat_client):
'Is Xi older than Namita? Just say the older persons full name. file:"Namita.markdown" file:"Xi Li.markdown"' 'Is Xi older than Namita? Just say the older persons full name. file:"Namita.markdown" file:"Xi Li.markdown"'
) )
response = chat_client.get(f"/api/chat?q={query}&stream=true") response = chat_client.get(f"/api/chat?q={query}")
response_message = response.content.decode("utf-8").split("### compiled references")[0].lower() response_message = response.json()["response"].lower()
# Assert # Assert
expected_responses = ["he is older than namita", "xi is older than namita", "xi li is older than namita"] expected_responses = ["he is older than namita", "xi is older than namita", "xi li is older than namita"]