mirror of
https://github.com/khoj-ai/khoj.git
Truncate last message if still over max supported prompt size by model
parent ed4d0f9076
commit 1cd9ecd449
1 changed file with 10 additions and 1 deletion
```diff
@@ -97,10 +97,19 @@ def generate_chatml_messages_with_context(
     # Truncate oldest messages from conversation history until under max supported prompt size by model
     encoder = tiktoken.encoding_for_model(model_name)
     tokens = sum([len(encoder.encode(content)) for message in messages for content in message.content])
-    while tokens > max_prompt_size[model_name]:
+    while tokens > max_prompt_size[model_name] and len(messages) > 1:
         messages.pop()
         tokens = sum([len(encoder.encode(content)) for message in messages for content in message.content])
 
+    # Truncate last message if still over max supported prompt size by model
+    if tokens > max_prompt_size[model_name]:
+        last_message = messages[-1]
+        truncated_message = encoder.decode(encoder.encode(last_message.content)[: max_prompt_size[model_name]])
+        logger.debug(
+            f"Truncate last message to fit within max prompt size of {max_prompt_size[model_name]} supported by {model_name} model:\n {truncated_message}"
+        )
+        messages = [ChatMessage(content=[truncated_message], role=last_message.role)]
+
     # Return message in chronological order
     return messages[::-1]
```
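The loop guard and the new if block work together: before this change, a conversation whose newest message alone exceeded the model's limit would have every message popped, leaving an empty prompt. Now the loop always keeps at least one message, and that survivor is hard-truncated by token count. For illustration, here is a minimal runnable sketch of the resulting strategy, assuming messages are ordered newest-first and each message's content is a plain string; the ChatMessage dataclass, the truncate_messages name, and the sample budget below are hypothetical stand-ins, not the repo's exact code:

```python
import tiktoken
from dataclasses import dataclass
from typing import List


@dataclass
class ChatMessage:
    # Minimal stand-in for the message type used in the diff
    content: str
    role: str


def truncate_messages(messages: List[ChatMessage], max_prompt_size: int, model_name: str) -> List[ChatMessage]:
    """Drop oldest messages, then hard-truncate the newest one if still too large."""
    encoder = tiktoken.encoding_for_model(model_name)

    # Messages are ordered newest-first, so pop() discards the oldest;
    # the len(messages) > 1 guard keeps the loop from emptying the history
    tokens = sum(len(encoder.encode(message.content)) for message in messages)
    while tokens > max_prompt_size and len(messages) > 1:
        messages.pop()
        tokens = sum(len(encoder.encode(message.content)) for message in messages)

    # If the single remaining message is still over budget, truncate it by
    # encoding to tokens, slicing to the budget, and decoding back to text
    if tokens > max_prompt_size:
        last_message = messages[-1]
        truncated = encoder.decode(encoder.encode(last_message.content)[:max_prompt_size])
        messages = [ChatMessage(content=truncated, role=last_message.role)]

    # Reverse to return messages in chronological order
    return messages[::-1]


if __name__ == "__main__":
    history = [
        ChatMessage(content="newest message " * 2000, role="user"),
        ChatMessage(content="older message", role="assistant"),
        ChatMessage(content="oldest message", role="user"),
    ]
    # Even though the newest message alone exceeds the 100-token budget,
    # we get back exactly one truncated message instead of an empty prompt
    result = truncate_messages(history, max_prompt_size=100, model_name="gpt-4")
    print(len(result), len(result[0].content))
```

Slicing the token sequence before decoding guarantees the surviving message fits the budget exactly; the cut falls on a token boundary, so the decoded text remains valid even if it ends mid-word.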