Use offline chat prompt config to set context window of loaded chat model

Previously you couldn't configure the n_ctx of the loaded offline chat
model. This made it hard to use good offline chat models (which these
days also have larger context windows) on machines with lower VRAM.
Debanjum Singh Solanky 2024-04-13 22:15:34 +05:30
parent 689202e00e
commit 4977b55106
8 changed files with 81 additions and 44 deletions
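
The gist of the change, as a minimal sketch (the import path and model repo are taken from the diffs below; treat this as illustrative, not as the project's documented API): the configured max prompt size is passed down to download_model as max_tokens and ends up as the n_ctx the llama.cpp model is loaded with.

    from khoj.processor.conversation.offline.utils import download_model

    # Load the offline chat model with a 4096 token context window, rather than
    # relying on the context size baked into the GGUF file.
    chat_model = download_model(
        "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", max_tokens=4096
    )
    print(chat_model.n_ctx())  # -> 4096, assuming the model's own window is at least that large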

View file

@@ -78,6 +78,7 @@ dependencies = [
"phonenumbers == 8.13.27",
"markdownify ~= 0.11.6",
"websockets == 12.0",
"psutil >= 5.8.0",
]
dynamic = ["version"]
@@ -105,7 +106,6 @@ dev = [
"pytest-asyncio == 0.21.1",
"freezegun >= 1.2.0",
"factory-boy >= 3.2.1",
"psutil >= 5.8.0",
"mypy >= 1.0.1",
"black >= 23.1.0",
"pre-commit >= 3.0.4",

View file

@@ -30,6 +30,7 @@ def extract_questions_offline(
use_history: bool = True,
should_extract_questions: bool = True,
location_data: LocationData = None,
max_prompt_size: int = None,
) -> List[str]:
"""
Infer search queries to retrieve relevant notes to answer user query
@@ -41,7 +42,7 @@ def extract_questions_offline(
return all_questions
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
offline_chat_model = loaded_model or download_model(model)
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
@@ -67,12 +68,14 @@ def extract_questions_offline(
location=location,
)
messages = generate_chatml_messages_with_context(
example_questions, model_name=model, loaded_model=offline_chat_model
example_questions, model_name=model, loaded_model=offline_chat_model, max_prompt_size=max_prompt_size
)
state.chat_lock.acquire()
try:
response = send_message_to_model_offline(messages, loaded_model=offline_chat_model)
response = send_message_to_model_offline(
messages, loaded_model=offline_chat_model, max_prompt_size=max_prompt_size
)
finally:
state.chat_lock.release()
@@ -138,7 +141,7 @@ def converse_offline(
"""
# Initialize Variables
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
offline_chat_model = loaded_model or download_model(model)
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
compiled_references_message = "\n\n".join({f"{item}" for item in references})
current_date = datetime.now().strftime("%Y-%m-%d")
@@ -190,18 +193,18 @@ def converse_offline(
)
g = ThreadedGenerator(references, online_results, completion_func=completion_func)
t = Thread(target=llm_thread, args=(g, messages, offline_chat_model))
t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size))
t.start()
return g
def llm_thread(g, messages: List[ChatMessage], model: Any):
def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None):
stop_phrases = ["<s>", "INST]", "Notes:"]
state.chat_lock.acquire()
try:
response_iterator = send_message_to_model_offline(
messages, loaded_model=model, stop=stop_phrases, streaming=True
messages, loaded_model=model, stop=stop_phrases, max_prompt_size=max_prompt_size, streaming=True
)
for response in response_iterator:
g.send(response["choices"][0]["delta"].get("content", ""))
@@ -216,9 +219,10 @@ def send_message_to_model_offline(
model="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF",
streaming=False,
stop=[],
max_prompt_size: int = None,
):
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
offline_chat_model = loaded_model or download_model(model)
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
messages_dict = [{"role": message.role, "content": message.content} for message in messages]
response = offline_chat_model.create_chat_completion(messages_dict, stop=stop, stream=streaming)
if streaming:
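
For callers, the offline chat helpers now accept an optional max_prompt_size that is forwarded to download_model. A hedged usage sketch (ChatMessage is the langchain class referenced by the type hints above; exact import paths are assumed):

    from langchain.schema import ChatMessage

    from khoj.processor.conversation.offline.chat_model import send_message_to_model_offline

    messages = [ChatMessage(role="user", content="What is the capital of France?")]

    # Passing max_prompt_size loads the default offline model with a matching n_ctx.
    response = send_message_to_model_offline(messages, max_prompt_size=4096)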

View file

@@ -1,18 +1,19 @@
import glob
import logging
import math
import os
from huggingface_hub.constants import HF_HUB_CACHE
from khoj.utils import state
from khoj.utils.helpers import get_device_memory
logger = logging.getLogger(__name__)
def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf"):
from llama_cpp.llama import Llama
# Initialize Model Parameters. Use n_ctx=0 to get context size from the model
def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf", max_tokens: int = None):
# Initialize Model Parameters
# Use n_ctx=0 to get context size from the model
kwargs = {"n_threads": 4, "n_ctx": 0, "verbose": False}
# Decide whether to load model to GPU or CPU
@@ -23,23 +24,33 @@ def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf"):
model_path = load_model_from_cache(repo_id, filename)
chat_model = None
try:
if model_path:
chat_model = Llama(model_path, **kwargs)
else:
Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)
chat_model = load_model(model_path, repo_id, filename, kwargs)
except:
# Load model on CPU if GPU is not available
kwargs["n_gpu_layers"], device = 0, "cpu"
if model_path:
chat_model = Llama(model_path, **kwargs)
else:
chat_model = Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)
chat_model = load_model(model_path, repo_id, filename, kwargs)
logger.debug(f"{'Loaded' if model_path else 'Downloaded'} chat model to {device.upper()}")
# Now load the model with context size set based on:
# 1. context size supported by model and
# 2. configured size or machine (V)RAM
kwargs["n_ctx"] = infer_max_tokens(chat_model.n_ctx(), max_tokens)
chat_model = load_model(model_path, repo_id, filename, kwargs)
logger.debug(
f"{'Loaded' if model_path else 'Downloaded'} chat model to {device.upper()} with {kwargs['n_ctx']} token context window."
)
return chat_model
def load_model(model_path: str, repo_id: str, filename: str = "*Q4_K_M.gguf", kwargs: dict = {}):
from llama_cpp.llama import Llama
if model_path:
return Llama(model_path, **kwargs)
else:
return Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)
def load_model_from_cache(repo_id: str, filename: str, repo_type="models"):
# Construct the path to the model file in the cache directory
repo_org, repo_name = repo_id.split("/")
@@ -52,3 +63,12 @@ def load_model_from_cache(repo_id: str, filename: str, repo_type="models"):
return paths[0]
else:
return None
def infer_max_tokens(model_context_window: int, configured_max_tokens=math.inf) -> int:
"""Infer max prompt size based on device memory and max context window supported by the model"""
vram_based_n_ctx = int(get_device_memory() / 1e6) # based on heuristic
if configured_max_tokens:
return min(configured_max_tokens, model_context_window)
else:
return min(vram_based_n_ctx, model_context_window)
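
A worked example of the new infer_max_tokens heuristic (import path assumed; numbers are illustrative). get_device_memory() returns bytes, so dividing by 1e6 gives roughly one context token per MB of device memory:

    from khoj.processor.conversation.offline.utils import infer_max_tokens

    # An explicitly configured max_prompt_size wins, capped by the model's own window.
    assert infer_max_tokens(32768, configured_max_tokens=4096) == 4096
    assert infer_max_tokens(2048, configured_max_tokens=4096) == 2048

    # With nothing configured (download_model passes max_tokens=None), the estimate
    # comes from device memory instead: ~8e9 bytes of VRAM / 1e6 ≈ 8000 tokens,
    # so infer_max_tokens(32768, None) would land at ~8000 on an 8 GB GPU.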

View file

@@ -1,5 +1,6 @@
import json
import logging
import math
import queue
from datetime import datetime
from time import perf_counter
@@ -141,14 +142,12 @@ def generate_chatml_messages_with_context(
tokenizer_name=None,
):
"""Generate messages for ChatGPT with context from previous conversation"""
# Set max prompt size from user config, pre-configured for model or to default prompt size
try:
max_prompt_size = max_prompt_size or model_to_prompt_size[model_name]
except:
max_prompt_size = 2000
logger.warning(
f"Fallback to default prompt size: {max_prompt_size}.\nConfigure max_prompt_size for unsupported model: {model_name} in Khoj settings to longer context window."
)
# Set max prompt size from user config or based on pre-configured for model and machine specs
if not max_prompt_size:
if loaded_model:
max_prompt_size = min(loaded_model.n_ctx(), model_to_prompt_size.get(model_name, math.inf))
else:
max_prompt_size = model_to_prompt_size.get(model_name, 2000)
# Scale lookback turns proportional to max prompt size supported by model
lookback_turns = max_prompt_size // 750
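
As a concrete example of the new fallback above: with an offline model loaded at n_ctx=4096 and no entry for it in model_to_prompt_size, the prompt budget and history lookback work out to (illustrative arithmetic only):

    import math

    # max_prompt_size falls back to the loaded model's actual context window ...
    max_prompt_size = min(4096, math.inf)    # -> 4096

    # ... and the conversation lookback scales with it, at roughly 750 tokens per turn.
    lookback_turns = max_prompt_size // 750  # -> 5 turns of history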
@@ -187,7 +186,7 @@ def truncate_messages(
max_prompt_size,
model_name: str,
loaded_model: Optional[Llama] = None,
tokenizer_name=None,
tokenizer_name="hf-internal-testing/llama-tokenizer",
) -> list[ChatMessage]:
"""Truncate messages to fit within max prompt size supported by model"""
@@ -197,15 +196,11 @@ def truncate_messages(
elif model_name.startswith("gpt-"):
encoder = tiktoken.encoding_for_model(model_name)
else:
try:
encoder = download_model(model_name).tokenizer()
except:
encoder = AutoTokenizer.from_pretrained(tokenizer_name or model_to_tokenizer[model_name])
except:
default_tokenizer = "hf-internal-testing/llama-tokenizer"
encoder = AutoTokenizer.from_pretrained(default_tokenizer)
encoder = AutoTokenizer.from_pretrained(tokenizer_name)
logger.warning(
f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
f"Fallback to default chat model tokenizer: {tokenizer_name}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
)
# Extract system message from messages
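
The tokenizer fallback for unknown offline models is now the hf-internal-testing/llama-tokenizer default baked into the signature rather than a constant inside the except block. What that fallback amounts to, as a small sketch (requires the transformers package):

    from transformers import AutoTokenizer

    # The encoder truncate_messages falls back to when it cannot get a tokenizer
    # from the downloaded model or from model_to_tokenizer.
    encoder = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

    # Token counts from this encoder drive how much conversation history is kept.
    print(len(encoder.encode("Hello from the offline chat model")))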

View file

@@ -315,8 +315,9 @@ async def extract_references_and_questions(
using_offline_chat = True
default_offline_llm = await ConversationAdapters.get_default_offline_llm()
chat_model = default_offline_llm.chat_model
max_tokens = default_offline_llm.max_prompt_size
if state.offline_chat_processor_config is None:
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model=chat_model)
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
loaded_model = state.offline_chat_processor_config.loaded_model
@@ -326,6 +327,7 @@ async def extract_references_and_questions(
conversation_log=meta_log,
should_extract_questions=True,
location_data=location_data,
max_prompt_size=conversation_config.max_prompt_size,
)
elif conversation_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
openai_chat_config = await ConversationAdapters.get_openai_chat_config()

View file

@@ -82,9 +82,10 @@ async def is_ready_to_chat(user: KhojUser):
if has_offline_config and user_conversation_config and user_conversation_config.model_type == "offline":
chat_model = user_conversation_config.chat_model
max_tokens = user_conversation_config.max_prompt_size
if state.offline_chat_processor_config is None:
logger.info("Loading Offline Chat Model...")
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model=chat_model)
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
return True
ready = has_openai_config or has_offline_config
@@ -385,10 +386,11 @@ async def send_message_to_model_wrapper(
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
chat_model = conversation_config.chat_model
max_tokens = conversation_config.max_prompt_size
if conversation_config.model_type == "offline":
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model)
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
loaded_model = state.offline_chat_processor_config.loaded_model
truncated_messages = generate_chatml_messages_with_context(
@@ -455,7 +457,9 @@ def generate_chat_response(
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
if conversation_config.model_type == "offline":
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
state.offline_chat_processor_config = OfflineChatProcessorModel(conversation_config.chat_model)
chat_model = conversation_config.chat_model
max_tokens = conversation_config.max_prompt_size
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
loaded_model = state.offline_chat_processor_config.loaded_model
chat_response = converse_offline(

View file

@@ -69,11 +69,11 @@ class OfflineChatProcessorConfig:
class OfflineChatProcessorModel:
def __init__(self, chat_model: str = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF"):
def __init__(self, chat_model: str = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", max_tokens: int = None):
self.chat_model = chat_model
self.loaded_model = None
try:
self.loaded_model = download_model(self.chat_model)
self.loaded_model = download_model(self.chat_model, max_tokens=max_tokens)
except ValueError as e:
self.loaded_model = None
logger.error(f"Error while loading offline chat model: {e}", exc_info=True)
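
Server code can therefore hand the configured prompt size straight to the processor. A minimal usage sketch (import path assumed from the Khoj source layout):

    from khoj.utils.config import OfflineChatProcessorModel

    # max_tokens flows through download_model and ends up as the Llama n_ctx.
    processor = OfflineChatProcessorModel(
        chat_model="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", max_tokens=4096
    )
    if processor.loaded_model:  # None if the model failed to load
        print(processor.loaded_model.n_ctx())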

View file

@@ -17,6 +17,7 @@ from time import perf_counter
from typing import TYPE_CHECKING, Optional, Union
from urllib.parse import urlparse
import psutil
import torch
from asgiref.sync import sync_to_async
from magika import Magika
@@ -271,6 +272,17 @@ def log_telemetry(
return request_body
def get_device_memory() -> int:
"""Get device memory in GB"""
device = get_device()
if device.type == "cuda":
return torch.cuda.get_device_properties(device).total_memory
elif device.type == "mps":
return torch.mps.driver_allocated_memory()
else:
return psutil.virtual_memory().total
def get_device() -> torch.device:
"""Get device to run model on"""
if torch.cuda.is_available():
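
Taken together with infer_max_tokens above, get_device_memory is what grounds the default context window in the machine's actual memory; note it returns bytes. A quick check of what your machine reports (import path assumed):

    from khoj.utils.helpers import get_device_memory

    total_bytes = get_device_memory()
    print(f"{total_bytes / 1e9:.1f} GB -> ~{int(total_bytes / 1e6)} token context heuristic")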