Use set, inferred max token limits wherever chat models are used (#713)
- User configured max token limits weren't being passed to `send_message_to_model_wrapper`
- One of the load offline model code paths wasn't reachable. Remove it to simplify the code
- When max prompt size isn't set, infer max tokens based on free VRAM on the machine
- Use the min of the app-configured max tokens, the VRAM-based max tokens, and the model context window
commit 419b044ac5

4 changed files with 7 additions and 13 deletions
@@ -777,7 +777,9 @@ class ConversationAdapters:
         if offline_chat_config and offline_chat_config.enabled and conversation_config.model_type == "offline":
             if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
-                state.offline_chat_processor_config = OfflineChatProcessorModel(conversation_config.chat_model)
+                chat_model = conversation_config.chat_model
+                max_tokens = conversation_config.max_prompt_size
+                state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)

             return conversation_config
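
The point of the two-argument OfflineChatProcessorModel(chat_model, max_tokens) call above is to carry the user's configured limit into model loading. As a rough sketch of that intent only (the class internals are not part of this diff; the use of llama-cpp-python and the load_offline_model name are assumptions for illustration), the limit would bound the context window requested from the backend:

    from llama_cpp import Llama  # assumption: the offline backend is llama-cpp-python

    # Hypothetical helper, not khoj code: a configured max_tokens bounds the
    # context window (n_ctx) requested when loading a local GGUF model.
    def load_offline_model(model_path: str, max_tokens=None) -> Llama:
        n_ctx = max_tokens or 2048  # conservative default when no limit is configured
        return Llama(model_path=model_path, n_ctx=n_ctx)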
@@ -68,7 +68,4 @@ def load_model_from_cache(repo_id: str, filename: str, repo_type="models"):
 def infer_max_tokens(model_context_window: int, configured_max_tokens=math.inf) -> int:
     """Infer max prompt size based on device memory and max context window supported by the model"""
     vram_based_n_ctx = int(get_device_memory() / 2e6)  # based on heuristic
-    if configured_max_tokens:
-        return min(configured_max_tokens, model_context_window)
-    else:
-        return min(vram_based_n_ctx, model_context_window)
+    return min(configured_max_tokens, vram_based_n_ctx, model_context_window)
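
To make the heuristic above concrete, a worked example (an illustrative restatement only, with device memory passed in explicitly so the arithmetic is visible): 8 GB of device memory allows int(8e9 / 2e6) = 4000 tokens, so an 8192-token context window with no configured limit (the math.inf default) resolves to 4000, while an explicit limit of 3000 wins.

    import math

    # Restates infer_max_tokens with device memory as a parameter for clarity.
    def infer_max_tokens_example(device_memory_bytes: int, model_context_window: int, configured_max_tokens=math.inf) -> int:
        vram_based_n_ctx = int(device_memory_bytes / 2e6)  # same heuristic as the diff
        return min(configured_max_tokens, vram_based_n_ctx, model_context_window)

    assert infer_max_tokens_example(8_000_000_000, 8192) == 4000
    assert infer_max_tokens_example(8_000_000_000, 8192, configured_max_tokens=3000) == 3000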
@@ -13,7 +13,7 @@ from transformers import AutoTokenizer

 from khoj.database.adapters import ConversationAdapters
 from khoj.database.models import ClientApplication, KhojUser
-from khoj.processor.conversation.offline.utils import download_model
+from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
 from khoj.utils.helpers import is_none_or_empty, merge_dicts

 logger = logging.getLogger(__name__)
@@ -145,7 +145,7 @@ def generate_chatml_messages_with_context(
     # Set max prompt size from user config or based on pre-configured for model and machine specs
     if not max_prompt_size:
         if loaded_model:
-            max_prompt_size = min(loaded_model.n_ctx(), model_to_prompt_size.get(model_name, math.inf))
+            max_prompt_size = infer_max_tokens(loaded_model.n_ctx(), model_to_prompt_size.get(model_name, math.inf))
         else:
             max_prompt_size = model_to_prompt_size.get(model_name, 2000)
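
Taken together, the hunk above gives max_prompt_size a three-step resolution order. A self-contained sketch of that order (the table contents and model names here are made up for illustration and are not khoj's real model_to_prompt_size table):

    import math

    model_to_prompt_size = {"gpt-4": 8192}  # illustrative stand-in for the real table

    def resolve_max_prompt_size(model_name, max_prompt_size=None, loaded_model_n_ctx=None,
                                device_memory_bytes=8_000_000_000):
        if max_prompt_size:                 # 1. explicit user/app configuration wins
            return max_prompt_size
        if loaded_model_n_ctx:              # 2. offline model: infer from device memory and context window
            configured = model_to_prompt_size.get(model_name, math.inf)
            return min(configured, int(device_memory_bytes / 2e6), loaded_model_n_ctx)
        return model_to_prompt_size.get(model_name, 2000)  # 3. per-model table, else 2000 tokens

    assert resolve_max_prompt_size("gpt-4") == 8192
    assert resolve_max_prompt_size("unknown-model") == 2000
    assert resolve_max_prompt_size("local-gguf-model", loaded_model_n_ctx=4096) == 4000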
@@ -409,7 +409,7 @@ async def send_message_to_model_wrapper(
         openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
         api_key = openai_chat_config.api_key
         truncated_messages = generate_chatml_messages_with_context(
-            user_message=message, system_message=system_message, model_name=chat_model
+            user_message=message, system_message=system_message, model_name=chat_model, max_prompt_size=max_tokens
         )

         openai_response = send_message_to_model(
@@ -457,11 +457,6 @@ def generate_chat_response(

         conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
         if conversation_config.model_type == "offline":
-            if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
-                chat_model = conversation_config.chat_model
-                max_tokens = conversation_config.max_prompt_size
-                state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
-
             loaded_model = state.offline_chat_processor_config.loaded_model
             chat_response = converse_offline(
                 references=compiled_references,