Mirror of https://github.com/khoj-ai/khoj.git, synced 2025-02-17 08:04:21 +00:00
Do not drop system message when truncating context to max prompt size
Previously, the system message was dropped when the context size with chat history exceeded the max prompt size supported by the chat model. Now only the previous chat messages are dropped, or the current message is truncated, but the system message is kept to provide guidance to the chat model.
parent 02e216c135
commit 48e5ac0169
1 changed file with 14 additions and 10 deletions
@@ -100,30 +100,34 @@ def generate_chatml_messages_with_context(
     return messages[::-1]
 
 
-def truncate_messages(messages, max_prompt_size, model_name):
+def truncate_messages(messages: list[ChatMessage], max_prompt_size, model_name) -> list[ChatMessage]:
     """Truncate messages to fit within max prompt size supported by model"""
     try:
         encoder = tiktoken.encoding_for_model(model_name)
     except KeyError:
         encoder = tiktoken.encoding_for_model("text-davinci-001")
+
+    system_message = messages.pop()
+    system_message_tokens = len(encoder.encode(system_message.content))
+
     tokens = sum([len(encoder.encode(message.content)) for message in messages])
-    while tokens > max_prompt_size and len(messages) > 1:
+    while (tokens + system_message_tokens) > max_prompt_size and len(messages) > 1:
         messages.pop()
         tokens = sum([len(encoder.encode(message.content)) for message in messages])
 
-    # Truncate last message if still over max supported prompt size by model
-    if tokens > max_prompt_size:
-        last_message = "\n".join(messages[-1].content.split("\n")[:-1])
-        original_question = "\n".join(messages[-1].content.split("\n")[-1:])
+    # Truncate current message if still over max supported prompt size by model
+    if (tokens + system_message_tokens) > max_prompt_size:
+        current_message = "\n".join(messages[0].content.split("\n")[:-1])
+        original_question = "\n".join(messages[0].content.split("\n")[-1:])
         original_question_tokens = len(encoder.encode(original_question))
         remaining_tokens = max_prompt_size - original_question_tokens
-        truncated_message = encoder.decode(encoder.encode(last_message)[:remaining_tokens]).strip()
+        truncated_message = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip()
         logger.debug(
-            f"Truncate last message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message}"
+            f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message}"
         )
-        messages = [ChatMessage(content=truncated_message + original_question, role=messages[-1].role)]
+        messages = [ChatMessage(content=truncated_message + original_question, role=messages[0].role)]
 
-    return messages
+    return messages + [system_message]
 
 
 def reciprocal_conversation_to_chatml(message_pair):
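As a rough sanity check of the new behavior, something like the sketch below could be used. It is only a sketch: the import paths for truncate_messages and ChatMessage are assumptions (they are not shown in this diff), and the token budget and model name are arbitrary. The messages list follows the ordering the function expects: current user message first, older history after it, and the system message last.

# Sketch of a sanity check for the new truncate_messages behavior.
# Assumptions (not taken from this diff): the two import paths below, and that
# ChatMessage is langchain's schema class with `content` and `role` fields.
from langchain.schema import ChatMessage  # assumed import path
from khoj.processor.conversation.utils import truncate_messages  # assumed import path


def check_system_message_survives_truncation():
    long_history = "assistant reply " * 2000  # far larger than the 500 token budget below
    messages = [
        ChatMessage(content="What is the capital of France?", role="user"),  # current message
        ChatMessage(content=long_history, role="assistant"),  # older chat history
        ChatMessage(content="You are Khoj, a helpful personal assistant.", role="system"),
    ]

    truncated = truncate_messages(messages, max_prompt_size=500, model_name="gpt-3.5-turbo")

    # The system message is popped up front and re-appended at the end,
    # so it survives even though the oversized history was dropped.
    assert truncated[-1].role == "system"
    assert any(m.role == "user" for m in truncated)


if __name__ == "__main__":
    check_system_message_survives_truncation()
    print("system message preserved after truncation")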