khoj/tests/test_conversation_utils.py
2023-06-05 19:01:03 -07:00

74 lines
3.1 KiB
Python

from khoj.processor.conversation import utils
from langchain.schema import ChatMessage
import factory
import tiktoken
class ChatMessageFactory(factory.Factory):
class Meta:
model = ChatMessage
content = factory.Faker('paragraph')
role = factory.Faker('name')
class TestTruncateMessage:
max_prompt_size = 4096
model_name = 'gpt-3.5-turbo'
encoder = tiktoken.encoding_for_model(model_name)
def test_truncate_message_all_small(self):
chat_messages = ChatMessageFactory.build_batch(500)
assert len(chat_messages) == 500
tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
assert tokens > self.max_prompt_size
prompt = utils.truncate_message(chat_messages, self.max_prompt_size, self.model_name)
# The original object has been modified. Verify certain properties
assert len(chat_messages) < 500
assert len(chat_messages) > 1
assert prompt == chat_messages
tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
assert tokens <= self.max_prompt_size
def test_truncate_message_first_large(self):
chat_messages = ChatMessageFactory.build_batch(25)
big_chat_message = ChatMessageFactory.build(content=factory.Faker('paragraph', nb_sentences=1000))
big_chat_message.content = big_chat_message.content + "\n" + "Question?"
copy_big_chat_message = big_chat_message.copy()
chat_messages.insert(0, big_chat_message)
assert len(chat_messages) == 26
tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
assert tokens > self.max_prompt_size
prompt = utils.truncate_message(chat_messages, self.max_prompt_size, self.model_name)
# The original object has been modified. Verify certain properties
assert len(chat_messages) < 26
assert len(chat_messages) == 1
assert prompt[0] != copy_big_chat_message
tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
assert tokens <= self.max_prompt_size
def test_truncate_message_last_large(self):
chat_messages = ChatMessageFactory.build_batch(25)
big_chat_message = ChatMessageFactory.build(content=factory.Faker('paragraph', nb_sentences=1000))
big_chat_message.content = big_chat_message.content + "\n" + "Question?"
copy_big_chat_message = big_chat_message.copy()
chat_messages.append(big_chat_message)
assert len(chat_messages) == 26
tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
assert tokens > self.max_prompt_size
prompt = utils.truncate_message(chat_messages, self.max_prompt_size, self.model_name)
# The original object has been modified. Verify certain properties
assert len(chat_messages) < 26
assert len(chat_messages) > 1
assert prompt[0] != copy_big_chat_message
tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
assert tokens < self.max_prompt_size