Run pre-commit script

Saba 2023-06-05 19:29:23 -07:00
parent 948ba6ddca
commit 7119ed0849
2 changed files with 11 additions and 9 deletions

@@ -102,6 +102,7 @@ def generate_chatml_messages_with_context(
     # Return message in chronological order
     return messages[::-1]

 def truncate_message(messages, max_prompt_size, model_name):
     """Truncate messages to fit within max prompt size supported by model"""
     encoder = tiktoken.encoding_for_model(model_name)
@@ -112,8 +113,8 @@ def truncate_message(messages, max_prompt_size, model_name):
     # Truncate last message if still over max supported prompt size by model
     if tokens > max_prompt_size:
-        last_message = '\n'.join(messages[-1].content.split("\n")[:-1])
-        original_question = '\n'.join(messages[-1].content.split("\n")[-1:])
+        last_message = "\n".join(messages[-1].content.split("\n")[:-1])
+        original_question = "\n".join(messages[-1].content.split("\n")[-1:])
         original_question_tokens = len(encoder.encode(original_question))
         remaining_tokens = max_prompt_size - original_question_tokens
         truncated_message = encoder.decode(encoder.encode(last_message)[:remaining_tokens]).strip()
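
Note: the hunks above show only the tail of truncate_message. As a reading aid, here is a minimal, self-contained sketch of the truncation flow those lines belong to; the loop that drops older messages and the final rebuild of the surviving message are assumptions inferred from the visible context, not verbatim lines from the file.

    import tiktoken
    from langchain.schema import ChatMessage


    def truncate_message(messages, max_prompt_size, model_name):
        """Truncate messages to fit within max prompt size supported by model."""
        encoder = tiktoken.encoding_for_model(model_name)
        tokens = sum(len(encoder.encode(message.content)) for message in messages)

        # Assumed step (not shown in this diff): drop older messages until the
        # conversation fits, always keeping at least the latest message
        while tokens > max_prompt_size and len(messages) > 1:
            messages.pop()
            tokens = sum(len(encoder.encode(message.content)) for message in messages)

        # Truncate last message if still over max supported prompt size by model
        if tokens > max_prompt_size:
            last_message = "\n".join(messages[-1].content.split("\n")[:-1])
            original_question = "\n".join(messages[-1].content.split("\n")[-1:])
            original_question_tokens = len(encoder.encode(original_question))
            remaining_tokens = max_prompt_size - original_question_tokens
            truncated_message = encoder.decode(encoder.encode(last_message)[:remaining_tokens]).strip()
            # Assumed step: keep the trailing question intact and replace the rest
            # of the conversation with the truncated body
            messages = [ChatMessage(content=truncated_message + "\n" + original_question, role=messages[-1].role)]

        return messages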

@@ -3,16 +3,18 @@ from langchain.schema import ChatMessage
 import factory
 import tiktoken

 class ChatMessageFactory(factory.Factory):
     class Meta:
         model = ChatMessage

-    content = factory.Faker('paragraph')
-    role = factory.Faker('name')
+    content = factory.Faker("paragraph")
+    role = factory.Faker("name")

 class TestTruncateMessage:
     max_prompt_size = 4096
-    model_name = 'gpt-3.5-turbo'
+    model_name = "gpt-3.5-turbo"
     encoder = tiktoken.encoding_for_model(model_name)

     def test_truncate_message_all_small(self):
@@ -33,7 +35,7 @@ class TestTruncateMessage:
     def test_truncate_message_first_large(self):
         chat_messages = ChatMessageFactory.build_batch(25)
-        big_chat_message = ChatMessageFactory.build(content=factory.Faker('paragraph', nb_sentences=1000))
+        big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=1000))
         big_chat_message.content = big_chat_message.content + "\n" + "Question?"
         copy_big_chat_message = big_chat_message.copy()
         chat_messages.insert(0, big_chat_message)
@@ -53,7 +55,7 @@ class TestTruncateMessage:
     def test_truncate_message_last_large(self):
         chat_messages = ChatMessageFactory.build_batch(25)
-        big_chat_message = ChatMessageFactory.build(content=factory.Faker('paragraph', nb_sentences=1000))
+        big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=1000))
         big_chat_message.content = big_chat_message.content + "\n" + "Question?"
         copy_big_chat_message = big_chat_message.copy()
@@ -71,4 +73,3 @@ class TestTruncateMessage:
         tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
         assert tokens < self.max_prompt_size
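
For reference, the test hunks above follow one pattern: build 25 short messages plus one oversized message that ends in a question, run truncate_message, and assert the result fits the prompt budget. Below is a condensed, standalone sketch of that pattern; it assumes truncate_message is importable, and the import path is a placeholder since the real module name is not shown in this diff.

    import factory
    import tiktoken
    from langchain.schema import ChatMessage

    # Placeholder import path; the module containing truncate_message is not named in this diff
    from chat_utils import truncate_message


    class ChatMessageFactory(factory.Factory):
        class Meta:
            model = ChatMessage

        content = factory.Faker("paragraph")
        role = factory.Faker("name")


    def test_truncate_big_message_fits_prompt_size():
        max_prompt_size = 4096
        model_name = "gpt-3.5-turbo"
        encoder = tiktoken.encoding_for_model(model_name)

        # 25 short messages plus one very long message that ends in a question
        chat_messages = ChatMessageFactory.build_batch(25)
        big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=1000))
        big_chat_message.content = big_chat_message.content + "\n" + "Question?"
        chat_messages.insert(0, big_chat_message)

        prompt = truncate_message(chat_messages, max_prompt_size, model_name)

        # The truncated conversation should fit within the model's prompt budget
        tokens = sum(len(encoder.encode(message.content)) for message in prompt)
        assert tokens < max_prompt_size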