mirror of https://github.com/khoj-ai/khoj.git
Add default tokenizer, max_prompt as fallback for non-default offline chat models
Pass the user-configured chat model as an argument for converse_offline to use. The proper fix would allow users to configure the max_prompt and tokenizer to use (while supplying defaults if none are provided). For now, this is a reasonable start.
parent 56bd69d5af
commit 1ad8b150e8
3 changed files with 12 additions and 5 deletions
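The change is essentially a dictionary lookup with a shared "default" key instead of a direct index. A minimal sketch of the pattern, independent of the khoj codebase (the helper name and the non-default model name below are illustrative, not taken from this commit):

max_prompt_size = {
    "gpt-4": 8192,
    "llama-2-7b-chat.ggmlv3.q4_0.bin": 1548,
    "default": 1600,  # fallback budget for any chat model not listed explicitly
}


def get_max_prompt_size(chat_model: str) -> int:
    # Hypothetical helper: unknown models fall back to the "default" entry
    # instead of raising a KeyError, mirroring max_prompt_size.get(...) in the diff.
    return max_prompt_size.get(chat_model, max_prompt_size["default"])


print(get_max_prompt_size("gpt-4"))                    # 8192
print(get_max_prompt_size("my-custom-offline-model"))  # 1600

The same .get(..., "default") style of lookup is what the diff below applies to both the max_prompt_size and tokenizer maps.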
@@ -59,8 +59,8 @@ dependencies = [
     "bs4 >= 0.0.1",
     "anyio == 3.7.1",
     "pymupdf >= 1.23.3",
-    "gpt4all == 1.0.12; platform_system == 'Linux' and platform_machine == 'x86_64'",
-    "gpt4all == 1.0.12; platform_system == 'Windows' or platform_system == 'Darwin'",
+    "gpt4all >= 1.0.12; platform_system == 'Linux' and platform_machine == 'x86_64'",
+    "gpt4all >= 1.0.12; platform_system == 'Windows' or platform_system == 'Darwin'",
 ]
 dynamic = ["version"]
@@ -19,8 +19,12 @@ max_prompt_size = {
     "gpt-4": 8192,
     "llama-2-7b-chat.ggmlv3.q4_0.bin": 1548,
     "gpt-3.5-turbo-16k": 15000,
+    "default": 1600,
 }
-tokenizer = {"llama-2-7b-chat.ggmlv3.q4_0.bin": "hf-internal-testing/llama-tokenizer"}
+tokenizer = {
+    "llama-2-7b-chat.ggmlv3.q4_0.bin": "hf-internal-testing/llama-tokenizer",
+    "default": "hf-internal-testing/llama-tokenizer",
+}
 
 
 class ThreadedGenerator:
@@ -105,7 +109,7 @@ def generate_chatml_messages_with_context(
     messages = user_chatml_message + rest_backnforths + system_chatml_message
 
     # Truncate oldest messages from conversation history until under max supported prompt size by model
-    messages = truncate_messages(messages, max_prompt_size[model_name], model_name)
+    messages = truncate_messages(messages, max_prompt_size.get(model_name, max_prompt_size["default"]), model_name)
 
     # Return message in chronological order
     return messages[::-1]
@@ -116,8 +120,10 @@ def truncate_messages(messages: list[ChatMessage], max_prompt_size, model_name)
 
     if "llama" in model_name:
         encoder = LlamaTokenizerFast.from_pretrained(tokenizer[model_name])
-    else:
+    elif "gpt" in model_name:
         encoder = tiktoken.encoding_for_model(model_name)
+    else:
+        encoder = LlamaTokenizerFast.from_pretrained(tokenizer["default"])
 
     system_message = messages.pop()
     system_message_tokens = len(encoder.encode(system_message.content))
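With the elif/else split above, a chat model name that matches neither a known llama file nor an OpenAI "gpt" model is tokenized with the default llama tokenizer rather than raising a KeyError. A rough standalone illustration of that fallback, assuming the transformers and tiktoken packages are installed (the model name is made up):

import tiktoken
from transformers import LlamaTokenizerFast

tokenizer = {
    "llama-2-7b-chat.ggmlv3.q4_0.bin": "hf-internal-testing/llama-tokenizer",
    "default": "hf-internal-testing/llama-tokenizer",
}

model_name = "my-custom-offline-model.bin"  # hypothetical, not in the tokenizer map

if "llama" in model_name:
    encoder = LlamaTokenizerFast.from_pretrained(tokenizer[model_name])
elif "gpt" in model_name:
    encoder = tiktoken.encoding_for_model(model_name)
else:
    # Fallback path introduced by this commit: use the "default" tokenizer entry
    encoder = LlamaTokenizerFast.from_pretrained(tokenizer["default"])

print(len(encoder.encode("How many tokens does this prompt use?")))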
@@ -122,6 +122,7 @@ def generate_chat_response(
             conversation_log=meta_log,
             completion_func=partial_completion,
             conversation_command=conversation_command,
+            model=state.processor_config.conversation.gpt4all_model.chat_model,
         )
 
     elif state.processor_config.conversation.openai_model: