Add default tokenizer, max_prompt as fallback for non-default offline chat models

Pass the user-configured chat model as an argument for converse_offline to use

The proper fix would allow users to configure the max_prompt_size and
tokenizer to use (while supplying defaults if none are provided).
For now, this is a reasonable start.
Debanjum Singh Solanky 2023-10-13 22:26:59 -07:00
parent 56bd69d5af
commit 1ad8b150e8
3 changed files with 12 additions and 5 deletions
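
In effect, the change makes model lookups degrade to a "default" entry instead of raising a KeyError for offline chat models that are not hard-coded in the lookup tables. A minimal, self-contained Python sketch of that behaviour (dict contents are trimmed to the entries visible in the diffs below, and the example model name is hypothetical):

max_prompt_size = {
    "gpt-4": 8192,
    "llama-2-7b-chat.ggmlv3.q4_0.bin": 1548,
    "gpt-3.5-turbo-16k": 15000,
    "default": 1600,
}
tokenizer = {
    "llama-2-7b-chat.ggmlv3.q4_0.bin": "hf-internal-testing/llama-tokenizer",
    "default": "hf-internal-testing/llama-tokenizer",
}

# Hypothetical user-configured offline chat model with no explicit entry above.
chat_model = "mistral-7b-instruct.Q4_0.gguf"

# Fall back to the "default" entries rather than raising a KeyError.
prompt_size = max_prompt_size.get(chat_model, max_prompt_size["default"])  # 1600
tokenizer_id = tokenizer.get(chat_model, tokenizer["default"])  # llama tokenizer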


@@ -59,8 +59,8 @@ dependencies = [
     "bs4 >= 0.0.1",
     "anyio == 3.7.1",
     "pymupdf >= 1.23.3",
-    "gpt4all == 1.0.12; platform_system == 'Linux' and platform_machine == 'x86_64'",
-    "gpt4all == 1.0.12; platform_system == 'Windows' or platform_system == 'Darwin'",
+    "gpt4all >= 1.0.12; platform_system == 'Linux' and platform_machine == 'x86_64'",
+    "gpt4all >= 1.0.12; platform_system == 'Windows' or platform_system == 'Darwin'",
 ]
 dynamic = ["version"]


@@ -19,8 +19,12 @@ max_prompt_size = {
     "gpt-4": 8192,
     "llama-2-7b-chat.ggmlv3.q4_0.bin": 1548,
     "gpt-3.5-turbo-16k": 15000,
+    "default": 1600,
 }
-tokenizer = {"llama-2-7b-chat.ggmlv3.q4_0.bin": "hf-internal-testing/llama-tokenizer"}
+tokenizer = {
+    "llama-2-7b-chat.ggmlv3.q4_0.bin": "hf-internal-testing/llama-tokenizer",
+    "default": "hf-internal-testing/llama-tokenizer",
+}


 class ThreadedGenerator:
@@ -105,7 +109,7 @@ def generate_chatml_messages_with_context(
     messages = user_chatml_message + rest_backnforths + system_chatml_message

     # Truncate oldest messages from conversation history until under max supported prompt size by model
-    messages = truncate_messages(messages, max_prompt_size[model_name], model_name)
+    messages = truncate_messages(messages, max_prompt_size.get(model_name, max_prompt_size["default"]), model_name)

     # Return message in chronological order
     return messages[::-1]
@@ -116,8 +120,10 @@ def truncate_messages(messages: list[ChatMessage], max_prompt_size, model_name)
     if "llama" in model_name:
         encoder = LlamaTokenizerFast.from_pretrained(tokenizer[model_name])
-    else:
+    elif "gpt" in model_name:
         encoder = tiktoken.encoding_for_model(model_name)
+    else:
+        encoder = LlamaTokenizerFast.from_pretrained(tokenizer["default"])

     system_message = messages.pop()
     system_message_tokens = len(encoder.encode(system_message.content))
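
With the elif/else split above, encoder selection in truncate_messages gains an explicit fallback path for model names that are neither llama nor gpt variants. A rough sketch of that selection logic in isolation (select_encoder is an illustrative name, not a function in the codebase; it assumes the transformers and tiktoken packages the module already depends on):

import tiktoken
from transformers import LlamaTokenizerFast

tokenizer = {
    "llama-2-7b-chat.ggmlv3.q4_0.bin": "hf-internal-testing/llama-tokenizer",
    "default": "hf-internal-testing/llama-tokenizer",
}

def select_encoder(model_name: str):
    """Mirror the branching added in this commit to pick a token encoder."""
    if "llama" in model_name:
        # Known llama chat model: use its configured Hugging Face tokenizer.
        return LlamaTokenizerFast.from_pretrained(tokenizer[model_name])
    elif "gpt" in model_name:
        # OpenAI models: tiktoken ships their encodings.
        return tiktoken.encoding_for_model(model_name)
    else:
        # Any other offline chat model: fall back to the default llama tokenizer.
        return LlamaTokenizerFast.from_pretrained(tokenizer["default"])

So a non-default offline model (say, a hypothetical mistral GGUF) still gets a working encoder for message truncation instead of a KeyError.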


@@ -122,6 +122,7 @@ def generate_chat_response(
             conversation_log=meta_log,
             completion_func=partial_completion,
             conversation_command=conversation_command,
+            model=state.processor_config.conversation.gpt4all_model.chat_model,
         )

     elif state.processor_config.conversation.openai_model: