Use set, inferred max token limits wherever chat models are used (#713)

- User-configured max token limits weren't being passed to
  `send_message_to_model_wrapper`
- One of the offline model loading code paths wasn't reachable. Remove it
  to simplify the code
- When max prompt size isn't set, infer max tokens based on free VRAM
  on the machine
- Use the minimum of the app-configured max tokens, the VRAM-based max
  tokens, and the model context window (see the sketch below)
Authored by Debanjum on 2024-04-23 16:42:35 +05:30, committed via GitHub
commit 419b044ac5
4 changed files with 7 additions and 13 deletions
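
For reference, a minimal sketch of the min-of-three heuristic described in the commit message, assuming free system memory can stand in for the VRAM check and keeping the ~2 MB-per-token divisor from the diff below; it is illustrative, not the project's actual helper.

```python
# Sketch of the inferred-max-tokens heuristic (assumption: psutil's free RAM
# approximates the free VRAM the real helper would query).
import math

import psutil  # assumed dependency, used only for the memory lookup


def get_device_memory() -> int:
    """Free device memory in bytes (RAM used as a stand-in for VRAM)."""
    return psutil.virtual_memory().available


def infer_max_tokens(model_context_window: int, configured_max_tokens=math.inf) -> int:
    """Cap the prompt at the smallest of: configured limit, memory heuristic, model window."""
    configured_max_tokens = configured_max_tokens or math.inf  # treat None/0 as "not set"
    vram_based_n_ctx = int(get_device_memory() / 2e6)  # heuristic: ~2 MB of memory per token
    return int(min(configured_max_tokens, vram_based_n_ctx, model_context_window))
```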

@@ -777,7 +777,9 @@ class ConversationAdapters:
         if offline_chat_config and offline_chat_config.enabled and conversation_config.model_type == "offline":
             if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
-                state.offline_chat_processor_config = OfflineChatProcessorModel(conversation_config.chat_model)
+                chat_model = conversation_config.chat_model
+                max_tokens = conversation_config.max_prompt_size
+                state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
             return conversation_config
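
To show what the new second argument could be used for, here is a hypothetical wrapper that forwards max_tokens to llama.cpp as the context size; only llama_cpp.Llama and its model_path/n_ctx parameters are known APIs, the rest is assumed for illustration and is not the project's actual class.

```python
# Hypothetical offline model wrapper: max_tokens becomes the llama.cpp context size.
from llama_cpp import Llama


class OfflineChatProcessorModel:
    def __init__(self, chat_model: str, max_tokens: int | None = None):
        # Assumption: chat_model resolves to a local GGUF file path.
        # Fall back to a conservative 2048-token context when no limit is given.
        self.loaded_model = Llama(model_path=chat_model, n_ctx=max_tokens or 2048)
```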

@@ -68,7 +68,4 @@ def load_model_from_cache(repo_id: str, filename: str, repo_type="models"):
 def infer_max_tokens(model_context_window: int, configured_max_tokens=math.inf) -> int:
     """Infer max prompt size based on device memory and max context window supported by the model"""
     vram_based_n_ctx = int(get_device_memory() / 2e6)  # based on heuristic
-    if configured_max_tokens:
-        return min(configured_max_tokens, model_context_window)
-    else:
-        return min(vram_based_n_ctx, model_context_window)
+    return min(configured_max_tokens, vram_based_n_ctx, model_context_window)
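
A quick worked example of the consolidated rule above, using an assumed 8 GB of free device memory; the figures are illustrative only.

```python
# min(configured_max_tokens, vram_based_n_ctx, model_context_window) with made-up numbers.
import math

free_memory_bytes = 8e9                          # assume ~8 GB of device memory is free
vram_based_n_ctx = int(free_memory_bytes / 2e6)  # -> 4000 tokens via the 2 MB/token heuristic

print(min(math.inf, vram_based_n_ctx, 8192))     # 4000: the memory heuristic is the binding limit
print(min(1000, vram_based_n_ctx, 8192))         # 1000: the user-configured cap wins
print(min(math.inf, vram_based_n_ctx, 2048))     # 2048: the model context window wins
```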

@@ -13,7 +13,7 @@ from transformers import AutoTokenizer
 from khoj.database.adapters import ConversationAdapters
 from khoj.database.models import ClientApplication, KhojUser
-from khoj.processor.conversation.offline.utils import download_model
+from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
 from khoj.utils.helpers import is_none_or_empty, merge_dicts

 logger = logging.getLogger(__name__)

@@ -145,7 +145,7 @@ def generate_chatml_messages_with_context(
     # Set max prompt size from user config or based on pre-configured for model and machine specs
     if not max_prompt_size:
         if loaded_model:
-            max_prompt_size = min(loaded_model.n_ctx(), model_to_prompt_size.get(model_name, math.inf))
+            max_prompt_size = infer_max_tokens(loaded_model.n_ctx(), model_to_prompt_size.get(model_name, math.inf))
         else:
             max_prompt_size = model_to_prompt_size.get(model_name, 2000)

@@ -409,7 +409,7 @@ async def send_message_to_model_wrapper(
         openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
         api_key = openai_chat_config.api_key
         truncated_messages = generate_chatml_messages_with_context(
-            user_message=message, system_message=system_message, model_name=chat_model
+            user_message=message, system_message=system_message, model_name=chat_model, max_prompt_size=max_tokens
         )

         openai_response = send_message_to_model(

@@ -457,11 +457,6 @@ def generate_chat_response(
     conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)

     if conversation_config.model_type == "offline":
-        if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
-            chat_model = conversation_config.chat_model
-            max_tokens = conversation_config.max_prompt_size
-            state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
-
         loaded_model = state.offline_chat_processor_config.loaded_model
         chat_response = converse_offline(
             references=compiled_references,