Use llama.cpp for offline chat models

- Benefits of moving to llama-cpp-python from gpt4all:
  - Support for all GGUF format chat models
  - Support for AMD, Nvidia, Mac, Vulkan GPU machines (instead of just Vulkan, Mac)
  - Supports models with more capabilities like tools, schema
    enforcement, speculative decoding, image gen etc.
- Upgrade default chat model, prompt size, tokenizer for new supported
  chat models

- Load offline chat model when present on disk without requiring internet
  - Load model onto GPU if not disabled and device has GPU
  - Load model onto CPU if loading model onto GPU fails
  - Create helper function to check and load model from disk, when model
    glob is present on disk.

    `Llama.from_pretrained` needs internet to get repo info from
    HuggingFace. This isn't required if the model is already downloaded.

    Didn't find any existing HF or llama.cpp method that looked for model
    glob on disk without internet
This commit is contained in:
Debanjum Singh Solanky 2024-03-16 01:49:44 +05:30
parent 0a7392f6ec
commit 8ca39a436c
12 changed files with 146 additions and 164 deletions

View file

@ -62,8 +62,7 @@ dependencies = [
"pymupdf >= 1.23.5", "pymupdf >= 1.23.5",
"django == 4.2.10", "django == 4.2.10",
"authlib == 1.2.1", "authlib == 1.2.1",
"gpt4all == 2.1.0; platform_system == 'Linux' and platform_machine == 'x86_64'", "llama-cpp-python == 0.2.56",
"gpt4all == 2.1.0; platform_system == 'Windows' or platform_system == 'Darwin'",
"itsdangerous == 2.1.2", "itsdangerous == 2.1.2",
"httpx == 0.25.0", "httpx == 0.25.0",
"pgvector == 0.2.4", "pgvector == 0.2.4",

View file

@ -160,7 +160,7 @@ class ChatModelOptions(BaseModel):
max_prompt_size = models.IntegerField(default=None, null=True, blank=True) max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True) tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
chat_model = models.CharField(max_length=200, default="mistral-7b-instruct-v0.1.Q4_0.gguf") chat_model = models.CharField(max_length=200, default="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF")
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE) model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)

View file

