Mirror of https://github.com/khoj-ai/khoj.git (synced 2024-11-23 23:48:56 +01:00)
Resolve merge conflicts: let Khoj fail if the model tokenizer is not found
Commit 209975e065
2 changed files with 14 additions and 11 deletions
@@ -125,7 +125,7 @@ def converse_offline(
     # Get Conversation Primer appropriate to Conversation Type
     # TODO If compiled_references_message is too long, we need to truncate it.
     if compiled_references_message == "":
-        conversation_primer = prompts.conversation_llamav2.format(query=user_query)
+        conversation_primer = user_query
     else:
         conversation_primer = prompts.notes_conversation_llamav2.format(
             query=user_query, references=compiled_references_message
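This hunk makes converse_offline pass the raw user query through as the conversation primer when no notes were retrieved, instead of wrapping it in the conversation_llamav2 template. A minimal, self-contained sketch of the resulting branch; NOTES_TEMPLATE and build_conversation_primer are hypothetical stand-ins for prompts.notes_conversation_llamav2 and the surrounding function:

# Sketch of the post-change primer selection in converse_offline.
# NOTES_TEMPLATE is a hypothetical stand-in for prompts.notes_conversation_llamav2.
NOTES_TEMPLATE = "Notes:\n{references}\n\nQuery: {query}"


def build_conversation_primer(user_query: str, compiled_references_message: str) -> str:
    if compiled_references_message == "":
        # No retrieved notes: send the user's query through unchanged (the new behaviour).
        return user_query
    # Notes were retrieved: embed them alongside the query via the notes template.
    return NOTES_TEMPLATE.format(query=user_query, references=compiled_references_message)


print(build_conversation_primer("What did I note about llamas?", ""))  # prints the bare query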
@@ -102,7 +102,7 @@ def generate_chatml_messages_with_context(
     return messages[::-1]


-def truncate_messages(messages, max_prompt_size, model_name):
+def truncate_messages(messages: list[ChatMessage], max_prompt_size, model_name) -> list[ChatMessage]:
     """Truncate messages to fit within max prompt size supported by model"""

     if "llama" in model_name:
@@ -110,24 +110,27 @@ def truncate_messages(messages, max_prompt_size, model_name):
     else:
         encoder = tiktoken.encoding_for_model(model_name)

+    system_message = messages.pop()
+    system_message_tokens = len(encoder.encode(system_message.content))
+
     tokens = sum([len(encoder.encode(message.content)) for message in messages])
-    while tokens > max_prompt_size and len(messages) > 1:
+    while (tokens + system_message_tokens) > max_prompt_size and len(messages) > 1:
         messages.pop()
         tokens = sum([len(encoder.encode(message.content)) for message in messages])

-    # Truncate last message if still over max supported prompt size by model
-    if tokens > max_prompt_size:
-        last_message = "\n".join(messages[-1].content.split("\n")[:-1])
-        original_question = "\n".join(messages[-1].content.split("\n")[-1:])
+    # Truncate current message if still over max supported prompt size by model
+    if (tokens + system_message_tokens) > max_prompt_size:
+        current_message = "\n".join(messages[0].content.split("\n")[:-1])
+        original_question = "\n".join(messages[0].content.split("\n")[-1:])
         original_question_tokens = len(encoder.encode(original_question))
         remaining_tokens = max_prompt_size - original_question_tokens
-        truncated_message = encoder.decode(encoder.encode(last_message)[:remaining_tokens]).strip()
+        truncated_message = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip()
         logger.debug(
-            f"Truncate last message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message}"
+            f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message}"
         )
-        messages = [ChatMessage(content=truncated_message + original_question, role=messages[-1].role)]
+        messages = [ChatMessage(content=truncated_message + original_question, role=messages[0].role)]

-    return messages
+    return messages + [system_message]


 def reciprocal_conversation_to_chatml(message_pair):
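Taken together, this hunk and the previous one rework truncate_messages so that the system prompt (the last element of the newest-first message list) is set aside, its tokens count against the budget, the oldest history messages are dropped until the total fits, the remaining current message is trimmed if it still overflows, and the untouched system prompt is re-appended at the end. A consolidated sketch of the updated function, covering only the tiktoken branch (the real code also picks a llama tokenizer when "llama" is in the model name) and assuming ChatMessage here is langchain.schema.ChatMessage:

# Consolidated sketch of truncate_messages after this commit (tiktoken branch only).
import tiktoken
from langchain.schema import ChatMessage  # assumed stand-in for the project's ChatMessage import


def truncate_messages(messages: list[ChatMessage], max_prompt_size: int, model_name: str) -> list[ChatMessage]:
    """Truncate messages to fit within max prompt size supported by model."""
    encoder = tiktoken.encoding_for_model(model_name)

    # Messages arrive newest-first with the system prompt last; keep the system
    # prompt aside so it is never dropped, but count its tokens against the budget.
    system_message = messages.pop()
    system_message_tokens = len(encoder.encode(system_message.content))

    # Drop the oldest history messages (from the end of the list) until everything fits.
    tokens = sum(len(encoder.encode(message.content)) for message in messages)
    while (tokens + system_message_tokens) > max_prompt_size and len(messages) > 1:
        messages.pop()
        tokens = sum(len(encoder.encode(message.content)) for message in messages)

    # If the single remaining (current) message still overflows, trim its body but
    # preserve its last line, which carries the user's question.
    if (tokens + system_message_tokens) > max_prompt_size:
        current_message = "\n".join(messages[0].content.split("\n")[:-1])
        original_question = "\n".join(messages[0].content.split("\n")[-1:])
        remaining_tokens = max_prompt_size - len(encoder.encode(original_question))
        truncated_message = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip()
        messages = [ChatMessage(content=truncated_message + original_question, role=messages[0].role)]

    # Re-attach the untouched system prompt at the end, as the caller expects.
    return messages + [system_message]

The behavioural fix is that the system prompt now counts toward max_prompt_size but can never be truncated away, complementing the type annotations added to the signature in the previous hunk.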