Mirror of https://github.com/khoj-ai/khoj.git, synced 2024-11-27 09:25:06 +01:00
Update truncation test to reduce flakiness in cloud tests

Removed the dependency on faker and factory for the truncation tests, as that seemed to be the source of the flakiness
This commit is contained in:
parent dbb06466bf
commit 5f2442450c

1 changed file with 24 additions and 27 deletions
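
For context, a minimal sketch (not part of the commit) of why the change removes the flakiness: faker's paragraph provider produces different text, and therefore a different token count, on every run, whereas the deterministic generate_content helper added in this diff always encodes to the same number of tokens.

# Sketch only; generate_content is copied from the diff below.
import tiktoken

encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")

def generate_content(count):
    # Deterministic content: "0 1 2 ... count-1"
    return " ".join([f"{index}" for index, _ in enumerate(range(count))])

# The same deterministic string always yields the same token count,
# so assertions against max_prompt_size cannot flake between runs.
assert len(encoder.encode(generate_content(6))) == len(encoder.encode(generate_content(6)))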
@@ -1,26 +1,17 @@
-import factory
 import tiktoken
 from langchain.schema import ChatMessage
 
 from khoj.processor.conversation import utils
 
 
-class ChatMessageFactory(factory.Factory):
-    class Meta:
-        model = ChatMessage
-
-    content = factory.Faker("paragraph")
-    role = factory.Faker("name")
-
-
 class TestTruncateMessage:
-    max_prompt_size = 4096
+    max_prompt_size = 10
     model_name = "gpt-3.5-turbo"
     encoder = tiktoken.encoding_for_model(model_name)
 
     def test_truncate_message_all_small(self):
         # Arrange
-        chat_history = ChatMessageFactory.build_batch(500)
+        chat_history = generate_chat_history(50)
 
         # Act
         truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
@@ -28,15 +19,14 @@ class TestTruncateMessage:
 
         # Assert
         # The original object has been modified. Verify certain properties
-        assert len(chat_history) < 500
+        assert len(chat_history) < 50
         assert len(chat_history) > 1
         assert tokens <= self.max_prompt_size
 
     def test_truncate_message_first_large(self):
         # Arrange
-        chat_history = ChatMessageFactory.build_batch(25)
-        big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=2000))
-        big_chat_message.content = big_chat_message.content + "\n" + "Question?"
+        chat_history = generate_chat_history(5)
+        big_chat_message = ChatMessage(role="user", content=f"{generate_content(6)}\nQuestion?")
         copy_big_chat_message = big_chat_message.copy()
         chat_history.insert(0, big_chat_message)
         tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history])
@@ -53,10 +43,9 @@ class TestTruncateMessage:
 
     def test_truncate_message_last_large(self):
         # Arrange
-        chat_history = ChatMessageFactory.build_batch(25)
+        chat_history = generate_chat_history(5)
         chat_history[0].role = "system"  # Mark the first message as system message
-        big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=1000))
-        big_chat_message.content = big_chat_message.content + "\n" + "Question?"
+        big_chat_message = ChatMessage(role="user", content=f"{generate_content(11)}\nQuestion?")
         copy_big_chat_message = big_chat_message.copy()
 
         chat_history.insert(0, big_chat_message)
@@ -68,10 +57,10 @@ class TestTruncateMessage:
 
         # Assert
         # The original object has been modified. Verify certain properties.
-        assert len(truncated_chat_history) == (
-            len(chat_history) + 1
-        )  # Because the system_prompt is popped off from the chat_messages lsit
-        assert len(truncated_chat_history) < 26
+        assert (
+            len(truncated_chat_history) == len(chat_history) + 1
+        )  # Because the system_prompt is popped off from the chat_messages list
+        assert len(truncated_chat_history) < 10
         assert len(truncated_chat_history) > 1
         assert truncated_chat_history[0] != copy_big_chat_message
         assert initial_tokens > self.max_prompt_size
@@ -79,9 +68,7 @@ class TestTruncateMessage:
 
     def test_truncate_single_large_non_system_message(self):
         # Arrange
-        big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=2000))
-        big_chat_message.content = big_chat_message.content + "\n" + "Question?"
-        big_chat_message.role = "user"
+        big_chat_message = ChatMessage(role="user", content=f"{generate_content(11)}\nQuestion?")
         copy_big_chat_message = big_chat_message.copy()
         chat_messages = [big_chat_message]
         initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
@@ -100,8 +87,7 @@ class TestTruncateMessage:
     def test_truncate_single_large_question(self):
         # Arrange
         big_chat_message_content = " ".join(["hi"] * (self.max_prompt_size + 1))
-        big_chat_message = ChatMessageFactory.build(content=big_chat_message_content)
-        big_chat_message.role = "user"
+        big_chat_message = ChatMessage(role="user", content=big_chat_message_content)
         copy_big_chat_message = big_chat_message.copy()
         chat_messages = [big_chat_message]
         initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
@@ -116,3 +102,14 @@ class TestTruncateMessage:
         assert final_tokens <= self.max_prompt_size
         assert len(chat_messages) == 1
         assert truncated_chat_history[0] != copy_big_chat_message
+
+
+def generate_content(count):
+    return " ".join([f"{index}" for index, _ in enumerate(range(count))])
+
+
+def generate_chat_history(count):
+    return [
+        ChatMessage(role="user" if index % 2 == 0 else "assistant", content=f"{index}")
+        for index, _ in enumerate(range(count))
+    ]
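
As a rough usage sketch, assuming the truncate_messages signature used in the tests above and that it mutates the passed list in place (per the test comments):

from langchain.schema import ChatMessage

from khoj.processor.conversation import utils

# Build a deterministic alternating user/assistant history, as the new helper does.
chat_history = [
    ChatMessage(role="user" if index % 2 == 0 else "assistant", content=f"{index}")
    for index, _ in enumerate(range(50))
]

# Expected to shrink the history so it fits within 10 tokens for gpt-3.5-turbo,
# per the assertions in the tests above.
truncated = utils.truncate_messages(chat_history, 10, "gpt-3.5-turbo")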