from datetime import datetime

import freezegun
import pytest
from freezegun import freeze_time

from khoj.processor.conversation.openai.gpt import converse, extract_questions
from khoj.processor.conversation.utils import message_to_log
from khoj.routers.helpers import (
    aget_data_sources_and_output_format,
    generate_online_subqueries,
    infer_webpage_urls,
    schedule_query,
    should_notify,
)
from khoj.utils.helpers import ConversationCommand
from tests.helpers import generate_chat_history, get_chat_api_key

# Initialize variables for tests
# Resolved once at import time and passed to the converse() calls below.
# Presumably read from the environment variables named in the skip reason - confirm in tests.helpers.
api_key = get_chat_api_key()
if api_key is None:
    # No chat API key available: skip the whole module instead of failing each test individually.
    pytest.skip(
        reason="Set OPENAI_API_KEY, GEMINI_API_KEY or ANTHROPIC_API_KEY environment variable to run tests below.",
        allow_module_level=True,
    )

# Add "transformers" to the list of modules freezegun ignores when time is frozen in the tests below.
freezegun.configure(extend_ignore_list=["transformers"])


# Test
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@freeze_time("1984-04-02", ignore=["transformers"])
def test_extract_question_with_date_filter_from_relative_day():
    """A question about "yesterday" should produce a date filter scoped to 1984-04-01."""
    # Act
    queries = extract_questions("Where did I go for dinner yesterday?")
    first_query = queries[0]

    # Assert
    # Each pair is an acceptable (lower bound, upper bound) rendering of the date filter
    acceptable_filters = (
        ("dt='1984-04-01'", ""),
        ("dt>='1984-04-01'", "dt<'1984-04-02'"),
        ("dt>'1984-03-31'", "dt<'1984-04-02'"),
    )
    has_expected_filter = any(lower in first_query and upper in first_query for lower, upper in acceptable_filters)
    assert has_expected_filter, (
        "Expected date filter to limit to 1st April 1984 in response but got: " + first_query
    )


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@freeze_time("1984-04-02", ignore=["transformers"])
def test_extract_question_with_date_filter_from_relative_month():
    """A question about "last month" should produce a single query filtered to March 1984."""
    # Act
    queries = extract_questions("Which countries did I visit last month?")

    # Assert
    assert len(queries) == 1
    only_query = queries[0]
    # Upper bound may be exclusive (before 1st April) or inclusive (up to 31st March)
    acceptable_filters = (
        ("dt>='1984-03-01'", "dt<'1984-04-01'"),
        ("dt>='1984-03-01'", "dt<='1984-03-31'"),
    )
    has_expected_filter = any(lower in only_query and upper in only_query for lower, upper in acceptable_filters)
    assert has_expected_filter, "Expected date filter to limit to March 1984 in response but got: " + only_query


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@freeze_time("1984-04-02", ignore=["transformers"])
def test_extract_question_with_date_filter_from_relative_year():
    """A question about "this year" should produce a single query filtered to 1984."""
    # Act
    queries = extract_questions("Which countries have I visited this year?")

    # Assert
    assert len(queries) == 1
    only_query = queries[0]
    # The upper bound is optional; when present it may be exclusive or inclusive
    acceptable_filters = (
        ("dt>='1984-01-01'", ""),
        ("dt>='1984-01-01'", "dt<'1985-01-01'"),
        ("dt>='1984-01-01'", "dt<='1984-12-31'"),
    )
    has_expected_filter = any(lower in only_query and upper in only_query for lower, upper in acceptable_filters)
    assert has_expected_filter, "Expected date filter to limit to 1984 in response but got: " + only_query


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_extract_multiple_explicit_questions_from_message():
    """A message with two explicit questions should produce at least two search queries."""
    # Act
    responses = extract_questions("What is the Sun? What is the Moon?")

    # Assert
    assert len(responses) >= 2
    # `responses` is a collection of queries, not a string. Convert it with str() before
    # concatenating, otherwise a failing assertion raises TypeError and hides the real failure.
    assert any("sun" in response.lower() or "moon" in response.lower() for response in responses), (
        "Expected sun or moon mentioned in generated search queries but got: " + str(responses)
    )


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_extract_multiple_implicit_questions_from_message():
    """A comparison question should be split into separate search queries, one per entity."""
    # Act
    response = extract_questions("Is Morpheus taller than Neo?")

    # Assert
    # Each pair is (term expected in the first query, term expected in the second query)
    expected_responses = [
        ("morpheus", "neo"),
    ]
    assert len(response) > 1, "Expected more than one search query in response but got: " + str(response)
    first_query, second_query = response[0].lower(), response[1].lower()
    # Report every generated query on failure, and describe the condition actually being checked
    assert any(first in first_query and second in second_query for first, second in expected_responses), (
        "Expected first search query about Morpheus and second about Neo but got: " + str(response)
    )


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_generate_search_query_using_question_from_chat_history():
    """Search queries for a follow-up should mention the person named in the earlier question."""
    # Arrange
    prior_turns = [
        ("What is the name of Mr. Vader's daughter?", "Princess Leia", []),
    ]
    chat_history = populate_chat_history(prior_turns)

    # Act
    queries = extract_questions("Does he have any sons?", conversation_log=chat_history)

    # Assert
    assert all("Vader" in query for query in queries)


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_generate_search_query_using_answer_from_chat_history():
    """Search queries for a follow-up should mention the person named in the earlier answer."""
    # Arrange
    prior_turns = [
        ("What is the name of Mr. Vader's daughter?", "Princess Leia", []),
    ]
    chat_history = populate_chat_history(prior_turns)

    # Act
    queries = extract_questions("Is she a Jedi?", conversation_log=chat_history)

    # Assert
    assert all("Leia" in query for query in queries)


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_generate_search_query_using_question_and_answer_from_chat_history():
    """Search queries for a follow-up should mention a person from the earlier question or answer."""
    # Arrange
    prior_turns = [
        ("Does Luke Skywalker have any Siblings?", "Yes, Princess Leia", []),
    ]
    chat_history = populate_chat_history(prior_turns)

    # Act
    queries = extract_questions("Who is their father?", conversation_log=chat_history)

    # Assert
    # At least one query should name either sibling
    assert any("Leia" in query or "Luke" in query for query in queries)


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_chat_with_no_chat_history_or_retrieved_content():
    """With no chat history or notes, the chat actor should still introduce itself by name."""
    # Act
    chunks = converse(
        references=[],  # Assume no context retrieved from notes for the user_query
        user_query="Hello, my name is Testatron. Who are you?",
        api_key=api_key,
    )
    reply = "".join(chunks)

    # Assert
    assistant_names = ("Khoj", "khoj")
    assert len(reply) > 0
    assert any(name in reply for name in assistant_names), (
        "Expected assistants name, [K|k]hoj, in response but got: " + reply
    )


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_from_chat_history_and_no_content():
    """Chat actor should recall the user's name from chat history when no notes are retrieved."""
    # Arrange
    prior_turns = [
        ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
        ("When was I born?", "You were born on 1st April 1984.", []),
    ]
    chat_history = populate_chat_history(prior_turns)

    # Act
    chunks = converse(
        references=[],  # Assume no context retrieved from notes for the user_query
        user_query="What is my name?",
        conversation_log=chat_history,
        api_key=api_key,
    )
    reply = "".join(chunks)

    # Assert
    user_names = ("Testatron", "testatron")
    assert len(reply) > 0
    assert any(name in reply for name in user_names), "Expected [T|t]estatron in response but got: " + reply


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_from_chat_history_and_previously_retrieved_content():
    """Chat actor needs to use context in previous notes and chat history to answer question."""
    # Arrange
    introduction_turn = (
        "Hello, my name is Testatron. Who are you?",
        "Hi, I am Khoj, a personal assistant. How can I help?",
        [],
    )
    # This earlier turn carries the note that mentions the birthplace
    birthday_turn = (
        "When was I born?",
        "You were born on 1st April 1984.",
        [{"compiled": "Testatron was born on 1st April 1984 in Testville.", "file": "birth.org"}],
    )
    chat_history = populate_chat_history([introduction_turn, birthday_turn])

    # Act
    chunks = converse(
        references=[],  # Assume no context retrieved from notes for the user_query
        user_query="Where was I born?",
        conversation_log=chat_history,
        api_key=api_key,
    )
    reply = "".join(chunks)

    # Assert
    assert len(reply) > 0
    # Infer who I am and use that to infer I was born in Testville using chat history and previously retrieved notes
    assert "Testville" in reply


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_from_chat_history_and_currently_retrieved_content():
    """Chat actor needs to use context across currently retrieved notes and chat history to answer question."""
    # Arrange
    prior_turns = [
        ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
        ("When was I born?", "You were born on 1st April 1984.", []),
    ]
    chat_history = populate_chat_history(prior_turns)
    # Assume context retrieved from notes for the user_query
    retrieved_notes = [{"compiled": "Testatron was born on 1st April 1984 in Testville.", "file": "background.md"}]

    # Act
    chunks = converse(
        references=retrieved_notes,
        user_query="Where was I born?",
        conversation_log=chat_history,
        api_key=api_key,
    )
    reply = "".join(chunks)

    # Assert
    assert len(reply) > 0
    assert "Testville" in reply


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_refuse_answering_unanswerable_question():
    """Chat actor should not try make up answers to unanswerable questions."""
    # Arrange
    # Neither the chat history nor the (empty) references mention a birthplace
    prior_turns = [
        ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
        ("When was I born?", "You were born on 1st April 1984.", []),
    ]
    chat_history = populate_chat_history(prior_turns)

    # Act
    chunks = converse(
        references=[],  # Assume no context retrieved from notes for the user_query
        user_query="Where was I born?",
        conversation_log=chat_history,
        api_key=api_key,
    )
    reply = "".join(chunks)

    # Assert
    refusal_phrases = (
        "don't know",
        "do not know",
        "no information",
        "do not have",
        "don't have",
        "cannot answer",
        "I'm sorry",
    )
    assert len(reply) > 0
    assert any(phrase in reply for phrase in refusal_phrases), (
        "Expected chat actor to say they don't know in response, but got: " + reply
    )


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_requires_current_date_awareness():
    """Chat actor should be able to answer questions relative to current date using provided notes."""
    # Arrange
    # Two ledger entries dated today and two dated in 2020
    tacos_today = {
        "compiled": f"""{datetime.now().strftime("%Y-%m-%d")} "Naco Taco" "Tacos for Dinner"
Expenses:Food:Dining  10.00 USD""",
        "file": "Ledger.org",
    }
    dosa_today = {
        "compiled": f"""{datetime.now().strftime("%Y-%m-%d")} "Sagar Ratna" "Dosa for Lunch"
Expenses:Food:Dining  10.00 USD""",
        "file": "Ledger.org",
    }
    bananas_2020 = {
        "compiled": f"""2020-04-01 "SuperMercado" "Bananas"
Expenses:Food:Groceries  10.00 USD""",
        "file": "Ledger.org",
    }
    burritos_2020 = {
        "compiled": f"""2020-01-01 "Naco Taco" "Burittos for Dinner"
Expenses:Food:Dining  10.00 USD""",
        "file": "Ledger.org",
    }
    retrieved_notes = [tacos_today, dosa_today, bananas_2020, burritos_2020]

    # Act
    chunks = converse(
        references=retrieved_notes,  # Assume context retrieved from notes for the user_query
        user_query="What did I have for Dinner today?",
        api_key=api_key,
    )
    reply = "".join(chunks)

    # Assert
    dinner_today = ("tacos", "Tacos")
    assert len(reply) > 0
    assert any(dish in reply for dish in dinner_today), "Expected [T|t]acos in response, but got: " + reply


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_requires_date_aware_aggregation_across_provided_notes():
    """Chat actor should be able to answer questions that require date aware aggregation across multiple notes."""
    # Arrange
    # Two dining entries dated today (10.00 USD each), plus a groceries and a dining entry dated in 2020
    tacos_today = {
        "compiled": f"""# {datetime.now().strftime("%Y-%m-%d")} "Naco Taco" "Tacos for Dinner"
Expenses:Food:Dining  10.00 USD""",
        "file": "Ledger.md",
    }
    dosa_today = {
        "compiled": f"""{datetime.now().strftime("%Y-%m-%d")} "Sagar Ratna" "Dosa for Lunch"
Expenses:Food:Dining  10.00 USD""",
        "file": "Ledger.md",
    }
    bananas_2020 = {
        "compiled": f"""2020-04-01 "SuperMercado" "Bananas"
Expenses:Food:Groceries  10.00 USD""",
        "file": "Ledger.md",
    }
    burritos_2020 = {
        "compiled": f"""2020-01-01 "Naco Taco" "Burittos for Dinner"
Expenses:Food:Dining  10.00 USD""",
        "file": "Ledger.md",
    }
    retrieved_notes = [tacos_today, dosa_today, bananas_2020, burritos_2020]

    # Act
    chunks = converse(
        references=retrieved_notes,  # Assume context retrieved from notes for the user_query
        user_query="How much did I spend on dining this year?",
        api_key=api_key,
    )
    reply = "".join(chunks)

    # Assert
    assert len(reply) > 0
    # Sum of the two dining entries dated today
    assert "20" in reply


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_general_question_not_in_chat_history_or_retrieved_content():
    """Chat actor should be able to answer general questions not requiring looking at chat history or notes."""
    # Arrange
    prior_turns = [
        ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
        ("When was I born?", "You were born on 1st April 1984.", []),
        ("Where was I born?", "You were born Testville.", []),
    ]
    chat_history = populate_chat_history(prior_turns)

    # Act
    chunks = converse(
        references=[],  # Assume no context retrieved from notes for the user_query
        user_query="Write a haiku about unit testing in 3 lines. Do not say anything else",
        conversation_log=chat_history,
        api_key=api_key,
    )
    reply = "".join(chunks)

    # Assert
    topical_words = ("test", "bug", "code")
    haiku_lines = reply.splitlines()
    assert len(haiku_lines) == 3  # haikus are 3 lines long
    lowered_reply = reply.lower()
    assert any(word in lowered_reply for word in topical_words), "Expected haiku about unit test, but got: " + reply


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_ask_for_clarification_if_not_enough_context_in_question():
    """Chat actor should ask for clarification if question cannot be answered unambiguously with the provided context."""
    # Arrange
    # Three notes, each about a different sister; none says which one is older
    ramya_note = {
        "compiled": f"""# Ramya
My sister, Ramya, is married to Kali Devi. They have 2 kids, Ravi and Rani.""",
        "file": "Family.md",
    }
    fang_note = {
        "compiled": f"""# Fang
My sister, Fang Liu is married to Xi Li. They have 1 kid, Xiao Li.""",
        "file": "Family.md",
    }
    aiyla_note = {
        "compiled": f"""# Aiyla
My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""",
        "file": "Family.md",
    }
    retrieved_notes = [ramya_note, fang_note, aiyla_note]

    # Act
    chunks = converse(
        references=retrieved_notes,  # Assume context retrieved from notes for the user_query
        user_query="How many kids does my older sister have?",
        api_key=api_key,
    )
    reply = "".join(chunks)

    # Assert
    clarifying_phrases = (
        "which sister",
        "Which sister",
        "which of your sister",
        "Which of your sister",
        "Could you provide",
    )
    assert any(phrase in reply for phrase in clarifying_phrases), (
        "Expected chat actor to ask for clarification in response, but got: " + reply
    )


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_agent_prompt_should_be_used(openai_agent):
    """Chat actor should be tuned to think like an accountant based on the agent definition.

    Asks the same question with and without the agent and checks that only the
    agent-backed response contains the total value of the purchases.
    """
    # Arrange
    context = [
        {"compiled": "I went to the store and bought some bananas for 2.20", "file": "Ledger.md"},
        {"compiled": "I went to the store and bought some apples for 1.30", "file": "Ledger.md"},
        {"compiled": "I went to the store and bought some oranges for 6.00", "file": "Ledger.md"},
    ]
    # Total of the three purchases: 2.20 + 1.30 + 6.00
    expected_responses = ["9.50", "9.5"]

    # Act
    response_gen = converse(
        references=context,  # Assume context retrieved from notes for the user_query
        user_query="What did I buy?",
        api_key=api_key,
    )
    no_agent_response = "".join(response_gen)
    response_gen = converse(
        references=context,  # Assume context retrieved from notes for the user_query
        user_query="What did I buy?",
        api_key=api_key,
        agent=openai_agent,
    )
    agent_response = "".join(response_gen)

    # Assert that the model without the agent prompt does not include the summary of purchases
    assert all(expected_response not in no_agent_response for expected_response in expected_responses), (
        "Expected chat actor without agent to not summarize values of purchases, but got: " + no_agent_response
    )
    # Assert that the model with the agent prompt does include the summary of purchases
    assert any(expected_response in agent_response for expected_response in expected_responses), (
        "Expected chat actor with agent to summarize values of purchases, but got: " + agent_response
    )


# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
@freeze_time("2024-04-04", ignore=["transformers"])
async def test_websearch_with_operators(chat_client, default_user2):
    """Online subqueries for a subreddit request should target the subreddit and use the site: operator."""
    # Arrange
    user_query = "Share popular posts on r/worldnews this month"

    # Act
    responses = await generate_online_subqueries(user_query, {}, None, default_user2)

    # Assert
    # The message names the fragment actually checked; it previously repeated the site:reddit.com text
    assert any(
        "reddit.com/r/worldnews" in response for response in responses
    ), "Expected a search query to include reddit.com/r/worldnews but got: " + str(responses)

    assert any(
        "site:reddit.com" in response for response in responses
    ), "Expected a search query to include site:reddit.com but got: " + str(responses)


# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_websearch_khoj_website_for_info_about_khoj(chat_client, default_user2):
    """A question about the assistant's own features should produce a search restricted to khoj.dev."""
    # Arrange
    question = "Do you support image search?"

    # Act
    subqueries = await generate_online_subqueries(question, {}, None, default_user2)

    # Assert
    targets_khoj_site = any("site:khoj.dev" in subquery for subquery in subqueries)
    assert targets_khoj_site, "Expected search query to include site:khoj.dev but got: " + str(subqueries)


# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
@pytest.mark.parametrize(
    "user_query, expected_conversation_commands",
    [
        # Personal memory: notes only
        (
            "Where did I learn to swim?",
            {"sources": [ConversationCommand.Notes], "output": ConversationCommand.Text},
        ),
        # Location lookup: online search
        (
            "Where is the nearest hospital?",
            {"sources": [ConversationCommand.Online], "output": ConversationCommand.Text},
        ),
        # Specific page to read: webpage
        (
            "Summarize the wikipedia page on the history of the internet",
            {"sources": [ConversationCommand.Webpage], "output": ConversationCommand.Text},
        ),
        # General knowledge: no data source lookup
        (
            "How many noble gases are there?",
            {"sources": [ConversationCommand.General], "output": ConversationCommand.Text},
        ),
        # Notes as source, image as output
        (
            "Make a painting incorporating my past diving experiences",
            {"sources": [ConversationCommand.Notes], "output": ConversationCommand.Image},
        ),
        # Online data combined with code
        (
            "Create a chart of the weather over the next 7 days in Timbuktu",
            {"sources": [ConversationCommand.Online, ConversationCommand.Code], "output": ConversationCommand.Text},
        ),
        # Online data combined with notes
        (
            "What's the highest point in this country and have I been there?",
            {"sources": [ConversationCommand.Online, ConversationCommand.Notes], "output": ConversationCommand.Text},
        ),
    ],
)
async def test_select_data_sources_actor_chooses_to_search_notes(
    chat_client, user_query, expected_conversation_commands, default_user2
):
    """The selected data sources (order-insensitive) and output format should match the expectation for each query."""
    # Act
    selection = await aget_data_sources_and_output_format(user_query, {}, False, default_user2)

    # Assert
    expected_sources = set(expected_conversation_commands["sources"])
    assert expected_sources == set(selection["sources"])
    assert expected_conversation_commands["output"] == selection["output"]


# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_get_correct_tools_with_chat_history(chat_client, default_user2):
    """A current-events question following a current-events conversation should select only the online source."""
    # Arrange
    question = "What's the latest in the Israel/Palestine conflict?"
    opening_turn = (
        "Let's talk about the current events around the world.",
        "Sure, let's discuss the current events. What would you like to know?",
        [],
    )
    new_york_turn = (
        "What's up in New York City?",
        "A Pride parade has recently been held in New York City, on July 31st.",
        [],
    )
    chat_history = generate_chat_history([opening_turn, new_york_turn])

    # Act
    selection = await aget_data_sources_and_output_format(question, chat_history, False, default_user2)

    # Assert
    assert selection["sources"] == [ConversationCommand.Online]


# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_infer_webpage_urls_actor_extracts_correct_links(chat_client, default_user2):
    """The inferred URLs for a wikipedia summary request should include the matching wikipedia page."""
    # Arrange
    question = "Summarize the wikipedia page on the history of the internet"
    expected_url = "https://en.wikipedia.org/wiki/History_of_the_Internet"

    # Act
    inferred_urls = await infer_webpage_urls(question, {}, None, default_user2)

    # Assert
    assert expected_url in inferred_urls


# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
@pytest.mark.parametrize(
    "user_query, expected_crontime, expected_qs, unexpected_qs",
    [
        (
            "Share the weather forecast for the next day daily at 7:30pm",
            "30 19 * * *",
            ["weather forecast"],
            ["7:30"],
        ),
        (
            "Notify me when the new President of Brazil is announced",
            "* *",  # crontime is variable
            ["brazil", "president"],
            ["notify"],  # ensure reminder isn't re-triggered on scheduled query run
        ),
        (
            "Let me know whenever Elon leaves Twitter. Check this every afternoon at 12",
            "0 12 * * *",  # ensure correctly converts to utc
            ["elon", "twitter"],
            ["12"],
        ),
        (
            "Draw a wallpaper every morning using the current weather",
            "* * *",  # daily crontime
            ["weather", "wallpaper"],
            ["every"],
        ),
    ],
)
async def test_infer_task_scheduling_request(
    chat_client, user_query, expected_crontime, expected_qs, unexpected_qs, default_user2
):
    """The inferred schedule should contain the expected crontime and a query stripped of scheduling words."""
    # Act
    crontime, raw_inferred_query, _ = await schedule_query(user_query, {}, default_user2)
    # Fragments are compared case-insensitively
    inferred_query = raw_inferred_query.lower()

    # Assert
    # Substring match, so a partial crontime pattern can cover variable schedules
    assert expected_crontime in crontime
    for expected_q in expected_qs:
        assert expected_q in inferred_query, f"Expected fragment {expected_q} in query: {inferred_query}"
    for unexpected_q in unexpected_qs:
        fragment_absent = unexpected_q not in inferred_query
        assert fragment_absent, f"Did not expect fragment '{unexpected_q}' in query: '{inferred_query}'"


# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
@pytest.mark.parametrize(
    "scheduling_query, executing_query, generated_response, expected_should_notify",
    [
        # Condition in the scheduling request is not met by the response
        (
            "Notify me only if it is going to rain tomorrow?",
            "What's the weather forecast for tomorrow?",
            "It is sunny and warm tomorrow.",
            False,
        ),
        # Unconditional recurring summary
        (
            "Summarize the latest news every morning",
            "Summarize today's news",
            "Today in the news: AI is taking over the world",
            True,
        ),
        # Unconditional recurring image generation
        (
            "Create a weather wallpaper every morning using the current weather",
            "Paint a weather wallpaper using the current weather",
            "https://khoj-generated-wallpaper.khoj.dev/user110/weathervane.webp",
            True,
        ),
        # Awaited event has not happened yet
        (
            "Let me know the election results once they are offically declared",
            "What are the results of the elections? Has the winner been declared?",
            "The election results has not been declared yet.",
            False,
        ),
    ],
)
def test_decision_on_when_to_notify_scheduled_task_results(
    chat_client, default_user2, scheduling_query, executing_query, generated_response, expected_should_notify
):
    """The notify decision for a scheduled task's result should match the expectation for each case."""
    # Act
    decision = should_notify(scheduling_query, executing_query, generated_response, default_user2)

    # Assert
    assert decision == expected_should_notify


# Helpers
# ----------------------------------------------------------------------------------------------------
def populate_chat_history(message_list):
    """Build a conversation log dict from (user message, assistant message, context) tuples.

    Returns a dict with a single "chat" key holding the accumulated log entries
    produced by message_to_log for each tuple, in order.
    """
    chat_entries = []
    for user_message, gpt_message, context in message_list:
        # The user message doubles as the intent query and its only inferred query
        metadata = {
            "context": context,
            "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'},
        }
        # Log each exchange against a fresh, empty log and append the result to the running history
        chat_entries += message_to_log(
            user_message,
            gpt_message,
            khoj_message_metadata=metadata,
            conversation_log=[],
        )
    return {"chat": chat_entries}