Mirror of https://github.com/khoj-ai/khoj.git, synced 2024-11-23 23:48:56 +01:00
Handle message truncation when the question is larger than the max prompt size
Notice this case and truncate the question itself so the prompt still fits
This commit is contained in:
parent
c6487f2e48
commit
4228965c9b
4 changed files with 29 additions and 8 deletions
@@ -21,6 +21,7 @@ def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf"):
 
     # Check if the model is already downloaded
     model_path = load_model_from_cache(repo_id, filename)
+    chat_model = None
     try:
         if model_path:
             chat_model = Llama(model_path, **kwargs)

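The added `chat_model = None` is a guard: if `Llama(...)` raises, or the cached path is missing, before the variable is ever assigned, a later reference to `chat_model` would otherwise fail with UnboundLocalError. A minimal sketch of the pattern under that assumption; the `load_cached_model` helper and its logging below are illustrative, not khoj's actual function:

import logging

from llama_cpp import Llama

logger = logging.getLogger(__name__)


def load_cached_model(model_path: str, **kwargs):
    # Initialize before the try block so the failure path can still read the name
    chat_model = None
    try:
        if model_path:
            chat_model = Llama(model_path, **kwargs)
    except Exception as exc:
        # Hypothetical fallback: log and let the caller download the model instead
        logger.warning(f"Could not load model from {model_path}: {exc}")
    return chat_model  # None signals "load failed or no cached file"
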
@@ -101,8 +101,3 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None, model_
     chat(messages=messages)
 
     g.close()
-
-
-def extract_summaries(metadata):
-    """Extract summaries from metadata"""
-    return "".join([f'\n{session["summary"]}' for session in metadata])

@@ -232,12 +232,17 @@ def truncate_messages(
         original_question = "\n".join(messages[0].content.split("\n")[-1:]) if type(messages[0].content) == str else ""
         original_question = f"\n{original_question}"
         original_question_tokens = len(encoder.encode(original_question))
-        remaining_tokens = max_prompt_size - original_question_tokens - system_message_tokens
-        truncated_message = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip()
+        remaining_tokens = max_prompt_size - system_message_tokens
+        if remaining_tokens > original_question_tokens:
+            remaining_tokens -= original_question_tokens
+            truncated_message = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip()
+            messages = [ChatMessage(content=truncated_message + original_question, role=messages[0].role)]
+        else:
+            truncated_message = encoder.decode(encoder.encode(original_question)[:remaining_tokens]).strip()
+            messages = [ChatMessage(content=truncated_message, role=messages[0].role)]
         logger.debug(
             f"Truncate current message to fit within max prompt size of {max_prompt_size} supported by {model_name} model:\n {truncated_message}"
         )
-        messages = [ChatMessage(content=truncated_message + original_question, role=messages[0].role)]
-
 
     return messages + [system_message] if system_message else messages

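The gist of the change above: compute the token budget left after the system message first, then decide what to trim. When the budget still covers the final question, only the older context (`current_message`) is truncated and the question is appended verbatim, as before; when the question alone exceeds the budget, the question itself is now truncated instead of emitting an over-long prompt. A standalone sketch of that branch using tiktoken, with khoj's ChatMessage and logging plumbing left out; the `fit_to_budget` wrapper is a hypothetical name for illustration:

import tiktoken

encoder = tiktoken.get_encoding("cl100k_base")


def fit_to_budget(current_message: str, original_question: str, max_prompt_size: int, system_message_tokens: int) -> str:
    # Budget left for the user turn once the system message is accounted for
    remaining_tokens = max_prompt_size - system_message_tokens
    original_question_tokens = len(encoder.encode(original_question))
    if remaining_tokens > original_question_tokens:
        # Usual case: trim the older conversation context, keep the question intact
        remaining_tokens -= original_question_tokens
        truncated_context = encoder.decode(encoder.encode(current_message)[:remaining_tokens]).strip()
        return truncated_context + "\n" + original_question
    # Case this commit handles: the question alone blows the budget, so truncate it too
    return encoder.decode(encoder.encode(original_question)[:remaining_tokens]).strip()
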
@@ -96,3 +96,23 @@ class TestTruncateMessage:
         assert final_tokens <= self.max_prompt_size
         assert len(chat_messages) == 1
         assert truncated_chat_history[0] != copy_big_chat_message
+
+    def test_truncate_single_large_question(self):
+        # Arrange
+        big_chat_message_content = " ".join(["hi"] * (self.max_prompt_size + 1))
+        big_chat_message = ChatMessageFactory.build(content=big_chat_message_content)
+        big_chat_message.role = "user"
+        copy_big_chat_message = big_chat_message.copy()
+        chat_messages = [big_chat_message]
+        initial_tokens = sum([len(self.encoder.encode(message.content)) for message in chat_messages])
+
+        # Act
+        truncated_chat_history = utils.truncate_messages(chat_messages, self.max_prompt_size, self.model_name)
+        final_tokens = sum([len(self.encoder.encode(message.content)) for message in truncated_chat_history])
+
+        # Assert
+        # The original object has been modified. Verify certain properties
+        assert initial_tokens > self.max_prompt_size
+        assert final_tokens <= self.max_prompt_size
+        assert len(chat_messages) == 1
+        assert truncated_chat_history[0] != copy_big_chat_message

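The new test reuses `ChatMessageFactory`, `self.encoder`, and `self.max_prompt_size` from the surrounding suite. For orientation only, a hedged sketch of what such a factory could look like with factory_boy and langchain's ChatMessage; the real fixture lives elsewhere in the repository, so the imports and defaults below are assumptions:

import factory
from langchain.schema import ChatMessage


class ChatMessageFactory(factory.Factory):
    class Meta:
        model = ChatMessage

    # Placeholder defaults; tests override content and role explicitly
    content = factory.Faker("paragraph")
    role = "user"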