From 4977b55106e7aac94a5d618fae990b71fb9bcc1a Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Sat, 13 Apr 2024 22:15:34 +0530
Subject: [PATCH] Use offline chat prompt config to set context window of
 loaded chat model

Previously you couldn't configure the n_ctx of the loaded offline chat
model. This made it hard to use good offline chat models (which these days
also have larger context windows) on machines with lower VRAM.
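
For illustration, this is the knob being exposed: with llama-cpp-python the
context window is fixed at load time via n_ctx. A minimal standalone sketch,
not code from this patch ("model.gguf" is a placeholder path):

    from llama_cpp import Llama

    # n_ctx=0 would instead use the context length the model was trained with
    chat_model = Llama(model_path="model.gguf", n_ctx=4096, verbose=False)
    print(chat_model.n_ctx())  # 4096, independent of the model's trained window
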
---
 pyproject.toml                               |  2 +-
 .../conversation/offline/chat_model.py       | 20 ++++----
 .../processor/conversation/offline/utils.py  | 46 +++++++++++++------
 src/khoj/processor/conversation/utils.py     | 27 +++++------
 src/khoj/routers/api.py                      |  4 +-
 src/khoj/routers/helpers.py                  | 10 ++--
 src/khoj/utils/config.py                     |  4 +-
 src/khoj/utils/helpers.py                    | 12 +++++
 8 files changed, 81 insertions(+), 44 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 0b7483a1..c9c96691 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -78,6 +78,7 @@ dependencies = [
     "phonenumbers == 8.13.27",
     "markdownify ~= 0.11.6",
     "websockets == 12.0",
+    "psutil >= 5.8.0",
 ]
 dynamic = ["version"]
 
@@ -105,7 +106,6 @@ dev = [
     "pytest-asyncio == 0.21.1",
     "freezegun >= 1.2.0",
     "factory-boy >= 3.2.1",
-    "psutil >= 5.8.0",
     "mypy >= 1.0.1",
     "black >= 23.1.0",
     "pre-commit >= 3.0.4",
diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py
index 10dc08fa..a559df22 100644
--- a/src/khoj/processor/conversation/offline/chat_model.py
+++ b/src/khoj/processor/conversation/offline/chat_model.py
@@ -30,6 +30,7 @@ def extract_questions_offline(
     use_history: bool = True,
     should_extract_questions: bool = True,
     location_data: LocationData = None,
+    max_prompt_size: int = None,
 ) -> List[str]:
     """
     Infer search queries to retrieve relevant notes to answer user query
@@ -41,7 +42,7 @@
         return all_questions
 
     assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
-    offline_chat_model = loaded_model or download_model(model)
+    offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
 
     location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
 
@@ -67,12 +68,14 @@
         location=location,
     )
     messages = generate_chatml_messages_with_context(
-        example_questions, model_name=model, loaded_model=offline_chat_model
+        example_questions, model_name=model, loaded_model=offline_chat_model, max_prompt_size=max_prompt_size
     )
 
     state.chat_lock.acquire()
     try:
-        response = send_message_to_model_offline(messages, loaded_model=offline_chat_model)
+        response = send_message_to_model_offline(
+            messages, loaded_model=offline_chat_model, max_prompt_size=max_prompt_size
+        )
     finally:
         state.chat_lock.release()
 
@@ -138,7 +141,7 @@
     """
     # Initialize Variables
     assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
-    offline_chat_model = loaded_model or download_model(model)
+    offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
     compiled_references_message = "\n\n".join({f"{item}" for item in references})
 
     current_date = datetime.now().strftime("%Y-%m-%d")
@@ -190,18 +193,18 @@
     )
 
     g = ThreadedGenerator(references, online_results, completion_func=completion_func)
-    t = Thread(target=llm_thread, args=(g, messages, offline_chat_model))
+    t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size))
     t.start()
     return g
 
 
-def llm_thread(g, messages: List[ChatMessage], model: Any):
+def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None):
     stop_phrases = ["<s>", "INST]", "Notes:"]
 
     state.chat_lock.acquire()
     try:
         response_iterator = send_message_to_model_offline(
-            messages, loaded_model=model, stop=stop_phrases, streaming=True
+            messages, loaded_model=model, stop=stop_phrases, max_prompt_size=max_prompt_size, streaming=True
         )
         for response in response_iterator:
             g.send(response["choices"][0]["delta"].get("content", ""))
@@ -216,9 +219,10 @@ def send_message_to_model_offline(
     model="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF",
     streaming=False,
     stop=[],
+    max_prompt_size: int = None,
 ):
     assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
-    offline_chat_model = loaded_model or download_model(model)
+    offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
     messages_dict = [{"role": message.role, "content": message.content} for message in messages]
     response = offline_chat_model.create_chat_completion(messages_dict, stop=stop, stream=streaming)
    if streaming:
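
The chat_model.py changes above mostly thread one optional keyword argument
through several call layers. A standalone sketch of that pattern, with
hypothetical names standing in for converse_offline, llm_thread and
send_message_to_model_offline:

    from threading import Thread
    from typing import Optional

    def send_message(prompt: str, max_prompt_size: Optional[int] = None) -> str:
        # Stand-in for the model call; just reports which window it was given.
        return f"answered {prompt!r} with a {max_prompt_size or 'model-default'} token window"

    def worker(results: list, prompt: str, max_prompt_size: Optional[int] = None):
        # Stand-in for the thread target; forwards the optional limit unchanged.
        results.append(send_message(prompt, max_prompt_size=max_prompt_size))

    def converse(prompt: str, max_prompt_size: Optional[int] = None) -> str:
        # Stand-in for the public entry point; passes the limit via Thread args.
        results: list = []
        t = Thread(target=worker, args=(results, prompt, max_prompt_size))
        t.start()
        t.join()
        return results[0]

    print(converse("hello", max_prompt_size=4096))
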
diff --git a/src/khoj/processor/conversation/offline/utils.py b/src/khoj/processor/conversation/offline/utils.py
index b711c11a..c2b08bfa 100644
--- a/src/khoj/processor/conversation/offline/utils.py
+++ b/src/khoj/processor/conversation/offline/utils.py
@@ -1,18 +1,19 @@
 import glob
 import logging
+import math
 import os
 
 from huggingface_hub.constants import HF_HUB_CACHE
 
 from khoj.utils import state
+from khoj.utils.helpers import get_device_memory
 
 logger = logging.getLogger(__name__)
 
 
-def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf"):
-    from llama_cpp.llama import Llama
-
-    # Initialize Model Parameters. Use n_ctx=0 to get context size from the model
+def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf", max_tokens: int = None):
+    # Initialize Model Parameters
+    # Use n_ctx=0 to get context size from the model
     kwargs = {"n_threads": 4, "n_ctx": 0, "verbose": False}
 
     # Decide whether to load model to GPU or CPU
@@ -23,23 +24,33 @@ def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf"):
     model_path = load_model_from_cache(repo_id, filename)
     chat_model = None
     try:
-        if model_path:
-            chat_model = Llama(model_path, **kwargs)
-        else:
-            Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)
+        chat_model = load_model(model_path, repo_id, filename, kwargs)
     except:
         # Load model on CPU if GPU is not available
         kwargs["n_gpu_layers"], device = 0, "cpu"
-        if model_path:
-            chat_model = Llama(model_path, **kwargs)
-        else:
-            chat_model = Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)
+        chat_model = load_model(model_path, repo_id, filename, kwargs)
 
-    logger.debug(f"{'Loaded' if model_path else 'Downloaded'} chat model to {device.upper()}")
+    # Now load the model with context size set based on:
+    # 1. context size supported by model and
+    # 2. configured size or machine (V)RAM
+    kwargs["n_ctx"] = infer_max_tokens(chat_model.n_ctx(), max_tokens)
+    chat_model = load_model(model_path, repo_id, filename, kwargs)
+
+    logger.debug(
+        f"{'Loaded' if model_path else 'Downloaded'} chat model to {device.upper()} with {kwargs['n_ctx']} token context window."
+    )
     return chat_model
 
 
+def load_model(model_path: str, repo_id: str, filename: str = "*Q4_K_M.gguf", kwargs: dict = {}):
+    from llama_cpp.llama import Llama
+
+    if model_path:
+        return Llama(model_path, **kwargs)
+    else:
+        return Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)
+
+
 def load_model_from_cache(repo_id: str, filename: str, repo_type="models"):
     # Construct the path to the model file in the cache directory
     repo_org, repo_name = repo_id.split("/")
@@ -52,3 +63,12 @@
         return paths[0]
     else:
         return None
+
+
+def infer_max_tokens(model_context_window: int, configured_max_tokens=math.inf) -> int:
+    """Infer max prompt size based on device memory and max context window supported by the model"""
+    vram_based_n_ctx = int(get_device_memory() / 1e6)  # heuristic: ~1000 tokens per GB of device memory
+    if configured_max_tokens:
+        return min(configured_max_tokens, model_context_window)
+    else:
+        return min(vram_based_n_ctx, model_context_window)
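
To make the sizing rule concrete, here is a standalone sketch with
hypothetical names (simplified to a None default rather than math.inf): the
window handed back to llama.cpp is the configured limit or a memory-based
guess, but never more than the model's own context window. Since
get_device_memory() returns bytes, bytes / 1e6 works out to roughly 1000
tokens per GB of (V)RAM.

    def infer_window(model_context_window: int, device_memory_bytes: int, configured_max_tokens=None) -> int:
        # Prefer the user-configured limit; otherwise guess from device memory.
        if configured_max_tokens:
            return min(configured_max_tokens, model_context_window)
        memory_based_tokens = int(device_memory_bytes / 1e6)  # ~1000 tokens per GB
        return min(memory_based_tokens, model_context_window)

    print(infer_window(32768, 8_000_000_000))        # 8000: 32k model, 8 GB machine, nothing configured
    print(infer_window(32768, 8_000_000_000, 4096))  # 4096: capped by the configured max prompt size
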
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 845ccb48..e787eedf 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -1,5 +1,6 @@
 import json
 import logging
+import math
 import queue
 from datetime import datetime
 from time import perf_counter
@@ -141,14 +142,12 @@
     tokenizer_name=None,
 ):
     """Generate messages for ChatGPT with context from previous conversation"""
-    # Set max prompt size from user config, pre-configured for model or to default prompt size
-    try:
-        max_prompt_size = max_prompt_size or model_to_prompt_size[model_name]
-    except:
-        max_prompt_size = 2000
-        logger.warning(
-            f"Fallback to default prompt size: {max_prompt_size}.\nConfigure max_prompt_size for unsupported model: {model_name} in Khoj settings to longer context window."
-        )
+    # Set max prompt size from user config or infer it from the model and machine specs
+    if not max_prompt_size:
+        if loaded_model:
+            max_prompt_size = min(loaded_model.n_ctx(), model_to_prompt_size.get(model_name, math.inf))
+        else:
+            max_prompt_size = model_to_prompt_size.get(model_name, 2000)
 
     # Scale lookback turns proportional to max prompt size supported by model
     lookback_turns = max_prompt_size // 750
@@ -187,7 +186,7 @@
     max_prompt_size,
     model_name: str,
     loaded_model: Optional[Llama] = None,
-    tokenizer_name=None,
+    tokenizer_name="hf-internal-testing/llama-tokenizer",
 ) -> list[ChatMessage]:
     """Truncate messages to fit within max prompt size supported by model"""
 
@@ -197,15 +196,11 @@
         elif model_name.startswith("gpt-"):
             encoder = tiktoken.encoding_for_model(model_name)
         else:
-            try:
-                encoder = download_model(model_name).tokenizer()
-            except:
-                encoder = AutoTokenizer.from_pretrained(tokenizer_name or model_to_tokenizer[model_name])
+            encoder = download_model(model_name).tokenizer()
     except:
-        default_tokenizer = "hf-internal-testing/llama-tokenizer"
-        encoder = AutoTokenizer.from_pretrained(default_tokenizer)
+        encoder = AutoTokenizer.from_pretrained(tokenizer_name)
         logger.warning(
-            f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
+            f"Fallback to default chat model tokenizer: {tokenizer_name}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
         )
 
     # Extract system message from messages
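
A standalone sketch (hypothetical names and table entries) of the fallback
order the generate_chatml_messages_with_context hunk introduces: an
explicitly configured limit wins, then a loaded llama.cpp model's own context
window capped by the lookup table, then the table alone with 2000 tokens as
the last resort.

    import math
    from typing import Optional

    MODEL_TO_PROMPT_SIZE = {"gpt-4": 8192}  # illustrative entries only

    def resolve_max_prompt_size(model_name: str, loaded_model_n_ctx: Optional[int], configured: Optional[int]) -> int:
        if configured:
            return configured
        if loaded_model_n_ctx:
            # A locally loaded model can't use more context than it was loaded with.
            return min(loaded_model_n_ctx, MODEL_TO_PROMPT_SIZE.get(model_name, math.inf))
        return MODEL_TO_PROMPT_SIZE.get(model_name, 2000)

    print(resolve_max_prompt_size("my-local-gguf", loaded_model_n_ctx=4096, configured=None))  # 4096
    print(resolve_max_prompt_size("gpt-4", loaded_model_n_ctx=None, configured=None))          # 8192
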
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index 7f546832..c511b6d9 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -315,8 +315,9 @@ async def extract_references_and_questions(
         using_offline_chat = True
         default_offline_llm = await ConversationAdapters.get_default_offline_llm()
         chat_model = default_offline_llm.chat_model
+        max_tokens = default_offline_llm.max_prompt_size
         if state.offline_chat_processor_config is None:
-            state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model=chat_model)
+            state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
 
         loaded_model = state.offline_chat_processor_config.loaded_model
 
@@ -326,6 +327,7 @@
             conversation_log=meta_log,
             should_extract_questions=True,
             location_data=location_data,
+            max_prompt_size=conversation_config.max_prompt_size,
         )
     elif conversation_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
         openai_chat_config = await ConversationAdapters.get_openai_chat_config()
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index cbf29c02..06d849ca 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -82,9 +82,10 @@ async def is_ready_to_chat(user: KhojUser):
     if has_offline_config and user_conversation_config and user_conversation_config.model_type == "offline":
         chat_model = user_conversation_config.chat_model
+        max_tokens = user_conversation_config.max_prompt_size
         if state.offline_chat_processor_config is None:
             logger.info("Loading Offline Chat Model...")
-            state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model=chat_model)
+            state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
         return True
 
     ready = has_openai_config or has_offline_config
 
@@ -385,10 +386,11 @@ async def send_message_to_model_wrapper(
         raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
 
     chat_model = conversation_config.chat_model
+    max_tokens = conversation_config.max_prompt_size
 
     if conversation_config.model_type == "offline":
         if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
-            state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model)
+            state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
 
         loaded_model = state.offline_chat_processor_config.loaded_model
         truncated_messages = generate_chatml_messages_with_context(
@@ -455,7 +457,9 @@ def generate_chat_response(
     conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
     if conversation_config.model_type == "offline":
         if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
-            state.offline_chat_processor_config = OfflineChatProcessorModel(conversation_config.chat_model)
+            chat_model = conversation_config.chat_model
+            max_tokens = conversation_config.max_prompt_size
+            state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
 
         loaded_model = state.offline_chat_processor_config.loaded_model
         chat_response = converse_offline(
diff --git a/src/khoj/utils/config.py b/src/khoj/utils/config.py
index 3f95030f..1732271a 100644
--- a/src/khoj/utils/config.py
+++ b/src/khoj/utils/config.py
@@ -69,11 +69,11 @@ class OfflineChatProcessorConfig:
 
 
 class OfflineChatProcessorModel:
-    def __init__(self, chat_model: str = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF"):
+    def __init__(self, chat_model: str = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", max_tokens: int = None):
         self.chat_model = chat_model
         self.loaded_model = None
         try:
-            self.loaded_model = download_model(self.chat_model)
+            self.loaded_model = download_model(self.chat_model, max_tokens=max_tokens)
         except ValueError as e:
             self.loaded_model = None
             logger.error(f"Error while loading offline chat model: {e}", exc_info=True)
diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py
index a61387d5..04974b7d 100644
--- a/src/khoj/utils/helpers.py
+++ b/src/khoj/utils/helpers.py
@@ -17,6 +17,7 @@ from time import perf_counter
 from typing import TYPE_CHECKING, Optional, Union
 from urllib.parse import urlparse
 
+import psutil
 import torch
 from asgiref.sync import sync_to_async
 from magika import Magika
@@ -271,6 +272,17 @@ def log_telemetry(
     return request_body
 
 
+def get_device_memory() -> int:
+    """Get device memory in bytes"""
+    device = get_device()
+    if device.type == "cuda":
+        return torch.cuda.get_device_properties(device).total_memory
+    elif device.type == "mps":
+        return torch.mps.driver_allocated_memory()
+    else:
+        return psutil.virtual_memory().total
+
+
 def get_device() -> torch.device:
     """Get device to run model on"""
     if torch.cuda.is_available():
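
Taken together, the router-side changes follow one pattern: read
max_prompt_size off the selected chat model's config and hand it to the
lazily constructed offline processor, which forwards it to the model loader.
A standalone sketch with hypothetical names (ChatModelConfig, OfflineProcessor
and get_processor are stand-ins, not Khoj APIs):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class ChatModelConfig:
        # Stand-in for the chat model settings the max prompt size is read from.
        chat_model: str
        max_prompt_size: Optional[int] = None

    class OfflineProcessor:
        # Stand-in for OfflineChatProcessorModel: forwards max_tokens to the loader.
        def __init__(self, chat_model: str, max_tokens: Optional[int] = None):
            self.chat_model = chat_model
            self.loaded_model = f"{chat_model} (n_ctx <= {max_tokens or 'auto'})"

    _processor: Optional[OfflineProcessor] = None

    def get_processor(config: ChatModelConfig) -> OfflineProcessor:
        # Load the offline model once, with the configured prompt size, then reuse it.
        global _processor
        if _processor is None:
            _processor = OfflineProcessor(config.chat_model, config.max_prompt_size)
        return _processor

    config = ChatModelConfig("NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", max_prompt_size=4096)
    print(get_processor(config).loaded_model)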