mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Improve Offline Chat Model Experience (#494)
- Make offline chat model user configurable. Use `filename` of any [GPT4All supported model](https://github.com/nomic-ai/gpt4all/blob/main/gpt4all-chat/metadata/models.json) like below: - Run GPT4All Chat Model on GPU, when available via [GPT4All Vulcan support](https://blog.nomic.ai/posts/gpt4all-gpu-inference-with-vulkan) - Use default Llama 2 supported by GPT4All - Make `tokenizer` and `max-prompt-size` of chat model user configurable. E.g When using chat models not in [this pre-defined list](https://github.com/khoj-ai/khoj/blob/master/src/khoj/processor/conversation/utils.py) that support larger context window or a different tokenizer. Closes #406, #418
This commit is contained in:
commit
b4949f7f0b
16 changed files with 230 additions and 141 deletions
|
@ -19,7 +19,7 @@ from khoj.utils.config import (
|
|||
)
|
||||
from khoj.utils.helpers import resolve_absolute_path, merge_dicts
|
||||
from khoj.utils.fs_syncer import collect_files
|
||||
from khoj.utils.rawconfig import FullConfig, ProcessorConfig, ConversationProcessorConfig
|
||||
from khoj.utils.rawconfig import FullConfig, OfflineChatProcessorConfig, ProcessorConfig, ConversationProcessorConfig
|
||||
from khoj.routers.indexer import configure_content, load_content, configure_search
|
||||
|
||||
|
||||
|
@ -168,9 +168,7 @@ def configure_conversation_processor(
|
|||
conversation_config=ConversationProcessorConfig(
|
||||
conversation_logfile=conversation_logfile,
|
||||
openai=(conversation_config.openai if (conversation_config is not None) else None),
|
||||
enable_offline_chat=(
|
||||
conversation_config.enable_offline_chat if (conversation_config is not None) else False
|
||||
),
|
||||
offline_chat=conversation_config.offline_chat if conversation_config else OfflineChatProcessorConfig(),
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
|
|
@ -236,7 +236,7 @@
|
|||
</h3>
|
||||
</div>
|
||||
<div class="card-description-row">
|
||||
<p class="card-description">Setup chat using OpenAI</p>
|
||||
<p class="card-description">Setup online chat using OpenAI</p>
|
||||
</div>
|
||||
<div class="card-action-row">
|
||||
<a class="card-button" href="/config/processor/conversation/openai">
|
||||
|
@ -261,21 +261,21 @@
|
|||
<img class="card-icon" src="/static/assets/icons/chat.svg" alt="Chat">
|
||||
<h3 class="card-title">
|
||||
Offline Chat
|
||||
<img id="configured-icon-conversation-enable-offline-chat" class="configured-icon {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat and current_model_state.conversation_gpt4all %}enabled{% else %}disabled{% endif %}" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
|
||||
{% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat and not current_model_state.conversation_gpt4all %}
|
||||
<img id="configured-icon-conversation-enable-offline-chat" class="configured-icon {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.offline_chat.enable_offline_chat and current_model_state.conversation_gpt4all %}enabled{% else %}disabled{% endif %}" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
|
||||
{% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.offline_chat.enable_offline_chat and not current_model_state.conversation_gpt4all %}
|
||||
<img id="misconfigured-icon-conversation-enable-offline-chat" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="The model was not downloaded as expected.">
|
||||
{% endif %}
|
||||
</h3>
|
||||
</div>
|
||||
<div class="card-description-row">
|
||||
<p class="card-description">Setup offline chat (Llama V2)</p>
|
||||
<p class="card-description">Setup offline chat</p>
|
||||
</div>
|
||||
<div id="clear-enable-offline-chat" class="card-action-row {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat %}enabled{% else %}disabled{% endif %}">
|
||||
<div id="clear-enable-offline-chat" class="card-action-row {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.offline_chat.enable_offline_chat %}enabled{% else %}disabled{% endif %}">
|
||||
<button class="card-button" onclick="toggleEnableLocalLLLM(false)">
|
||||
Disable
|
||||
</button>
|
||||
</div>
|
||||
<div id="set-enable-offline-chat" class="card-action-row {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat %}disabled{% else %}enabled{% endif %}">
|
||||
<div id="set-enable-offline-chat" class="card-action-row {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.offline_chat.enable_offline_chat %}disabled{% else %}enabled{% endif %}">
|
||||
<button class="card-button happy" onclick="toggleEnableLocalLLLM(true)">
|
||||
Enable
|
||||
</button>
|
||||
|
@ -346,7 +346,7 @@
|
|||
featuresHintText.classList.add("show");
|
||||
}
|
||||
|
||||
fetch('/api/config/data/processor/conversation/enable_offline_chat' + '?enable_offline_chat=' + enable, {
|
||||
fetch('/api/config/data/processor/conversation/offline_chat' + '?enable_offline_chat=' + enable, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
|
|
83
src/khoj/migrations/migrate_offline_chat_schema.py
Normal file
83
src/khoj/migrations/migrate_offline_chat_schema.py
Normal file
|
@ -0,0 +1,83 @@
|
|||
"""
|
||||
Current format of khoj.yml
|
||||
---
|
||||
app:
|
||||
...
|
||||
content-type:
|
||||
...
|
||||
processor:
|
||||
conversation:
|
||||
enable-offline-chat: false
|
||||
conversation-logfile: ~/.khoj/processor/conversation/conversation_logs.json
|
||||
openai:
|
||||
...
|
||||
search-type:
|
||||
...
|
||||
|
||||
New format of khoj.yml
|
||||
---
|
||||
app:
|
||||
...
|
||||
content-type:
|
||||
...
|
||||
processor:
|
||||
conversation:
|
||||
offline-chat:
|
||||
enable-offline-chat: false
|
||||
chat-model: llama-2-7b-chat.ggmlv3.q4_0.bin
|
||||
tokenizer: null
|
||||
max_prompt_size: null
|
||||
conversation-logfile: ~/.khoj/processor/conversation/conversation_logs.json
|
||||
openai:
|
||||
...
|
||||
search-type:
|
||||
...
|
||||
"""
|
||||
import logging
|
||||
from packaging import version
|
||||
|
||||
from khoj.utils.yaml import load_config_from_file, save_config_to_file
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def migrate_offline_chat_schema(args):
|
||||
schema_version = "0.12.3"
|
||||
raw_config = load_config_from_file(args.config_file)
|
||||
previous_version = raw_config.get("version")
|
||||
|
||||
if "processor" not in raw_config:
|
||||
return args
|
||||
if raw_config["processor"] is None:
|
||||
return args
|
||||
if "conversation" not in raw_config["processor"]:
|
||||
return args
|
||||
|
||||
if previous_version is None or version.parse(previous_version) < version.parse("0.12.3"):
|
||||
logger.info(
|
||||
f"Upgrading config schema to {schema_version} from {previous_version} to make (offline) chat more configuration"
|
||||
)
|
||||
raw_config["version"] = schema_version
|
||||
|
||||
# Create max-prompt-size field in conversation processor schema
|
||||
raw_config["processor"]["conversation"]["max-prompt-size"] = None
|
||||
raw_config["processor"]["conversation"]["tokenizer"] = None
|
||||
|
||||
# Create offline chat schema based on existing enable_offline_chat field in khoj config schema
|
||||
offline_chat_model = (
|
||||
raw_config["processor"]["conversation"]
|
||||
.get("offline-chat", {})
|
||||
.get("chat-model", "llama-2-7b-chat.ggmlv3.q4_0.bin")
|
||||
)
|
||||
raw_config["processor"]["conversation"]["offline-chat"] = {
|
||||
"enable-offline-chat": raw_config["processor"]["conversation"].get("enable-offline-chat", False),
|
||||
"chat-model": offline_chat_model,
|
||||
}
|
||||
|
||||
# Delete old enable-offline-chat field from conversation processor schema
|
||||
if "enable-offline-chat" in raw_config["processor"]["conversation"]:
|
||||
del raw_config["processor"]["conversation"]["enable-offline-chat"]
|
||||
|
||||
save_config_to_file(raw_config, args.config_file)
|
||||
return args
|
|
@ -16,7 +16,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
def extract_questions_offline(
|
||||
text: str,
|
||||
model: str = "llama-2-7b-chat.ggmlv3.q4_K_S.bin",
|
||||
model: str = "llama-2-7b-chat.ggmlv3.q4_0.bin",
|
||||
loaded_model: Union[Any, None] = None,
|
||||
conversation_log={},
|
||||
use_history: bool = True,
|
||||
|
@ -113,7 +113,7 @@ def filter_questions(questions: List[str]):
|
|||
]
|
||||
filtered_questions = []
|
||||
for q in questions:
|
||||
if not any([word in q.lower() for word in hint_words]):
|
||||
if not any([word in q.lower() for word in hint_words]) and not is_none_or_empty(q):
|
||||
filtered_questions.append(q)
|
||||
|
||||
return filtered_questions
|
||||
|
@ -123,10 +123,12 @@ def converse_offline(
|
|||
references,
|
||||
user_query,
|
||||
conversation_log={},
|
||||
model: str = "llama-2-7b-chat.ggmlv3.q4_K_S.bin",
|
||||
model: str = "llama-2-7b-chat.ggmlv3.q4_0.bin",
|
||||
loaded_model: Union[Any, None] = None,
|
||||
completion_func=None,
|
||||
conversation_command=ConversationCommand.Default,
|
||||
max_prompt_size=None,
|
||||
tokenizer_name=None,
|
||||
) -> Union[ThreadedGenerator, Iterator[str]]:
|
||||
"""
|
||||
Converse with user using Llama
|
||||
|
@ -158,6 +160,8 @@ def converse_offline(
|
|||
prompts.system_prompt_message_llamav2,
|
||||
conversation_log,
|
||||
model_name=model,
|
||||
max_prompt_size=max_prompt_size,
|
||||
tokenizer_name=tokenizer_name,
|
||||
)
|
||||
|
||||
g = ThreadedGenerator(references, completion_func=completion_func)
|
||||
|
|
|
@ -1,3 +0,0 @@
|
|||
model_name_to_url = {
|
||||
"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/resolve/main/llama-2-7b-chat.ggmlv3.q4_K_S.bin"
|
||||
}
|
|
@ -1,24 +1,8 @@
|
|||
import os
|
||||
import logging
|
||||
import requests
|
||||
import hashlib
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from khoj.processor.conversation.gpt4all import model_metadata
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
expected_checksum = {"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "cfa87b15d92fb15a2d7c354b0098578b"}
|
||||
|
||||
|
||||
def get_md5_checksum(filename: str):
|
||||
hash_md5 = hashlib.md5()
|
||||
with open(filename, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(8192), b""):
|
||||
hash_md5.update(chunk)
|
||||
return hash_md5.hexdigest()
|
||||
|
||||
|
||||
def download_model(model_name: str):
|
||||
try:
|
||||
|
@ -27,57 +11,12 @@ def download_model(model_name: str):
|
|||
logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
|
||||
raise e
|
||||
|
||||
url = model_metadata.model_name_to_url.get(model_name)
|
||||
model_path = os.path.expanduser(f"~/.cache/gpt4all/")
|
||||
if not url:
|
||||
logger.debug(f"Model {model_name} not found in model metadata. Skipping download.")
|
||||
return GPT4All(model_name=model_name, model_path=model_path)
|
||||
|
||||
filename = os.path.expanduser(f"~/.cache/gpt4all/{model_name}")
|
||||
if os.path.exists(filename):
|
||||
# Check if the user is connected to the internet
|
||||
try:
|
||||
requests.get("https://www.google.com/", timeout=5)
|
||||
except:
|
||||
logger.debug("User is offline. Disabling allowed download flag")
|
||||
return GPT4All(model_name=model_name, model_path=model_path, allow_download=False)
|
||||
return GPT4All(model_name=model_name, model_path=model_path)
|
||||
|
||||
# Download the model to a tmp file. Once the download is completed, move the tmp file to the actual file
|
||||
tmp_filename = filename + ".tmp"
|
||||
|
||||
# Use GPU for Chat Model, if available
|
||||
try:
|
||||
os.makedirs(os.path.dirname(tmp_filename), exist_ok=True)
|
||||
logger.debug(f"Downloading model {model_name} from {url} to {filename}...")
|
||||
with requests.get(url, stream=True) as r:
|
||||
r.raise_for_status()
|
||||
total_size = int(r.headers.get("content-length", 0))
|
||||
with open(tmp_filename, "wb") as f, tqdm(
|
||||
unit="B", # unit string to be displayed.
|
||||
unit_scale=True, # let tqdm to determine the scale in kilo, mega..etc.
|
||||
unit_divisor=1024, # is used when unit_scale is true
|
||||
total=total_size, # the total iteration.
|
||||
desc=model_name, # prefix to be displayed on progress bar.
|
||||
) as progress_bar:
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
progress_bar.update(len(chunk))
|
||||
model = GPT4All(model_name=model_name, device="gpu")
|
||||
logger.debug("Loaded chat model to GPU.")
|
||||
except ValueError:
|
||||
model = GPT4All(model_name=model_name)
|
||||
logger.debug("Loaded chat model to CPU.")
|
||||
|
||||
# Verify the checksum
|
||||
if expected_checksum.get(model_name) != get_md5_checksum(tmp_filename):
|
||||
logger.error(
|
||||
f"Checksum verification failed for {filename}. Removing the tmp file. Offline model will not be available."
|
||||
)
|
||||
os.remove(tmp_filename)
|
||||
raise ValueError(f"Checksum verification failed for downloading {model_name} from {url}.")
|
||||
|
||||
# Move the tmp file to the actual file
|
||||
os.rename(tmp_filename, filename)
|
||||
logger.debug(f"Successfully downloaded model {model_name} from {url} to {filename}")
|
||||
return GPT4All(model_name)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to download model {model_name} from {url} to {filename}. Error: {e}", exc_info=True)
|
||||
# Remove the tmp file if it exists
|
||||
if os.path.exists(tmp_filename):
|
||||
os.remove(tmp_filename)
|
||||
return None
|
||||
return model
|
||||
|
|
|
@ -116,6 +116,8 @@ def converse(
|
|||
temperature: float = 0.2,
|
||||
completion_func=None,
|
||||
conversation_command=ConversationCommand.Default,
|
||||
max_prompt_size=None,
|
||||
tokenizer_name=None,
|
||||
):
|
||||
"""
|
||||
Converse with user using OpenAI's ChatGPT
|
||||
|
@ -141,6 +143,8 @@ def converse(
|
|||
prompts.personality.format(),
|
||||
conversation_log,
|
||||
model,
|
||||
max_prompt_size,
|
||||
tokenizer_name,
|
||||
)
|
||||
truncated_messages = "\n".join({f"{message.content[:40]}..." for message in messages})
|
||||
logger.debug(f"Conversation Context for GPT: {truncated_messages}")
|
||||
|
|
|
@ -23,7 +23,7 @@ no_notes_found = PromptTemplate.from_template(
|
|||
""".strip()
|
||||
)
|
||||
|
||||
system_prompt_message_llamav2 = f"""You are Khoj, a friendly, smart and helpful personal assistant.
|
||||
system_prompt_message_llamav2 = f"""You are Khoj, a smart, inquisitive and helpful personal assistant.
|
||||
Using your general knowledge and our past conversations as context, answer the following question.
|
||||
If you do not know the answer, say 'I don't know.'"""
|
||||
|
||||
|
@ -51,13 +51,13 @@ extract_questions_system_prompt_llamav2 = PromptTemplate.from_template(
|
|||
|
||||
general_conversation_llamav2 = PromptTemplate.from_template(
|
||||
"""
|
||||
<s>[INST]{query}[/INST]
|
||||
<s>[INST] {query} [/INST]
|
||||
""".strip()
|
||||
)
|
||||
|
||||
chat_history_llamav2_from_user = PromptTemplate.from_template(
|
||||
"""
|
||||
<s>[INST]{message}[/INST]
|
||||
<s>[INST] {message} [/INST]
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
@ -69,7 +69,7 @@ chat_history_llamav2_from_assistant = PromptTemplate.from_template(
|
|||
|
||||
conversation_llamav2 = PromptTemplate.from_template(
|
||||
"""
|
||||
<s>[INST]{query}[/INST]
|
||||
<s>[INST] {query} [/INST]
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
@ -91,7 +91,7 @@ Question: {query}
|
|||
|
||||
notes_conversation_llamav2 = PromptTemplate.from_template(
|
||||
"""
|
||||
Notes:
|
||||
User's Notes:
|
||||
{references}
|
||||
Question: {query}
|
||||
""".strip()
|
||||
|
@ -134,19 +134,25 @@ Answer (in second person):"""
|
|||
|
||||
extract_questions_llamav2_sample = PromptTemplate.from_template(
|
||||
"""
|
||||
<s>[INST]<<SYS>>Current Date: {current_date}<</SYS>>[/INST]</s>
|
||||
<s>[INST]How was my trip to Cambodia?[/INST][]</s>
|
||||
<s>[INST]Who did I visit the temple with on that trip?[/INST]Who did I visit the temple with in Cambodia?</s>
|
||||
<s>[INST]How should I take care of my plants?[/INST]What kind of plants do I have? What issues do my plants have?</s>
|
||||
<s>[INST]How many tennis balls fit in the back of a 2002 Honda Civic?[/INST]What is the size of a tennis ball? What is the trunk size of a 2002 Honda Civic?</s>
|
||||
<s>[INST]What did I do for Christmas last year?[/INST]What did I do for Christmas {last_year} dt>='{last_christmas_date}' dt<'{next_christmas_date}'</s>
|
||||
<s>[INST]How are you feeling today?[/INST]</s>
|
||||
<s>[INST]Is Alice older than Bob?[/INST]When was Alice born? What is Bob's age?</s>
|
||||
<s>[INST]<<SYS>>
|
||||
<s>[INST] <<SYS>>Current Date: {current_date}<</SYS>> [/INST]</s>
|
||||
<s>[INST] How was my trip to Cambodia? [/INST]
|
||||
How was my trip to Cambodia?</s>
|
||||
<s>[INST] Who did I visit the temple with on that trip? [/INST]
|
||||
Who did I visit the temple with in Cambodia?</s>
|
||||
<s>[INST] How should I take care of my plants? [/INST]
|
||||
What kind of plants do I have? What issues do my plants have?</s>
|
||||
<s>[INST] How many tennis balls fit in the back of a 2002 Honda Civic? [/INST]
|
||||
What is the size of a tennis ball? What is the trunk size of a 2002 Honda Civic?</s>
|
||||
<s>[INST] What did I do for Christmas last year? [/INST]
|
||||
What did I do for Christmas {last_year} dt>='{last_christmas_date}' dt<'{next_christmas_date}'</s>
|
||||
<s>[INST] How are you feeling today? [/INST]</s>
|
||||
<s>[INST] Is Alice older than Bob? [/INST]
|
||||
When was Alice born? What is Bob's age?</s>
|
||||
<s>[INST] <<SYS>>
|
||||
Use these notes from the user's previous conversations to provide a response:
|
||||
{chat_history}
|
||||
<</SYS>>[/INST]</s>
|
||||
<s>[INST]{query}[/INST]
|
||||
<</SYS>> [/INST]</s>
|
||||
<s>[INST] {query} [/INST]
|
||||
"""
|
||||
)
|
||||
|
||||
|
|
|
@ -3,24 +3,27 @@ import logging
|
|||
from time import perf_counter
|
||||
import json
|
||||
from datetime import datetime
|
||||
import queue
|
||||
import tiktoken
|
||||
|
||||
# External packages
|
||||
from langchain.schema import ChatMessage
|
||||
from transformers import LlamaTokenizerFast
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
# Internal Packages
|
||||
import queue
|
||||
from khoj.utils.helpers import merge_dicts
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
max_prompt_size = {
|
||||
model_to_prompt_size = {
|
||||
"gpt-3.5-turbo": 4096,
|
||||
"gpt-4": 8192,
|
||||
"llama-2-7b-chat.ggmlv3.q4_K_S.bin": 1548,
|
||||
"llama-2-7b-chat.ggmlv3.q4_0.bin": 1548,
|
||||
"gpt-3.5-turbo-16k": 15000,
|
||||
}
|
||||
tokenizer = {"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "hf-internal-testing/llama-tokenizer"}
|
||||
model_to_tokenizer = {
|
||||
"llama-2-7b-chat.ggmlv3.q4_0.bin": "hf-internal-testing/llama-tokenizer",
|
||||
}
|
||||
|
||||
|
||||
class ThreadedGenerator:
|
||||
|
@ -82,9 +85,26 @@ def message_to_log(
|
|||
|
||||
|
||||
def generate_chatml_messages_with_context(
|
||||
user_message, system_message, conversation_log={}, model_name="gpt-3.5-turbo", lookback_turns=2
|
||||
user_message,
|
||||
system_message,
|
||||
conversation_log={},
|
||||
model_name="gpt-3.5-turbo",
|
||||
max_prompt_size=None,
|
||||
tokenizer_name=None,
|
||||
):
|
||||
"""Generate messages for ChatGPT with context from previous conversation"""
|
||||
# Set max prompt size from user config, pre-configured for model or to default prompt size
|
||||
try:
|
||||
max_prompt_size = max_prompt_size or model_to_prompt_size[model_name]
|
||||
except:
|
||||
max_prompt_size = 2000
|
||||
logger.warning(
|
||||
f"Fallback to default prompt size: {max_prompt_size}.\nConfigure max_prompt_size for unsupported model: {model_name} in Khoj settings to longer context window."
|
||||
)
|
||||
|
||||
# Scale lookback turns proportional to max prompt size supported by model
|
||||
lookback_turns = max_prompt_size // 750
|
||||
|
||||
# Extract Chat History for Context
|
||||
chat_logs = []
|
||||
for chat in conversation_log.get("chat", []):
|
||||
|
@ -105,19 +125,28 @@ def generate_chatml_messages_with_context(
|
|||
messages = user_chatml_message + rest_backnforths + system_chatml_message
|
||||
|
||||
# Truncate oldest messages from conversation history until under max supported prompt size by model
|
||||
messages = truncate_messages(messages, max_prompt_size[model_name], model_name)
|
||||
messages = truncate_messages(messages, max_prompt_size, model_name, tokenizer_name)
|
||||
|
||||
# Return message in chronological order
|
||||
return messages[::-1]
|
||||
|
||||
|
||||
def truncate_messages(messages: list[ChatMessage], max_prompt_size, model_name) -> list[ChatMessage]:
|
||||
def truncate_messages(
|
||||
messages: list[ChatMessage], max_prompt_size, model_name: str, tokenizer_name=None
|
||||
) -> list[ChatMessage]:
|
||||
"""Truncate messages to fit within max prompt size supported by model"""
|
||||
|
||||
if "llama" in model_name:
|
||||
encoder = LlamaTokenizerFast.from_pretrained(tokenizer[model_name])
|
||||
else:
|
||||
encoder = tiktoken.encoding_for_model(model_name)
|
||||
try:
|
||||
if model_name.startswith("gpt-"):
|
||||
encoder = tiktoken.encoding_for_model(model_name)
|
||||
else:
|
||||
encoder = AutoTokenizer.from_pretrained(tokenizer_name or model_to_tokenizer[model_name])
|
||||
except:
|
||||
default_tokenizer = "hf-internal-testing/llama-tokenizer"
|
||||
encoder = AutoTokenizer.from_pretrained(default_tokenizer)
|
||||
logger.warning(
|
||||
f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
|
||||
)
|
||||
|
||||
system_message = messages.pop()
|
||||
system_message_tokens = len(encoder.encode(system_message.content))
|
||||
|
|
|
@ -284,10 +284,11 @@ if not state.demo:
|
|||
except Exception as e:
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
@api.post("/config/data/processor/conversation/enable_offline_chat", status_code=200)
|
||||
@api.post("/config/data/processor/conversation/offline_chat", status_code=200)
|
||||
async def set_processor_enable_offline_chat_config_data(
|
||||
request: Request,
|
||||
enable_offline_chat: bool,
|
||||
offline_chat_model: Optional[str] = None,
|
||||
client: Optional[str] = None,
|
||||
):
|
||||
_initialize_config()
|
||||
|
@ -301,7 +302,9 @@ if not state.demo:
|
|||
state.config.processor = ProcessorConfig(conversation=ConversationProcessorConfig(conversation_logfile=conversation_logfile)) # type: ignore
|
||||
|
||||
assert state.config.processor.conversation is not None
|
||||
state.config.processor.conversation.enable_offline_chat = enable_offline_chat
|
||||
state.config.processor.conversation.offline_chat.enable_offline_chat = enable_offline_chat
|
||||
if offline_chat_model is not None:
|
||||
state.config.processor.conversation.offline_chat.chat_model = offline_chat_model
|
||||
state.processor_config = configure_processor(state.config.processor, state.processor_config)
|
||||
|
||||
update_telemetry_state(
|
||||
|
@ -713,7 +716,7 @@ async def chat(
|
|||
conversation_command = ConversationCommand.General
|
||||
|
||||
if conversation_command == ConversationCommand.Help:
|
||||
model_type = "offline" if state.processor_config.conversation.enable_offline_chat else "openai"
|
||||
model_type = "offline" if state.processor_config.conversation.offline_chat.enable_offline_chat else "openai"
|
||||
formatted_help = help_message.format(model=model_type, version=state.khoj_version)
|
||||
return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200)
|
||||
|
||||
|
@ -788,7 +791,7 @@ async def extract_references_and_questions(
|
|||
# Infer search queries from user message
|
||||
with timer("Extracting search queries took", logger):
|
||||
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
|
||||
if state.processor_config.conversation.enable_offline_chat:
|
||||
if state.processor_config.conversation.offline_chat.enable_offline_chat:
|
||||
loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
|
||||
inferred_queries = extract_questions_offline(
|
||||
defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
|
||||
|
@ -804,7 +807,7 @@ async def extract_references_and_questions(
|
|||
with timer("Searching knowledge base took", logger):
|
||||
result_list = []
|
||||
for query in inferred_queries:
|
||||
n_items = min(n, 3) if state.processor_config.conversation.enable_offline_chat else n
|
||||
n_items = min(n, 3) if state.processor_config.conversation.offline_chat.enable_offline_chat else n
|
||||
result_list.extend(
|
||||
await search(
|
||||
f"{query} {filters_in_query}",
|
||||
|
|
|
@ -113,7 +113,7 @@ def generate_chat_response(
|
|||
meta_log=meta_log,
|
||||
)
|
||||
|
||||
if state.processor_config.conversation.enable_offline_chat:
|
||||
if state.processor_config.conversation.offline_chat.enable_offline_chat:
|
||||
loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
|
||||
chat_response = converse_offline(
|
||||
references=compiled_references,
|
||||
|
@ -122,6 +122,9 @@ def generate_chat_response(
|
|||
conversation_log=meta_log,
|
||||
completion_func=partial_completion,
|
||||
conversation_command=conversation_command,
|
||||
model=state.processor_config.conversation.offline_chat.chat_model,
|
||||
max_prompt_size=state.processor_config.conversation.max_prompt_size,
|
||||
tokenizer_name=state.processor_config.conversation.tokenizer,
|
||||
)
|
||||
|
||||
elif state.processor_config.conversation.openai_model:
|
||||
|
@ -135,6 +138,8 @@ def generate_chat_response(
|
|||
api_key=api_key,
|
||||
completion_func=partial_completion,
|
||||
conversation_command=conversation_command,
|
||||
max_prompt_size=state.processor_config.conversation.max_prompt_size,
|
||||
tokenizer_name=state.processor_config.conversation.tokenizer,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
@ -9,6 +9,7 @@ from khoj.utils.yaml import parse_config_from_file
|
|||
from khoj.migrations.migrate_version import migrate_config_to_version
|
||||
from khoj.migrations.migrate_processor_config_openai import migrate_processor_conversation_schema
|
||||
from khoj.migrations.migrate_offline_model import migrate_offline_model
|
||||
from khoj.migrations.migrate_offline_chat_schema import migrate_offline_chat_schema
|
||||
|
||||
|
||||
def cli(args=None):
|
||||
|
@ -55,7 +56,12 @@ def cli(args=None):
|
|||
|
||||
|
||||
def run_migrations(args):
|
||||
migrations = [migrate_config_to_version, migrate_processor_conversation_schema, migrate_offline_model]
|
||||
migrations = [
|
||||
migrate_config_to_version,
|
||||
migrate_processor_conversation_schema,
|
||||
migrate_offline_model,
|
||||
migrate_offline_chat_schema,
|
||||
]
|
||||
for migration in migrations:
|
||||
args = migration(args)
|
||||
return args
|
||||
|
|
|
@ -84,7 +84,6 @@ class SearchModels:
|
|||
|
||||
@dataclass
|
||||
class GPT4AllProcessorConfig:
|
||||
chat_model: Optional[str] = "llama-2-7b-chat.ggmlv3.q4_K_S.bin"
|
||||
loaded_model: Union[Any, None] = None
|
||||
|
||||
|
||||
|
@ -95,18 +94,20 @@ class ConversationProcessorConfigModel:
|
|||
):
|
||||
self.openai_model = conversation_config.openai
|
||||
self.gpt4all_model = GPT4AllProcessorConfig()
|
||||
self.enable_offline_chat = conversation_config.enable_offline_chat
|
||||
self.offline_chat = conversation_config.offline_chat
|
||||
self.max_prompt_size = conversation_config.max_prompt_size
|
||||
self.tokenizer = conversation_config.tokenizer
|
||||
self.conversation_logfile = Path(conversation_config.conversation_logfile)
|
||||
self.chat_session: List[str] = []
|
||||
self.meta_log: dict = {}
|
||||
|
||||
if self.enable_offline_chat:
|
||||
if self.offline_chat.enable_offline_chat:
|
||||
try:
|
||||
self.gpt4all_model.loaded_model = download_model(self.gpt4all_model.chat_model)
|
||||
self.gpt4all_model.loaded_model = download_model(self.offline_chat.chat_model)
|
||||
except ValueError as e:
|
||||
self.offline_chat.enable_offline_chat = False
|
||||
self.gpt4all_model.loaded_model = None
|
||||
logger.error(f"Error while loading offline chat model: {e}", exc_info=True)
|
||||
self.enable_offline_chat = False
|
||||
else:
|
||||
self.gpt4all_model.loaded_model = None
|
||||
|
||||
|
|
|
@ -91,10 +91,17 @@ class OpenAIProcessorConfig(ConfigBase):
|
|||
chat_model: Optional[str] = "gpt-3.5-turbo"
|
||||
|
||||
|
||||
class OfflineChatProcessorConfig(ConfigBase):
|
||||
enable_offline_chat: Optional[bool] = False
|
||||
chat_model: Optional[str] = "llama-2-7b-chat.ggmlv3.q4_0.bin"
|
||||
|
||||
|
||||
class ConversationProcessorConfig(ConfigBase):
|
||||
conversation_logfile: Path
|
||||
openai: Optional[OpenAIProcessorConfig]
|
||||
enable_offline_chat: Optional[bool] = False
|
||||
offline_chat: Optional[OfflineChatProcessorConfig]
|
||||
max_prompt_size: Optional[int]
|
||||
tokenizer: Optional[str]
|
||||
|
||||
|
||||
class ProcessorConfig(ConfigBase):
|
||||
|
|
|
@ -16,6 +16,7 @@ from khoj.utils.helpers import resolve_absolute_path
|
|||
from khoj.utils.rawconfig import (
|
||||
ContentConfig,
|
||||
ConversationProcessorConfig,
|
||||
OfflineChatProcessorConfig,
|
||||
OpenAIProcessorConfig,
|
||||
ProcessorConfig,
|
||||
TextContentConfig,
|
||||
|
@ -205,8 +206,9 @@ def processor_config_offline_chat(tmp_path_factory):
|
|||
|
||||
# Setup conversation processor
|
||||
processor_config = ProcessorConfig()
|
||||
offline_chat = OfflineChatProcessorConfig(enable_offline_chat=True)
|
||||
processor_config.conversation = ConversationProcessorConfig(
|
||||
enable_offline_chat=True,
|
||||
offline_chat=offline_chat,
|
||||
conversation_logfile=processor_dir.joinpath("conversation_logs.json"),
|
||||
)
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ from khoj.processor.conversation.gpt4all.utils import download_model
|
|||
|
||||
from khoj.processor.conversation.utils import message_to_log
|
||||
|
||||
MODEL_NAME = "llama-2-7b-chat.ggmlv3.q4_K_S.bin"
|
||||
MODEL_NAME = "llama-2-7b-chat.ggmlv3.q4_0.bin"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
@ -128,15 +128,15 @@ def test_extract_multiple_explicit_questions_from_message(loaded_model):
|
|||
@pytest.mark.chatquality
|
||||
def test_extract_multiple_implicit_questions_from_message(loaded_model):
|
||||
# Act
|
||||
response = extract_questions_offline("Is Morpheus taller than Neo?", loaded_model=loaded_model)
|
||||
response = extract_questions_offline("Is Carl taller than Ross?", loaded_model=loaded_model)
|
||||
|
||||
# Assert
|
||||
expected_responses = ["height", "taller", "shorter", "heights"]
|
||||
expected_responses = ["height", "taller", "shorter", "heights", "who"]
|
||||
assert len(response) <= 3
|
||||
|
||||
for question in response:
|
||||
assert any([expected_response in question.lower() for expected_response in expected_responses]), (
|
||||
"Expected chat actor to ask follow-up questions about Morpheus and Neo, but got: " + question
|
||||
"Expected chat actor to ask follow-up questions about Carl and Ross, but got: " + question
|
||||
)
|
||||
|
||||
|
||||
|
@ -145,7 +145,7 @@ def test_extract_multiple_implicit_questions_from_message(loaded_model):
|
|||
def test_generate_search_query_using_question_from_chat_history(loaded_model):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("What is the name of Mr. Vader's daughter?", "Princess Leia", []),
|
||||
("What is the name of Mr. Anderson's daughter?", "Miss Barbara", []),
|
||||
]
|
||||
|
||||
# Act
|
||||
|
@ -156,17 +156,22 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model):
|
|||
use_history=True,
|
||||
)
|
||||
|
||||
expected_responses = [
|
||||
"Vader",
|
||||
"sons",
|
||||
all_expected_in_response = [
|
||||
"Anderson",
|
||||
]
|
||||
|
||||
any_expected_in_response = [
|
||||
"son",
|
||||
"Darth",
|
||||
"sons",
|
||||
"children",
|
||||
]
|
||||
|
||||
# Assert
|
||||
assert len(response) >= 1
|
||||
assert any([expected_response in response[0] for expected_response in expected_responses]), (
|
||||
assert all([expected_response in response[0] for expected_response in all_expected_in_response]), (
|
||||
"Expected chat actor to ask for clarification in response, but got: " + response[0]
|
||||
)
|
||||
assert any([expected_response in response[0] for expected_response in any_expected_in_response]), (
|
||||
"Expected chat actor to ask for clarification in response, but got: " + response[0]
|
||||
)
|
||||
|
||||
|
@ -176,20 +181,20 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model):
|
|||
def test_generate_search_query_using_answer_from_chat_history(loaded_model):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("What is the name of Mr. Vader's daughter?", "Princess Leia", []),
|
||||
("What is the name of Mr. Anderson's daughter?", "Miss Barbara", []),
|
||||
]
|
||||
|
||||
# Act
|
||||
response = extract_questions_offline(
|
||||
"Is she a Jedi?",
|
||||
"Is she a Doctor?",
|
||||
conversation_log=populate_chat_history(message_list),
|
||||
loaded_model=loaded_model,
|
||||
use_history=True,
|
||||
)
|
||||
|
||||
expected_responses = [
|
||||
"Leia",
|
||||
"Vader",
|
||||
"Barbara",
|
||||
"Robert",
|
||||
"daughter",
|
||||
]
|
||||
|
||||
|
|
Loading…
Reference in a new issue