khoj/tests/test_conversation_utils.py

import tiktoken
from langchain.schema import ChatMessage
from khoj.processor.conversation import utils


class TestTruncateMessage:
    max_prompt_size = 10
    model_name = "gpt-3.5-turbo"
    encoder = tiktoken.encoding_for_model(model_name)
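
    # These tests exercise utils.truncate_messages, which trims a chat message
    # list to fit within a token budget (max_prompt_size) for the given model.
    # Based on the assertions below: the input list is mutated in place, any
    # system message is retained, and a single oversized message is truncated
    # rather than dropped.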

    def test_truncate_message_all_small(self):
        """Truncate a long history of small messages down to the token budget."""
        # Arrange
        chat_history = generate_chat_history(50)

        # Act
        truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
        tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])

        # Assert
        # The original object has been modified. Verify certain properties
        assert len(chat_history) < 50
        assert len(chat_history) > 1
        assert tokens <= self.max_prompt_size

    def test_truncate_message_first_large(self):
        """Keep the large first message, truncated to fit, and drop the older messages."""
        # Arrange
        chat_history = generate_chat_history(5)
        big_chat_message = ChatMessage(role="user", content=f"{generate_content(6)}\nQuestion?")
        copy_big_chat_message = big_chat_message.copy()
        chat_history.insert(0, big_chat_message)
        tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history])

        # Act
        truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
        tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])

        # Assert
        # The original object has been modified. Verify certain properties
        assert len(chat_history) == 1
        assert truncated_chat_history[0] != copy_big_chat_message
        assert tokens <= self.max_prompt_size

    def test_truncate_message_last_large(self):
        """Retain the system message and truncate the oversized user message to fit the token budget."""
        # Arrange
        chat_history = generate_chat_history(5)
        chat_history[0].role = "system"  # Mark the first message as system message
        big_chat_message = ChatMessage(role="user", content=f"{generate_content(11)}\nQuestion?")
        copy_big_chat_message = big_chat_message.copy()
        chat_history.insert(0, big_chat_message)
        initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history])

        # Act
        truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
        final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])

        # Assert
        # The original object has been modified. Verify certain properties.
        assert (
            len(truncated_chat_history) == len(chat_history) + 1
        )  # Because the system_prompt is popped off from the chat_messages list
        assert len(truncated_chat_history) < 10
        assert len(truncated_chat_history) > 1
        assert truncated_chat_history[0] != copy_big_chat_message
        assert initial_tokens > self.max_prompt_size
        assert final_tokens <= self.max_prompt_size

    def test_truncate_single_large_non_system_message(self):
        """Truncate a single oversized user message rather than dropping it."""
        # Arrange
        big_chat_message = ChatMessage(role="user", content=f"{generate_content(11)}\nQuestion?")
        copy_big_chat_message = big_chat_message.copy()
        chat_messages = [big_chat_message]
        initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])

        # Act
        truncated_chat_history = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
        final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])

        # Assert
        # The original object has been modified. Verify certain properties
        assert initial_tokens > self.max_prompt_size
        assert final_tokens <= self.max_prompt_size
        assert len(chat_messages) == 1
        assert truncated_chat_history[0] != copy_big_chat_message

    def test_truncate_single_large_question(self):
        """Truncate a lone user question that exceeds the token budget."""
        # Arrange
        big_chat_message_content = " ".join(["hi"] * (self.max_prompt_size + 1))
        big_chat_message = ChatMessage(role="user", content=big_chat_message_content)
        copy_big_chat_message = big_chat_message.copy()
        chat_messages = [big_chat_message]
        initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])

        # Act
        truncated_chat_history = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
        final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])

        # Assert
        # The original object has been modified. Verify certain properties
        assert initial_tokens > self.max_prompt_size
        assert final_tokens <= self.max_prompt_size
        assert len(chat_messages) == 1
        assert truncated_chat_history[0] != copy_big_chat_message


def generate_content(count):
    """Return `count` space-separated numbers to use as message content."""
    return " ".join([f"{index}" for index in range(count)])


def generate_chat_history(count):
    """Return `count` short messages that alternate between user and assistant roles."""
    return [
        ChatMessage(role="user" if index % 2 == 0 else "assistant", content=f"{index}")
        for index in range(count)
    ]
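

# Illustrative usage sketch (not part of the original test file): how the
# truncation utility exercised above could be invoked directly. It assumes only
# the behaviour verified by the tests: truncate_messages takes a message list,
# a token budget, and a model name, trims the list, and returns messages whose
# combined tokens fit within that budget.
if __name__ == "__main__":
    messages = generate_chat_history(20)
    truncated = utils.truncate_messages(messages, 10, "gpt-3.5-turbo")
    encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")
    total_tokens = sum(len(encoder.encode(message.content)) for message in truncated)
    print(f"{len(truncated)} messages, {total_tokens} tokens after truncation")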