mirror of
https://github.com/khoj-ai/khoj.git
synced 2025-02-17 08:04:21 +00:00
Use max_prompt_size, tokenizer from config for chat model context stuffing
This commit is contained in:
parent
116595b351
commit
df1d74a879
5 changed files with 48 additions and 11 deletions
|
@ -127,6 +127,8 @@ def converse_offline(
|
|||
loaded_model: Union[Any, None] = None,
|
||||
completion_func=None,
|
||||
conversation_command=ConversationCommand.Default,
|
||||
max_prompt_size=None,
|
||||
tokenizer_name=None,
|
||||
) -> Union[ThreadedGenerator, Iterator[str]]:
|
||||
"""
|
||||
Converse with user using Llama
|
||||
|
@ -158,6 +160,8 @@ def converse_offline(
|
|||
prompts.system_prompt_message_llamav2,
|
||||
conversation_log,
|
||||
model_name=model,
|
||||
max_prompt_size=max_prompt_size,
|
||||
tokenizer_name=tokenizer_name,
|
||||
)
|
||||
|
||||
g = ThreadedGenerator(references, completion_func=completion_func)
|
||||
|
|
|
@ -116,6 +116,8 @@ def converse(
|
|||
temperature: float = 0.2,
|
||||
completion_func=None,
|
||||
conversation_command=ConversationCommand.Default,
|
||||
max_prompt_size=None,
|
||||
tokenizer_name=None,
|
||||
):
|
||||
"""
|
||||
Converse with user using OpenAI's ChatGPT
|
||||
|
@ -141,6 +143,8 @@ def converse(
|
|||
prompts.personality.format(),
|
||||
conversation_log,
|
||||
model,
|
||||
max_prompt_size,
|
||||
tokenizer_name,
|
||||
)
|
||||
truncated_messages = "\n".join({f"{message.content[:40]}..." for message in messages})
|
||||
logger.debug(f"Conversation Context for GPT: {truncated_messages}")
|
||||
|
|
|
@ -13,17 +13,16 @@ from transformers import AutoTokenizer
|
|||
import queue
|
||||
from khoj.utils.helpers import merge_dicts
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
max_prompt_size = {
|
||||
model_to_prompt_size = {
|
||||
"gpt-3.5-turbo": 4096,
|
||||
"gpt-4": 8192,
|
||||
"llama-2-7b-chat.ggmlv3.q4_0.bin": 1548,
|
||||
"gpt-3.5-turbo-16k": 15000,
|
||||
"default": 1600,
|
||||
}
|
||||
tokenizer = {
|
||||
model_to_tokenizer = {
|
||||
"llama-2-7b-chat.ggmlv3.q4_0.bin": "hf-internal-testing/llama-tokenizer",
|
||||
"default": "hf-internal-testing/llama-tokenizer",
|
||||
}
|
||||
|
||||
|
||||
|
@ -86,7 +85,13 @@ def message_to_log(
|
|||
|
||||
|
||||
def generate_chatml_messages_with_context(
|
||||
user_message, system_message, conversation_log={}, model_name="gpt-3.5-turbo", lookback_turns=2
|
||||
user_message,
|
||||
system_message,
|
||||
conversation_log={},
|
||||
model_name="gpt-3.5-turbo",
|
||||
lookback_turns=2,
|
||||
max_prompt_size=None,
|
||||
tokenizer_name=None,
|
||||
):
|
||||
"""Generate messages for ChatGPT with context from previous conversation"""
|
||||
# Extract Chat History for Context
|
||||
|
@ -108,20 +113,38 @@ def generate_chatml_messages_with_context(
|
|||
|
||||
messages = user_chatml_message + rest_backnforths + system_chatml_message
|
||||
|
||||
# Set max prompt size from user config, pre-configured for model or to default prompt size
|
||||
try:
|
||||
max_prompt_size = max_prompt_size or model_to_prompt_size[model_name]
|
||||
except:
|
||||
max_prompt_size = 2000
|
||||
logger.warning(
|
||||
f"Fallback to default prompt size: {max_prompt_size}.\nConfigure max_prompt_size for unsupported model: {model_name} in Khoj settings to longer context window."
|
||||
)
|
||||
|
||||
# Truncate oldest messages from conversation history until under max supported prompt size by model
|
||||
messages = truncate_messages(messages, max_prompt_size.get(model_name, max_prompt_size["default"]), model_name)
|
||||
messages = truncate_messages(messages, max_prompt_size, model_name, tokenizer_name)
|
||||
|
||||
# Return message in chronological order
|
||||
return messages[::-1]
|
||||
|
||||
|
||||
def truncate_messages(messages: list[ChatMessage], max_prompt_size, model_name: str) -> list[ChatMessage]:
|
||||
def truncate_messages(
|
||||
messages: list[ChatMessage], max_prompt_size, model_name: str, tokenizer_name=None
|
||||
) -> list[ChatMessage]:
|
||||
"""Truncate messages to fit within max prompt size supported by model"""
|
||||
|
||||
if model_name.startswith("gpt-"):
|
||||
encoder = tiktoken.encoding_for_model(model_name)
|
||||
else:
|
||||
encoder = AutoTokenizer.from_pretrained(tokenizer.get(model_name, tokenizer["default"]))
|
||||
try:
|
||||
if model_name.startswith("gpt-"):
|
||||
encoder = tiktoken.encoding_for_model(model_name)
|
||||
else:
|
||||
encoder = AutoTokenizer.from_pretrained(tokenizer_name or model_to_tokenizer[model_name])
|
||||
except:
|
||||
default_tokenizer = "hf-internal-testing/llama-tokenizer"
|
||||
encoder = AutoTokenizer.from_pretrained(default_tokenizer)
|
||||
logger.warning(
|
||||
f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
|
||||
)
|
||||
|
||||
system_message = messages.pop()
|
||||
system_message_tokens = len(encoder.encode(system_message.content))
|
||||
|
|
|
@ -123,6 +123,8 @@ def generate_chat_response(
|
|||
completion_func=partial_completion,
|
||||
conversation_command=conversation_command,
|
||||
model=state.processor_config.conversation.offline_chat.chat_model,
|
||||
max_prompt_size=state.processor_config.conversation.max_prompt_size,
|
||||
tokenizer_name=state.processor_config.conversation.tokenizer,
|
||||
)
|
||||
|
||||
elif state.processor_config.conversation.openai_model:
|
||||
|
@ -136,6 +138,8 @@ def generate_chat_response(
|
|||
api_key=api_key,
|
||||
completion_func=partial_completion,
|
||||
conversation_command=conversation_command,
|
||||
max_prompt_size=state.processor_config.conversation.max_prompt_size,
|
||||
tokenizer_name=state.processor_config.conversation.tokenizer,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
@ -95,6 +95,8 @@ class ConversationProcessorConfigModel:
|
|||
self.openai_model = conversation_config.openai
|
||||
self.gpt4all_model = GPT4AllProcessorConfig()
|
||||
self.offline_chat = conversation_config.offline_chat
|
||||
self.max_prompt_size = conversation_config.max_prompt_size
|
||||
self.tokenizer = conversation_config.tokenizer
|
||||
self.conversation_logfile = Path(conversation_config.conversation_logfile)
|
||||
self.chat_session: List[str] = []
|
||||
self.meta_log: dict = {}
|
||||
|
|
Loading…
Add table
Reference in a new issue