Commit 7119ed0849 in khoj-ai/khoj (mirror of https://github.com/khoj-ai/khoj.git)
Run pre-commit script
Parent: 948ba6ddca
2 changed files with 11 additions and 9 deletions
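The diff below is the mechanical output of the repository's pre-commit hooks, per the commit message: string literals are normalized to double quotes, blank lines around top-level definitions are adjusted, and a trailing blank line is dropped at the end of the test file. A minimal illustration of the quote normalization, assuming a black-style formatter (the specific hooks are not named in this commit):

# Before the hook runs: single-quoted string literal
model_name = 'gpt-3.5-turbo'

# After the hook runs: double-quoted literal; runtime behavior is unchanged
model_name = "gpt-3.5-turbo"
assert 'gpt-3.5-turbo' == "gpt-3.5-turbo"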
@@ -102,6 +102,7 @@ def generate_chatml_messages_with_context(
     # Return message in chronological order
     return messages[::-1]
 
+
 def truncate_message(messages, max_prompt_size, model_name):
     """Truncate messages to fit within max prompt size supported by model"""
     encoder = tiktoken.encoding_for_model(model_name)
@@ -112,8 +113,8 @@ def truncate_message(messages, max_prompt_size, model_name):
 
     # Truncate last message if still over max supported prompt size by model
     if tokens > max_prompt_size:
-        last_message = '\n'.join(messages[-1].content.split("\n")[:-1])
-        original_question = '\n'.join(messages[-1].content.split("\n")[-1:])
+        last_message = "\n".join(messages[-1].content.split("\n")[:-1])
+        original_question = "\n".join(messages[-1].content.split("\n")[-1:])
         original_question_tokens = len(encoder.encode(original_question))
         remaining_tokens = max_prompt_size - original_question_tokens
         truncated_message = encoder.decode(encoder.encode(last_message)[:remaining_tokens]).strip()
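The hunk above touches the token-budget truncation in truncate_message: the final line of the last message (the user's question) is kept verbatim, and the preceding body is cut down to whatever token budget remains. A minimal standalone sketch of that technique using tiktoken; the helper name and sample text are illustrative, not part of the repository:

import tiktoken

def fit_to_token_budget(text, max_prompt_size, model_name="gpt-3.5-turbo"):
    """Keep the final line intact and truncate the preceding body to the remaining token budget."""
    encoder = tiktoken.encoding_for_model(model_name)
    body = "\n".join(text.split("\n")[:-1])      # everything except the last line
    question = text.split("\n")[-1]              # the last line, preserved verbatim
    remaining_tokens = max_prompt_size - len(encoder.encode(question))
    truncated_body = encoder.decode(encoder.encode(body)[:remaining_tokens]).strip()
    return truncated_body + "\n" + question

# Example: a very long body followed by a short question is cut to roughly the budget.
sample = ("lorem ipsum " * 5000) + "\nWhat is Khoj?"
shortened = fit_to_token_budget(sample, max_prompt_size=4096)
print(shortened.endswith("What is Khoj?"))  # True: the trailing question survives truncation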
@@ -3,16 +3,18 @@ from langchain.schema import ChatMessage
 import factory
 import tiktoken
 
+
 class ChatMessageFactory(factory.Factory):
     class Meta:
         model = ChatMessage
 
-    content = factory.Faker('paragraph')
-    role = factory.Faker('name')
+    content = factory.Faker("paragraph")
+    role = factory.Faker("name")
 
+
 class TestTruncateMessage:
     max_prompt_size = 4096
-    model_name = 'gpt-3.5-turbo'
+    model_name = "gpt-3.5-turbo"
     encoder = tiktoken.encoding_for_model(model_name)
 
     def test_truncate_message_all_small(self):
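The test hunk above defines a factory_boy factory over langchain's ChatMessage so the tests can mass-produce realistic messages. A short usage sketch under the same definitions; Faker output is random, so exact content differs per run:

import factory
from langchain.schema import ChatMessage

class ChatMessageFactory(factory.Factory):
    class Meta:
        model = ChatMessage

    content = factory.Faker("paragraph")
    role = factory.Faker("name")

# build() creates an in-memory instance; build_batch() creates several at once.
messages = ChatMessageFactory.build_batch(3)
long_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=1000))
print(len(messages), type(messages[0]).__name__)              # 3 ChatMessage
print(len(long_message.content) > len(messages[0].content))   # True for a 1000-sentence paragraph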
@@ -33,7 +35,7 @@ class TestTruncateMessage:
 
     def test_truncate_message_first_large(self):
         chat_messages = ChatMessageFactory.build_batch(25)
-        big_chat_message = ChatMessageFactory.build(content=factory.Faker('paragraph', nb_sentences=1000))
+        big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=1000))
         big_chat_message.content = big_chat_message.content + "\n" + "Question?"
         copy_big_chat_message = big_chat_message.copy()
         chat_messages.insert(0, big_chat_message)
@@ -53,10 +55,10 @@ class TestTruncateMessage:
 
     def test_truncate_message_last_large(self):
         chat_messages = ChatMessageFactory.build_batch(25)
-        big_chat_message = ChatMessageFactory.build(content=factory.Faker('paragraph', nb_sentences=1000))
+        big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=1000))
         big_chat_message.content = big_chat_message.content + "\n" + "Question?"
         copy_big_chat_message = big_chat_message.copy()
 
         chat_messages.append(big_chat_message)
         assert len(chat_messages) == 26
         tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
@@ -71,4 +73,3 @@ class TestTruncateMessage:
 
         tokens = sum([len(self.encoder.encode(message.content)) for message in prompt])
         assert tokens < self.max_prompt_size
-