Resolve merge conflicts: let Khoj fail if the model tokenizer is not found

sabaimran 2023-07-31 19:12:26 -07:00
commit 209975e065
2 changed files with 14 additions and 11 deletions


@@ -125,7 +125,7 @@ def converse_offline(
     # Get Conversation Primer appropriate to Conversation Type
     # TODO If compiled_references_message is too long, we need to truncate it.
     if compiled_references_message == "":
-        conversation_primer = prompts.conversation_llamav2.format(query=user_query)
+        conversation_primer = user_query
     else:
         conversation_primer = prompts.notes_conversation_llamav2.format(
             query=user_query, references=compiled_references_message
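
In effect, when no compiled references are found the raw user query is now used directly as the conversation primer, instead of being wrapped in the llamav2 conversation prompt. A minimal sketch of the resulting branching, using a made-up stand-in for the prompts.notes_conversation_llamav2 template (its real text is not shown in this diff) and an invented helper name:

# Illustrative sketch only; NOTES_TEMPLATE and build_conversation_primer are stand-ins.
NOTES_TEMPLATE = "Notes:\n{references}\n\nQuery: {query}"

def build_conversation_primer(user_query: str, compiled_references_message: str) -> str:
    if compiled_references_message == "":
        # After this patch: the bare user query passes through unchanged.
        return user_query
    # With references, the query is still rendered through the notes template.
    return NOTES_TEMPLATE.format(query=user_query, references=compiled_references_message)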


@@ -102,7 +102,7 @@ def generate_chatml_messages_with_context(
     return messages[::-1]
 
 
-def truncate_messages(messages, max_prompt_size, model_name):
+def truncate_messages(messages: list[ChatMessage], max_prompt_size, model_name) -> list[ChatMessage]:
     """Truncate messages to fit within max prompt size supported by model"""
     if "llama" in model_name:
@@ -110,24 +110,27 @@ def truncate_messages(messages, max_prompt_size, model_name):
     else:
         encoder = tiktoken.encoding_for_model(model_name)
 
+    system_message = messages.pop()
+    system_message_tokens = len(encoder.encode(system_message.content))
+
     tokens = sum([len(encoder.encode(message.content)) for message in messages])
-    while tokens > max_prompt_size and len(messages) > 1:
+    while (tokens + system_message_tokens) > max_prompt_size and len(messages) > 1:
         messages.pop()
         tokens = sum([len(encoder.encode(message.content)) for message in messages])
 
-    # Truncate last message if still over max supported prompt size by model
-    if tokens > max_prompt_size:
-        last_message = "\n".join(messages[-1].content.split("\n")[:-1])
-        original_question = "\n".join(messages[-1].content.split("\n")[-1:])
+    # Truncate current message if still over max supported prompt size by model
+    if (tokens + system_message_tokens) > max_prompt_size:
+        current_message = "\n".join(messages[0].content.split("\n")[:-1])
+        original_question = "\n".join(messages[0].content.split("\n")[-1:])
         original_question_tokens = len(encoder.encode(original_question))
         remaining_tokens = max_prompt_size - original_question_tokens
-        truncated_message = encoder.decode(encoder.encode(last_message)[:remaining_tokens]).strip()
+        truncated_message = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip()
         logger.debug(
-            f"Truncate last message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message}"
+            f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message}"
         )
-        messages = [ChatMessage(content=truncated_message + original_question, role=messages[-1].role)]
+        messages = [ChatMessage(content=truncated_message + original_question, role=messages[0].role)]
 
-    return messages
+    return messages + [system_message]
 
 
 def reciprocal_conversation_to_chatml(message_pair):
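
The net behavior of the revised truncate_messages is easier to see outside the diff. Below is an illustrative, self-contained sketch (not code from the repository), with a stand-in ChatMessage dataclass and a toy character-level encoder in place of tiktoken; the names truncate_messages_sketch and CharEncoder are invented for the example. Per the diff, the system prompt sits at the end of the message list and the newest user message at index 0: the system prompt is held back from truncation and re-appended, while the current message is trimmed with its last line (the original question) preserved.

# Illustrative sketch of the truncation flow introduced by this patch.
from dataclasses import dataclass


@dataclass
class ChatMessage:  # stand-in for the chat message class used by Khoj
    content: str
    role: str


class CharEncoder:  # stand-in for a tiktoken encoding: one "token" per character
    def encode(self, text):
        return list(text)

    def decode(self, tokens):
        return "".join(tokens)


def truncate_messages_sketch(messages, max_prompt_size):
    encoder = CharEncoder()

    # The system prompt sits at the end of the list; keep it aside so it is
    # never dropped, but still count it against the prompt budget.
    system_message = messages.pop()
    system_message_tokens = len(encoder.encode(system_message.content))

    tokens = sum(len(encoder.encode(m.content)) for m in messages)
    # Drop older messages from the end until the budget fits, always keeping
    # the current message at index 0.
    while (tokens + system_message_tokens) > max_prompt_size and len(messages) > 1:
        messages.pop()
        tokens = sum(len(encoder.encode(m.content)) for m in messages)

    # If the current message alone still overflows, truncate its body but
    # preserve its last line (the original question).
    if (tokens + system_message_tokens) > max_prompt_size:
        body = "\n".join(messages[0].content.split("\n")[:-1])
        question = "\n".join(messages[0].content.split("\n")[-1:])
        remaining = max_prompt_size - len(encoder.encode(question))
        truncated = encoder.decode(encoder.encode(body)[:remaining]).strip()
        messages = [ChatMessage(content=truncated + question, role=messages[0].role)]

    # The retained system message is re-appended rather than discarded.
    return messages + [system_message]


if __name__ == "__main__":
    msgs = [
        ChatMessage("notes...\nnotes...\nWhat did I eat for breakfast?", "user"),
        ChatMessage("You ate oatmeal.", "assistant"),
        ChatMessage("Tell me about my day.", "user"),
        ChatMessage("You are Khoj, a helpful assistant.", "system"),
    ]
    for m in truncate_messages_sketch(msgs, max_prompt_size=60):
        print(m.role, "=>", repr(m.content))

Compared to the pre-patch version, the system message no longer competes with conversation history to be dropped or truncated, and the history never shrinks below the current question.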