Move message truncation logic into a separate function. Add unit tests with factory boy.

This commit is contained in:
Saba 2023-06-05 18:58:29 -07:00
parent 5f4223efb4
commit f65ff9815d
3 changed files with 98 additions and 8 deletions

View file

@ -56,6 +56,8 @@ dependencies = [
"aiohttp == 3.8.4",
"langchain >= 0.0.187",
"pypdf >= 3.9.0",
"factory-boy==3.2.1",
"Faker==18.10.1"
]
dynamic = ["version"]

View file

@ -97,23 +97,33 @@ def generate_chatml_messages_with_context(
messages = user_chatml_message + rest_backnforths + system_chatml_message
# Truncate oldest messages from conversation history until under max supported prompt size by model
messages = truncate_message(messages, max_prompt_size[model_name], model_name)
# Return message in chronological order
return messages[::-1]
def truncate_message(messages, max_prompt_size, model_name):
"""Truncate messages to fit within max prompt size supported by model"""
encoder = tiktoken.encoding_for_model(model_name)
tokens = sum([len(encoder.encode(message.content)) for message in messages])
while tokens > max_prompt_size[model_name] and len(messages) > 1:
logger.info(f"num tokens: {tokens}")
while tokens > max_prompt_size and len(messages) > 1:
messages.pop()
tokens = sum([len(encoder.encode(message.content)) for message in messages])
# Truncate last message if still over max supported prompt size by model
if tokens > max_prompt_size[model_name]:
last_message = messages[-1]
truncated_message = encoder.decode(encoder.encode(last_message.content))
if tokens > max_prompt_size:
last_message = '\n'.join(messages[-1].content.split("\n")[:-1])
original_question = '\n'.join(messages[-1].content.split("\n")[-1:])
original_question_tokens = len(encoder.encode(original_question))
remaining_tokens = max_prompt_size - original_question_tokens
truncated_message = encoder.decode(encoder.encode(last_message)[:remaining_tokens]).strip()
logger.debug(
f"Truncate last message to fit within max prompt size of {max_prompt_size[model_name]} supported by {model_name} model:\n {truncated_message}"
f"Truncate last message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message}"
)
messages = [ChatMessage(content=truncated_message, role=last_message.role)]
messages = [ChatMessage(content=truncated_message + original_question, role=messages[-1].role)]
# Return message in chronological order
return messages[::-1]
return messages
def reciprocal_conversation_to_chatml(message_pair):

View file

@ -0,0 +1,78 @@
from khoj.processor.conversation import utils
from langchain.schema import ChatMessage
import factory
import logging
import tiktoken
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
class ChatMessageFactory(factory.Factory):
class Meta:
model = ChatMessage
content = factory.Faker('paragraph')
role = factory.Faker('name')
class TestTruncateMessage:
max_prompt_size = 4096
model_name = 'gpt-3.5-turbo'
encoder = tiktoken.encoding_for_model(model_name)
def test_truncate_message_all_small(self):
chat_messages = ChatMessageFactory.build_batch(500)
assert len(chat_messages) == 500
tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
assert tokens > self.max_prompt_size
prompt = utils.truncate_message(chat_messages, self.max_prompt_size, self.model_name)
# The original object has been modified. Verify certain properties
assert len(chat_messages) < 500
assert len(chat_messages) > 1
assert prompt == chat_messages
tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
assert tokens <= self.max_prompt_size
def test_truncate_message_first_large(self):
chat_messages = ChatMessageFactory.build_batch(25)
big_chat_message = ChatMessageFactory.build(content=factory.Faker('paragraph', nb_sentences=1000))
big_chat_message.content = big_chat_message.content + "\n" + "Question?"
copy_big_chat_message = big_chat_message.copy()
chat_messages.insert(0, big_chat_message)
assert len(chat_messages) == 26
tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
assert tokens > self.max_prompt_size
prompt = utils.truncate_message(chat_messages, self.max_prompt_size, self.model_name)
# The original object has been modified. Verify certain properties
assert len(chat_messages) < 26
assert len(chat_messages) == 1
assert prompt[0] != copy_big_chat_message
tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
assert tokens <= self.max_prompt_size
def test_truncate_message_last_large(self):
chat_messages = ChatMessageFactory.build_batch(25)
big_chat_message = ChatMessageFactory.build(content=factory.Faker('paragraph', nb_sentences=1000))
big_chat_message.content = big_chat_message.content + "\n" + "Question?"
copy_big_chat_message = big_chat_message.copy()
chat_messages.append(big_chat_message)
assert len(chat_messages) == 26
tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
assert tokens > self.max_prompt_size
prompt = utils.truncate_message(chat_messages, self.max_prompt_size, self.model_name)
# The original object has been modified. Verify certain properties
assert len(chat_messages) < 26
assert len(chat_messages) > 1
assert prompt[0] != copy_big_chat_message
tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
assert tokens < self.max_prompt_size