Improve Offline Chat Model Experience (#494)

- Make offline chat model user configurable. Use `filename` of any [GPT4All supported  model](https://github.com/nomic-ai/gpt4all/blob/main/gpt4all-chat/metadata/models.json) like below:
- Run GPT4All Chat Model on GPU, when available via [GPT4All Vulcan support](https://blog.nomic.ai/posts/gpt4all-gpu-inference-with-vulkan)
- Use default Llama 2 supported by GPT4All
- Make `tokenizer` and `max-prompt-size` of chat model user configurable. E.g When using chat models not in [this pre-defined list](https://github.com/khoj-ai/khoj/blob/master/src/khoj/processor/conversation/utils.py) that support larger context window or a different tokenizer.

Closes #406, #418
This commit is contained in:
Debanjum 2023-10-16 17:44:49 -07:00 committed by GitHub
commit b4949f7f0b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 230 additions and 141 deletions

View file

@ -19,7 +19,7 @@ from khoj.utils.config import (
)
from khoj.utils.helpers import resolve_absolute_path, merge_dicts
from khoj.utils.fs_syncer import collect_files
from khoj.utils.rawconfig import FullConfig, ProcessorConfig, ConversationProcessorConfig
from khoj.utils.rawconfig import FullConfig, OfflineChatProcessorConfig, ProcessorConfig, ConversationProcessorConfig
from khoj.routers.indexer import configure_content, load_content, configure_search
@ -168,9 +168,7 @@ def configure_conversation_processor(
conversation_config=ConversationProcessorConfig(
conversation_logfile=conversation_logfile,
openai=(conversation_config.openai if (conversation_config is not None) else None),
enable_offline_chat=(
conversation_config.enable_offline_chat if (conversation_config is not None) else False
),
offline_chat=conversation_config.offline_chat if conversation_config else OfflineChatProcessorConfig(),
)
)
else:

View file

@ -236,7 +236,7 @@
</h3>
</div>
<div class="card-description-row">
<p class="card-description">Setup chat using OpenAI</p>
<p class="card-description">Setup online chat using OpenAI</p>
</div>
<div class="card-action-row">
<a class="card-button" href="/config/processor/conversation/openai">
@ -261,21 +261,21 @@
<img class="card-icon" src="/static/assets/icons/chat.svg" alt="Chat">
<h3 class="card-title">
Offline Chat
<img id="configured-icon-conversation-enable-offline-chat" class="configured-icon {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat and current_model_state.conversation_gpt4all %}enabled{% else %}disabled{% endif %}" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
{% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat and not current_model_state.conversation_gpt4all %}
<img id="configured-icon-conversation-enable-offline-chat" class="configured-icon {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.offline_chat.enable_offline_chat and current_model_state.conversation_gpt4all %}enabled{% else %}disabled{% endif %}" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
{% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.offline_chat.enable_offline_chat and not current_model_state.conversation_gpt4all %}
<img id="misconfigured-icon-conversation-enable-offline-chat" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="The model was not downloaded as expected.">
{% endif %}
</h3>
</div>
<div class="card-description-row">
<p class="card-description">Setup offline chat (Llama V2)</p>
<p class="card-description">Setup offline chat</p>
</div>
<div id="clear-enable-offline-chat" class="card-action-row {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat %}enabled{% else %}disabled{% endif %}">
<div id="clear-enable-offline-chat" class="card-action-row {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.offline_chat.enable_offline_chat %}enabled{% else %}disabled{% endif %}">
<button class="card-button" onclick="toggleEnableLocalLLLM(false)">
Disable
</button>
</div>
<div id="set-enable-offline-chat" class="card-action-row {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat %}disabled{% else %}enabled{% endif %}">
<div id="set-enable-offline-chat" class="card-action-row {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.offline_chat.enable_offline_chat %}disabled{% else %}enabled{% endif %}">
<button class="card-button happy" onclick="toggleEnableLocalLLLM(true)">
Enable
</button>
@ -346,7 +346,7 @@
featuresHintText.classList.add("show");
}
fetch('/api/config/data/processor/conversation/enable_offline_chat' + '?enable_offline_chat=' + enable, {
fetch('/api/config/data/processor/conversation/offline_chat' + '?enable_offline_chat=' + enable, {
method: 'POST',
headers: {
'Content-Type': 'application/json',

View file

@ -0,0 +1,83 @@
"""
Current format of khoj.yml
---
app:
...
content-type:
...
processor:
conversation:
enable-offline-chat: false
conversation-logfile: ~/.khoj/processor/conversation/conversation_logs.json
openai:
...
search-type:
...
New format of khoj.yml
---
app:
...
content-type:
...
processor:
conversation:
offline-chat:
enable-offline-chat: false
chat-model: llama-2-7b-chat.ggmlv3.q4_0.bin
tokenizer: null
max_prompt_size: null
conversation-logfile: ~/.khoj/processor/conversation/conversation_logs.json
openai:
...
search-type:
...
"""
import logging
from packaging import version
from khoj.utils.yaml import load_config_from_file, save_config_to_file
logger = logging.getLogger(__name__)
def migrate_offline_chat_schema(args):
schema_version = "0.12.3"
raw_config = load_config_from_file(args.config_file)
previous_version = raw_config.get("version")
if "processor" not in raw_config:
return args
if raw_config["processor"] is None:
return args
if "conversation" not in raw_config["processor"]:
return args
if previous_version is None or version.parse(previous_version) < version.parse("0.12.3"):
logger.info(
f"Upgrading config schema to {schema_version} from {previous_version} to make (offline) chat more configuration"
)
raw_config["version"] = schema_version
# Create max-prompt-size field in conversation processor schema
raw_config["processor"]["conversation"]["max-prompt-size"] = None
raw_config["processor"]["conversation"]["tokenizer"] = None
# Create offline chat schema based on existing enable_offline_chat field in khoj config schema
offline_chat_model = (
raw_config["processor"]["conversation"]
.get("offline-chat", {})
.get("chat-model", "llama-2-7b-chat.ggmlv3.q4_0.bin")
)
raw_config["processor"]["conversation"]["offline-chat"] = {
"enable-offline-chat": raw_config["processor"]["conversation"].get("enable-offline-chat", False),
"chat-model": offline_chat_model,
}
# Delete old enable-offline-chat field from conversation processor schema
if "enable-offline-chat" in raw_config["processor"]["conversation"]:
del raw_config["processor"]["conversation"]["enable-offline-chat"]
save_config_to_file(raw_config, args.config_file)
return args

View file

@ -16,7 +16,7 @@ logger = logging.getLogger(__name__)
def extract_questions_offline(
text: str,
model: str = "llama-2-7b-chat.ggmlv3.q4_K_S.bin",
model: str = "llama-2-7b-chat.ggmlv3.q4_0.bin",
loaded_model: Union[Any, None] = None,
conversation_log={},
use_history: bool = True,
@ -113,7 +113,7 @@ def filter_questions(questions: List[str]):
]
filtered_questions = []
for q in questions:
if not any([word in q.lower() for word in hint_words]):
if not any([word in q.lower() for word in hint_words]) and not is_none_or_empty(q):
filtered_questions.append(q)
return filtered_questions
@ -123,10 +123,12 @@ def converse_offline(
references,
user_query,
conversation_log={},
model: str = "llama-2-7b-chat.ggmlv3.q4_K_S.bin",
model: str = "llama-2-7b-chat.ggmlv3.q4_0.bin",
loaded_model: Union[Any, None] = None,
completion_func=None,
conversation_command=ConversationCommand.Default,
max_prompt_size=None,
tokenizer_name=None,
) -> Union[ThreadedGenerator, Iterator[str]]:
"""
Converse with user using Llama
@ -158,6 +160,8 @@ def converse_offline(
prompts.system_prompt_message_llamav2,
conversation_log,
model_name=model,
max_prompt_size=max_prompt_size,
tokenizer_name=tokenizer_name,
)
g = ThreadedGenerator(references, completion_func=completion_func)

View file

@ -1,3 +0,0 @@
model_name_to_url = {
"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/resolve/main/llama-2-7b-chat.ggmlv3.q4_K_S.bin"
}

View file

@ -1,24 +1,8 @@
import os
import logging
import requests
import hashlib
from tqdm import tqdm
from khoj.processor.conversation.gpt4all import model_metadata
logger = logging.getLogger(__name__)
expected_checksum = {"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "cfa87b15d92fb15a2d7c354b0098578b"}
def get_md5_checksum(filename: str):
hash_md5 = hashlib.md5()
with open(filename, "rb") as f:
for chunk in iter(lambda: f.read(8192), b""):
hash_md5.update(chunk)
return hash_md5.hexdigest()
def download_model(model_name: str):
try:
@ -27,57 +11,12 @@ def download_model(model_name: str):
logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
raise e
url = model_metadata.model_name_to_url.get(model_name)
model_path = os.path.expanduser(f"~/.cache/gpt4all/")
if not url:
logger.debug(f"Model {model_name} not found in model metadata. Skipping download.")
return GPT4All(model_name=model_name, model_path=model_path)
filename = os.path.expanduser(f"~/.cache/gpt4all/{model_name}")
if os.path.exists(filename):
# Check if the user is connected to the internet
try:
requests.get("https://www.google.com/", timeout=5)
except:
logger.debug("User is offline. Disabling allowed download flag")
return GPT4All(model_name=model_name, model_path=model_path, allow_download=False)
return GPT4All(model_name=model_name, model_path=model_path)
# Download the model to a tmp file. Once the download is completed, move the tmp file to the actual file
tmp_filename = filename + ".tmp"
# Use GPU for Chat Model, if available
try:
os.makedirs(os.path.dirname(tmp_filename), exist_ok=True)
logger.debug(f"Downloading model {model_name} from {url} to {filename}...")
with requests.get(url, stream=True) as r:
r.raise_for_status()
total_size = int(r.headers.get("content-length", 0))
with open(tmp_filename, "wb") as f, tqdm(
unit="B", # unit string to be displayed.
unit_scale=True, # let tqdm to determine the scale in kilo, mega..etc.
unit_divisor=1024, # is used when unit_scale is true
total=total_size, # the total iteration.
desc=model_name, # prefix to be displayed on progress bar.
) as progress_bar:
for chunk in r.iter_content(chunk_size=8192):
f.write(chunk)
progress_bar.update(len(chunk))
model = GPT4All(model_name=model_name, device="gpu")
logger.debug("Loaded chat model to GPU.")
except ValueError:
model = GPT4All(model_name=model_name)
logger.debug("Loaded chat model to CPU.")
# Verify the checksum
if expected_checksum.get(model_name) != get_md5_checksum(tmp_filename):
logger.error(
f"Checksum verification failed for {filename}. Removing the tmp file. Offline model will not be available."
)
os.remove(tmp_filename)
raise ValueError(f"Checksum verification failed for downloading {model_name} from {url}.")
# Move the tmp file to the actual file
os.rename(tmp_filename, filename)
logger.debug(f"Successfully downloaded model {model_name} from {url} to {filename}")
return GPT4All(model_name)
except Exception as e:
logger.error(f"Failed to download model {model_name} from {url} to {filename}. Error: {e}", exc_info=True)
# Remove the tmp file if it exists
if os.path.exists(tmp_filename):
os.remove(tmp_filename)
return None
return model

View file

@ -116,6 +116,8 @@ def converse(
temperature: float = 0.2,
completion_func=None,
conversation_command=ConversationCommand.Default,
max_prompt_size=None,
tokenizer_name=None,
):
"""
Converse with user using OpenAI's ChatGPT
@ -141,6 +143,8 @@ def converse(
prompts.personality.format(),
conversation_log,
model,
max_prompt_size,
tokenizer_name,
)
truncated_messages = "\n".join({f"{message.content[:40]}..." for message in messages})
logger.debug(f"Conversation Context for GPT: {truncated_messages}")

View file

@ -23,7 +23,7 @@ no_notes_found = PromptTemplate.from_template(
""".strip()
)
system_prompt_message_llamav2 = f"""You are Khoj, a friendly, smart and helpful personal assistant.
system_prompt_message_llamav2 = f"""You are Khoj, a smart, inquisitive and helpful personal assistant.
Using your general knowledge and our past conversations as context, answer the following question.
If you do not know the answer, say 'I don't know.'"""
@ -51,13 +51,13 @@ extract_questions_system_prompt_llamav2 = PromptTemplate.from_template(
general_conversation_llamav2 = PromptTemplate.from_template(
"""
<s>[INST]{query}[/INST]
<s>[INST] {query} [/INST]
""".strip()
)
chat_history_llamav2_from_user = PromptTemplate.from_template(
"""
<s>[INST]{message}[/INST]
<s>[INST] {message} [/INST]
""".strip()
)
@ -69,7 +69,7 @@ chat_history_llamav2_from_assistant = PromptTemplate.from_template(
conversation_llamav2 = PromptTemplate.from_template(
"""
<s>[INST]{query}[/INST]
<s>[INST] {query} [/INST]
""".strip()
)
@ -91,7 +91,7 @@ Question: {query}
notes_conversation_llamav2 = PromptTemplate.from_template(
"""
Notes:
User's Notes:
{references}
Question: {query}
""".strip()
@ -134,19 +134,25 @@ Answer (in second person):"""
extract_questions_llamav2_sample = PromptTemplate.from_template(
"""
<s>[INST]<<SYS>>Current Date: {current_date}<</SYS>>[/INST]</s>
<s>[INST]How was my trip to Cambodia?[/INST][]</s>
<s>[INST]Who did I visit the temple with on that trip?[/INST]Who did I visit the temple with in Cambodia?</s>
<s>[INST]How should I take care of my plants?[/INST]What kind of plants do I have? What issues do my plants have?</s>
<s>[INST]How many tennis balls fit in the back of a 2002 Honda Civic?[/INST]What is the size of a tennis ball? What is the trunk size of a 2002 Honda Civic?</s>
<s>[INST]What did I do for Christmas last year?[/INST]What did I do for Christmas {last_year} dt>='{last_christmas_date}' dt<'{next_christmas_date}'</s>
<s>[INST]How are you feeling today?[/INST]</s>
<s>[INST]Is Alice older than Bob?[/INST]When was Alice born? What is Bob's age?</s>
<s>[INST]<<SYS>>
<s>[INST] <<SYS>>Current Date: {current_date}<</SYS>> [/INST]</s>
<s>[INST] How was my trip to Cambodia? [/INST]
How was my trip to Cambodia?</s>
<s>[INST] Who did I visit the temple with on that trip? [/INST]
Who did I visit the temple with in Cambodia?</s>
<s>[INST] How should I take care of my plants? [/INST]
What kind of plants do I have? What issues do my plants have?</s>
<s>[INST] How many tennis balls fit in the back of a 2002 Honda Civic? [/INST]
What is the size of a tennis ball? What is the trunk size of a 2002 Honda Civic?</s>
<s>[INST] What did I do for Christmas last year? [/INST]
What did I do for Christmas {last_year} dt>='{last_christmas_date}' dt<'{next_christmas_date}'</s>
<s>[INST] How are you feeling today? [/INST]</s>
<s>[INST] Is Alice older than Bob? [/INST]
When was Alice born? What is Bob's age?</s>
<s>[INST] <<SYS>>
Use these notes from the user's previous conversations to provide a response:
{chat_history}
<</SYS>>[/INST]</s>
<s>[INST]{query}[/INST]
<</SYS>> [/INST]</s>
<s>[INST] {query} [/INST]
"""
)

View file

@ -3,24 +3,27 @@ import logging
from time import perf_counter
import json
from datetime import datetime
import queue
import tiktoken
# External packages
from langchain.schema import ChatMessage
from transformers import LlamaTokenizerFast
from transformers import AutoTokenizer
# Internal Packages
import queue
from khoj.utils.helpers import merge_dicts
logger = logging.getLogger(__name__)
max_prompt_size = {
model_to_prompt_size = {
"gpt-3.5-turbo": 4096,
"gpt-4": 8192,
"llama-2-7b-chat.ggmlv3.q4_K_S.bin": 1548,
"llama-2-7b-chat.ggmlv3.q4_0.bin": 1548,
"gpt-3.5-turbo-16k": 15000,
}
tokenizer = {"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "hf-internal-testing/llama-tokenizer"}
model_to_tokenizer = {
"llama-2-7b-chat.ggmlv3.q4_0.bin": "hf-internal-testing/llama-tokenizer",
}
class ThreadedGenerator:
@ -82,9 +85,26 @@ def message_to_log(
def generate_chatml_messages_with_context(
user_message, system_message, conversation_log={}, model_name="gpt-3.5-turbo", lookback_turns=2
user_message,
system_message,
conversation_log={},
model_name="gpt-3.5-turbo",
max_prompt_size=None,
tokenizer_name=None,
):
"""Generate messages for ChatGPT with context from previous conversation"""
# Set max prompt size from user config, pre-configured for model or to default prompt size
try:
max_prompt_size = max_prompt_size or model_to_prompt_size[model_name]
except:
max_prompt_size = 2000
logger.warning(
f"Fallback to default prompt size: {max_prompt_size}.\nConfigure max_prompt_size for unsupported model: {model_name} in Khoj settings to longer context window."
)
# Scale lookback turns proportional to max prompt size supported by model
lookback_turns = max_prompt_size // 750
# Extract Chat History for Context
chat_logs = []
for chat in conversation_log.get("chat", []):
@ -105,19 +125,28 @@ def generate_chatml_messages_with_context(
messages = user_chatml_message + rest_backnforths + system_chatml_message
# Truncate oldest messages from conversation history until under max supported prompt size by model
messages = truncate_messages(messages, max_prompt_size[model_name], model_name)
messages = truncate_messages(messages, max_prompt_size, model_name, tokenizer_name)
# Return message in chronological order
return messages[::-1]
def truncate_messages(messages: list[ChatMessage], max_prompt_size, model_name) -> list[ChatMessage]:
def truncate_messages(
messages: list[ChatMessage], max_prompt_size, model_name: str, tokenizer_name=None
) -> list[ChatMessage]:
"""Truncate messages to fit within max prompt size supported by model"""
if "llama" in model_name:
encoder = LlamaTokenizerFast.from_pretrained(tokenizer[model_name])
else:
encoder = tiktoken.encoding_for_model(model_name)
try:
if model_name.startswith("gpt-"):
encoder = tiktoken.encoding_for_model(model_name)
else:
encoder = AutoTokenizer.from_pretrained(tokenizer_name or model_to_tokenizer[model_name])
except:
default_tokenizer = "hf-internal-testing/llama-tokenizer"
encoder = AutoTokenizer.from_pretrained(default_tokenizer)
logger.warning(
f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
)
system_message = messages.pop()
system_message_tokens = len(encoder.encode(system_message.content))

View file

@ -284,10 +284,11 @@ if not state.demo:
except Exception as e:
return {"status": "error", "message": str(e)}
@api.post("/config/data/processor/conversation/enable_offline_chat", status_code=200)
@api.post("/config/data/processor/conversation/offline_chat", status_code=200)
async def set_processor_enable_offline_chat_config_data(
request: Request,
enable_offline_chat: bool,
offline_chat_model: Optional[str] = None,
client: Optional[str] = None,
):
_initialize_config()
@ -301,7 +302,9 @@ if not state.demo:
state.config.processor = ProcessorConfig(conversation=ConversationProcessorConfig(conversation_logfile=conversation_logfile)) # type: ignore
assert state.config.processor.conversation is not None
state.config.processor.conversation.enable_offline_chat = enable_offline_chat
state.config.processor.conversation.offline_chat.enable_offline_chat = enable_offline_chat
if offline_chat_model is not None:
state.config.processor.conversation.offline_chat.chat_model = offline_chat_model
state.processor_config = configure_processor(state.config.processor, state.processor_config)
update_telemetry_state(
@ -713,7 +716,7 @@ async def chat(
conversation_command = ConversationCommand.General
if conversation_command == ConversationCommand.Help:
model_type = "offline" if state.processor_config.conversation.enable_offline_chat else "openai"
model_type = "offline" if state.processor_config.conversation.offline_chat.enable_offline_chat else "openai"
formatted_help = help_message.format(model=model_type, version=state.khoj_version)
return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200)
@ -788,7 +791,7 @@ async def extract_references_and_questions(
# Infer search queries from user message
with timer("Extracting search queries took", logger):
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
if state.processor_config.conversation.enable_offline_chat:
if state.processor_config.conversation.offline_chat.enable_offline_chat:
loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
inferred_queries = extract_questions_offline(
defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
@ -804,7 +807,7 @@ async def extract_references_and_questions(
with timer("Searching knowledge base took", logger):
result_list = []
for query in inferred_queries:
n_items = min(n, 3) if state.processor_config.conversation.enable_offline_chat else n
n_items = min(n, 3) if state.processor_config.conversation.offline_chat.enable_offline_chat else n
result_list.extend(
await search(
f"{query} {filters_in_query}",

View file

@ -113,7 +113,7 @@ def generate_chat_response(
meta_log=meta_log,
)
if state.processor_config.conversation.enable_offline_chat:
if state.processor_config.conversation.offline_chat.enable_offline_chat:
loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
chat_response = converse_offline(
references=compiled_references,
@ -122,6 +122,9 @@ def generate_chat_response(
conversation_log=meta_log,
completion_func=partial_completion,
conversation_command=conversation_command,
model=state.processor_config.conversation.offline_chat.chat_model,
max_prompt_size=state.processor_config.conversation.max_prompt_size,
tokenizer_name=state.processor_config.conversation.tokenizer,
)
elif state.processor_config.conversation.openai_model:
@ -135,6 +138,8 @@ def generate_chat_response(
api_key=api_key,
completion_func=partial_completion,
conversation_command=conversation_command,
max_prompt_size=state.processor_config.conversation.max_prompt_size,
tokenizer_name=state.processor_config.conversation.tokenizer,
)
except Exception as e:

View file

@ -9,6 +9,7 @@ from khoj.utils.yaml import parse_config_from_file
from khoj.migrations.migrate_version import migrate_config_to_version
from khoj.migrations.migrate_processor_config_openai import migrate_processor_conversation_schema
from khoj.migrations.migrate_offline_model import migrate_offline_model
from khoj.migrations.migrate_offline_chat_schema import migrate_offline_chat_schema
def cli(args=None):
@ -55,7 +56,12 @@ def cli(args=None):
def run_migrations(args):
migrations = [migrate_config_to_version, migrate_processor_conversation_schema, migrate_offline_model]
migrations = [
migrate_config_to_version,
migrate_processor_conversation_schema,
migrate_offline_model,
migrate_offline_chat_schema,
]
for migration in migrations:
args = migration(args)
return args

View file

@ -84,7 +84,6 @@ class SearchModels:
@dataclass
class GPT4AllProcessorConfig:
chat_model: Optional[str] = "llama-2-7b-chat.ggmlv3.q4_K_S.bin"
loaded_model: Union[Any, None] = None
@ -95,18 +94,20 @@ class ConversationProcessorConfigModel:
):
self.openai_model = conversation_config.openai
self.gpt4all_model = GPT4AllProcessorConfig()
self.enable_offline_chat = conversation_config.enable_offline_chat
self.offline_chat = conversation_config.offline_chat
self.max_prompt_size = conversation_config.max_prompt_size
self.tokenizer = conversation_config.tokenizer
self.conversation_logfile = Path(conversation_config.conversation_logfile)
self.chat_session: List[str] = []
self.meta_log: dict = {}
if self.enable_offline_chat:
if self.offline_chat.enable_offline_chat:
try:
self.gpt4all_model.loaded_model = download_model(self.gpt4all_model.chat_model)
self.gpt4all_model.loaded_model = download_model(self.offline_chat.chat_model)
except ValueError as e:
self.offline_chat.enable_offline_chat = False
self.gpt4all_model.loaded_model = None
logger.error(f"Error while loading offline chat model: {e}", exc_info=True)
self.enable_offline_chat = False
else:
self.gpt4all_model.loaded_model = None

View file

@ -91,10 +91,17 @@ class OpenAIProcessorConfig(ConfigBase):
chat_model: Optional[str] = "gpt-3.5-turbo"
class OfflineChatProcessorConfig(ConfigBase):
enable_offline_chat: Optional[bool] = False
chat_model: Optional[str] = "llama-2-7b-chat.ggmlv3.q4_0.bin"
class ConversationProcessorConfig(ConfigBase):
conversation_logfile: Path
openai: Optional[OpenAIProcessorConfig]
enable_offline_chat: Optional[bool] = False
offline_chat: Optional[OfflineChatProcessorConfig]
max_prompt_size: Optional[int]
tokenizer: Optional[str]
class ProcessorConfig(ConfigBase):

View file

@ -16,6 +16,7 @@ from khoj.utils.helpers import resolve_absolute_path
from khoj.utils.rawconfig import (
ContentConfig,
ConversationProcessorConfig,
OfflineChatProcessorConfig,
OpenAIProcessorConfig,
ProcessorConfig,
TextContentConfig,
@ -205,8 +206,9 @@ def processor_config_offline_chat(tmp_path_factory):
# Setup conversation processor
processor_config = ProcessorConfig()
offline_chat = OfflineChatProcessorConfig(enable_offline_chat=True)
processor_config.conversation = ConversationProcessorConfig(
enable_offline_chat=True,
offline_chat=offline_chat,
conversation_logfile=processor_dir.joinpath("conversation_logs.json"),
)

View file

@ -24,7 +24,7 @@ from khoj.processor.conversation.gpt4all.utils import download_model
from khoj.processor.conversation.utils import message_to_log
MODEL_NAME = "llama-2-7b-chat.ggmlv3.q4_K_S.bin"
MODEL_NAME = "llama-2-7b-chat.ggmlv3.q4_0.bin"
@pytest.fixture(scope="session")
@ -128,15 +128,15 @@ def test_extract_multiple_explicit_questions_from_message(loaded_model):
@pytest.mark.chatquality
def test_extract_multiple_implicit_questions_from_message(loaded_model):
# Act
response = extract_questions_offline("Is Morpheus taller than Neo?", loaded_model=loaded_model)
response = extract_questions_offline("Is Carl taller than Ross?", loaded_model=loaded_model)
# Assert
expected_responses = ["height", "taller", "shorter", "heights"]
expected_responses = ["height", "taller", "shorter", "heights", "who"]
assert len(response) <= 3
for question in response:
assert any([expected_response in question.lower() for expected_response in expected_responses]), (
"Expected chat actor to ask follow-up questions about Morpheus and Neo, but got: " + question
"Expected chat actor to ask follow-up questions about Carl and Ross, but got: " + question
)
@ -145,7 +145,7 @@ def test_extract_multiple_implicit_questions_from_message(loaded_model):
def test_generate_search_query_using_question_from_chat_history(loaded_model):
# Arrange
message_list = [
("What is the name of Mr. Vader's daughter?", "Princess Leia", []),
("What is the name of Mr. Anderson's daughter?", "Miss Barbara", []),
]
# Act
@ -156,17 +156,22 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model):
use_history=True,
)
expected_responses = [
"Vader",
"sons",
all_expected_in_response = [
"Anderson",
]
any_expected_in_response = [
"son",
"Darth",
"sons",
"children",
]
# Assert
assert len(response) >= 1
assert any([expected_response in response[0] for expected_response in expected_responses]), (
assert all([expected_response in response[0] for expected_response in all_expected_in_response]), (
"Expected chat actor to ask for clarification in response, but got: " + response[0]
)
assert any([expected_response in response[0] for expected_response in any_expected_in_response]), (
"Expected chat actor to ask for clarification in response, but got: " + response[0]
)
@ -176,20 +181,20 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model):
def test_generate_search_query_using_answer_from_chat_history(loaded_model):
# Arrange
message_list = [
("What is the name of Mr. Vader's daughter?", "Princess Leia", []),
("What is the name of Mr. Anderson's daughter?", "Miss Barbara", []),
]
# Act
response = extract_questions_offline(
"Is she a Jedi?",
"Is she a Doctor?",
conversation_log=populate_chat_history(message_list),
loaded_model=loaded_model,
use_history=True,
)
expected_responses = [
"Leia",
"Vader",
"Barbara",
"Robert",
"daughter",
]