Use offline chat prompt config to set context window of loaded chat model

Previously you couldn't configure the n_ctx of the loaded offline chat
model. This made it hard to use good offline chat model (which these
days also have larger context) on machines with lower VRAM
This commit is contained in:
Debanjum Singh Solanky 2024-04-13 22:15:34 +05:30
parent 689202e00e
commit 4977b55106
8 changed files with 81 additions and 44 deletions

View file

@ -78,6 +78,7 @@ dependencies = [
"phonenumbers == 8.13.27", "phonenumbers == 8.13.27",
"markdownify ~= 0.11.6", "markdownify ~= 0.11.6",
"websockets == 12.0", "websockets == 12.0",
"psutil >= 5.8.0",
] ]
dynamic = ["version"] dynamic = ["version"]
@ -105,7 +106,6 @@ dev = [
"pytest-asyncio == 0.21.1", "pytest-asyncio == 0.21.1",
"freezegun >= 1.2.0", "freezegun >= 1.2.0",
"factory-boy >= 3.2.1", "factory-boy >= 3.2.1",
"psutil >= 5.8.0",
"mypy >= 1.0.1", "mypy >= 1.0.1",
"black >= 23.1.0", "black >= 23.1.0",
"pre-commit >= 3.0.4", "pre-commit >= 3.0.4",

View file

@ -30,6 +30,7 @@ def extract_questions_offline(
use_history: bool = True, use_history: bool = True,
should_extract_questions: bool = True, should_extract_questions: bool = True,
location_data: LocationData = None, location_data: LocationData = None,
max_prompt_size: int = None,
) -> List[str]: ) -> List[str]:
""" """
Infer search queries to retrieve relevant notes to answer user query Infer search queries to retrieve relevant notes to answer user query
@ -41,7 +42,7 @@ def extract_questions_offline(
return all_questions return all_questions
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured" assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
offline_chat_model = loaded_model or download_model(model) offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
@ -67,12 +68,14 @@ def extract_questions_offline(
location=location, location=location,
) )
messages = generate_chatml_messages_with_context( messages = generate_chatml_messages_with_context(
example_questions, model_name=model, loaded_model=offline_chat_model example_questions, model_name=model, loaded_model=offline_chat_model, max_prompt_size=max_prompt_size
) )
state.chat_lock.acquire() state.chat_lock.acquire()
try: try:
response = send_message_to_model_offline(messages, loaded_model=offline_chat_model) response = send_message_to_model_offline(
messages, loaded_model=offline_chat_model, max_prompt_size=max_prompt_size
)
finally: finally:
state.chat_lock.release() state.chat_lock.release()
@ -138,7 +141,7 @@ def converse_offline(
""" """
# Initialize Variables # Initialize Variables
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured" assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
offline_chat_model = loaded_model or download_model(model) offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
compiled_references_message = "\n\n".join({f"{item}" for item in references}) compiled_references_message = "\n\n".join({f"{item}" for item in references})
current_date = datetime.now().strftime("%Y-%m-%d") current_date = datetime.now().strftime("%Y-%m-%d")
@ -190,18 +193,18 @@ def converse_offline(
) )
g = ThreadedGenerator(references, online_results, completion_func=completion_func) g = ThreadedGenerator(references, online_results, completion_func=completion_func)
t = Thread(target=llm_thread, args=(g, messages, offline_chat_model)) t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size))
t.start() t.start()
return g return g
def llm_thread(g, messages: List[ChatMessage], model: Any): def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None):
stop_phrases = ["<s>", "INST]", "Notes:"] stop_phrases = ["<s>", "INST]", "Notes:"]
state.chat_lock.acquire() state.chat_lock.acquire()
try: try:
response_iterator = send_message_to_model_offline( response_iterator = send_message_to_model_offline(
messages, loaded_model=model, stop=stop_phrases, streaming=True messages, loaded_model=model, stop=stop_phrases, max_prompt_size=max_prompt_size, streaming=True
) )
for response in response_iterator: for response in response_iterator:
g.send(response["choices"][0]["delta"].get("content", "")) g.send(response["choices"][0]["delta"].get("content", ""))
@ -216,9 +219,10 @@ def send_message_to_model_offline(
model="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", model="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF",
streaming=False, streaming=False,
stop=[], stop=[],
max_prompt_size: int = None,
): ):
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured" assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
offline_chat_model = loaded_model or download_model(model) offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
messages_dict = [{"role": message.role, "content": message.content} for message in messages] messages_dict = [{"role": message.role, "content": message.content} for message in messages]
response = offline_chat_model.create_chat_completion(messages_dict, stop=stop, stream=streaming) response = offline_chat_model.create_chat_completion(messages_dict, stop=stop, stream=streaming)
if streaming: if streaming:

View file

