Use set, inferred max token limits wherever chat models are used (#713)

- User configured max tokens limits weren't being passed to
  `send_message_to_model_wrapper'
- One of the load offline model code paths wasn't reachable. Remove it
  to simplify code
- When max prompt size isn't set infer max tokens based on free VRAM
  on machine
- Use min of app configured max tokens, vram based max tokens and
  model context window
This commit is contained in:
Debanjum 2024-04-23 16:42:35 +05:30 committed by GitHub
commit 419b044ac5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 7 additions and 13 deletions

View file

@ -777,7 +777,9 @@ class ConversationAdapters:
if offline_chat_config and offline_chat_config.enabled and conversation_config.model_type == "offline":
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
state.offline_chat_processor_config = OfflineChatProcessorModel(conversation_config.chat_model)
chat_model = conversation_config.chat_model
max_tokens = conversation_config.max_prompt_size
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
return conversation_config

View file

@ -68,7 +68,4 @@ def load_model_from_cache(repo_id: str, filename: str, repo_type="models"):
def infer_max_tokens(model_context_window: int, configured_max_tokens=math.inf) -> int:
"""Infer max prompt size based on device memory and max context window supported by the model"""
vram_based_n_ctx = int(get_device_memory() / 2e6) # based on heuristic
if configured_max_tokens:
return min(configured_max_tokens, model_context_window)
else:
return min(vram_based_n_ctx, model_context_window)
return min(configured_max_tokens, vram_based_n_ctx, model_context_window)

View file

@ -13,7 +13,7 @@ from transformers import AutoTokenizer
from khoj.database.adapters import ConversationAdapters
from khoj.database.models import ClientApplication, KhojUser
from khoj.processor.conversation.offline.utils import download_model
from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
from khoj.utils.helpers import is_none_or_empty, merge_dicts
logger = logging.getLogger(__name__)
@ -145,7 +145,7 @@ def generate_chatml_messages_with_context(
# Set max prompt size from user config or based on pre-configured for model and machine specs
if not max_prompt_size:
if loaded_model:
max_prompt_size = min(loaded_model.n_ctx(), model_to_prompt_size.get(model_name, math.inf))
max_prompt_size = infer_max_tokens(loaded_model.n_ctx(), model_to_prompt_size.get(model_name, math.inf))
else:
max_prompt_size = model_to_prompt_size.get(model_name, 2000)

View file

@ -409,7 +409,7 @@ async def send_message_to_model_wrapper(
openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
api_key = openai_chat_config.api_key
truncated_messages = generate_chatml_messages_with_context(
user_message=message, system_message=system_message, model_name=chat_model
user_message=message, system_message=system_message, model_name=chat_model, max_prompt_size=max_tokens
)
openai_response = send_message_to_model(
@ -457,11 +457,6 @@ def generate_chat_response(
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
if conversation_config.model_type == "offline":
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
chat_model = conversation_config.chat_model
max_tokens = conversation_config.max_prompt_size
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
loaded_model = state.offline_chat_processor_config.loaded_model
chat_response = converse_offline(
references=compiled_references,