Mirror of https://github.com/khoj-ai/khoj.git (synced 2025-02-20 06:55:08 +00:00)
Update truncation test to reduce flakiness in cloud tests
Removed the dependency on faker/factory in the truncation tests, as that appeared to be the source of the flakiness.
This commit is contained in: parent dbb06466bf · commit 5f2442450c
1 changed file with 24 additions and 27 deletions
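The flakiness appears to stem from faker generating random paragraphs, so the token count of each fixture varied from run to run. A minimal sketch of the idea behind the fix (not part of the commit itself): deterministic numeric content always encodes to the same number of tokens, so assertions against a fixed prompt-size budget behave identically on every run.

# Sketch: deterministic fixture content yields a stable token count,
# unlike faker's random paragraphs, whose length varies per run.
import tiktoken

encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")

# Mirrors what the new generate_content() helper builds: "0 1 2 ... n-1".
deterministic = " ".join(str(index) for index in range(11))

# The same input always encodes to the same number of tokens.
print(len(encoder.encode(deterministic)))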
@@ -1,26 +1,17 @@
-import factory
 import tiktoken
 from langchain.schema import ChatMessage
 
 from khoj.processor.conversation import utils
 
-
-class ChatMessageFactory(factory.Factory):
-    class Meta:
-        model = ChatMessage
-
-    content = factory.Faker("paragraph")
-    role = factory.Faker("name")
-
 
 class TestTruncateMessage:
-    max_prompt_size = 4096
+    max_prompt_size = 10
     model_name = "gpt-3.5-turbo"
     encoder = tiktoken.encoding_for_model(model_name)
 
     def test_truncate_message_all_small(self):
         # Arrange
-        chat_history = ChatMessageFactory.build_batch(500)
+        chat_history = generate_chat_history(50)
 
         # Act
         truncated_chat_history = utils.truncate_messages(chat_history, self.max_prompt_size, self.model_name)
@@ -28,15 +19,14 @@ class TestTruncateMessage:
 
         # Assert
         # The original object has been modified. Verify certain properties
-        assert len(chat_history) < 500
+        assert len(chat_history) < 50
         assert len(chat_history) > 1
         assert tokens <= self.max_prompt_size
 
     def test_truncate_message_first_large(self):
         # Arrange
-        chat_history = ChatMessageFactory.build_batch(25)
-        big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=2000))
-        big_chat_message.content = big_chat_message.content + "\n" + "Question?"
+        chat_history = generate_chat_history(5)
+        big_chat_message = ChatMessage(role="user", content=f"{generate_content(6)}\nQuestion?")
         copy_big_chat_message = big_chat_message.copy()
         chat_history.insert(0, big_chat_message)
         tokens = sum([len(self.encoder.encode(message.content)) for message in chat_history])
@@ -53,10 +43,9 @@ class TestTruncateMessage:
 
     def test_truncate_message_last_large(self):
         # Arrange
-        chat_history = ChatMessageFactory.build_batch(25)
+        chat_history = generate_chat_history(5)
         chat_history[0].role = "system"  # Mark the first message as system message
-        big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=1000))
-        big_chat_message.content = big_chat_message.content + "\n" + "Question?"
+        big_chat_message = ChatMessage(role="user", content=f"{generate_content(11)}\nQuestion?")
         copy_big_chat_message = big_chat_message.copy()
 
         chat_history.insert(0, big_chat_message)
@@ -68,10 +57,10 @@ class TestTruncateMessage:
 
         # Assert
         # The original object has been modified. Verify certain properties.
-        assert len(truncated_chat_history) == (
-            len(chat_history) + 1
-        )  # Because the system_prompt is popped off from the chat_messages lsit
-        assert len(truncated_chat_history) < 26
+        assert (
+            len(truncated_chat_history) == len(chat_history) + 1
+        )  # Because the system_prompt is popped off from the chat_messages list
+        assert len(truncated_chat_history) < 10
         assert len(truncated_chat_history) > 1
         assert truncated_chat_history[0] != copy_big_chat_message
         assert initial_tokens > self.max_prompt_size
@@ -79,9 +68,7 @@ class TestTruncateMessage:
 
     def test_truncate_single_large_non_system_message(self):
         # Arrange
-        big_chat_message = ChatMessageFactory.build(content=factory.Faker("paragraph", nb_sentences=2000))
-        big_chat_message.content = big_chat_message.content + "\n" + "Question?"
-        big_chat_message.role = "user"
+        big_chat_message = ChatMessage(role="user", content=f"{generate_content(11)}\nQuestion?")
         copy_big_chat_message = big_chat_message.copy()
         chat_messages = [big_chat_message]
         initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
@@ -100,8 +87,7 @@ class TestTruncateMessage:
     def test_truncate_single_large_question(self):
         # Arrange
         big_chat_message_content = " ".join(["hi"] * (self.max_prompt_size + 1))
-        big_chat_message = ChatMessageFactory.build(content=big_chat_message_content)
-        big_chat_message.role = "user"
+        big_chat_message = ChatMessage(role="user", content=big_chat_message_content)
         copy_big_chat_message = big_chat_message.copy()
         chat_messages = [big_chat_message]
         initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
@@ -116,3 +102,14 @@ class TestTruncateMessage:
         assert final_tokens <= self.max_prompt_size
         assert len(chat_messages) == 1
         assert truncated_chat_history[0] != copy_big_chat_message
+
+
+def generate_content(count):
+    return " ".join([f"{index}" for index, _ in enumerate(range(count))])
+
+
+def generate_chat_history(count):
+    return [
+        ChatMessage(role="user" if index % 2 == 0 else "assistant", content=f"{index}")
+        for index, _ in enumerate(range(count))
+    ]
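For reference, a small usage sketch of the helpers added above, driving the truncation utility the same way the updated test does (it assumes only the truncate_messages signature shown in the diff):

import tiktoken
from langchain.schema import ChatMessage

from khoj.processor.conversation import utils


def generate_content(count):
    # "0 1 2 ... count-1" -- deterministic, so the token count is stable.
    return " ".join([f"{index}" for index, _ in enumerate(range(count))])


def generate_chat_history(count):
    # Alternate user/assistant messages with short numeric content.
    return [
        ChatMessage(role="user" if index % 2 == 0 else "assistant", content=f"{index}")
        for index, _ in enumerate(range(count))
    ]


max_prompt_size = 10
model_name = "gpt-3.5-turbo"
encoder = tiktoken.encoding_for_model(model_name)

chat_history = generate_chat_history(50)
truncated = utils.truncate_messages(chat_history, max_prompt_size, model_name)

# The truncated history fits within the configured prompt budget.
tokens = sum(len(encoder.encode(message.content)) for message in truncated)
assert tokens <= max_prompt_size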