Truncate message logs to below max supported prompt size by model

- Use tiktoken to count tokens for chat models
- Make the number of conversation turns added to the prompt configurable
  via an argument to the generate_chatml_messages_with_context method
Debanjum Singh Solanky 2023-03-25 04:37:55 +07:00
parent 4725416fbd
commit 7e36f421f9
4 changed files with 30 additions and 6 deletions


@@ -41,6 +41,7 @@ dependencies = [
     "fastapi == 0.77.1",
     "jinja2 == 3.1.2",
     "openai >= 0.27.0",
+    "tiktoken >= 0.3.0",
     "pillow == 9.3.0",
     "pydantic == 1.9.1",
     "pyqt6 == 6.3.1",

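The new tiktoken dependency is what enables model-aware token counting. A minimal sketch of how it is typically used, assuming an installed tiktoken and an illustrative input string that is not from this commit:

    import tiktoken

    # Pick the tokenizer that matches the target chat model
    encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")

    # encode() returns token ids; their count is the prompt cost of the string
    num_tokens = len(encoder.encode("Hello from Khoj!"))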

@@ -247,6 +247,7 @@ Question: {user_query}"""
         conversation_primer,
         personality_primer,
         conversation_log,
+        model,
     )

     # Get Response from GPT


@@ -1,22 +1,44 @@
 # Standard Packages
 from datetime import datetime

+# External Packages
+import tiktoken
+
 # Internal Packages
 from khoj.utils.helpers import merge_dicts


-def generate_chatml_messages_with_context(user_message, system_message, conversation_log={}):
+max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192}
+
+
+def generate_chatml_messages_with_context(
+    user_message, system_message, conversation_log={}, model_name="gpt-3.5-turbo", lookback_turns=2
+):
     """Generate messages for ChatGPT with context from previous conversation"""

     # Extract Chat History for Context
     chat_logs = [f'{chat["message"]}\n\nNotes:\n{chat.get("context","")}' for chat in conversation_log.get("chat", [])]
-    last_backnforth = reciprocal_conversation_to_chatml(chat_logs[-2:])
-    rest_backnforth = reciprocal_conversation_to_chatml(chat_logs[-4:-2])
+    rest_backnforths = []
+    # Extract in reverse chronological order
+    for user_msg, assistant_msg in zip(chat_logs[-2::-2], chat_logs[::-2]):
+        if len(rest_backnforths) >= 2 * lookback_turns:
+            break
+        rest_backnforths += reciprocal_conversation_to_chatml([user_msg, assistant_msg])[::-1]

     # Format user and system messages to chatml format
     system_chatml_message = [message_to_chatml(system_message, "system")]
     user_chatml_message = [message_to_chatml(user_message, "user")]

-    return rest_backnforth + system_chatml_message + last_backnforth + user_chatml_message
+    messages = user_chatml_message + rest_backnforths[:2] + system_chatml_message + rest_backnforths[2:]
+
+    # Truncate oldest messages from conversation history until under max supported prompt size by model
+    encoder = tiktoken.encoding_for_model(model_name)
+    tokens = sum([len(encoder.encode(value)) for message in messages for value in message.values()])
+    while tokens > max_prompt_size[model_name]:
+        messages.pop()
+        tokens = sum([len(encoder.encode(value)) for message in messages for value in message.values()])
+
+    # Return message in chronological order
+    return messages[::-1]


 def reciprocal_conversation_to_chatml(message_pair):
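The reverse-stride zip and the token-budget loop are the heart of this change and are easier to follow in isolation. Below is a self-contained sketch, assuming chat logs alternate user/assistant messages and end on an assistant reply; the sample history and the toy 12-token budget are illustrative, not from the commit:

    import tiktoken

    # Alternating history, oldest first, ending on an assistant reply
    chat_logs = ["q1", "a1", "q2", "a2", "q3", "a3"]

    # Pair turns newest-first: [-2::-2] walks user messages backwards,
    # [::-2] walks assistant messages backwards
    pairs = list(zip(chat_logs[-2::-2], chat_logs[::-2]))
    # -> [("q3", "a3"), ("q2", "a2"), ("q1", "a1")]

    # Build a newest-first chatml-style list: assistant reply before its user message
    messages = [
        {"role": role, "content": content}
        for user_msg, assistant_msg in pairs
        for role, content in (("assistant", assistant_msg), ("user", user_msg))
    ]

    # Trim the oldest entries until the prompt fits the toy token budget
    encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")
    while sum(len(encoder.encode(v)) for m in messages for v in m.values()) > 12:
        messages.pop()  # list is newest-first, so this drops the oldest message

    print(messages[::-1])  # restore chronological order before calling the API

Keeping the working list newest-first is what lets truncation stay a one-line pop(): the oldest context is always sacrificed first.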


@@ -47,7 +47,7 @@ def test_chat_with_no_chat_history_or_retrieved_content(chat_client):
     expected_responses = ["Khoj", "khoj"]
     assert response.status_code == 200
     assert any([expected_response in response_message for expected_response in expected_responses]), (
-        "Expected assistants name, [K|k]hoj, in response but got" + response_message
+        "Expected assistants name, [K|k]hoj, in response but got: " + response_message
     )

@@ -69,7 +69,7 @@ def test_answer_from_chat_history(chat_client):
     expected_responses = ["Testatron", "testatron"]
     assert response.status_code == 200
     assert any([expected_response in response_message for expected_response in expected_responses]), (
-        "Expected [T|t]estatron in response but got" + response_message
+        "Expected [T|t]estatron in response but got: " + response_message
     )