Mirror of https://github.com/khoj-ai/khoj.git
Commit 209975e065: Resolve merge conflicts: let Khoj fail if the model tokenizer is not found

2 changed files with 14 additions and 11 deletions
@@ -125,7 +125,7 @@ def converse_offline(
     # Get Conversation Primer appropriate to Conversation Type
     # TODO If compiled_references_message is too long, we need to truncate it.
     if compiled_references_message == "":
-        conversation_primer = prompts.conversation_llamav2.format(query=user_query)
+        conversation_primer = user_query
     else:
         conversation_primer = prompts.notes_conversation_llamav2.format(
             query=user_query, references=compiled_references_message
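The primer change above means that when no note references were compiled for the query, the raw user query is now used directly instead of being wrapped in the llama-v2 conversation prompt template. A minimal sketch of the post-change selection logic follows; build_conversation_primer is a hypothetical helper for illustration, and the template string is only a stand-in for Khoj's actual prompts.notes_conversation_llamav2:

# Hypothetical stand-in for Khoj's prompts module; the real llama-v2 notes template differs.
NOTES_CONVERSATION_LLAMAV2 = "Notes:\n{references}\n\nQuery: {query}"


def build_conversation_primer(user_query: str, compiled_references_message: str) -> str:
    """Illustrative helper mirroring the post-change primer selection in converse_offline."""
    if compiled_references_message == "":
        # No notes matched: pass the user query through untouched.
        return user_query
    # Notes matched: wrap the query and references in the notes prompt template.
    return NOTES_CONVERSATION_LLAMAV2.format(
        query=user_query, references=compiled_references_message
    )


print(build_conversation_primer("What did I plan for June?", ""))  # prints the raw query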
@@ -102,7 +102,7 @@ def generate_chatml_messages_with_context(
     return messages[::-1]


-def truncate_messages(messages, max_prompt_size, model_name):
+def truncate_messages(messages: list[ChatMessage], max_prompt_size, model_name) -> list[ChatMessage]:
     """Truncate messages to fit within max prompt size supported by model"""

     if "llama" in model_name:
@@ -110,24 +110,27 @@ def truncate_messages(messages, max_prompt_size, model_name):
     else:
         encoder = tiktoken.encoding_for_model(model_name)

+    system_message = messages.pop()
+    system_message_tokens = len(encoder.encode(system_message.content))
+
     tokens = sum([len(encoder.encode(message.content)) for message in messages])
-    while tokens > max_prompt_size and len(messages) > 1:
+    while (tokens + system_message_tokens) > max_prompt_size and len(messages) > 1:
         messages.pop()
         tokens = sum([len(encoder.encode(message.content)) for message in messages])

-    # Truncate last message if still over max supported prompt size by model
-    if tokens > max_prompt_size:
-        last_message = "\n".join(messages[-1].content.split("\n")[:-1])
-        original_question = "\n".join(messages[-1].content.split("\n")[-1:])
+    # Truncate current message if still over max supported prompt size by model
+    if (tokens + system_message_tokens) > max_prompt_size:
+        current_message = "\n".join(messages[0].content.split("\n")[:-1])
+        original_question = "\n".join(messages[0].content.split("\n")[-1:])
         original_question_tokens = len(encoder.encode(original_question))
         remaining_tokens = max_prompt_size - original_question_tokens
-        truncated_message = encoder.decode(encoder.encode(last_message)[:remaining_tokens]).strip()
+        truncated_message = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip()
         logger.debug(
-            f"Truncate last message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message}"
+            f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message}"
         )
-        messages = [ChatMessage(content=truncated_message + original_question, role=messages[-1].role)]
+        messages = [ChatMessage(content=truncated_message + original_question, role=messages[0].role)]

-    return messages
+    return messages + [system_message]


 def reciprocal_conversation_to_chatml(message_pair):
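To make the reworked truncation concrete, here is a self-contained sketch of the post-commit truncate_messages flow. It assumes, as the pop()/messages[0] indexing above implies, that the list arrives newest-first with the system prompt as its final element; the ChatMessage dataclass below is a minimal stand-in for the class Khoj imports, and the demo values are hypothetical:

from dataclasses import dataclass

import tiktoken


@dataclass
class ChatMessage:  # minimal stand-in for the ChatMessage class Khoj uses
    content: str
    role: str


def truncate_messages(messages: list[ChatMessage], max_prompt_size, model_name) -> list[ChatMessage]:
    """Sketch of the post-commit logic: reserve tokens for the system message,
    drop older messages, then hard-truncate the current message if still too big."""
    encoder = tiktoken.encoding_for_model(model_name)

    # Set the system message aside so it always survives truncation.
    system_message = messages.pop()
    system_message_tokens = len(encoder.encode(system_message.content))

    tokens = sum(len(encoder.encode(m.content)) for m in messages)
    while (tokens + system_message_tokens) > max_prompt_size and len(messages) > 1:
        messages.pop()  # drop the oldest remaining message (list is newest-first)
        tokens = sum(len(encoder.encode(m.content)) for m in messages)

    # If the lone remaining (current) message still overflows, truncate its body
    # but keep its last line, which carries the original question.
    if (tokens + system_message_tokens) > max_prompt_size:
        current_message = "\n".join(messages[0].content.split("\n")[:-1])
        original_question = "\n".join(messages[0].content.split("\n")[-1:])
        remaining_tokens = max_prompt_size - len(encoder.encode(original_question))
        truncated = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip()
        messages = [ChatMessage(content=truncated + original_question, role=messages[0].role)]

    # Re-append the untouched system message.
    return messages + [system_message]


# Hypothetical demo: a long chat history squeezed into a small prompt budget.
history = [
    ChatMessage(content="Summarize my notes.\nWhat did I plan for June?", role="user"),
    ChatMessage(content="an earlier assistant reply " * 50, role="assistant"),
    ChatMessage(content="an earlier user message " * 50, role="user"),
    ChatMessage(content="You are Khoj, a helpful personal assistant.", role="system"),
]
truncated_history = truncate_messages(history, max_prompt_size=100, model_name="gpt-3.5-turbo")
assert truncated_history[-1].role == "system"  # the system prompt is always preserved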