@ -1,18 +1,19 @@
import glob import glob
import logging import logging
import math
import os import os
from huggingface_hub.constants import HF_HUB_CACHE from huggingface_hub.constants import HF_HUB_CACHE
from khoj.utils import state from khoj.utils import state
from khoj.utils.helpers import get_device_memory
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf"): def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf", max_tokens: int = None):
from llama_cpp.llama import Llama # Initialize Model Parameters
# Use n_ctx=0 to get context size from the model
# Initialize Model Parameters. Use n_ctx=0 to get context size from the model
kwargs = {"n_threads": 4, "n_ctx": 0, "verbose": False} kwargs = {"n_threads": 4, "n_ctx": 0, "verbose": False}
# Decide whether to load model to GPU or CPU # Decide whether to load model to GPU or CPU
@ -23,23 +24,33 @@ def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf"):
model_path = load_model_from_cache(repo_id, filename) model_path = load_model_from_cache(repo_id, filename)
chat_model = None chat_model = None
try: try:
if model_path: chat_model = load_model(model_path, repo_id, filename, kwargs)
chat_model = Llama(model_path, **kwargs)
else:
Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)
except: except:
# Load model on CPU if GPU is not available # Load model on CPU if GPU is not available
kwargs["n_gpu_layers"], device = 0, "cpu" kwargs["n_gpu_layers"], device = 0, "cpu"
if model_path: chat_model = load_model(model_path, repo_id, filename, kwargs)
chat_model = Llama(model_path, **kwargs)
else:
chat_model = Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)
logger.debug(f"{'Loaded' if model_path else 'Downloaded'} chat model to {device.upper()}") # Now load the model with context size set based on:
# 1. context size supported by model and
# 2. configured size or machine (V)RAM
kwargs["n_ctx"] = infer_max_tokens(chat_model.n_ctx(), max_tokens)
chat_model = load_model(model_path, repo_id, filename, kwargs)
logger.debug(
f"{'Loaded' if model_path else 'Downloaded'} chat model to {device.upper()} with {kwargs['n_ctx']} token context window."
)
return chat_model return chat_model
def load_model(model_path: str, repo_id: str, filename: str = "*Q4_K_M.gguf", kwargs: dict = {}):
from llama_cpp.llama import Llama
if model_path:
return Llama(model_path, **kwargs)
else:
return Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)
def load_model_from_cache(repo_id: str, filename: str, repo_type="models"): def load_model_from_cache(repo_id: str, filename: str, repo_type="models"):
# Construct the path to the model file in the cache directory # Construct the path to the model file in the cache directory
repo_org, repo_name = repo_id.split("/") repo_org, repo_name = repo_id.split("/")
@ -52,3 +63,12 @@ def load_model_from_cache(repo_id: str, filename: str, repo_type="models"):
return paths[0] return paths[0]
else: else:
return None return None
def infer_max_tokens(model_context_window: int, configured_max_tokens=math.inf) -> int:
"""Infer max prompt size based on device memory and max context window supported by the model"""
vram_based_n_ctx = int(get_device_memory() / 1e6) # based on heuristic
if configured_max_tokens:
return min(configured_max_tokens, model_context_window)
else:
return min(vram_based_n_ctx, model_context_window)

View file

@ -1,5 +1,6 @@
import json import json
import logging import logging
import math
import queue import queue
from datetime import datetime from datetime import datetime
from time import perf_counter from time import perf_counter
@ -141,14 +142,12 @@ def generate_chatml_messages_with_context(
tokenizer_name=None, tokenizer_name=None,
): ):
"""Generate messages for ChatGPT with context from previous conversation""" """Generate messages for ChatGPT with context from previous conversation"""
# Set max prompt size from user config, pre-configured for model or to default prompt size # Set max prompt size from user config or based on pre-configured for model and machine specs
try: if not max_prompt_size:
max_prompt_size = max_prompt_size or model_to_prompt_size[model_name] if loaded_model:
except: max_prompt_size = min(loaded_model.n_ctx(), model_to_prompt_size.get(model_name, math.inf))
max_prompt_size = 2000 else:
logger.warning( max_prompt_size = model_to_prompt_size.get(model_name, 2000)
f"Fallback to default prompt size: {max_prompt_size}.\nConfigure max_prompt_size for unsupported model: {model_name} in Khoj settings to longer context window."
)
# Scale lookback turns proportional to max prompt size supported by model # Scale lookback turns proportional to max prompt size supported by model
lookback_turns = max_prompt_size // 750 lookback_turns = max_prompt_size // 750
@ -187,7 +186,7 @@ def truncate_messages(
max_prompt_size, max_prompt_size,
model_name: str, model_name: str,
loaded_model: Optional[Llama] = None, loaded_model: Optional[Llama] = None,
tokenizer_name=None, tokenizer_name="hf-internal-testing/llama-tokenizer",
) -> list[ChatMessage]: ) -> list[ChatMessage]:
"""Truncate messages to fit within max prompt size supported by model""" """Truncate messages to fit within max prompt size supported by model"""
@ -197,15 +196,11 @@ def truncate_messages(
elif model_name.startswith("gpt-"): elif model_name.startswith("gpt-"):
encoder = tiktoken.encoding_for_model(model_name) encoder = tiktoken.encoding_for_model(model_name)
else: else:
try:
encoder = download_model(model_name).tokenizer() encoder = download_model(model_name).tokenizer()
except: except:
encoder = AutoTokenizer.from_pretrained(tokenizer_name or model_to_tokenizer[model_name]) encoder = AutoTokenizer.from_pretrained(tokenizer_name)
except:
default_tokenizer = "hf-internal-testing/llama-tokenizer"
encoder = AutoTokenizer.from_pretrained(default_tokenizer)
logger.warning( logger.warning(
f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing." f"Fallback to default chat model tokenizer: {tokenizer_name}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
) )
# Extract system message from messages # Extract system message from messages

View file

@ -315,8 +315,9 @@ async def extract_references_and_questions(
using_offline_chat = True using_offline_chat = True
default_offline_llm = await ConversationAdapters.get_default_offline_llm() default_offline_llm = await ConversationAdapters.get_default_offline_llm()
chat_model = default_offline_llm.chat_model chat_model = default_offline_llm.chat_model
max_tokens = default_offline_llm.max_prompt_size
if state.offline_chat_processor_config is None: if state.offline_chat_processor_config is None:
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model=chat_model) state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
loaded_model = state.offline_chat_processor_config.loaded_model loaded_model = state.offline_chat_processor_config.loaded_model
@ -326,6 +327,7 @@ async def extract_references_and_questions(
conversation_log=meta_log, conversation_log=meta_log,
should_extract_questions=True, should_extract_questions=True,
location_data=location_data, location_data=location_data,
max_prompt_size=conversation_config.max_prompt_size,
) )
elif conversation_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI: elif conversation_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
openai_chat_config = await ConversationAdapters.get_openai_chat_config() openai_chat_config = await ConversationAdapters.get_openai_chat_config()

