From 175169c1567b3f42657a0b833dfa73955f06024b Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Sat, 20 Apr 2024 10:23:30 +0530
Subject: [PATCH] Use set, inferred max token limits wherever chat models are
 used

- User-configured max token limits weren't being passed to
  `send_message_to_model_wrapper`
- One of the offline model load code paths wasn't reachable. Remove it to
  simplify the code
- When max prompt size isn't set, infer max tokens based on free VRAM on
  the machine
- Use the min of the app-configured max tokens, the VRAM-based max tokens,
  and the model's context window
---
 src/khoj/database/adapters/__init__.py           | 4 +++-
 src/khoj/processor/conversation/offline/utils.py | 5 +----
 src/khoj/processor/conversation/utils.py         | 4 ++--
 src/khoj/routers/helpers.py                      | 7 +------
 4 files changed, 7 insertions(+), 13 deletions(-)

diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py
index e2b84d17..4752a859 100644
--- a/src/khoj/database/adapters/__init__.py
+++ b/src/khoj/database/adapters/__init__.py
@@ -777,7 +777,9 @@ class ConversationAdapters:
 
         if offline_chat_config and offline_chat_config.enabled and conversation_config.model_type == "offline":
             if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
-                state.offline_chat_processor_config = OfflineChatProcessorModel(conversation_config.chat_model)
+                chat_model = conversation_config.chat_model
+                max_tokens = conversation_config.max_prompt_size
+                state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
 
             return conversation_config
diff --git a/src/khoj/processor/conversation/offline/utils.py b/src/khoj/processor/conversation/offline/utils.py
index 4a7c69a9..44dec0b6 100644
--- a/src/khoj/processor/conversation/offline/utils.py
+++ b/src/khoj/processor/conversation/offline/utils.py
@@ -68,7 +68,4 @@ def load_model_from_cache(repo_id: str, filename: str, repo_type="models"):
 def infer_max_tokens(model_context_window: int, configured_max_tokens=math.inf) -> int:
     """Infer max prompt size based on device memory and max context window supported by the model"""
     vram_based_n_ctx = int(get_device_memory() / 2e6)  # based on heuristic
-    if configured_max_tokens:
-        return min(configured_max_tokens, model_context_window)
-    else:
-        return min(vram_based_n_ctx, model_context_window)
+    return min(configured_max_tokens, vram_based_n_ctx, model_context_window)
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index e787eedf..877e5a43 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -13,7 +13,7 @@ from transformers import AutoTokenizer
 
 from khoj.database.adapters import ConversationAdapters
 from khoj.database.models import ClientApplication, KhojUser
-from khoj.processor.conversation.offline.utils import download_model
+from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
 from khoj.utils.helpers import is_none_or_empty, merge_dicts
 
 logger = logging.getLogger(__name__)
@@ -145,7 +145,7 @@ def generate_chatml_messages_with_context(
     # Set max prompt size from user config or based on pre-configured for model and machine specs
     if not max_prompt_size:
         if loaded_model:
-            max_prompt_size = min(loaded_model.n_ctx(), model_to_prompt_size.get(model_name, math.inf))
+            max_prompt_size = infer_max_tokens(loaded_model.n_ctx(), model_to_prompt_size.get(model_name, math.inf))
         else:
             max_prompt_size = model_to_prompt_size.get(model_name, 2000)
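
Note: the rewritten `infer_max_tokens` takes the minimum of three limits. A
minimal, self-contained sketch of the heuristic follows; the `get_device_memory`
stub is hypothetical (Khoj's real helper lives elsewhere and is assumed here to
report memory in bytes, so dividing by 2e6 budgets roughly 2 MB per token):

    import math

    def get_device_memory() -> int:
        # Hypothetical stub standing in for Khoj's device-memory helper;
        # pretend the machine has 8 GB of usable (V)RAM, reported in bytes.
        return 8 * 1024**3

    def infer_max_tokens(model_context_window: int, configured_max_tokens=math.inf) -> int:
        """Infer max prompt size from device memory and the model's context window."""
        vram_based_n_ctx = int(get_device_memory() / 2e6)  # heuristic: ~2 MB per token
        return min(configured_max_tokens, vram_based_n_ctx, model_context_window)

    print(infer_max_tokens(8192))        # 4294: VRAM-based limit wins (8 GB / 2e6)
    print(infer_max_tokens(8192, 2000))  # 2000: user-configured limit wins

With `configured_max_tokens` defaulting to `math.inf`, the single `min` call
covers both the configured and the unconfigured case, which is what lets the
old if/else branch be deleted.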
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index 3c93385d..20dbb1e4 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -409,7 +409,7 @@ async def send_message_to_model_wrapper(
         openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
         api_key = openai_chat_config.api_key
         truncated_messages = generate_chatml_messages_with_context(
-            user_message=message, system_message=system_message, model_name=chat_model
+            user_message=message, system_message=system_message, model_name=chat_model, max_prompt_size=max_tokens
         )
 
         openai_response = send_message_to_model(
@@ -457,11 +457,6 @@ def generate_chat_response(
         conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
 
         if conversation_config.model_type == "offline":
-            if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
-                chat_model = conversation_config.chat_model
-                max_tokens = conversation_config.max_prompt_size
-                state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
-
            loaded_model = state.offline_chat_processor_config.loaded_model
            chat_response = converse_offline(
                references=compiled_references,
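
Note: with `max_prompt_size` now threaded through to
`generate_chatml_messages_with_context`, chat history can be truncated to the
configured budget before it reaches the model. A toy sketch of that kind of
truncation (not Khoj's actual implementation, which counts tokens with a
model-specific tokenizer rather than whitespace splitting):

    def truncate_messages(messages: list[str], max_prompt_size: int) -> list[str]:
        kept: list[str] = []
        used = 0
        # Walk newest-to-oldest, keeping messages while they fit the budget.
        for message in reversed(messages):
            cost = len(message.split())  # crude stand-in for a tokenizer
            if used + cost > max_prompt_size:
                break
            kept.append(message)
            used += cost
        return list(reversed(kept))

    history = ["hi", "hello, how can I help?", "summarize my notes on llamas"]
    print(truncate_messages(history, max_prompt_size=10))
    # ['hello, how can I help?', 'summarize my notes on llamas']

The deletion in `generate_chat_response` is safe because, per the adapters
change above, `get_valid_conversation_config` now loads the offline model
before this code path runs, so the lazy-load block here was dead code.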