Resolve merge conflicts: let Khoj fail if the model tokenizer is not found

Commit 209975e065 by sabaimran, 2023-07-31 19:12:26 -07:00
2 changed files with 14 additions and 11 deletions


@@ -125,7 +125,7 @@ def converse_offline(
     # Get Conversation Primer appropriate to Conversation Type
     # TODO If compiled_references_message is too long, we need to truncate it.
     if compiled_references_message == "":
-        conversation_primer = prompts.conversation_llamav2.format(query=user_query)
+        conversation_primer = user_query
     else:
         conversation_primer = prompts.notes_conversation_llamav2.format(
             query=user_query, references=compiled_references_message
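For context, the primer branch of converse_offline reads as follows after this change. This is just the new side of the hunk above assembled as a sketch; the surrounding function body is assumed and not part of this commit. With no compiled references, the raw user query is now passed through as-is instead of being wrapped in the llama-v2 conversation prompt template.

# Sketch of the primer logic after this change (names taken from the hunk above; surrounding code assumed)
if compiled_references_message == "":
    conversation_primer = user_query
else:
    conversation_primer = prompts.notes_conversation_llamav2.format(
        query=user_query, references=compiled_references_message
    )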


@@ -102,7 +102,7 @@ def generate_chatml_messages_with_context(
     return messages[::-1]
 
 
-def truncate_messages(messages, max_prompt_size, model_name):
+def truncate_messages(messages: list[ChatMessage], max_prompt_size, model_name) -> list[ChatMessage]:
     """Truncate messages to fit within max prompt size supported by model"""
     if "llama" in model_name:
@@ -110,24 +110,27 @@ def truncate_messages(messages, max_prompt_size, model_name):
     else:
         encoder = tiktoken.encoding_for_model(model_name)
 
+    system_message = messages.pop()
+    system_message_tokens = len(encoder.encode(system_message.content))
+
     tokens = sum([len(encoder.encode(message.content)) for message in messages])
-    while tokens > max_prompt_size and len(messages) > 1:
+    while (tokens + system_message_tokens) > max_prompt_size and len(messages) > 1:
         messages.pop()
         tokens = sum([len(encoder.encode(message.content)) for message in messages])
 
-    # Truncate last message if still over max supported prompt size by model
-    if tokens > max_prompt_size:
-        last_message = "\n".join(messages[-1].content.split("\n")[:-1])
-        original_question = "\n".join(messages[-1].content.split("\n")[-1:])
+    # Truncate current message if still over max supported prompt size by model
+    if (tokens + system_message_tokens) > max_prompt_size:
+        current_message = "\n".join(messages[0].content.split("\n")[:-1])
+        original_question = "\n".join(messages[0].content.split("\n")[-1:])
         original_question_tokens = len(encoder.encode(original_question))
         remaining_tokens = max_prompt_size - original_question_tokens
-        truncated_message = encoder.decode(encoder.encode(last_message)[:remaining_tokens]).strip()
+        truncated_message = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip()
         logger.debug(
-            f"Truncate last message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message}"
+            f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message}"
         )
-        messages = [ChatMessage(content=truncated_message + original_question, role=messages[-1].role)]
+        messages = [ChatMessage(content=truncated_message + original_question, role=messages[0].role)]
 
-    return messages
+    return messages + [system_message]
 
 
 def reciprocal_conversation_to_chatml(message_pair):
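To make the new truncation flow easy to inspect in isolation, here is a minimal, runnable sketch of the behaviour the hunk above implements. The ChatMessage dataclass, the WordEncoder, and the sample conversation are stand-ins invented for illustration; the real function resolves its encoder from model_name (the llama tokenizer or tiktoken, as shown in the diff) and operates on the project's own ChatMessage type.

# Minimal sketch of the updated truncate_messages flow (stand-in types; for illustration only)
from dataclasses import dataclass


@dataclass
class ChatMessage:
    content: str
    role: str


class WordEncoder:
    """Toy encoder standing in for tiktoken: one token per whitespace-separated word."""

    def encode(self, text):
        return text.split()

    def decode(self, tokens):
        return " ".join(tokens)


def truncate_messages(messages: list[ChatMessage], max_prompt_size, encoder) -> list[ChatMessage]:
    """Drop oldest messages first, then trim the current one, always keeping the system message."""
    # The system message sits at the end of the list (messages are ordered newest-first),
    # so set it aside and budget for its tokens separately.
    system_message = messages.pop()
    system_message_tokens = len(encoder.encode(system_message.content))

    tokens = sum(len(encoder.encode(message.content)) for message in messages)
    while (tokens + system_message_tokens) > max_prompt_size and len(messages) > 1:
        messages.pop()
        tokens = sum(len(encoder.encode(message.content)) for message in messages)

    # If the current (first) message alone still overflows, keep its last line (the question)
    # intact and truncate the preceding lines to the remaining token budget.
    if (tokens + system_message_tokens) > max_prompt_size:
        current_message = "\n".join(messages[0].content.split("\n")[:-1])
        original_question = "\n".join(messages[0].content.split("\n")[-1:])
        remaining_tokens = max_prompt_size - len(encoder.encode(original_question))
        truncated_message = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip()
        # The question is appended directly to the truncated body, mirroring the hunk above.
        messages = [ChatMessage(content=truncated_message + original_question, role=messages[0].role)]

    # Re-append the system message so it always survives truncation.
    return messages + [system_message]


conversation = [
    ChatMessage(content="long notes about topic X\nwhat did I write about X?", role="user"),
    ChatMessage(content="an older assistant reply", role="assistant"),
    ChatMessage(content="You are Khoj, a helpful personal assistant.", role="system"),
]
for message in truncate_messages(conversation, max_prompt_size=10, encoder=WordEncoder()):
    print(f"{message.role}: {message.content}")

Running the sketch drops the older assistant reply and re-appends the system prompt at the end of the returned list, which is the point of this change: truncation can no longer silently discard the system message, and the final user question is always kept intact.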