@ -1,12 +1,13 @@
import logging import logging
from collections import deque
from datetime import datetime from datetime import datetime
from threading import Thread from threading import Thread
from typing import Any, Iterator, List, Union from typing import Any, Iterator, List, Union
from langchain.schema import ChatMessage from langchain.schema import ChatMessage
from llama_cpp import Llama
from khoj.processor.conversation import prompts from khoj.processor.conversation import prompts
from khoj.processor.conversation.offline.utils import download_model
from khoj.processor.conversation.utils import ( from khoj.processor.conversation.utils import (
ThreadedGenerator, ThreadedGenerator,
generate_chatml_messages_with_context, generate_chatml_messages_with_context,
@ -21,7 +22,7 @@ logger = logging.getLogger(__name__)
def extract_questions_offline( def extract_questions_offline(
text: str, text: str,
model: str = "mistral-7b-instruct-v0.1.Q4_0.gguf", model: str = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF",
loaded_model: Union[Any, None] = None, loaded_model: Union[Any, None] = None,
conversation_log={}, conversation_log={},
use_history: bool = True, use_history: bool = True,
@ -31,22 +32,14 @@ def extract_questions_offline(
""" """
Infer search queries to retrieve relevant notes to answer user query Infer search queries to retrieve relevant notes to answer user query
""" """
try:
from gpt4all import GPT4All
except ModuleNotFoundError as e:
logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
raise e
# Assert that loaded_model is either None or of type GPT4All
assert loaded_model is None or isinstance(loaded_model, GPT4All), "loaded_model must be of type GPT4All or None"
all_questions = text.split("? ") all_questions = text.split("? ")
all_questions = [q + "?" for q in all_questions[:-1]] + [all_questions[-1]] all_questions = [q + "?" for q in all_questions[:-1]] + [all_questions[-1]]
if not should_extract_questions: if not should_extract_questions:
return all_questions return all_questions
gpt4all_model = loaded_model or GPT4All(model) assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
offline_chat_model = loaded_model or download_model(model)
location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
@ -75,10 +68,12 @@ def extract_questions_offline(
next_christmas_date=next_christmas_date, next_christmas_date=next_christmas_date,
location=location, location=location,
) )
message = system_prompt + example_questions messages = generate_chatml_messages_with_context(example_questions, system_message=system_prompt, model_name=model)
state.chat_lock.acquire() state.chat_lock.acquire()
try: try:
response = gpt4all_model.generate(message, max_tokens=200, top_k=2, temp=0, n_batch=512) response = offline_chat_model.create_chat_completion(messages, max_tokens=200, top_k=2, temp=0)
response = response[0]["choices"][0]["message"]["content"]
finally: finally:
state.chat_lock.release() state.chat_lock.release()
@ -133,7 +128,7 @@ def converse_offline(
references=[], references=[],
online_results=[], online_results=[],
conversation_log={}, conversation_log={},
model: str = "mistral-7b-instruct-v0.1.Q4_0.gguf", model: str = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF",
loaded_model: Union[Any, None] = None, loaded_model: Union[Any, None] = None,
completion_func=None, completion_func=None,
conversation_commands=[ConversationCommand.Default], conversation_commands=[ConversationCommand.Default],
@ -145,15 +140,9 @@ def converse_offline(
""" """
Converse with user using Llama Converse with user using Llama
""" """
try:
from gpt4all import GPT4All
except ModuleNotFoundError as e:
logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
raise e
assert loaded_model is None or isinstance(loaded_model, GPT4All), "loaded_model must be of type GPT4All or None"
gpt4all_model = loaded_model or GPT4All(model)
# Initialize Variables # Initialize Variables
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
offline_chat_model = loaded_model or download_model(model)
compiled_references_message = "\n\n".join({f"{item}" for item in references}) compiled_references_message = "\n\n".join({f"{item}" for item in references})
conversation_primer = prompts.query_prompt.format(query=user_query) conversation_primer = prompts.query_prompt.format(query=user_query)
@ -191,72 +180,44 @@ def converse_offline(
prompts.system_prompt_message_gpt4all.format(current_date=current_date), prompts.system_prompt_message_gpt4all.format(current_date=current_date),
conversation_log, conversation_log,
model_name=model, model_name=model,
loaded_model=offline_chat_model,
max_prompt_size=max_prompt_size, max_prompt_size=max_prompt_size,
tokenizer_name=tokenizer_name, tokenizer_name=tokenizer_name,
) )
g = ThreadedGenerator(references, online_results, completion_func=completion_func) g = ThreadedGenerator(references, online_results, completion_func=completion_func)
t = Thread(target=llm_thread, args=(g, messages, gpt4all_model)) t = Thread(target=llm_thread, args=(g, messages, offline_chat_model))
t.start() t.start()
return g return g
def llm_thread(g, messages: List[ChatMessage], model: Any): def llm_thread(g, messages: List[ChatMessage], model: Any):
user_message = messages[-1]
system_message = messages[0]
conversation_history = messages[1:-1]
formatted_messages = [
prompts.khoj_message_gpt4all.format(message=message.content)
if message.role == "assistant"
else prompts.user_message_gpt4all.format(message=message.content)
for message in conversation_history
]
stop_phrases = ["<s>", "INST]", "Notes:"] stop_phrases = ["<s>", "INST]", "Notes:"]
chat_history = "".join(formatted_messages)
templated_system_message = prompts.system_prompt_gpt4all.format(message=system_message.content)
templated_user_message = prompts.user_message_gpt4all.format(message=user_message.content)
prompted_message = templated_system_message + chat_history + templated_user_message
response_queue: deque[str] = deque(maxlen=3) # Create a response queue with a maximum length of 3
hit_stop_phrase = False
state.chat_lock.acquire() state.chat_lock.acquire()
response_iterator = send_message_to_model_offline(prompted_message, loaded_model=model, streaming=True)
try: try:
response_iterator = send_message_to_model_offline(
messages, loaded_model=model, stop=stop_phrases, streaming=True
)
for response in response_iterator: for response in response_iterator:
response_queue.append(response) g.send(response["choices"][0]["delta"].get("content", ""))
hit_stop_phrase = any(stop_phrase in "".join(response_queue) for stop_phrase in stop_phrases)
if hit_stop_phrase:
logger.debug(f"Stop response as hit stop phrase: {''.join(response_queue)}")
break
# Start streaming the response at a lag once the queue is full
# This allows stop word testing before sending the response
if len(response_queue) == response_queue.maxlen:
g.send(response_queue[0])
finally: finally:
if not hit_stop_phrase:
if len(response_queue) == response_queue.maxlen:
# remove already sent reponse chunk
response_queue.popleft()
# send the remaining response
g.send("".join(response_queue))
state.chat_lock.release() state.chat_lock.release()
g.close() g.close()
def send_message_to_model_offline( def send_message_to_model_offline(
message, loaded_model=None, model="mistral-7b-instruct-v0.1.Q4_0.gguf", streaming=False, system_message="" messages: List[ChatMessage],
) -> str: loaded_model=None,
try: model="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF",
from gpt4all import GPT4All streaming=False,
except ModuleNotFoundError as e: stop=[],
logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.") ):
raise e assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
offline_chat_model = loaded_model or download_model(model)
assert loaded_model is None or isinstance(loaded_model, GPT4All), "loaded_model must be of type GPT4All or None" messages_dict = [{"role": message.role, "content": message.content} for message in messages]
gpt4all_model = loaded_model or GPT4All(model) response = offline_chat_model.create_chat_completion(messages_dict, stop=stop, stream=streaming)
if streaming:
return gpt4all_model.generate( return response
system_message + message, max_tokens=200, top_k=2, temp=0, n_batch=512, streaming=streaming else:
) return response["choices"][0]["message"].get("content", "")

View file

@ -1,43 +1,53 @@
import glob
import logging import logging
import os
from huggingface_hub.constants import HF_HUB_CACHE
from khoj.utils import state from khoj.utils import state
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def download_model(model_name: str): def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf"):
try: from llama_cpp.llama import Llama
import gpt4all
except ModuleNotFoundError as e: # Initialize Model Parameters
logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.") kwargs = {"n_threads": 4, "n_ctx": 4096, "verbose": False}
raise e
# Decide whether to load model to GPU or CPU # Decide whether to load model to GPU or CPU
chat_model_config = None device = "gpu" if state.chat_on_gpu and state.device != "cpu" else "cpu"
kwargs["n_gpu_layers"] = -1 if device == "gpu" else 0
# Check if the model is already downloaded
model_path = load_model_from_cache(repo_id, filename)
try: try:
# Download the chat model and its config if model_path:
chat_model_config = gpt4all.GPT4All.retrieve_model(model_name=model_name, allow_download=True) chat_model = Llama(model_path, **kwargs)
# Try load chat model to GPU if:
# 1. Loading chat model to GPU isn't disabled via CLI and
# 2. Machine has GPU
# 3. GPU has enough free memory to load the chat model with max context length of 4096
device = (
"gpu"
if state.chat_on_gpu and gpt4all.pyllmodel.LLModel().list_gpu(chat_model_config["path"], 4096)
else "cpu"
)
except ValueError:
device = "cpu"
except Exception as e:
if chat_model_config is None:
device = "cpu" # Fallback to CPU as can't determine if GPU has enough memory
logger.debug(f"Unable to download model config from gpt4all website: {e}")
else: else:
raise e Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)
except:
# Load model on CPU if GPU is not available
kwargs["n_gpu_layers"], device = 0, "cpu"
if model_path:
chat_model = Llama(model_path, **kwargs)
else:
chat_model = Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)
# Now load the downloaded chat model onto appropriate device logger.debug(f"{'Loaded' if model_path else 'Downloaded'} chat model to {device.upper()}")
chat_model = gpt4all.GPT4All(model_name=model_name, n_ctx=4096, device=device, allow_download=False)
logger.debug(f"Loaded chat model to {device.upper()}.")
return chat_model return chat_model
def load_model_from_cache(repo_id: str, filename: str, repo_type="models"):
# Construct the path to the model file in the cache directory
repo_org, repo_name = repo_id.split("/")
object_id = "--".join([repo_type, repo_org, repo_name])
model_path = os.path.sep.join([HF_HUB_CACHE, object_id, "snapshots", "**", filename])
# Check if the model file exists
paths = glob.glob(model_path)
if paths:
return paths[0]
else:
return None

View file

@ -3,29 +3,28 @@ import logging
import queue import queue
from datetime import datetime from datetime import datetime
from time import perf_counter from time import perf_counter
from typing import Any, Dict, List from typing import Any, Dict, List, Optional
import tiktoken import tiktoken
from langchain.schema import ChatMessage from langchain.schema import ChatMessage
from llama_cpp.llama import Llama
from transformers import AutoTokenizer from transformers import AutoTokenizer
from khoj.database.adapters import ConversationAdapters from khoj.database.adapters import ConversationAdapters
from khoj.database.models import ClientApplication, KhojUser from khoj.database.models import ClientApplication, KhojUser
from khoj.processor.conversation.offline.utils import download_model
from khoj.utils.helpers import is_none_or_empty, merge_dicts from khoj.utils.helpers import is_none_or_empty, merge_dicts
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
model_to_prompt_size = { model_to_prompt_size = {
"gpt-3.5-turbo": 3000, "gpt-3.5-turbo": 12000,
"gpt-3.5-turbo-0125": 3000, "gpt-3.5-turbo-0125": 12000,
"gpt-4-0125-preview": 7000, "gpt-4-0125-preview": 20000,
"gpt-4-turbo-preview": 7000, "gpt-4-turbo-preview": 20000,
"llama-2-7b-chat.ggmlv3.q4_0.bin": 1548, "TheBloke/Mistral-7B-Instruct-v0.2-GGUF": 3500,
"mistral-7b-instruct-v0.1.Q4_0.gguf": 1548, "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF": 3500,
}
model_to_tokenizer = {
"llama-2-7b-chat.ggmlv3.q4_0.bin": "hf-internal-testing/llama-tokenizer",
"mistral-7b-instruct-v0.1.Q4_0.gguf": "mistralai/Mistral-7B-Instruct-v0.1",
} }
model_to_tokenizer: Dict[str, str] = {}
class ThreadedGenerator: class ThreadedGenerator:
@ -134,9 +133,10 @@ Khoj: "{inferred_queries if ("text-to-image" in intent_type) else chat_response}
def generate_chatml_messages_with_context( def generate_chatml_messages_with_context(
user_message, user_message,
system_message, system_message=None,
conversation_log={}, conversation_log={},
model_name="gpt-3.5-turbo", model_name="gpt-3.5-turbo",
loaded_model: Optional[Llama] = None,
max_prompt_size=None, max_prompt_size=None,
tokenizer_name=None, tokenizer_name=None,
): ):
@ -159,7 +159,7 @@ def generate_chatml_messages_with_context(
chat_notes = f'\n\n Notes:\n{chat.get("context")}' if chat.get("context") else "\n" chat_notes = f'\n\n Notes:\n{chat.get("context")}' if chat.get("context") else "\n"
chat_logs += [chat["message"] + chat_notes] chat_logs += [chat["message"] + chat_notes]
rest_backnforths = [] rest_backnforths: List[ChatMessage] = []
# Extract in reverse chronological order # Extract in reverse chronological order
for user_msg, assistant_msg in zip(chat_logs[-2::-2], chat_logs[::-2]): for user_msg, assistant_msg in zip(chat_logs[-2::-2], chat_logs[::-2]):
if len(rest_backnforths) >= 2 * lookback_turns: if len(rest_backnforths) >= 2 * lookback_turns:
@ -176,21 +176,30 @@ def generate_chatml_messages_with_context(
messages.append(ChatMessage(content=system_message, role="system")) messages.append(ChatMessage(content=system_message, role="system"))
# Truncate oldest messages from conversation history until under max supported prompt size by model # Truncate oldest messages from conversation history until under max supported prompt size by model
messages = truncate_messages(messages, max_prompt_size, model_name, tokenizer_name) messages = truncate_messages(messages, max_prompt_size, model_name, loaded_model, tokenizer_name)
# Return message in chronological order # Return message in chronological order
return messages[::-1] return messages[::-1]
def truncate_messages( def truncate_messages(
messages: list[ChatMessage], max_prompt_size, model_name: str, tokenizer_name=None messages: list[ChatMessage],
max_prompt_size,
model_name: str,
loaded_model: Optional[Llama] = None,
tokenizer_name=None,
) -> list[ChatMessage]: ) -> list[ChatMessage]:
"""Truncate messages to fit within max prompt size supported by model""" """Truncate messages to fit within max prompt size supported by model"""
try: try:
if model_name.startswith("gpt-"): if loaded_model:
encoder = loaded_model.tokenizer()
elif model_name.startswith("gpt-"):
encoder = tiktoken.encoding_for_model(model_name) encoder = tiktoken.encoding_for_model(model_name)
else: else:
try:
encoder = download_model(model_name).tokenizer()
except:
encoder = AutoTokenizer.from_pretrained(tokenizer_name or model_to_tokenizer[model_name]) encoder = AutoTokenizer.from_pretrained(tokenizer_name or model_to_tokenizer[model_name])
except: except:
default_tokenizer = "hf-internal-testing/llama-tokenizer" default_tokenizer = "hf-internal-testing/llama-tokenizer"

View file

@ -370,27 +370,31 @@ async def send_message_to_model_wrapper(
if conversation_config is None: if conversation_config is None:
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.") raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
truncated_messages = generate_chatml_messages_with_context( chat_model = conversation_config.chat_model
user_message=message, system_message=system_message, model_name=conversation_config.chat_model
)
if conversation_config.model_type == "offline": if conversation_config.model_type == "offline":
if state.gpt4all_processor_config is None or state.gpt4all_processor_config.loaded_model is None: if state.gpt4all_processor_config is None or state.gpt4all_processor_config.loaded_model is None:
state.gpt4all_processor_config = GPT4AllProcessorModel(conversation_config.chat_model) state.gpt4all_processor_config = GPT4AllProcessorModel(chat_model)
loaded_model = state.gpt4all_processor_config.loaded_model loaded_model = state.gpt4all_processor_config.loaded_model
truncated_messages = generate_chatml_messages_with_context(
user_message=message, system_message=system_message, model_name=chat_model, loaded_model=loaded_model
)
return send_message_to_model_offline( return send_message_to_model_offline(
message=truncated_messages[-1].content, messages=truncated_messages,
loaded_model=loaded_model, loaded_model=loaded_model,
model=conversation_config.chat_model, model=chat_model,
streaming=False, streaming=False,
system_message=truncated_messages[0].content,
) )
elif conversation_config.model_type == "openai": elif conversation_config.model_type == "openai":
openai_chat_config = await ConversationAdapters.aget_openai_conversation_config() openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
api_key = openai_chat_config.api_key api_key = openai_chat_config.api_key
chat_model = conversation_config.chat_model truncated_messages = generate_chatml_messages_with_context(
user_message=message, system_message=system_message, model_name=chat_model
)
openai_response = send_message_to_model( openai_response = send_message_to_model(
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type
) )

View file

@ -75,10 +75,7 @@ class GPT4AllProcessorConfig:
class GPT4AllProcessorModel: class GPT4AllProcessorModel:
def __init__( def __init__(self, chat_model: str = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF"):
self,
chat_model: str = "mistral-7b-instruct-v0.1.Q4_0.gguf",
):
self.chat_model = chat_model self.chat_model = chat_model
self.loaded_model = None self.loaded_model = None
try: try:

View file

@ -6,7 +6,7 @@ empty_escape_sequences = "\n|\r|\t| "
app_env_filepath = "~/.khoj/env" app_env_filepath = "~/.khoj/env"
telemetry_server = "https://khoj.beta.haletic.com/v1/telemetry" telemetry_server = "https://khoj.beta.haletic.com/v1/telemetry"
content_directory = "~/.khoj/content/" content_directory = "~/.khoj/content/"
default_offline_chat_model = "mistral-7b-instruct-v0.1.Q4_0.gguf" default_offline_chat_model = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF"
default_online_chat_model = "gpt-4-turbo-preview" default_online_chat_model = "gpt-4-turbo-preview"
empty_config = { empty_config = {

View file

@ -91,7 +91,7 @@ class OpenAIProcessorConfig(ConfigBase):
class OfflineChatProcessorConfig(ConfigBase): class OfflineChatProcessorConfig(ConfigBase):
enable_offline_chat: Optional[bool] = False enable_offline_chat: Optional[bool] = False
chat_model: Optional[str] = "mistral-7b-instruct-v0.1.Q4_0.gguf" chat_model: Optional[str] = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF"
class ConversationProcessorConfig(ConfigBase): class ConversationProcessorConfig(ConfigBase):

View file

@ -40,9 +40,9 @@ class ChatModelOptionsFactory(factory.django.DjangoModelFactory):
class Meta: class Meta:
model = ChatModelOptions model = ChatModelOptions
max_prompt_size = 2000 max_prompt_size = 3500
tokenizer = None tokenizer = None
chat_model = "mistral-7b-instruct-v0.1.Q4_0.gguf" chat_model = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF"
model_type = "offline" model_type = "offline"

View file

@ -5,18 +5,12 @@ import pytest
SKIP_TESTS = True SKIP_TESTS = True
pytestmark = pytest.mark.skipif( pytestmark = pytest.mark.skipif(
SKIP_TESTS, SKIP_TESTS,
reason="The GPT4All library has some quirks that make it hard to test in CI. This causes some tests to fail. Hence, disable it in CI.", reason="Disable in CI to avoid long test runs.",
) )
import freezegun import freezegun
from freezegun import freeze_time from freezegun import freeze_time
try:
from gpt4all import GPT4All
except ModuleNotFoundError as e:
print("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
from khoj.processor.conversation.offline.chat_model import ( from khoj.processor.conversation.offline.chat_model import (
converse_offline, converse_offline,
extract_questions_offline, extract_questions_offline,
@ -25,14 +19,12 @@ from khoj.processor.conversation.offline.chat_model import (
from khoj.processor.conversation.offline.utils import download_model from khoj.processor.conversation.offline.utils import download_model
from khoj.processor.conversation.utils import message_to_log from khoj.processor.conversation.utils import message_to_log
from khoj.routers.helpers import aget_relevant_output_modes from khoj.routers.helpers import aget_relevant_output_modes
from khoj.utils.constants import default_offline_chat_model
MODEL_NAME = "mistral-7b-instruct-v0.1.Q4_0.gguf"
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def loaded_model(): def loaded_model():
download_model(MODEL_NAME) return download_model(default_offline_chat_model)
return GPT4All(MODEL_NAME)
freezegun.configure(extend_ignore_list=["transformers"]) freezegun.configure(extend_ignore_list=["transformers"])
@ -40,7 +32,6 @@ freezegun.configure(extend_ignore_list=["transformers"])
# Test # Test
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Search actor isn't very date aware nor capable of formatting")
@pytest.mark.chatquality @pytest.mark.chatquality
@freeze_time("1984-04-02", ignore=["transformers"]) @freeze_time("1984-04-02", ignore=["transformers"])
def test_extract_question_with_date_filter_from_relative_day(loaded_model): def test_extract_question_with_date_filter_from_relative_day(loaded_model):
@ -149,20 +140,22 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model):
message_list = [ message_list = [
("What is the name of Mr. Anderson's daughter?", "Miss Barbara", []), ("What is the name of Mr. Anderson's daughter?", "Miss Barbara", []),
] ]
query = "Does he have any sons?"
# Act # Act
response = extract_questions_offline( response = extract_questions_offline(
"Does he have any sons?", query,
conversation_log=populate_chat_history(message_list), conversation_log=populate_chat_history(message_list),
loaded_model=loaded_model, loaded_model=loaded_model,
use_history=True, use_history=True,
) )
all_expected_in_response = [ any_expected_with_barbara = [
"Anderson", "sibling",
"brother",
] ]
any_expected_in_response = [ any_expected_with_anderson = [
"son", "son",
"sons", "sons",
"children", "children",
@ -170,11 +163,20 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model):
# Assert # Assert
assert len(response) >= 1 assert len(response) >= 1
assert all([expected_response in response[0] for expected_response in all_expected_in_response]), ( assert response[-1] == query, "Expected last question to be the user query, but got: " + response[-1]
"Expected chat actor to ask for clarification in response, but got: " + response[0] # Ensure the remaining generated search queries use proper nouns and chat history context
for question in response[:-1]:
if "Barbara" in question:
assert any([expected_relation in question for expected_relation in any_expected_with_barbara]), (
"Expected search queries using proper nouns and chat history for context, but got: " + question
) )
assert any([expected_response in response[0] for expected_response in any_expected_in_response]), ( elif "Anderson" in question:
"Expected chat actor to ask for clarification in response, but got: " + response[0] assert any([expected_response in question for expected_response in any_expected_with_anderson]), (
"Expected search queries using proper nouns and chat history for context, but got: " + question
)
else:
assert False, (
"Expected search queries using proper nouns and chat history for context, but got: " + question
) )
@ -312,6 +314,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content(loaded_model):
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat actor lies when it doesn't know the answer")
@pytest.mark.chatquality @pytest.mark.chatquality
def test_refuse_answering_unanswerable_question(loaded_model): def test_refuse_answering_unanswerable_question(loaded_model):
"Chat actor should not try make up answers to unanswerable questions." "Chat actor should not try make up answers to unanswerable questions."
@ -436,7 +439,6 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(loaded
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat actor doesn't ask clarifying questions when context is insufficient")
@pytest.mark.chatquality @pytest.mark.chatquality
def test_ask_for_clarification_if_not_enough_context_in_question(loaded_model): def test_ask_for_clarification_if_not_enough_context_in_question(loaded_model):
"Chat actor should ask for clarification if question cannot be answered unambiguously with the provided context" "Chat actor should ask for clarification if question cannot be answered unambiguously with the provided context"

View file

@ -14,7 +14,7 @@ from tests.helpers import ConversationFactory
SKIP_TESTS = True SKIP_TESTS = True
pytestmark = pytest.mark.skipif( pytestmark = pytest.mark.skipif(
SKIP_TESTS, SKIP_TESTS,
reason="The GPT4All library has some quirks that make it hard to test in CI. This causes some tests to fail. Hence, disable it in CI.", reason="Disable in CI to avoid long test runs.",
) )
fake = Faker() fake = Faker()
@ -47,7 +47,7 @@ def populate_chat_history(message_list, user):
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet") @pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
@pytest.mark.chatquality @pytest.mark.chatquality
@pytest.mark.django_db(transaction=True) @pytest.mark.django_db(transaction=True)
def test_chat_with_no_chat_history_or_retrieved_content_gpt4all(client_offline_chat): def test_offline_chat_with_no_chat_history_or_retrieved_content(client_offline_chat):
# Act # Act
response = client_offline_chat.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"&stream=true') response = client_offline_chat.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"&stream=true')
response_message = response.content.decode("utf-8") response_message = response.content.decode("utf-8")
@ -338,7 +338,7 @@ def test_answer_requires_date_aware_aggregation_across_provided_notes(client_off
# Assert # Assert
assert response.status_code == 200 assert response.status_code == 200
assert "23" in response_message assert "26" in response_message
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@ -514,7 +514,7 @@ async def test_get_correct_tools_general(client_offline_chat):
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio @pytest.mark.anyio
@pytest.mark.django_db(transaction=True) @pytest.mark.django_db(transaction=True)
async def test_get_correct_tools_with_chat_history(client_offline_chat): async def test_get_correct_tools_with_chat_history(client_offline_chat, default_user2):
# Arrange # Arrange
user_query = "What's the latest in the Israel/Palestine conflict?" user_query = "What's the latest in the Israel/Palestine conflict?"
chat_log = [ chat_log = [
@ -525,7 +525,7 @@ async def test_get_correct_tools_with_chat_history(client_offline_chat):
), ),
("What's up in New York City?", "A Pride parade has recently been held in New York City, on July 31st.", []), ("What's up in New York City?", "A Pride parade has recently been held in New York City, on July 31st.", []),
] ]
chat_history = populate_chat_history(chat_log) chat_history = populate_chat_history(chat_log, default_user2)
# Act # Act
tools = await aget_relevant_information_sources(user_query, chat_history) tools = await aget_relevant_information_sources(user_query, chat_history)