import tiktoken
from langchain.schema import ChatMessage

from khoj.processor.conversation import utils


class TestTruncateMessage:
    """Tests for ``utils.truncate_messages`` using a deliberately tiny prompt budget.

    Each test builds a list of chat messages, truncates it to at most
    ``max_prompt_size`` tokens and checks the token total of the returned
    messages, measured with the tiktoken encoder for ``model_name``.

    Several assertions inspect the list that was passed in rather than the
    returned one: they rely on ``truncate_messages`` modifying its argument
    in place.
    """

    max_prompt_size = 10
    model_name = "gpt-3.5-turbo"
    encoder = tiktoken.encoding_for_model(model_name)

    def _count_tokens(self, messages):
        """Return the total number of encoder tokens in the content of ``messages``.

        Not prefixed with ``test_`` so pytest does not collect it as a test.
        """
        return sum(len(self.encoder.encode(message.content)) for message in messages)

    def test_truncate_message_all_small(self):
        """Many short messages: some, but not all, must be dropped to fit the budget."""
        # Arrange
        chat_history = generate_chat_history(50)

        # Act
        truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
        tokens = self._count_tokens(truncated_chat_history)

        # Assert
        # The original object has been modified. Verify certain properties
        assert len(chat_history) < 50
        assert len(chat_history) > 1
        assert tokens <= self.max_prompt_size

    def test_truncate_message_first_large(self):
        """A large message placed at the front of a short history gets truncated."""
        # Arrange
        chat_history = generate_chat_history(5)
        big_chat_message = ChatMessage(role="user", content=f"{generate_content(6)}\nQuestion?")
        # Keep an untouched copy so a change to the big message can be detected.
        copy_big_chat_message = big_chat_message.copy()
        chat_history.insert(0, big_chat_message)

        # Act
        truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
        tokens = self._count_tokens(truncated_chat_history)

        # Assert
        # The original object has been modified. Verify certain properties
        assert len(chat_history) == 1
        assert truncated_chat_history[0] != copy_big_chat_message
        assert tokens <= self.max_prompt_size

    def test_truncate_message_last_large(self):
        """A large message in front of a history that contains a system message."""
        # Arrange
        chat_history = generate_chat_history(5)
        # Mark what is currently the first message as the system message.
        # It moves to index 1 once the big message is inserted below.
        chat_history[0].role = "system"
        big_chat_message = ChatMessage(role="user", content=f"{generate_content(11)}\nQuestion?")
        # Keep an untouched copy so a change to the big message can be detected.
        copy_big_chat_message = big_chat_message.copy()

        chat_history.insert(0, big_chat_message)
        initial_tokens = self._count_tokens(chat_history)

        # Act
        truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
        final_tokens = self._count_tokens(truncated_chat_history)

        # Assert
        # The original object has been modified. Verify certain properties.
        assert (
            len(truncated_chat_history) == len(chat_history) + 1
        )  # Because the system_prompt is popped off from the chat_messages list
        assert len(truncated_chat_history) < 10
        assert len(truncated_chat_history) > 1
        assert truncated_chat_history[0] != copy_big_chat_message
        assert initial_tokens > self.max_prompt_size
        assert final_tokens <= self.max_prompt_size

    def test_truncate_single_large_non_system_message(self):
        """A lone oversized user message (context plus a question line) is shortened."""
        # Arrange
        big_chat_message = ChatMessage(role="user", content=f"{generate_content(11)}\nQuestion?")
        # Keep an untouched copy so a change to the big message can be detected.
        copy_big_chat_message = big_chat_message.copy()
        chat_messages = [big_chat_message]
        initial_tokens = self._count_tokens(chat_messages)

        # Act
        truncated_chat_history = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
        final_tokens = self._count_tokens(truncated_chat_history)

        # Assert
        # The original object has been modified. Verify certain properties
        assert initial_tokens > self.max_prompt_size
        assert final_tokens <= self.max_prompt_size
        assert len(chat_messages) == 1
        assert truncated_chat_history[0] != copy_big_chat_message

    def test_truncate_single_large_question(self):
        """A lone oversized single-line message (no newline in its content) is shortened."""
        # Arrange
        # One more "hi" than the budget allows, all on a single line.
        big_chat_message_content = " ".join(["hi"] * (self.max_prompt_size + 1))
        big_chat_message = ChatMessage(role="user", content=big_chat_message_content)
        # Keep an untouched copy so a change to the big message can be detected.
        copy_big_chat_message = big_chat_message.copy()
        chat_messages = [big_chat_message]
        initial_tokens = self._count_tokens(chat_messages)

        # Act
        truncated_chat_history = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
        final_tokens = self._count_tokens(truncated_chat_history)

        # Assert
        # The original object has been modified. Verify certain properties
        assert initial_tokens > self.max_prompt_size
        assert final_tokens <= self.max_prompt_size
        assert len(chat_messages) == 1
        assert truncated_chat_history[0] != copy_big_chat_message

def generate_content(count):
    """Return the integers ``0`` to ``count - 1`` as a single space-separated string.

    A ``count`` of zero (or less) yields the empty string.
    """
    numbers = map(str, range(count))
    return " ".join(numbers)


def generate_chat_history(count):
    """Build a list of ``count`` chat messages with alternating roles.

    Messages at even positions have the role "user" and those at odd
    positions the role "assistant"; each message's content is its position
    rendered as a string.
    """
    roles = ("user", "assistant")
    history = []
    for position in range(count):
        # position % 2 is 0 for even positions (user) and 1 for odd (assistant).
        history.append(ChatMessage(role=roles[position % 2], content=str(position)))
    return history