Use max_prompt_size, tokenizer from config for chat model context stuffing

Debanjum Singh Solanky 2023-10-15 16:33:26 -07:00
parent 116595b351
commit df1d74a879
5 changed files with 48 additions and 11 deletions


@@ -127,6 +127,8 @@ def converse_offline(
    loaded_model: Union[Any, None] = None,
    completion_func=None,
    conversation_command=ConversationCommand.Default,
    max_prompt_size=None,
    tokenizer_name=None,
) -> Union[ThreadedGenerator, Iterator[str]]:
    """
    Converse with user using Llama

@@ -158,6 +160,8 @@ def converse_offline(
        prompts.system_prompt_message_llamav2,
        conversation_log,
        model_name=model,
        max_prompt_size=max_prompt_size,
        tokenizer_name=tokenizer_name,
    )

    g = ThreadedGenerator(references, completion_func=completion_func)

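The two new keyword arguments let a caller override the context-window budget and tokenizer for the offline chat model instead of relying only on the built-in lookup tables. A minimal usage sketch, assuming the import path and the other parameter names of `converse_offline`, which are not shown in this hunk:

```python
# Hedged sketch: drive the offline (llama) chat path with an explicit context
# budget and tokenizer. Names outside this hunk (module path, references,
# user_query, conversation_log) are assumptions, not part of this commit.
from khoj.processor.conversation.gpt4all.chat_model import converse_offline

response_iterator = converse_offline(
    references=compiled_references,  # search results to stuff into the prompt
    user_query="Summarize my notes on tokenizers",
    conversation_log=meta_log,
    model="llama-2-7b-chat.ggmlv3.q4_0.bin",
    max_prompt_size=1548,  # token budget for the stuffed prompt
    tokenizer_name="hf-internal-testing/llama-tokenizer",
)
for chunk in response_iterator:
    print(chunk, end="")
```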

@@ -116,6 +116,8 @@ def converse(
    temperature: float = 0.2,
    completion_func=None,
    conversation_command=ConversationCommand.Default,
    max_prompt_size=None,
    tokenizer_name=None,
):
    """
    Converse with user using OpenAI's ChatGPT

@@ -141,6 +143,8 @@ def converse(
        prompts.personality.format(),
        conversation_log,
        model,
        max_prompt_size,
        tokenizer_name,
    )
    truncated_messages = "\n".join({f"{message.content[:40]}..." for message in messages})
    logger.debug(f"Conversation Context for GPT: {truncated_messages}")

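The OpenAI path gains the same two knobs. For models already present in model_to_prompt_size the override is optional, but it lets a user cap or raise the context budget, for example when using gpt-3.5-turbo-16k. A sketch, assuming the import path and the leading positional parameters of `converse`, which are not shown in this hunk:

```python
# Hedged sketch: cap the prompt budget for a 16k model below its table default.
# The module path and positional parameter order are assumptions.
from khoj.processor.conversation.openai.gpt import converse

response = converse(
    compiled_references,  # references to stuff into the prompt
    "What did I work on this week?",  # user query
    meta_log,  # conversation log
    model="gpt-3.5-turbo-16k",
    api_key=api_key,
    max_prompt_size=8000,  # override the 15000-token table default
    tokenizer_name=None,  # gpt-* models are tokenized with tiktoken
)
```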

@@ -13,17 +13,16 @@ from transformers import AutoTokenizer
import queue

from khoj.utils.helpers import merge_dicts

logger = logging.getLogger(__name__)

max_prompt_size = {
model_to_prompt_size = {
    "gpt-3.5-turbo": 4096,
    "gpt-4": 8192,
    "llama-2-7b-chat.ggmlv3.q4_0.bin": 1548,
    "gpt-3.5-turbo-16k": 15000,
    "default": 1600,
}
tokenizer = {
model_to_tokenizer = {
    "llama-2-7b-chat.ggmlv3.q4_0.bin": "hf-internal-testing/llama-tokenizer",
    "default": "hf-internal-testing/llama-tokenizer",
}

@@ -86,7 +85,13 @@ def message_to_log(
def generate_chatml_messages_with_context(
    user_message, system_message, conversation_log={}, model_name="gpt-3.5-turbo", lookback_turns=2
    user_message,
    system_message,
    conversation_log={},
    model_name="gpt-3.5-turbo",
    lookback_turns=2,
    max_prompt_size=None,
    tokenizer_name=None,
):
    """Generate messages for ChatGPT with context from previous conversation"""
    # Extract Chat History for Context

@@ -108,20 +113,38 @@ def generate_chatml_messages_with_context(
    messages = user_chatml_message + rest_backnforths + system_chatml_message

    # Set max prompt size from user config, pre-configured for model or to default prompt size
    try:
        max_prompt_size = max_prompt_size or model_to_prompt_size[model_name]
    except:
        max_prompt_size = 2000
        logger.warning(
            f"Fallback to default prompt size: {max_prompt_size}.\nConfigure max_prompt_size for unsupported model: {model_name} in Khoj settings to longer context window."
        )

    # Truncate oldest messages from conversation history until under max supported prompt size by model
    messages = truncate_messages(messages, max_prompt_size.get(model_name, max_prompt_size["default"]), model_name)
    messages = truncate_messages(messages, max_prompt_size, model_name, tokenizer_name)

    # Return message in chronological order
    return messages[::-1]


def truncate_messages(messages: list[ChatMessage], max_prompt_size, model_name: str) -> list[ChatMessage]:
def truncate_messages(
    messages: list[ChatMessage], max_prompt_size, model_name: str, tokenizer_name=None
) -> list[ChatMessage]:
    """Truncate messages to fit within max prompt size supported by model"""
    if model_name.startswith("gpt-"):
        encoder = tiktoken.encoding_for_model(model_name)
    else:
        encoder = AutoTokenizer.from_pretrained(tokenizer.get(model_name, tokenizer["default"]))
    try:
        if model_name.startswith("gpt-"):
            encoder = tiktoken.encoding_for_model(model_name)
        else:
            encoder = AutoTokenizer.from_pretrained(tokenizer_name or model_to_tokenizer[model_name])
    except:
        default_tokenizer = "hf-internal-testing/llama-tokenizer"
        encoder = AutoTokenizer.from_pretrained(default_tokenizer)
        logger.warning(
            f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
        )

    system_message = messages.pop()
    system_message_tokens = len(encoder.encode(system_message.content))

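Put together, the new code resolves the prompt budget and the tokenizer in a fixed order: the user-configured value wins, then the per-model lookup table, then a hard-coded fallback accompanied by a warning. A condensed, self-contained sketch of that precedence (names mirror the hunk above; the dictionaries are trimmed copies and the helper functions are illustrative, not part of the commit):

```python
import logging
from typing import Optional

import tiktoken
from transformers import AutoTokenizer

logger = logging.getLogger(__name__)

# Trimmed copies of the lookup tables from the hunk above
model_to_prompt_size = {"gpt-3.5-turbo": 4096, "llama-2-7b-chat.ggmlv3.q4_0.bin": 1548}
model_to_tokenizer = {"llama-2-7b-chat.ggmlv3.q4_0.bin": "hf-internal-testing/llama-tokenizer"}


def resolve_prompt_size(model_name: str, configured: Optional[int] = None) -> int:
    """Configured value wins, then the per-model table, then a 2000 token fallback."""
    if configured:
        return configured
    if model_name in model_to_prompt_size:
        return model_to_prompt_size[model_name]
    logger.warning(f"Falling back to default prompt size of 2000 for unsupported model: {model_name}")
    return 2000


def resolve_encoder(model_name: str, tokenizer_name: Optional[str] = None):
    """tiktoken for gpt-* models, else the configured or per-model HF tokenizer, else the llama tokenizer."""
    try:
        if model_name.startswith("gpt-"):
            return tiktoken.encoding_for_model(model_name)
        return AutoTokenizer.from_pretrained(tokenizer_name or model_to_tokenizer[model_name])
    except Exception:
        logger.warning(f"Falling back to the default llama tokenizer for unsupported model: {model_name}")
        return AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
```

Leaving both settings unset keeps the table-driven defaults; the warnings nudge users of unlisted models to configure explicit values in Khoj settings.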

@@ -123,6 +123,8 @@ def generate_chat_response(
                completion_func=partial_completion,
                conversation_command=conversation_command,
                model=state.processor_config.conversation.offline_chat.chat_model,
                max_prompt_size=state.processor_config.conversation.max_prompt_size,
                tokenizer_name=state.processor_config.conversation.tokenizer,
            )

        elif state.processor_config.conversation.openai_model:

@@ -136,6 +138,8 @@ def generate_chat_response(
                api_key=api_key,
                completion_func=partial_completion,
                conversation_command=conversation_command,
                max_prompt_size=state.processor_config.conversation.max_prompt_size,
                tokenizer_name=state.processor_config.conversation.tokenizer,
            )

    except Exception as e:
except Exception as e:


@@ -95,6 +95,8 @@ class ConversationProcessorConfigModel:
        self.openai_model = conversation_config.openai
        self.gpt4all_model = GPT4AllProcessorConfig()
        self.offline_chat = conversation_config.offline_chat
        self.max_prompt_size = conversation_config.max_prompt_size
        self.tokenizer = conversation_config.tokenizer
        self.conversation_logfile = Path(conversation_config.conversation_logfile)
        self.chat_session: List[str] = []
        self.meta_log: dict = {}
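For the assignments above to work, the raw conversation config that users edit needs to expose the two new optional fields. A hedged sketch of that schema addition (the class name is inferred from the Model suffix convention; the base class and surrounding fields are assumptions, and only max_prompt_size and tokenizer are introduced here):

```python
from typing import Optional

from pydantic import BaseModel


class ConversationProcessorConfig(BaseModel):
    """Illustrative sketch of the user-facing conversation config; existing fields elided."""

    conversation_logfile: str
    # ... existing openai / offline_chat sections elided ...
    max_prompt_size: Optional[int] = None  # token budget for stuffed chat context
    tokenizer: Optional[str] = None  # Hugging Face tokenizer id for non-OpenAI chat models
```

Both default to None, so existing config files keep working and the per-model defaults from model_to_prompt_size and model_to_tokenizer continue to apply.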