Mirror of https://github.com/khoj-ai/khoj.git, synced 2024-11-23 23:48:56 +01:00
Fix unit tests and truncation logic
parent 2335f11b00
commit e55e9a7b67
2 changed files with 4 additions and 6 deletions
@@ -123,7 +123,7 @@ def truncate_messages(messages: list[ChatMessage], max_prompt_size, model_name)
         current_message = "\n".join(messages[0].content.split("\n")[:-1])
         original_question = "\n".join(messages[0].content.split("\n")[-1:])
         original_question_tokens = len(encoder.encode(original_question))
-        remaining_tokens = max_prompt_size - original_question_tokens
+        remaining_tokens = max_prompt_size - original_question_tokens - system_message_tokens
         truncated_message = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip()
         logger.debug(
             f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message}"
         )
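For context on the logic fix above: the old budget reserved room only for the original question, so a long system prompt could still push the final payload past the model's limit. Below is a minimal sketch of the corrected arithmetic, assuming a tiktoken encoder; the `truncate_to_budget` wrapper and the `max(0, ...)` guard are illustrative additions, while the variable names follow the hunk.

import tiktoken

def truncate_to_budget(content: str, system_message: str, max_prompt_size: int, model_name: str) -> str:
    # Hypothetical wrapper around the hunk's arithmetic, for illustration only.
    encoder = tiktoken.encoding_for_model(model_name)

    # The hunk splits the first message into chat history and the final question line.
    current_message = "\n".join(content.split("\n")[:-1])
    original_question = "\n".join(content.split("\n")[-1:])

    original_question_tokens = len(encoder.encode(original_question))
    system_message_tokens = len(encoder.encode(system_message))

    # The fix: also subtract the system message's tokens, so truncated history,
    # question, and system prompt fit within max_prompt_size together.
    # max(0, ...) guards against a system prompt larger than the whole budget;
    # this guard is not in the original hunk.
    remaining_tokens = max(0, max_prompt_size - original_question_tokens - system_message_tokens)

    # Keep only the first remaining_tokens tokens of the history.
    truncated_message = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip()
    return truncated_message + "\n" + original_question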
@@ -19,7 +19,6 @@ class TestTruncateMessage:
 
     def test_truncate_message_all_small(self):
         chat_messages = ChatMessageFactory.build_batch(500)
-        tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
 
         prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
         tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
@@ -27,7 +26,6 @@ class TestTruncateMessage:
         # The original object has been modified. Verify certain properties
         assert len(chat_messages) < 500
         assert len(chat_messages) > 1
-        assert prompt == chat_messages
         assert tokens <= self.max_prompt_size
 
     def test_truncate_message_first_large(self):
@@ -52,14 +50,14 @@ class TestTruncateMessage:
         big_chat_message.content = big_chat_message.content + "\n" + "Question?"
         copy_big_chat_message = big_chat_message.copy()
 
-        chat_messages.append(big_chat_message)
+        chat_messages.insert(0, big_chat_message)
         tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
 
         prompt = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
+        tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
 
+        # The original object has been modified. Verify certain properties
         assert len(chat_messages) < 26
-        assert len(chat_messages) > 1
-        assert len(prompt) < 26
         assert len(prompt) > 1
         assert prompt[0] != copy_big_chat_message
         assert tokens <= self.max_prompt_size
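The switch from append to insert(0, ...) matters because the truncation branch in the logic hunk at the top only rewrites messages[0]: an oversized message appended at the end would never take that path. Below is a character-based stand-in for that positional dependency; the Message class, truncate_head, and limit are hypothetical simplifications that avoid a tokenizer dependency.

from dataclasses import dataclass

@dataclass
class Message:
    content: str

def truncate_head(messages: list[Message], limit: int) -> list[Message]:
    # Stand-in for the token-based logic, using character counts for brevity:
    # split messages[0] into history and the final question line, then trim
    # the history to the remaining budget.
    history, _, question = messages[0].content.rpartition("\n")
    budget = max(0, limit - len(question))
    messages[0] = Message(history[:budget].strip() + "\n" + question)
    return messages

msgs = [Message("x" * 1000 + "\nQuestion?"), Message("a small reply")]
truncate_head(msgs, limit=50)
assert len(msgs[0].content) < 1000  # head message was truncated
assert msgs[1].content == "a small reply"  # later messages untouched

# Had the big message been appended at the end instead, messages[0] would be a
# small message and the oversized one would never be trimmed by this branch.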