Use llama.cpp for offline chat models
- Benefits of moving from gpt4all to llama-cpp-python:
  - Support for all GGUF format chat models
  - Support for AMD, Nvidia, Mac and Vulkan GPU machines (instead of just Vulkan and Mac)
  - Support for models with more capabilities, like tools, schema enforcement, speculative decoding, image generation, etc.
- Upgrade the default chat model, prompt size and tokenizer for the newly supported chat models
- Load the offline chat model when present on disk, without requiring internet access
- Load the model onto the GPU if GPU use isn't disabled and the device has one
- Load the model onto the CPU if loading it onto the GPU fails
- Create a helper function to check for and load the model from disk when the model glob is present there. `Llama.from_pretrained` needs internet access to fetch repo info from HuggingFace; this isn't required when the model is already downloaded. No existing HuggingFace or llama.cpp method was found that looks for the model glob on disk without internet access. A minimal sketch of this approach appears below, before the diff.
This commit is contained in:
parent
0a7392f6ec
commit
8ca39a436c
12 changed files with 146 additions and 164 deletions
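
The last point of the commit message — loading a cached GGUF model from disk without internet access, with GPU-then-CPU fallback — can be sketched roughly as follows. This is a simplified illustration of the approach, not the exact code in the diff below; `find_cached_gguf` and `load_offline_chat_model` are hypothetical helper names, and the hard-coded context size is a placeholder.

```python
import glob
import os

from huggingface_hub.constants import HF_HUB_CACHE
from llama_cpp import Llama


def find_cached_gguf(repo_id: str, filename: str = "*Q4_K_M.gguf"):
    """Return the path of a matching GGUF file in the local HuggingFace cache, else None (no network needed)."""
    org, repo = repo_id.split("/")
    pattern = os.path.join(HF_HUB_CACHE, f"models--{org}--{repo}", "snapshots", "**", filename)
    matches = glob.glob(pattern, recursive=True)
    return matches[0] if matches else None


def load_offline_chat_model(repo_id: str, filename: str = "*Q4_K_M.gguf"):
    # n_gpu_layers=-1 asks llama.cpp to offload all layers to the GPU
    kwargs = {"n_ctx": 4096, "verbose": False, "n_gpu_layers": -1}
    model_path = find_cached_gguf(repo_id, filename)
    try:
        # Use the cached file when present; only Llama.from_pretrained needs the network
        if model_path:
            return Llama(model_path, **kwargs)
        return Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)
    except Exception:
        # Loading onto the GPU failed (or no usable GPU): retry entirely on the CPU
        kwargs["n_gpu_layers"] = 0
        if model_path:
            return Llama(model_path, **kwargs)
        return Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)
```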
@@ -62,8 +62,7 @@ dependencies = [
"pymupdf >= 1.23.5",
"django == 4.2.10",
"authlib == 1.2.1",
"gpt4all == 2.1.0; platform_system == 'Linux' and platform_machine == 'x86_64'",
"gpt4all == 2.1.0; platform_system == 'Windows' or platform_system == 'Darwin'",
"llama-cpp-python == 0.2.56",
"itsdangerous == 2.1.2",
"httpx == 0.25.0",
"pgvector == 0.2.4",
@@ -160,7 +160,7 @@ class ChatModelOptions(BaseModel):
max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
chat_model = models.CharField(max_length=200, default="mistral-7b-instruct-v0.1.Q4_0.gguf")
chat_model = models.CharField(max_length=200, default="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF")
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
@@ -1,12 +1,13 @@
import logging
from collections import deque
from datetime import datetime
from threading import Thread
from typing import Any, Iterator, List, Union

from langchain.schema import ChatMessage
from llama_cpp import Llama

from khoj.processor.conversation import prompts
from khoj.processor.conversation.offline.utils import download_model
from khoj.processor.conversation.utils import (
ThreadedGenerator,
generate_chatml_messages_with_context,
@@ -21,7 +22,7 @@ logger = logging.getLogger(__name__)
def extract_questions_offline(
text: str,
model: str = "mistral-7b-instruct-v0.1.Q4_0.gguf",
model: str = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF",
loaded_model: Union[Any, None] = None,
conversation_log={},
use_history: bool = True,
@@ -31,22 +32,14 @@ def extract_questions_offline(
"""
Infer search queries to retrieve relevant notes to answer user query
"""
try:
from gpt4all import GPT4All
except ModuleNotFoundError as e:
logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
raise e

# Assert that loaded_model is either None or of type GPT4All
assert loaded_model is None or isinstance(loaded_model, GPT4All), "loaded_model must be of type GPT4All or None"

all_questions = text.split("? ")
all_questions = [q + "?" for q in all_questions[:-1]] + [all_questions[-1]]

if not should_extract_questions:
return all_questions

gpt4all_model = loaded_model or GPT4All(model)
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
offline_chat_model = loaded_model or download_model(model)

location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
@@ -75,10 +68,12 @@ def extract_questions_offline(
next_christmas_date=next_christmas_date,
location=location,
)
message = system_prompt + example_questions
messages = generate_chatml_messages_with_context(example_questions, system_message=system_prompt, model_name=model)

state.chat_lock.acquire()
try:
response = gpt4all_model.generate(message, max_tokens=200, top_k=2, temp=0, n_batch=512)
response = offline_chat_model.create_chat_completion(messages, max_tokens=200, top_k=2, temp=0)
response = response[0]["choices"][0]["message"]["content"]
finally:
state.chat_lock.release()
@@ -133,7 +128,7 @@ def converse_offline(
references=[],
online_results=[],
conversation_log={},
model: str = "mistral-7b-instruct-v0.1.Q4_0.gguf",
model: str = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF",
loaded_model: Union[Any, None] = None,
completion_func=None,
conversation_commands=[ConversationCommand.Default],
@@ -145,15 +140,9 @@ def converse_offline(
"""
Converse with user using Llama
"""
try:
from gpt4all import GPT4All
except ModuleNotFoundError as e:
logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
raise e

assert loaded_model is None or isinstance(loaded_model, GPT4All), "loaded_model must be of type GPT4All or None"
gpt4all_model = loaded_model or GPT4All(model)
# Initialize Variables
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
offline_chat_model = loaded_model or download_model(model)
compiled_references_message = "\n\n".join({f"{item}" for item in references})

conversation_primer = prompts.query_prompt.format(query=user_query)
@@ -191,72 +180,44 @@ def converse_offline(
prompts.system_prompt_message_gpt4all.format(current_date=current_date),
conversation_log,
model_name=model,
loaded_model=offline_chat_model,
max_prompt_size=max_prompt_size,
tokenizer_name=tokenizer_name,
)

g = ThreadedGenerator(references, online_results, completion_func=completion_func)
t = Thread(target=llm_thread, args=(g, messages, gpt4all_model))
t = Thread(target=llm_thread, args=(g, messages, offline_chat_model))
t.start()
return g


def llm_thread(g, messages: List[ChatMessage], model: Any):
user_message = messages[-1]
system_message = messages[0]
conversation_history = messages[1:-1]

formatted_messages = [
prompts.khoj_message_gpt4all.format(message=message.content)
if message.role == "assistant"
else prompts.user_message_gpt4all.format(message=message.content)
for message in conversation_history
]

stop_phrases = ["<s>", "INST]", "Notes:"]
chat_history = "".join(formatted_messages)
templated_system_message = prompts.system_prompt_gpt4all.format(message=system_message.content)
templated_user_message = prompts.user_message_gpt4all.format(message=user_message.content)
prompted_message = templated_system_message + chat_history + templated_user_message
response_queue: deque[str] = deque(maxlen=3)  # Create a response queue with a maximum length of 3
hit_stop_phrase = False

state.chat_lock.acquire()
response_iterator = send_message_to_model_offline(prompted_message, loaded_model=model, streaming=True)
try:
response_iterator = send_message_to_model_offline(
messages, loaded_model=model, stop=stop_phrases, streaming=True
)
for response in response_iterator:
response_queue.append(response)
hit_stop_phrase = any(stop_phrase in "".join(response_queue) for stop_phrase in stop_phrases)
if hit_stop_phrase:
logger.debug(f"Stop response as hit stop phrase: {''.join(response_queue)}")
break
# Start streaming the response at a lag once the queue is full
# This allows stop word testing before sending the response
if len(response_queue) == response_queue.maxlen:
g.send(response_queue[0])
g.send(response["choices"][0]["delta"].get("content", ""))
finally:
if not hit_stop_phrase:
if len(response_queue) == response_queue.maxlen:
# remove already sent reponse chunk
response_queue.popleft()
# send the remaining response
g.send("".join(response_queue))
state.chat_lock.release()
g.close()


def send_message_to_model_offline(
message, loaded_model=None, model="mistral-7b-instruct-v0.1.Q4_0.gguf", streaming=False, system_message=""
) -> str:
try:
from gpt4all import GPT4All
except ModuleNotFoundError as e:
logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
raise e

assert loaded_model is None or isinstance(loaded_model, GPT4All), "loaded_model must be of type GPT4All or None"
gpt4all_model = loaded_model or GPT4All(model)

return gpt4all_model.generate(
system_message + message, max_tokens=200, top_k=2, temp=0, n_batch=512, streaming=streaming
)
messages: List[ChatMessage],
loaded_model=None,
model="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF",
streaming=False,
stop=[],
):
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
offline_chat_model = loaded_model or download_model(model)
messages_dict = [{"role": message.role, "content": message.content} for message in messages]
response = offline_chat_model.create_chat_completion(messages_dict, stop=stop, stream=streaming)
if streaming:
return response
else:
return response["choices"][0]["message"].get("content", "")
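
For context on the streaming change in `llm_thread` above: llama-cpp-python's `create_chat_completion(..., stream=True)` yields OpenAI-style chunks, which is why each piece of generated text is read from `choices[0]["delta"]`. A minimal, standalone consumption sketch, where the model path and messages are placeholders:

```python
from llama_cpp import Llama

# Placeholder path to any local GGUF chat model
llm = Llama("hermes-2-pro-mistral-7b.Q4_K_M.gguf", n_ctx=4096, verbose=False)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Say hello in five words."},
]

# With stream=True each chunk mirrors the OpenAI chat format;
# the generated text fragment sits in choices[0]["delta"]["content"]
for chunk in llm.create_chat_completion(messages, stream=True, stop=["<s>"]):
    print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)
print()
```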
@@ -1,43 +1,53 @@
import glob
import logging
import os

from huggingface_hub.constants import HF_HUB_CACHE

from khoj.utils import state

logger = logging.getLogger(__name__)


def download_model(model_name: str):
try:
import gpt4all
except ModuleNotFoundError as e:
logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
raise e
def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf"):
from llama_cpp.llama import Llama

# Initialize Model Parameters
kwargs = {"n_threads": 4, "n_ctx": 4096, "verbose": False}

# Decide whether to load model to GPU or CPU
chat_model_config = None
device = "gpu" if state.chat_on_gpu and state.device != "cpu" else "cpu"
kwargs["n_gpu_layers"] = -1 if device == "gpu" else 0

# Check if the model is already downloaded
model_path = load_model_from_cache(repo_id, filename)
try:
# Download the chat model and its config
chat_model_config = gpt4all.GPT4All.retrieve_model(model_name=model_name, allow_download=True)

# Try load chat model to GPU if:
# 1. Loading chat model to GPU isn't disabled via CLI and
# 2. Machine has GPU
# 3. GPU has enough free memory to load the chat model with max context length of 4096
device = (
"gpu"
if state.chat_on_gpu and gpt4all.pyllmodel.LLModel().list_gpu(chat_model_config["path"], 4096)
else "cpu"
)
except ValueError:
device = "cpu"
except Exception as e:
if chat_model_config is None:
device = "cpu"  # Fallback to CPU as can't determine if GPU has enough memory
logger.debug(f"Unable to download model config from gpt4all website: {e}")
if model_path:
chat_model = Llama(model_path, **kwargs)
else:
raise e
Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)
except:
# Load model on CPU if GPU is not available
kwargs["n_gpu_layers"], device = 0, "cpu"
if model_path:
chat_model = Llama(model_path, **kwargs)
else:
chat_model = Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs)

# Now load the downloaded chat model onto appropriate device
chat_model = gpt4all.GPT4All(model_name=model_name, n_ctx=4096, device=device, allow_download=False)
logger.debug(f"Loaded chat model to {device.upper()}.")
logger.debug(f"{'Loaded' if model_path else 'Downloaded'} chat model to {device.upper()}")

return chat_model


def load_model_from_cache(repo_id: str, filename: str, repo_type="models"):
# Construct the path to the model file in the cache directory
repo_org, repo_name = repo_id.split("/")
object_id = "--".join([repo_type, repo_org, repo_name])
model_path = os.path.sep.join([HF_HUB_CACHE, object_id, "snapshots", "**", filename])

# Check if the model file exists
paths = glob.glob(model_path)
if paths:
return paths[0]
else:
return None
@@ -3,29 +3,28 @@ import logging
import queue
from datetime import datetime
from time import perf_counter
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional

import tiktoken
from langchain.schema import ChatMessage
from llama_cpp.llama import Llama
from transformers import AutoTokenizer

from khoj.database.adapters import ConversationAdapters
from khoj.database.models import ClientApplication, KhojUser
from khoj.processor.conversation.offline.utils import download_model
from khoj.utils.helpers import is_none_or_empty, merge_dicts

logger = logging.getLogger(__name__)
model_to_prompt_size = {
"gpt-3.5-turbo": 3000,
"gpt-3.5-turbo-0125": 3000,
"gpt-4-0125-preview": 7000,
"gpt-4-turbo-preview": 7000,
"llama-2-7b-chat.ggmlv3.q4_0.bin": 1548,
"mistral-7b-instruct-v0.1.Q4_0.gguf": 1548,
}
model_to_tokenizer = {
"llama-2-7b-chat.ggmlv3.q4_0.bin": "hf-internal-testing/llama-tokenizer",
"mistral-7b-instruct-v0.1.Q4_0.gguf": "mistralai/Mistral-7B-Instruct-v0.1",
"gpt-3.5-turbo": 12000,
"gpt-3.5-turbo-0125": 12000,
"gpt-4-0125-preview": 20000,
"gpt-4-turbo-preview": 20000,
"TheBloke/Mistral-7B-Instruct-v0.2-GGUF": 3500,
"NousResearch/Hermes-2-Pro-Mistral-7B-GGUF": 3500,
}
model_to_tokenizer: Dict[str, str] = {}


class ThreadedGenerator:
@@ -134,9 +133,10 @@ Khoj: "{inferred_queries if ("text-to-image" in intent_type) else chat_response}
def generate_chatml_messages_with_context(
user_message,
system_message,
system_message=None,
conversation_log={},
model_name="gpt-3.5-turbo",
loaded_model: Optional[Llama] = None,
max_prompt_size=None,
tokenizer_name=None,
):
@@ -159,7 +159,7 @@ def generate_chatml_messages_with_context(
chat_notes = f'\n\n Notes:\n{chat.get("context")}' if chat.get("context") else "\n"
chat_logs += [chat["message"] + chat_notes]

rest_backnforths = []
rest_backnforths: List[ChatMessage] = []
# Extract in reverse chronological order
for user_msg, assistant_msg in zip(chat_logs[-2::-2], chat_logs[::-2]):
if len(rest_backnforths) >= 2 * lookback_turns:
@@ -176,22 +176,31 @@ def generate_chatml_messages_with_context(
messages.append(ChatMessage(content=system_message, role="system"))

# Truncate oldest messages from conversation history until under max supported prompt size by model
messages = truncate_messages(messages, max_prompt_size, model_name, tokenizer_name)
messages = truncate_messages(messages, max_prompt_size, model_name, loaded_model, tokenizer_name)

# Return message in chronological order
return messages[::-1]


def truncate_messages(
messages: list[ChatMessage], max_prompt_size, model_name: str, tokenizer_name=None
messages: list[ChatMessage],
max_prompt_size,
model_name: str,
loaded_model: Optional[Llama] = None,
tokenizer_name=None,
) -> list[ChatMessage]:
"""Truncate messages to fit within max prompt size supported by model"""

try:
if model_name.startswith("gpt-"):
if loaded_model:
encoder = loaded_model.tokenizer()
elif model_name.startswith("gpt-"):
encoder = tiktoken.encoding_for_model(model_name)
else:
encoder = AutoTokenizer.from_pretrained(tokenizer_name or model_to_tokenizer[model_name])
try:
encoder = download_model(model_name).tokenizer()
except:
encoder = AutoTokenizer.from_pretrained(tokenizer_name or model_to_tokenizer[model_name])
except:
default_tokenizer = "hf-internal-testing/llama-tokenizer"
encoder = AutoTokenizer.from_pretrained(default_tokenizer)
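
The updated `truncate_messages` above picks a tokenizer in a specific order: the tokenizer of the already loaded llama.cpp model, tiktoken for OpenAI models, then a HuggingFace tokenizer as a fallback. A rough, self-contained sketch of that selection order; `pick_encoder` and its defaults are illustrative, not Khoj's API:

```python
from typing import Optional

import tiktoken
from llama_cpp import Llama
from transformers import AutoTokenizer


def pick_encoder(model_name: str, loaded_model: Optional[Llama] = None, tokenizer_name: Optional[str] = None):
    """Illustrative encoder-selection order for counting prompt tokens."""
    if loaded_model:
        # Reuse the tokenizer of the llama.cpp model that is already in memory
        return loaded_model.tokenizer()
    if model_name.startswith("gpt-"):
        # OpenAI models: tiktoken knows the per-model encoding
        return tiktoken.encoding_for_model(model_name)
    # Otherwise fall back to a HuggingFace tokenizer (generic llama tokenizer by default)
    return AutoTokenizer.from_pretrained(tokenizer_name or "hf-internal-testing/llama-tokenizer")


# Token counts drive truncation; each encoder exposes an encode() method
encoder = pick_encoder("gpt-3.5-turbo")
print(len(encoder.encode("How many tokens is this prompt?")))
```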
@@ -370,27 +370,31 @@ async def send_message_to_model_wrapper(
if conversation_config is None:
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")

truncated_messages = generate_chatml_messages_with_context(
user_message=message, system_message=system_message, model_name=conversation_config.chat_model
)
chat_model = conversation_config.chat_model

if conversation_config.model_type == "offline":
if state.gpt4all_processor_config is None or state.gpt4all_processor_config.loaded_model is None:
state.gpt4all_processor_config = GPT4AllProcessorModel(conversation_config.chat_model)
state.gpt4all_processor_config = GPT4AllProcessorModel(chat_model)

loaded_model = state.gpt4all_processor_config.loaded_model
truncated_messages = generate_chatml_messages_with_context(
user_message=message, system_message=system_message, model_name=chat_model, loaded_model=loaded_model
)

return send_message_to_model_offline(
message=truncated_messages[-1].content,
messages=truncated_messages,
loaded_model=loaded_model,
model=conversation_config.chat_model,
model=chat_model,
streaming=False,
system_message=truncated_messages[0].content,
)

elif conversation_config.model_type == "openai":
openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
api_key = openai_chat_config.api_key
chat_model = conversation_config.chat_model
truncated_messages = generate_chatml_messages_with_context(
user_message=message, system_message=system_message, model_name=chat_model
)

openai_response = send_message_to_model(
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type
)
@@ -75,10 +75,7 @@ class GPT4AllProcessorConfig:


class GPT4AllProcessorModel:
def __init__(
self,
chat_model: str = "mistral-7b-instruct-v0.1.Q4_0.gguf",
):
def __init__(self, chat_model: str = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF"):
self.chat_model = chat_model
self.loaded_model = None
try:
@@ -6,7 +6,7 @@ empty_escape_sequences = "\n|\r|\t| "
app_env_filepath = "~/.khoj/env"
telemetry_server = "https://khoj.beta.haletic.com/v1/telemetry"
content_directory = "~/.khoj/content/"
default_offline_chat_model = "mistral-7b-instruct-v0.1.Q4_0.gguf"
default_offline_chat_model = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF"
default_online_chat_model = "gpt-4-turbo-preview"

empty_config = {
@@ -91,7 +91,7 @@ class OpenAIProcessorConfig(ConfigBase):

class OfflineChatProcessorConfig(ConfigBase):
enable_offline_chat: Optional[bool] = False
chat_model: Optional[str] = "mistral-7b-instruct-v0.1.Q4_0.gguf"
chat_model: Optional[str] = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF"


class ConversationProcessorConfig(ConfigBase):
@@ -40,9 +40,9 @@ class ChatModelOptionsFactory(factory.django.DjangoModelFactory):
class Meta:
model = ChatModelOptions

max_prompt_size = 2000
max_prompt_size = 3500
tokenizer = None
chat_model = "mistral-7b-instruct-v0.1.Q4_0.gguf"
chat_model = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF"
model_type = "offline"
@@ -5,18 +5,12 @@ import pytest
SKIP_TESTS = True
pytestmark = pytest.mark.skipif(
SKIP_TESTS,
reason="The GPT4All library has some quirks that make it hard to test in CI. This causes some tests to fail. Hence, disable it in CI.",
reason="Disable in CI to avoid long test runs.",
)

import freezegun
from freezegun import freeze_time

try:
from gpt4all import GPT4All
except ModuleNotFoundError as e:
print("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")

from khoj.processor.conversation.offline.chat_model import (
converse_offline,
extract_questions_offline,
@@ -25,14 +19,12 @@ from khoj.processor.conversation.offline.chat_model import (
from khoj.processor.conversation.offline.utils import download_model
from khoj.processor.conversation.utils import message_to_log
from khoj.routers.helpers import aget_relevant_output_modes

MODEL_NAME = "mistral-7b-instruct-v0.1.Q4_0.gguf"
from khoj.utils.constants import default_offline_chat_model


@pytest.fixture(scope="session")
def loaded_model():
download_model(MODEL_NAME)
return GPT4All(MODEL_NAME)
return download_model(default_offline_chat_model)


freezegun.configure(extend_ignore_list=["transformers"])
@@ -40,7 +32,6 @@ freezegun.configure(extend_ignore_list=["transformers"])

# Test
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Search actor isn't very date aware nor capable of formatting")
@pytest.mark.chatquality
@freeze_time("1984-04-02", ignore=["transformers"])
def test_extract_question_with_date_filter_from_relative_day(loaded_model):
@@ -149,20 +140,22 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model):
message_list = [
("What is the name of Mr. Anderson's daughter?", "Miss Barbara", []),
]
query = "Does he have any sons?"

# Act
response = extract_questions_offline(
"Does he have any sons?",
query,
conversation_log=populate_chat_history(message_list),
loaded_model=loaded_model,
use_history=True,
)

all_expected_in_response = [
"Anderson",
any_expected_with_barbara = [
"sibling",
"brother",
]

any_expected_in_response = [
any_expected_with_anderson = [
"son",
"sons",
"children",
@@ -170,12 +163,21 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model):

# Assert
assert len(response) >= 1
assert all([expected_response in response[0] for expected_response in all_expected_in_response]), (
"Expected chat actor to ask for clarification in response, but got: " + response[0]
)
assert any([expected_response in response[0] for expected_response in any_expected_in_response]), (
"Expected chat actor to ask for clarification in response, but got: " + response[0]
)
assert response[-1] == query, "Expected last question to be the user query, but got: " + response[-1]
# Ensure the remaining generated search queries use proper nouns and chat history context
for question in response[:-1]:
if "Barbara" in question:
assert any([expected_relation in question for expected_relation in any_expected_with_barbara]), (
"Expected search queries using proper nouns and chat history for context, but got: " + question
)
elif "Anderson" in question:
assert any([expected_response in question for expected_response in any_expected_with_anderson]), (
"Expected search queries using proper nouns and chat history for context, but got: " + question
)
else:
assert False, (
"Expected search queries using proper nouns and chat history for context, but got: " + question
)


# ----------------------------------------------------------------------------------------------------
@@ -312,6 +314,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content(loaded_model):


# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat actor lies when it doesn't know the answer")
@pytest.mark.chatquality
def test_refuse_answering_unanswerable_question(loaded_model):
"Chat actor should not try make up answers to unanswerable questions."
@@ -436,7 +439,6 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(loaded


# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat actor doesn't ask clarifying questions when context is insufficient")
@pytest.mark.chatquality
def test_ask_for_clarification_if_not_enough_context_in_question(loaded_model):
"Chat actor should ask for clarification if question cannot be answered unambiguously with the provided context"
@@ -14,7 +14,7 @@ from tests.helpers import ConversationFactory
SKIP_TESTS = True
pytestmark = pytest.mark.skipif(
SKIP_TESTS,
reason="The GPT4All library has some quirks that make it hard to test in CI. This causes some tests to fail. Hence, disable it in CI.",
reason="Disable in CI to avoid long test runs.",
)

fake = Faker()
@@ -47,7 +47,7 @@ def populate_chat_history(message_list, user):
@pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet")
@pytest.mark.chatquality
@pytest.mark.django_db(transaction=True)
def test_chat_with_no_chat_history_or_retrieved_content_gpt4all(client_offline_chat):
def test_offline_chat_with_no_chat_history_or_retrieved_content(client_offline_chat):
# Act
response = client_offline_chat.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"&stream=true')
response_message = response.content.decode("utf-8")
@@ -338,7 +338,7 @@ def test_answer_requires_date_aware_aggregation_across_provided_notes(client_off

# Assert
assert response.status_code == 200
assert "23" in response_message
assert "26" in response_message


# ----------------------------------------------------------------------------------------------------
@@ -514,7 +514,7 @@ async def test_get_correct_tools_general(client_offline_chat):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_get_correct_tools_with_chat_history(client_offline_chat):
async def test_get_correct_tools_with_chat_history(client_offline_chat, default_user2):
# Arrange
user_query = "What's the latest in the Israel/Palestine conflict?"
chat_log = [
@@ -525,7 +525,7 @@ async def test_get_correct_tools_with_chat_history(client_offline_chat):
),
("What's up in New York City?", "A Pride parade has recently been held in New York City, on July 31st.", []),
]
chat_history = populate_chat_history(chat_log)
chat_history = populate_chat_history(chat_log, default_user2)

# Act
tools = await aget_relevant_information_sources(user_query, chat_history)