View file

@ -82,9 +82,10 @@ async def is_ready_to_chat(user: KhojUser):
if has_offline_config and user_conversation_config and user_conversation_config.model_type == "offline": if has_offline_config and user_conversation_config and user_conversation_config.model_type == "offline":
chat_model = user_conversation_config.chat_model chat_model = user_conversation_config.chat_model
max_tokens = user_conversation_config.max_prompt_size
if state.offline_chat_processor_config is None: if state.offline_chat_processor_config is None:
logger.info("Loading Offline Chat Model...") logger.info("Loading Offline Chat Model...")
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model=chat_model) state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
return True return True
ready = has_openai_config or has_offline_config ready = has_openai_config or has_offline_config
@ -385,10 +386,11 @@ async def send_message_to_model_wrapper(
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.") raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
chat_model = conversation_config.chat_model chat_model = conversation_config.chat_model
max_tokens = conversation_config.max_prompt_size
if conversation_config.model_type == "offline": if conversation_config.model_type == "offline":
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None: if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model) state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
loaded_model = state.offline_chat_processor_config.loaded_model loaded_model = state.offline_chat_processor_config.loaded_model
truncated_messages = generate_chatml_messages_with_context( truncated_messages = generate_chatml_messages_with_context(
@ -455,7 +457,9 @@ def generate_chat_response(
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation) conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
if conversation_config.model_type == "offline": if conversation_config.model_type == "offline":
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None: if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
state.offline_chat_processor_config = OfflineChatProcessorModel(conversation_config.chat_model) chat_model = conversation_config.chat_model
max_tokens = conversation_config.max_prompt_size
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
loaded_model = state.offline_chat_processor_config.loaded_model loaded_model = state.offline_chat_processor_config.loaded_model
chat_response = converse_offline( chat_response = converse_offline(

View file

@ -69,11 +69,11 @@ class OfflineChatProcessorConfig:
class OfflineChatProcessorModel: class OfflineChatProcessorModel:
def __init__(self, chat_model: str = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF"): def __init__(self, chat_model: str = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", max_tokens: int = None):
self.chat_model = chat_model self.chat_model = chat_model
self.loaded_model = None self.loaded_model = None
try: try:
self.loaded_model = download_model(self.chat_model) self.loaded_model = download_model(self.chat_model, max_tokens=max_tokens)
except ValueError as e: except ValueError as e:
self.loaded_model = None self.loaded_model = None
logger.error(f"Error while loading offline chat model: {e}", exc_info=True) logger.error(f"Error while loading offline chat model: {e}", exc_info=True)

View file

@ -17,6 +17,7 @@ from time import perf_counter
from typing import TYPE_CHECKING, Optional, Union from typing import TYPE_CHECKING, Optional, Union
from urllib.parse import urlparse from urllib.parse import urlparse
import psutil
import torch import torch
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
from magika import Magika from magika import Magika
@ -271,6 +272,17 @@ def log_telemetry(
return request_body return request_body
def get_device_memory() -> int:
"""Get device memory in GB"""
device = get_device()
if device.type == "cuda":
return torch.cuda.get_device_properties(device).total_memory
elif device.type == "mps":
return torch.mps.driver_allocated_memory()
else:
return psutil.virtual_memory().total
def get_device() -> torch.device: def get_device() -> torch.device:
"""Get device to run model on""" """Get device to run model on"""
if torch.cuda.is_available(): if torch.cuda.is_available():