2023-06-05 18:58:29 -07:00
|
|
|
import tiktoken
|
2023-12-28 18:04:02 +05:30
|
|
|
from langchain.schema import ChatMessage
|
|
|
|
|
|
|
|
from khoj.processor.conversation import utils
|
2023-06-05 18:58:29 -07:00
|
|
|
|
2023-06-05 19:29:23 -07:00
|
|
|
|
2023-06-05 18:58:29 -07:00
|
|
|
class TestTruncateMessage:
    """Tests for ``utils.truncate_messages`` using a deliberately tiny token budget.

    Token counts are measured independently with the tiktoken encoder for ``model_name``
    so the assertions do not depend on ``utils`` reporting its own sizes.
    """

    # Token budget kept very small so the generated test inputs force truncation.
    max_prompt_size = 10
    model_name = "gpt-4o-mini"
    encoder = tiktoken.encoding_for_model(model_name)

    def _count_tokens(self, messages):
        """Return the total tiktoken count of the ``content`` of every message in ``messages``."""
        return sum(len(self.encoder.encode(message.content)) for message in messages)

    def test_truncate_message_all_small(self):
        # Arrange
        chat_history = generate_chat_history(50)

        # Act
        truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
        tokens = self._count_tokens(truncated_chat_history)

        # Assert
        # The original object has been modified. Verify certain properties
        assert len(chat_history) < 50
        assert len(chat_history) > 1
        assert tokens <= self.max_prompt_size

    def test_truncate_message_first_large(self):
        # Arrange
        chat_history = generate_chat_history(5)
        big_chat_message = ChatMessage(role="user", content=f"{generate_content(6)}\nQuestion?")
        copy_big_chat_message = big_chat_message.copy()
        chat_history.insert(0, big_chat_message)
        initial_tokens = self._count_tokens(chat_history)

        # Act
        truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
        final_tokens = self._count_tokens(truncated_chat_history)

        # Assert
        # The input started with 6 messages; the original object has been modified down to one.
        assert len(chat_history) == 1
        assert truncated_chat_history[0] != copy_big_chat_message
        # Precondition: the input really was over budget, so truncation was exercised.
        assert initial_tokens > self.max_prompt_size
        assert final_tokens <= self.max_prompt_size

    def test_truncate_message_last_large(self):
        # Arrange
        chat_history = generate_chat_history(5)
        chat_history[0].role = "system"  # Mark the first message as system message
        big_chat_message = ChatMessage(role="user", content=f"{generate_content(11)}\nQuestion?")
        copy_big_chat_message = big_chat_message.copy()

        chat_history.insert(0, big_chat_message)
        initial_tokens = self._count_tokens(chat_history)

        # Act
        truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
        final_tokens = self._count_tokens(truncated_chat_history)

        # Assert
        # The original object has been modified. Verify certain properties.
        # Because the system_prompt is popped off from the chat_messages list, the result has one more entry.
        assert len(truncated_chat_history) == len(chat_history) + 1
        assert len(truncated_chat_history) < 10
        assert len(truncated_chat_history) > 1
        assert truncated_chat_history[0] != copy_big_chat_message
        assert initial_tokens > self.max_prompt_size
        assert final_tokens <= self.max_prompt_size

    def test_truncate_single_large_non_system_message(self):
        # Arrange
        big_chat_message = ChatMessage(role="user", content=f"{generate_content(11)}\nQuestion?")
        copy_big_chat_message = big_chat_message.copy()
        chat_messages = [big_chat_message]
        initial_tokens = self._count_tokens(chat_messages)

        # Act
        truncated_chat_history = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
        final_tokens = self._count_tokens(truncated_chat_history)

        # Assert
        # Verify properties of the input list and of the truncated result.
        assert initial_tokens > self.max_prompt_size
        assert final_tokens <= self.max_prompt_size
        assert len(chat_messages) == 1
        assert truncated_chat_history[0] != copy_big_chat_message

    def test_truncate_single_large_question(self):
        # Arrange
        # A single line of (max_prompt_size + 1) words with no newline in it.
        big_chat_message_content = " ".join(["hi"] * (self.max_prompt_size + 1))
        big_chat_message = ChatMessage(role="user", content=big_chat_message_content)
        copy_big_chat_message = big_chat_message.copy()
        chat_messages = [big_chat_message]
        initial_tokens = self._count_tokens(chat_messages)

        # Act
        truncated_chat_history = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
        final_tokens = self._count_tokens(truncated_chat_history)

        # Assert
        # Verify properties of the input list and of the truncated result.
        assert initial_tokens > self.max_prompt_size
        assert final_tokens <= self.max_prompt_size
        assert len(chat_messages) == 1
        assert truncated_chat_history[0] != copy_big_chat_message
|
2024-06-07 19:40:53 +05:30
|
|
|
|
|
|
|
|
2024-11-26 15:35:23 -08:00
|
|
|
def test_load_complex_raw_json_string():
    """``utils.load_complex_json`` should parse a value mixing escaped and unescaped quotes."""
    # Arrange
    raw_json = r"""{"key": "value with unescaped " and unescaped \' and escaped \" and escaped \\'"}"""
    expected_json = {"key": "value with unescaped \" and unescaped \\' and escaped \" and escaped \\'"}

    # Act
    parsed_json = utils.load_complex_json(raw_json)

    # Assert
    assert parsed_json == expected_json
|
|
|
|
|
|
|
|
|
2024-06-07 19:40:53 +05:30
|
|
|
def generate_content(count):
    """Return the integers ``0 .. count - 1`` joined by single spaces.

    Used by the tests to build message bodies whose size grows with ``count``.
    Returns an empty string when ``count`` is 0 or negative.
    """
    return " ".join(str(index) for index in range(count))
|
|
|
|
|
|
|
|
|
|
|
|
def generate_chat_history(count):
    """Return ``count`` ChatMessages with alternating roles.

    Message ``i`` has content ``str(i)``; even indices get role "user" and odd
    indices get role "assistant".
    """
    return [
        ChatMessage(role="user" if index % 2 == 0 else "assistant", content=f"{index}")
        for index in range(count)
    ]
|