mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Use offline chat prompt config to set context window of loaded chat model
Previously you couldn't configure the n_ctx of the loaded offline chat model. This made it hard to use good offline chat model (which these days also have larger context) on machines with lower VRAM
This commit is contained in:
parent
689202e00e
commit
4977b55106
8 changed files with 81 additions and 44 deletions
|
@ -78,6 +78,7 @@ dependencies = [
|
||||||
"phonenumbers == 8.13.27",
|
"phonenumbers == 8.13.27",
|
||||||
"markdownify ~= 0.11.6",
|
"markdownify ~= 0.11.6",
|
||||||
"websockets == 12.0",
|
"websockets == 12.0",
|
||||||
|
"psutil >= 5.8.0",
|
||||||
]
|
]
|
||||||
dynamic = ["version"]
|
dynamic = ["version"]
|
||||||
|
|
||||||
|
@ -105,7 +106,6 @@ dev = [
|
||||||
"pytest-asyncio == 0.21.1",
|
"pytest-asyncio == 0.21.1",
|
||||||
"freezegun >= 1.2.0",
|
"freezegun >= 1.2.0",
|
||||||
"factory-boy >= 3.2.1",
|
"factory-boy >= 3.2.1",
|
||||||
"psutil >= 5.8.0",
|
|
||||||
"mypy >= 1.0.1",
|
"mypy >= 1.0.1",
|
||||||
"black >= 23.1.0",
|
"black >= 23.1.0",
|
||||||
"pre-commit >= 3.0.4",
|
"pre-commit >= 3.0.4",
|
||||||
|
|
|
@ -30,6 +30,7 @@ def extract_questions_offline(
|
||||||
use_history: bool = True,
|
use_history: bool = True,
|
||||||
should_extract_questions: bool = True,
|
should_extract_questions: bool = True,
|
||||||
location_data: LocationData = None,
|
location_data: LocationData = None,
|
||||||
|
max_prompt_size: int = None,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Infer search queries to retrieve relevant notes to answer user query
|
Infer search queries to retrieve relevant notes to answer user query
|
||||||
|
@ -41,7 +42,7 @@ def extract_questions_offline(
|
||||||
return all_questions
|
return all_questions
|
||||||
|
|
||||||
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
|
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
|
||||||
offline_chat_model = loaded_model or download_model(model)
|
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
|
||||||
|
|
||||||
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
|
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
|
||||||
|
|
||||||
|
@ -67,12 +68,14 @@ def extract_questions_offline(
|
||||||
location=location,
|
location=location,
|
||||||
)
|
)
|
||||||
messages = generate_chatml_messages_with_context(
|
messages = generate_chatml_messages_with_context(
|
||||||
example_questions, model_name=model, loaded_model=offline_chat_model
|
example_questions, model_name=model, loaded_model=offline_chat_model, max_prompt_size=max_prompt_size
|
||||||
)
|
)
|
||||||
|
|
||||||
state.chat_lock.acquire()
|
state.chat_lock.acquire()
|
||||||
try:
|
try:
|
||||||
response = send_message_to_model_offline(messages, loaded_model=offline_chat_model)
|
response = send_message_to_model_offline(
|
||||||
|
messages, loaded_model=offline_chat_model, max_prompt_size=max_prompt_size
|
||||||
|
)
|
||||||
finally:
|
finally:
|
||||||
state.chat_lock.release()
|
state.chat_lock.release()
|
||||||
|
|
||||||
|
@ -138,7 +141,7 @@ def converse_offline(
|
||||||
"""
|
"""
|
||||||
# Initialize Variables
|
# Initialize Variables
|
||||||
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
|
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
|
||||||
offline_chat_model = loaded_model or download_model(model)
|
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
|
||||||
compiled_references_message = "\n\n".join({f"{item}" for item in references})
|
compiled_references_message = "\n\n".join({f"{item}" for item in references})
|
||||||
|
|
||||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||||
|
@ -190,18 +193,18 @@ def converse_offline(
|
||||||
)
|
)
|
||||||
|
|
||||||
g = ThreadedGenerator(references, online_results, completion_func=completion_func)
|
g = ThreadedGenerator(references, online_results, completion_func=completion_func)
|
||||||
t = Thread(target=llm_thread, args=(g, messages, offline_chat_model))
|
t = Thread(target=llm_thread, args=(g, messages, offline_chat_model, max_prompt_size))
|
||||||
t.start()
|
t.start()
|
||||||
return g
|
return g
|
||||||
|
|
||||||
|
|
||||||
def llm_thread(g, messages: List[ChatMessage], model: Any):
|
def llm_thread(g, messages: List[ChatMessage], model: Any, max_prompt_size: int = None):
|
||||||
stop_phrases = ["<s>", "INST]", "Notes:"]
|
stop_phrases = ["<s>", "INST]", "Notes:"]
|
||||||
|
|
||||||
state.chat_lock.acquire()
|
state.chat_lock.acquire()
|
||||||
try:
|
try:
|
||||||
response_iterator = send_message_to_model_offline(
|
response_iterator = send_message_to_model_offline(
|
||||||
messages, loaded_model=model, stop=stop_phrases, streaming=True
|
messages, loaded_model=model, stop=stop_phrases, max_prompt_size=max_prompt_size, streaming=True
|
||||||
)
|
)
|
||||||
for response in response_iterator:
|
for response in response_iterator:
|
||||||
g.send(response["choices"][0]["delta"].get("content", ""))
|
g.send(response["choices"][0]["delta"].get("content", ""))
|
||||||
|
@ -216,9 +219,10 @@ def send_message_to_model_offline(
|
||||||
model="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF",
|
model="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF",
|
||||||
streaming=False,
|
streaming=False,
|
||||||
stop=[],
|
stop=[],
|
||||||
|
max_prompt_size: int = None,
|
||||||
):
|
):
|
||||||
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
|
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
|
||||||
offline_chat_model = loaded_model or download_model(model)
|
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
|
||||||
messages_dict = [{"role": message.role, "content": message.content} for message in messages]
|
messages_dict = [{"role": message.role, "content": message.content} for message in messages]
|
||||||
response = offline_chat_model.create_chat_completion(messages_dict, stop=stop, stream=streaming)
|
response = offline_chat_model.create_chat_completion(messages_dict, stop=stop, stream=streaming)
|
||||||
if streaming:
|
if streaming:
|
||||||
|
|
|
@ -1,18 +1,19 @@
|
||||||
import glob
|
import glob
|
||||||
import logging
|
import logging
|
||||||
|
import math
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from huggingface_hub.constants import HF_HUB_CACHE
|
from huggingface_hub.constants import HF_HUB_CACHE
|
||||||
|
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
|
from khoj.utils.helpers import get_device_memory
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf"):
|
def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf", max_tokens: int = None):
|
||||||
from llama_cpp.llama import Llama
|
# Initialize Model Parameters
|
||||||
|
# Use n_ctx=0 to get context size from the model
|
||||||
# Initialize Model Parameters. Use n_ctx=0 to get context size from the model
|
|
||||||
kwargs = {"n_threads": 4, "n_ctx": 0, "verbose": False}
|
kwargs = {"n_threads": 4, "n_ctx": 0, "verbose": False}
|
||||||
|
|
||||||
# Decide whether to load model to GPU or CPU
|
# Decide whether to load model to GPU or CPU
|
||||||
|
@ -23,23 +24,33 @@ def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf"):
|
||||||
model_path = load_model_from_cache(repo_id, filename)
|
model_path = load_model_from_cache(repo_id, filename)
|
||||||
chat_model = None
|
chat_model = None
|
||||||
try:
|
try:
|
||||||
if model_path:
|
chat_model = load_model(model_path, repo_id, filename, kwargs)
|
||||||
chat_model = Llama(model_path, **kwargs)
|
|
||||||
else:
|
|
||||||
Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)
|
|
||||||
except:
|
except:
|
||||||
# Load model on CPU if GPU is not available
|
# Load model on CPU if GPU is not available
|
||||||
kwargs["n_gpu_layers"], device = 0, "cpu"
|
kwargs["n_gpu_layers"], device = 0, "cpu"
|
||||||
if model_path:
|
chat_model = load_model(model_path, repo_id, filename, kwargs)
|
||||||
chat_model = Llama(model_path, **kwargs)
|
|
||||||
else:
|
|
||||||
chat_model = Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)
|
|
||||||
|
|
||||||
logger.debug(f"{'Loaded' if model_path else 'Downloaded'} chat model to {device.upper()}")
|
# Now load the model with context size set based on:
|
||||||
|
# 1. context size supported by model and
|
||||||
|
# 2. configured size or machine (V)RAM
|
||||||
|
kwargs["n_ctx"] = infer_max_tokens(chat_model.n_ctx(), max_tokens)
|
||||||
|
chat_model = load_model(model_path, repo_id, filename, kwargs)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"{'Loaded' if model_path else 'Downloaded'} chat model to {device.upper()} with {kwargs['n_ctx']} token context window."
|
||||||
|
)
|
||||||
return chat_model
|
return chat_model
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(model_path: str, repo_id: str, filename: str = "*Q4_K_M.gguf", kwargs: dict = {}):
|
||||||
|
from llama_cpp.llama import Llama
|
||||||
|
|
||||||
|
if model_path:
|
||||||
|
return Llama(model_path, **kwargs)
|
||||||
|
else:
|
||||||
|
return Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def load_model_from_cache(repo_id: str, filename: str, repo_type="models"):
|
def load_model_from_cache(repo_id: str, filename: str, repo_type="models"):
|
||||||
# Construct the path to the model file in the cache directory
|
# Construct the path to the model file in the cache directory
|
||||||
repo_org, repo_name = repo_id.split("/")
|
repo_org, repo_name = repo_id.split("/")
|
||||||
|
@ -52,3 +63,12 @@ def load_model_from_cache(repo_id: str, filename: str, repo_type="models"):
|
||||||
return paths[0]
|
return paths[0]
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def infer_max_tokens(model_context_window: int, configured_max_tokens=math.inf) -> int:
|
||||||
|
"""Infer max prompt size based on device memory and max context window supported by the model"""
|
||||||
|
vram_based_n_ctx = int(get_device_memory() / 1e6) # based on heuristic
|
||||||
|
if configured_max_tokens:
|
||||||
|
return min(configured_max_tokens, model_context_window)
|
||||||
|
else:
|
||||||
|
return min(vram_based_n_ctx, model_context_window)
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import math
|
||||||
import queue
|
import queue
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from time import perf_counter
|
from time import perf_counter
|
||||||
|
@ -141,14 +142,12 @@ def generate_chatml_messages_with_context(
|
||||||
tokenizer_name=None,
|
tokenizer_name=None,
|
||||||
):
|
):
|
||||||
"""Generate messages for ChatGPT with context from previous conversation"""
|
"""Generate messages for ChatGPT with context from previous conversation"""
|
||||||
# Set max prompt size from user config, pre-configured for model or to default prompt size
|
# Set max prompt size from user config or based on pre-configured for model and machine specs
|
||||||
try:
|
if not max_prompt_size:
|
||||||
max_prompt_size = max_prompt_size or model_to_prompt_size[model_name]
|
if loaded_model:
|
||||||
except:
|
max_prompt_size = min(loaded_model.n_ctx(), model_to_prompt_size.get(model_name, math.inf))
|
||||||
max_prompt_size = 2000
|
else:
|
||||||
logger.warning(
|
max_prompt_size = model_to_prompt_size.get(model_name, 2000)
|
||||||
f"Fallback to default prompt size: {max_prompt_size}.\nConfigure max_prompt_size for unsupported model: {model_name} in Khoj settings to longer context window."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Scale lookback turns proportional to max prompt size supported by model
|
# Scale lookback turns proportional to max prompt size supported by model
|
||||||
lookback_turns = max_prompt_size // 750
|
lookback_turns = max_prompt_size // 750
|
||||||
|
@ -187,7 +186,7 @@ def truncate_messages(
|
||||||
max_prompt_size,
|
max_prompt_size,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
loaded_model: Optional[Llama] = None,
|
loaded_model: Optional[Llama] = None,
|
||||||
tokenizer_name=None,
|
tokenizer_name="hf-internal-testing/llama-tokenizer",
|
||||||
) -> list[ChatMessage]:
|
) -> list[ChatMessage]:
|
||||||
"""Truncate messages to fit within max prompt size supported by model"""
|
"""Truncate messages to fit within max prompt size supported by model"""
|
||||||
|
|
||||||
|
@ -197,15 +196,11 @@ def truncate_messages(
|
||||||
elif model_name.startswith("gpt-"):
|
elif model_name.startswith("gpt-"):
|
||||||
encoder = tiktoken.encoding_for_model(model_name)
|
encoder = tiktoken.encoding_for_model(model_name)
|
||||||
else:
|
else:
|
||||||
try:
|
encoder = download_model(model_name).tokenizer()
|
||||||
encoder = download_model(model_name).tokenizer()
|
|
||||||
except:
|
|
||||||
encoder = AutoTokenizer.from_pretrained(tokenizer_name or model_to_tokenizer[model_name])
|
|
||||||
except:
|
except:
|
||||||
default_tokenizer = "hf-internal-testing/llama-tokenizer"
|
encoder = AutoTokenizer.from_pretrained(tokenizer_name)
|
||||||
encoder = AutoTokenizer.from_pretrained(default_tokenizer)
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
|
f"Fallback to default chat model tokenizer: {tokenizer_name}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract system message from messages
|
# Extract system message from messages
|
||||||
|
|
|
@ -315,8 +315,9 @@ async def extract_references_and_questions(
|
||||||
using_offline_chat = True
|
using_offline_chat = True
|
||||||
default_offline_llm = await ConversationAdapters.get_default_offline_llm()
|
default_offline_llm = await ConversationAdapters.get_default_offline_llm()
|
||||||
chat_model = default_offline_llm.chat_model
|
chat_model = default_offline_llm.chat_model
|
||||||
|
max_tokens = default_offline_llm.max_prompt_size
|
||||||
if state.offline_chat_processor_config is None:
|
if state.offline_chat_processor_config is None:
|
||||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model=chat_model)
|
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
||||||
|
|
||||||
loaded_model = state.offline_chat_processor_config.loaded_model
|
loaded_model = state.offline_chat_processor_config.loaded_model
|
||||||
|
|
||||||
|
@ -326,6 +327,7 @@ async def extract_references_and_questions(
|
||||||
conversation_log=meta_log,
|
conversation_log=meta_log,
|
||||||
should_extract_questions=True,
|
should_extract_questions=True,
|
||||||
location_data=location_data,
|
location_data=location_data,
|
||||||
|
max_prompt_size=conversation_config.max_prompt_size,
|
||||||
)
|
)
|
||||||
elif conversation_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
elif conversation_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||||
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
|
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
|
||||||
|
|
|
@ -82,9 +82,10 @@ async def is_ready_to_chat(user: KhojUser):
|
||||||
|
|
||||||
if has_offline_config and user_conversation_config and user_conversation_config.model_type == "offline":
|
if has_offline_config and user_conversation_config and user_conversation_config.model_type == "offline":
|
||||||
chat_model = user_conversation_config.chat_model
|
chat_model = user_conversation_config.chat_model
|
||||||
|
max_tokens = user_conversation_config.max_prompt_size
|
||||||
if state.offline_chat_processor_config is None:
|
if state.offline_chat_processor_config is None:
|
||||||
logger.info("Loading Offline Chat Model...")
|
logger.info("Loading Offline Chat Model...")
|
||||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model=chat_model)
|
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
ready = has_openai_config or has_offline_config
|
ready = has_openai_config or has_offline_config
|
||||||
|
@ -385,10 +386,11 @@ async def send_message_to_model_wrapper(
|
||||||
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
|
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
|
||||||
|
|
||||||
chat_model = conversation_config.chat_model
|
chat_model = conversation_config.chat_model
|
||||||
|
max_tokens = conversation_config.max_prompt_size
|
||||||
|
|
||||||
if conversation_config.model_type == "offline":
|
if conversation_config.model_type == "offline":
|
||||||
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
||||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model)
|
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
||||||
|
|
||||||
loaded_model = state.offline_chat_processor_config.loaded_model
|
loaded_model = state.offline_chat_processor_config.loaded_model
|
||||||
truncated_messages = generate_chatml_messages_with_context(
|
truncated_messages = generate_chatml_messages_with_context(
|
||||||
|
@ -455,7 +457,9 @@ def generate_chat_response(
|
||||||
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
|
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
|
||||||
if conversation_config.model_type == "offline":
|
if conversation_config.model_type == "offline":
|
||||||
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
||||||
state.offline_chat_processor_config = OfflineChatProcessorModel(conversation_config.chat_model)
|
chat_model = conversation_config.chat_model
|
||||||
|
max_tokens = conversation_config.max_prompt_size
|
||||||
|
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
||||||
|
|
||||||
loaded_model = state.offline_chat_processor_config.loaded_model
|
loaded_model = state.offline_chat_processor_config.loaded_model
|
||||||
chat_response = converse_offline(
|
chat_response = converse_offline(
|
||||||
|
|
|
@ -69,11 +69,11 @@ class OfflineChatProcessorConfig:
|
||||||
|
|
||||||
|
|
||||||
class OfflineChatProcessorModel:
|
class OfflineChatProcessorModel:
|
||||||
def __init__(self, chat_model: str = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF"):
|
def __init__(self, chat_model: str = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", max_tokens: int = None):
|
||||||
self.chat_model = chat_model
|
self.chat_model = chat_model
|
||||||
self.loaded_model = None
|
self.loaded_model = None
|
||||||
try:
|
try:
|
||||||
self.loaded_model = download_model(self.chat_model)
|
self.loaded_model = download_model(self.chat_model, max_tokens=max_tokens)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
self.loaded_model = None
|
self.loaded_model = None
|
||||||
logger.error(f"Error while loading offline chat model: {e}", exc_info=True)
|
logger.error(f"Error while loading offline chat model: {e}", exc_info=True)
|
||||||
|
|
|
@ -17,6 +17,7 @@ from time import perf_counter
|
||||||
from typing import TYPE_CHECKING, Optional, Union
|
from typing import TYPE_CHECKING, Optional, Union
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import psutil
|
||||||
import torch
|
import torch
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
from magika import Magika
|
from magika import Magika
|
||||||
|
@ -271,6 +272,17 @@ def log_telemetry(
|
||||||
return request_body
|
return request_body
|
||||||
|
|
||||||
|
|
||||||
|
def get_device_memory() -> int:
|
||||||
|
"""Get device memory in GB"""
|
||||||
|
device = get_device()
|
||||||
|
if device.type == "cuda":
|
||||||
|
return torch.cuda.get_device_properties(device).total_memory
|
||||||
|
elif device.type == "mps":
|
||||||
|
return torch.mps.driver_allocated_memory()
|
||||||
|
else:
|
||||||
|
return psutil.virtual_memory().total
|
||||||
|
|
||||||
|
|
||||||
def get_device() -> torch.device:
|
def get_device() -> torch.device:
|
||||||
"""Get device to run model on"""
|
"""Get device to run model on"""
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
|
|
Loading…
Reference in a new issue