From 45c623f95c0c9d4e74a2f4159b706caa1800d931 Mon Sep 17 00:00:00 2001 From: Debanjum Date: Mon, 18 Nov 2024 16:05:15 -0800 Subject: [PATCH] Dedupe, organize chat actor, director tests - Move Chat actor tests that were previously in chat director tests file - Dedupe online, offline io selector chat actor tests --- tests/helpers.py | 14 ++++ tests/test_offline_chat_actors.py | 104 ++++++++++++++++++++------ tests/test_offline_chat_director.py | 87 ---------------------- tests/test_online_chat_actors.py | 37 +++++++++- tests/test_online_chat_director.py | 110 ++-------------------------- 5 files changed, 136 insertions(+), 216 deletions(-) diff --git a/tests/helpers.py b/tests/helpers.py index 6c70536e..3824c615 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -15,6 +15,7 @@ from khoj.database.models import ( Subscription, UserConversationConfig, ) +from khoj.processor.conversation.utils import message_to_log def get_chat_provider(default: ChatModelOptions.ModelType | None = ChatModelOptions.ModelType.OFFLINE): @@ -43,6 +44,19 @@ def get_chat_api_key(provider: ChatModelOptions.ModelType = None): return os.getenv("OPENAI_API_KEY") or os.getenv("GEMINI_API_KEY") or os.getenv("ANTHROPIC_API_KEY") +def generate_chat_history(message_list): + # Generate conversation logs + conversation_log = {"chat": []} + for user_message, chat_response, context in message_list: + message_to_log( + user_message, + chat_response, + {"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}}, + conversation_log=conversation_log.get("chat", []), + ) + return conversation_log + + class UserFactory(factory.django.DjangoModelFactory): class Meta: model = KhojUser diff --git a/tests/test_offline_chat_actors.py b/tests/test_offline_chat_actors.py index de0bbab6..b404f0e8 100644 --- a/tests/test_offline_chat_actors.py +++ b/tests/test_offline_chat_actors.py @@ -3,7 +3,9 @@ from datetime import datetime import pytest from khoj.database.models import 
ChatModelOptions -from tests.helpers import get_chat_provider +from khoj.routers.helpers import aget_data_sources_and_output_format +from khoj.utils.helpers import ConversationCommand +from tests.helpers import ConversationFactory, generate_chat_history, get_chat_provider SKIP_TESTS = get_chat_provider(default=None) != ChatModelOptions.ModelType.OFFLINE pytestmark = pytest.mark.skipif( @@ -134,7 +136,7 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model): # Act response = extract_questions_offline( query, - conversation_log=populate_chat_history(message_list), + conversation_log=generate_chat_history(message_list), loaded_model=loaded_model, use_history=True, ) @@ -180,7 +182,7 @@ def test_generate_search_query_using_answer_from_chat_history(loaded_model): # Act response = extract_questions_offline( "Is she a Doctor?", - conversation_log=populate_chat_history(message_list), + conversation_log=generate_chat_history(message_list), loaded_model=loaded_model, use_history=True, ) @@ -209,7 +211,7 @@ def test_generate_search_query_with_date_and_context_from_chat_history(loaded_mo # Act response = extract_questions_offline( "What was the Pizza place we ate at over there?", - conversation_log=populate_chat_history(message_list), + conversation_log=generate_chat_history(message_list), loaded_model=loaded_model, ) @@ -227,6 +229,77 @@ def test_generate_search_query_with_date_and_context_from_chat_history(loaded_mo ) +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.anyio +@pytest.mark.django_db(transaction=True) +@pytest.mark.parametrize( + "user_query, expected_conversation_commands", + [ + ( + "Where did I learn to swim?", + {"sources": [ConversationCommand.Notes], "output": ConversationCommand.Text}, + ), + ( + "Where is the nearest hospital?", + {"sources": [ConversationCommand.Online], "output": ConversationCommand.Text}, + ), + ( + "Summarize the wikipedia page on the 
history of the internet", + {"sources": [ConversationCommand.Webpage], "output": ConversationCommand.Text}, + ), + ( + "How many noble gases are there?", + {"sources": [ConversationCommand.General], "output": ConversationCommand.Text}, + ), + ( + "Make a painting incorporating my past diving experiences", + {"sources": [ConversationCommand.Notes], "output": ConversationCommand.Image}, + ), + ( + "Create a chart of the weather over the next 7 days in Timbuktu", + {"sources": [ConversationCommand.Online, ConversationCommand.Code], "output": ConversationCommand.Text}, + ), + ( + "What's the highest point in this country and have I been there?", + {"sources": [ConversationCommand.Online, ConversationCommand.Notes], "output": ConversationCommand.Text}, + ), + ], +) +async def test_select_data_sources_actor_chooses_to_search_notes( + client_offline_chat, user_query, expected_conversation_commands, default_user2 +): + # Act + selected_conversation_commands = await aget_data_sources_and_output_format(user_query, {}, False, default_user2) + + # Assert + assert set(expected_conversation_commands["sources"]) == set(selected_conversation_commands["sources"]) + assert expected_conversation_commands["output"] == selected_conversation_commands["output"] + + +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.anyio +@pytest.mark.django_db(transaction=True) +async def test_get_correct_tools_with_chat_history(client_offline_chat, default_user2): + # Arrange + user_query = "What's the latest in the Israel/Palestine conflict?" + chat_log = [ + ( + "Let's talk about the current events around the world.", + "Sure, let's discuss the current events. 
What would you like to know?", + [], + ), + ("What's up in New York City?", "A Pride parade has recently been held in New York City, on July 31st.", []), + ] + chat_history = generate_chat_history(chat_log) + + # Act: pass the chat log dict and the user, like the other actor calls in this file + selected = await aget_data_sources_and_output_format(user_query, chat_history, False, default_user2) + + # Assert: the actor returns a dict; only the online source should be selected + sources = selected["sources"] + assert sources == [ConversationCommand.Online] + + # ---------------------------------------------------------------------------------------------------- @pytest.mark.chatquality def test_chat_with_no_chat_history_or_retrieved_content(loaded_model): @@ -264,7 +337,7 @@ def test_answer_from_chat_history_and_previously_retrieved_content(loaded_model) response_gen = converse_offline( references=[], # Assume no context retrieved from notes for the user_query user_query="Where was I born?", - conversation_log=populate_chat_history(message_list), + conversation_log=generate_chat_history(message_list), loaded_model=loaded_model, ) response = "".join([response_chunk for response_chunk in response_gen]) @@ -291,7 +364,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content(loaded_model): {"compiled": "Testatron was born on 1st April 1984 in Testville."} ], # Assume context retrieved from notes for the user_query user_query="Where was I born?", - conversation_log=populate_chat_history(message_list), + conversation_log=generate_chat_history(message_list), loaded_model=loaded_model, ) response = "".join([response_chunk for response_chunk in response_gen]) @@ -316,7 +389,7 @@ def test_refuse_answering_unanswerable_question(loaded_model): response_gen = converse_offline( references=[], # Assume no context retrieved from notes for the user_query user_query="Where was I born?", - conversation_log=populate_chat_history(message_list), + conversation_log=generate_chat_history(message_list), loaded_model=loaded_model, ) response = "".join([response_chunk for response_chunk 
in response_gen]) @@ -429,7 +502,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(loaded response_gen = converse_offline( references=[], # Assume no context retrieved from notes for the user_query user_query="Write a haiku about unit testing in 3 lines", - conversation_log=populate_chat_history(message_list), + conversation_log=generate_chat_history(message_list), loaded_model=loaded_model, ) response = "".join([response_chunk for response_chunk in response_gen]) @@ -549,18 +622,3 @@ def test_filter_questions(): filtered_questions = filter_questions(test_questions) assert len(filtered_questions) == 1 assert filtered_questions[0] == "Who is on the basketball team?" - - -# Helpers -# ---------------------------------------------------------------------------------------------------- -def populate_chat_history(message_list): - # Generate conversation logs - conversation_log = {"chat": []} - for user_message, chat_response, context in message_list: - message_to_log( - user_message, - chat_response, - {"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}}, - conversation_log=conversation_log["chat"], - ) - return conversation_log diff --git a/tests/test_offline_chat_director.py b/tests/test_offline_chat_director.py index bb57d1b6..5caa1ca0 100644 --- a/tests/test_offline_chat_director.py +++ b/tests/test_offline_chat_director.py @@ -7,7 +7,6 @@ from freezegun import freeze_time from khoj.database.models import Agent, ChatModelOptions, Entry, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.utils import message_to_log -from khoj.routers.helpers import aget_data_sources_and_output_format from tests.helpers import ConversationFactory, get_chat_provider SKIP_TESTS = get_chat_provider(default=None) != ChatModelOptions.ModelType.OFFLINE @@ -725,89 +724,3 @@ def test_answer_requires_multiple_independent_searches(client_offline_chat): assert any([expected_response 
in response_message.lower() for expected_response in expected_responses]), ( "Expected Xi is older than Namita, but got: " + response_message ) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.anyio -@pytest.mark.django_db(transaction=True) -async def test_get_correct_tools_online(client_offline_chat): - # Arrange - user_query = "What's the weather in Patagonia this week?" - - # Act - tools = await aget_data_sources_and_output_format(user_query, {}, is_task=False) - - # Assert - tools = [tool.value for tool in tools] - assert tools == ["online"] - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.anyio -@pytest.mark.django_db(transaction=True) -async def test_get_correct_tools_notes(client_offline_chat): - # Arrange - user_query = "Where did I go for my first battleship training?" - - # Act - tools = await aget_data_sources_and_output_format(user_query, {}, is_task=False) - - # Assert - tools = [tool.value for tool in tools] - assert tools == ["notes"] - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.anyio -@pytest.mark.django_db(transaction=True) -async def test_get_correct_tools_online_or_general_and_notes(client_offline_chat): - # Arrange - user_query = "What's the highest point in Patagonia and have I been there?" - - # Act - tools = await aget_data_sources_and_output_format(user_query, {}, is_task=False) - - # Assert - tools = [tool.value for tool in tools] - assert len(tools) == 2 - assert "online" or "general" in tools - assert "notes" in tools - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.anyio -@pytest.mark.django_db(transaction=True) -async def test_get_correct_tools_general(client_offline_chat): - # Arrange - user_query = "How many noble gases are there?" 
- - # Act - tools = await aget_data_sources_and_output_format(user_query, {}, is_task=False) - - # Assert - tools = [tool.value for tool in tools] - assert tools == ["general"] - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.anyio -@pytest.mark.django_db(transaction=True) -async def test_get_correct_tools_with_chat_history(client_offline_chat, default_user2): - # Arrange - user_query = "What's the latest in the Israel/Palestine conflict?" - chat_log = [ - ( - "Let's talk about the current events around the world.", - "Sure, let's discuss the current events. What would you like to know?", - [], - ), - ("What's up in New York City?", "A Pride parade has recently been held in New York City, on July 31st.", []), - ] - chat_history = create_conversation(chat_log, default_user2) - - # Act - tools = await aget_data_sources_and_output_format(user_query, chat_history, is_task=False) - - # Assert - tools = [tool.value for tool in tools] - assert tools == ["online"] diff --git a/tests/test_online_chat_actors.py b/tests/test_online_chat_actors.py index 449ed9b1..3a6db6f8 100644 --- a/tests/test_online_chat_actors.py +++ b/tests/test_online_chat_actors.py @@ -14,7 +14,7 @@ from khoj.routers.helpers import ( should_notify, ) from khoj.utils.helpers import ConversationCommand -from tests.helpers import get_chat_api_key +from tests.helpers import generate_chat_history, get_chat_api_key # Initialize variables for tests api_key = get_chat_api_key() @@ -537,6 +537,10 @@ async def test_websearch_khoj_website_for_info_about_khoj(chat_client, default_u "Summarize the wikipedia page on the history of the internet", {"sources": [ConversationCommand.Webpage], "output": ConversationCommand.Text}, ), + ( + "How many noble gases are there?", + {"sources": [ConversationCommand.General], "output": ConversationCommand.Text}, + ), ( "Make a painting incorporating my past diving experiences", {"sources": 
[ConversationCommand.Notes], "output": ConversationCommand.Image}, @@ -545,6 +549,10 @@ async def test_websearch_khoj_website_for_info_about_khoj(chat_client, default_u "Create a chart of the weather over the next 7 days in Timbuktu", {"sources": [ConversationCommand.Online, ConversationCommand.Code], "output": ConversationCommand.Text}, ), + ( + "What's the highest point in this country and have I been there?", + {"sources": [ConversationCommand.Online, ConversationCommand.Notes], "output": ConversationCommand.Text}, + ), ], ) async def test_select_data_sources_actor_chooses_to_search_notes( @@ -554,7 +562,32 @@ async def test_select_data_sources_actor_chooses_to_search_notes( selected_conversation_commands = await aget_data_sources_and_output_format(user_query, {}, False, default_user2) # Assert - assert expected_conversation_commands == selected_conversation_commands + assert set(expected_conversation_commands["sources"]) == set(selected_conversation_commands["sources"]) + assert expected_conversation_commands["output"] == selected_conversation_commands["output"] + + +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.anyio +@pytest.mark.django_db(transaction=True) +async def test_get_correct_tools_with_chat_history(chat_client, default_user2): + # Arrange + user_query = "What's the latest in the Israel/Palestine conflict?" + chat_log = [ + ( + "Let's talk about the current events around the world.", + "Sure, let's discuss the current events. 
What would you like to know?", + [], + ), + ("What's up in New York City?", "A Pride parade has recently been held in New York City, on July 31st.", []), + ] + chat_history = generate_chat_history(chat_log) + + # Act + selected = await aget_data_sources_and_output_format(user_query, chat_history, False, default_user2) + sources = selected["sources"] + + # Assert + assert sources == [ConversationCommand.Online] # ---------------------------------------------------------------------------------------------------- diff --git a/tests/test_online_chat_director.py b/tests/test_online_chat_director.py index 317384cb..94545b4c 100644 --- a/tests/test_online_chat_director.py +++ b/tests/test_online_chat_director.py @@ -7,8 +7,7 @@ from freezegun import freeze_time from khoj.database.models import Agent, Entry, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.utils import message_to_log -from khoj.routers.helpers import aget_data_sources_and_output_format -from tests.helpers import ConversationFactory, get_chat_api_key +from tests.helpers import ConversationFactory, generate_chat_history, get_chat_api_key # Initialize variables for tests api_key = get_chat_api_key() @@ -21,22 +20,9 @@ if api_key is None: # Helpers # ---------------------------------------------------------------------------------------------------- -def generate_history(message_list): - # Generate conversation logs - conversation_log = {"chat": []} - for user_message, gpt_message, context in message_list: - message_to_log( - user_message, - gpt_message, - {"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}}, - conversation_log=conversation_log.get("chat", []), - ) - return conversation_log - - def create_conversation(message_list, user, agent=None): # Generate conversation logs - conversation_log = generate_history(message_list) + conversation_log = generate_chat_history(message_list) # Update Conversation Metadata Logs in 
Database return ConversationFactory(user=user, conversation_log=conversation_log, agent=agent) @@ -619,7 +605,9 @@ def test_answer_in_chat_history_by_conversation_id(chat_client, default_user2: K # Act query = "/general What is my favorite color?" - response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id, "stream": True}) + response = chat_client.post( + f"/api/chat", json={"q": query, "conversation_id": str(conversation.id), "stream": True} + ) response_message = response.content.decode("utf-8") # Assert @@ -652,7 +640,7 @@ def test_answer_in_chat_history_by_conversation_id_with_agent( # Act query = "/general What did I buy for breakfast?" - response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": conversation.id}) + response = chat_client.post(f"/api/chat", json={"q": query, "conversation_id": str(conversation.id)}) response_message = response.json()["response"] # Assert that agent only responds with the summary of spending @@ -708,89 +696,3 @@ def test_answer_using_file_filter(chat_client): assert only_full_name_check or comparative_statement_check, ( "Expected Xi is older than Namita, but got: " + response_message ) - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.anyio -@pytest.mark.django_db(transaction=True) -async def test_get_correct_tools_online(chat_client): - # Arrange - user_query = "What's the weather in Patagonia this week?" - - # Act - tools = await aget_data_sources_and_output_format(user_query, {}, False, False) - - # Assert - tools = [tool.value for tool in tools] - assert tools == ["online"] - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.anyio -@pytest.mark.django_db(transaction=True) -async def test_get_correct_tools_notes(chat_client): - # Arrange - user_query = "Where did I go for my first battleship training?" 
- - # Act - tools = await aget_data_sources_and_output_format(user_query, {}, False, False) - - # Assert - tools = [tool.value for tool in tools] - assert tools == ["notes"] - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.anyio -@pytest.mark.django_db(transaction=True) -async def test_get_correct_tools_online_or_general_and_notes(chat_client): - # Arrange - user_query = "What's the highest point in Patagonia and have I been there?" - - # Act - tools = await aget_data_sources_and_output_format(user_query, {}, False, False) - - # Assert - tools = [tool.value for tool in tools] - assert len(tools) == 2 - assert "online" in tools or "general" in tools - assert "notes" in tools - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.anyio -@pytest.mark.django_db(transaction=True) -async def test_get_correct_tools_general(chat_client): - # Arrange - user_query = "How many noble gases are there?" - - # Act - tools = await aget_data_sources_and_output_format(user_query, {}, False, False) - - # Assert - tools = [tool.value for tool in tools] - assert tools == ["general"] - - -# ---------------------------------------------------------------------------------------------------- -@pytest.mark.anyio -@pytest.mark.django_db(transaction=True) -async def test_get_correct_tools_with_chat_history(chat_client): - # Arrange - user_query = "What's the latest in the Israel/Palestine conflict?" - chat_log = [ - ( - "Let's talk about the current events around the world.", - "Sure, let's discuss the current events. 
What would you like to know?", - [], - ), - ("What's up in New York City?", "A Pride parade has recently been held in New York City, on July 31st.", []), - ] - chat_history = generate_history(chat_log) - - # Act - tools = await aget_data_sources_and_output_format(user_query, chat_history, False, False) - - # Assert - tools = [tool.value for tool in tools] - assert tools == ["online"]