Improve offline chat truncation to consider message separator tokens

Debanjum Singh Solanky 2024-07-18 02:39:56 +05:30
parent 6f46e6afc6
commit b0ee78586c


@@ -186,7 +186,7 @@ def generate_chatml_messages_with_context(
 def truncate_messages(
     messages: list[ChatMessage],
-    max_prompt_size,
+    max_prompt_size: int,
     model_name: str,
     loaded_model: Optional[Llama] = None,
     tokenizer_name=None,
@@ -232,7 +232,8 @@ def truncate_messages(
     tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str])
     # Drop older messages until under max supported prompt size by model
-    while (tokens + system_message_tokens) > max_prompt_size and len(messages) > 1:
+    # Reserves 4 tokens to demarcate each message (e.g. <|im_start|>user, <|im_end|>, <|endoftext|> etc.)
+    while (tokens + system_message_tokens + 4 * len(messages)) > max_prompt_size and len(messages) > 1:
        messages.pop()
        tokens = sum([len(encoder.encode(message.content)) for message in messages if type(message.content) == str])
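
For reference, a minimal self-contained sketch of the truncation strategy above, assuming a tiktoken encoder, plain-string messages ordered newest-first (so pop() drops the oldest message), and illustrative names truncate and PER_MESSAGE_OVERHEAD that are not from the Khoj codebase. The constant 4 approximates the chat template tokens (e.g. <|im_start|>user, <|im_end|>) that wrap each message when the prompt is rendered.

    import tiktoken

    # Hypothetical constant: roughly 4 template tokens wrap each chat message
    # (e.g. <|im_start|>user, newline, <|im_end|>, trailing separator)
    PER_MESSAGE_OVERHEAD = 4


    def truncate(messages: list[str], system_message: str, max_prompt_size: int) -> list[str]:
        """Drop oldest messages until the prompt fits the model context window."""
        encoder = tiktoken.get_encoding("cl100k_base")
        system_tokens = len(encoder.encode(system_message))
        tokens = sum(len(encoder.encode(m)) for m in messages)
        # Budget separator overhead per message, not just raw content tokens,
        # so the rendered prompt does not overshoot max_prompt_size
        while (tokens + system_tokens + PER_MESSAGE_OVERHEAD * len(messages)) > max_prompt_size and len(messages) > 1:
            messages.pop()  # messages are newest-first, so this drops the oldest
            tokens = sum(len(encoder.encode(m)) for m in messages)
        return messages

Reserving a fixed per-message budget is a cheap upper bound compared to rendering the full chat template and re-counting tokens on every iteration of the loop.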