Do not drop system message when truncating context to max prompt size

Previously, the system message was dropped when the context size with
chat history exceeded the max prompt size supported by the chat model.

Now only the previous chat messages are dropped or the current message
is truncated. The system message is always kept to provide guidance to
the chat model.
Debanjum Singh Solanky 2023-07-31 16:36:31 -07:00
parent 02e216c135
commit 48e5ac0169

@@ -100,30 +100,34 @@ def generate_chatml_messages_with_context(
     return messages[::-1]
 
 
-def truncate_messages(messages, max_prompt_size, model_name):
+def truncate_messages(messages: list[ChatMessage], max_prompt_size, model_name) -> list[ChatMessage]:
     """Truncate messages to fit within max prompt size supported by model"""
     try:
         encoder = tiktoken.encoding_for_model(model_name)
     except KeyError:
         encoder = tiktoken.encoding_for_model("text-davinci-001")
 
+    system_message = messages.pop()
+    system_message_tokens = len(encoder.encode(system_message.content))
+
     tokens = sum([len(encoder.encode(message.content)) for message in messages])
-    while tokens > max_prompt_size and len(messages) > 1:
+    while (tokens + system_message_tokens) > max_prompt_size and len(messages) > 1:
         messages.pop()
         tokens = sum([len(encoder.encode(message.content)) for message in messages])
 
-    # Truncate last message if still over max supported prompt size by model
-    if tokens > max_prompt_size:
-        last_message = "\n".join(messages[-1].content.split("\n")[:-1])
-        original_question = "\n".join(messages[-1].content.split("\n")[-1:])
+    # Truncate current message if still over max supported prompt size by model
+    if (tokens + system_message_tokens) > max_prompt_size:
+        current_message = "\n".join(messages[0].content.split("\n")[:-1])
+        original_question = "\n".join(messages[0].content.split("\n")[-1:])
         original_question_tokens = len(encoder.encode(original_question))
         remaining_tokens = max_prompt_size - original_question_tokens
-        truncated_message = encoder.decode(encoder.encode(last_message)[:remaining_tokens]).strip()
+        truncated_message = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip()
         logger.debug(
-            f"Truncate last message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message}"
+            f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message}"
        )
-        messages = [ChatMessage(content=truncated_message + original_question, role=messages[-1].role)]
+        messages = [ChatMessage(content=truncated_message + original_question, role=messages[0].role)]
 
-    return messages
+    return messages + [system_message]
 
 
 def reciprocal_conversation_to_chatml(message_pair):
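
For illustration, a minimal usage sketch of how the updated truncate_messages
is expected to behave. This is not part of the commit: the ChatMessage import
from langchain.schema, the model name, the max_prompt_size value, and the
message contents are assumptions made for the example. The caller is expected
to pass the system message as the last element of the list, which is why the
function pops it from the end and re-appends it on return.

    # Hypothetical example; ChatMessage is assumed to come from langchain.schema,
    # and truncate_messages is the function shown in the diff above.
    from langchain.schema import ChatMessage

    messages = [
        ChatMessage(content="What did I write about yesterday?", role="user"),  # current message
        ChatMessage(content="Earlier replies from the chat history...", role="assistant"),
        ChatMessage(content="You are a helpful personal assistant.", role="system"),  # system message, kept last
    ]

    truncated = truncate_messages(messages, max_prompt_size=50, model_name="gpt-3.5-turbo")

    # Older chat history may be dropped and the current message may be shortened,
    # but the system message is re-appended, so it always survives truncation.
    assert truncated[-1].role == "system"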