diff --git a/pyproject.toml b/pyproject.toml index d443568d..11520f9b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,8 +62,7 @@ dependencies = [ "pymupdf >= 1.23.5", "django == 4.2.10", "authlib == 1.2.1", - "gpt4all == 2.1.0; platform_system == 'Linux' and platform_machine == 'x86_64'", - "gpt4all == 2.1.0; platform_system == 'Windows' or platform_system == 'Darwin'", + "llama-cpp-python == 0.2.56", "itsdangerous == 2.1.2", "httpx == 0.25.0", "pgvector == 0.2.4", diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 3f8f50b4..dc532d78 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -160,7 +160,7 @@ class ChatModelOptions(BaseModel): max_prompt_size = models.IntegerField(default=None, null=True, blank=True) tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True) - chat_model = models.CharField(max_length=200, default="mistral-7b-instruct-v0.1.Q4_0.gguf") + chat_model = models.CharField(max_length=200, default="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF") model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE) diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py index d1469ecf..09a6d55e 100644 --- a/src/khoj/processor/conversation/offline/chat_model.py +++ b/src/khoj/processor/conversation/offline/chat_model.py @@ -1,12 +1,13 @@ import logging -from collections import deque from datetime import datetime from threading import Thread from typing import Any, Iterator, List, Union from langchain.schema import ChatMessage +from llama_cpp import Llama from khoj.processor.conversation import prompts +from khoj.processor.conversation.offline.utils import download_model from khoj.processor.conversation.utils import ( ThreadedGenerator, generate_chatml_messages_with_context, @@ -21,7 +22,7 @@ logger = logging.getLogger(__name__) def extract_questions_offline( text: str, - model: str = "mistral-7b-instruct-v0.1.Q4_0.gguf", + model: str = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", loaded_model: Union[Any, None] = None, conversation_log={}, use_history: bool = True, @@ -31,22 +32,14 @@ def extract_questions_offline( """ Infer search queries to retrieve relevant notes to answer user query """ - try: - from gpt4all import GPT4All - except ModuleNotFoundError as e: - logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.") - raise e - - # Assert that loaded_model is either None or of type GPT4All - assert loaded_model is None or isinstance(loaded_model, GPT4All), "loaded_model must be of type GPT4All or None" - all_questions = text.split("? ") all_questions = [q + "?" for q in all_questions[:-1]] + [all_questions[-1]] if not should_extract_questions: return all_questions - gpt4all_model = loaded_model or GPT4All(model) + assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured" + offline_chat_model = loaded_model or download_model(model) location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown" @@ -75,10 +68,12 @@ def extract_questions_offline( next_christmas_date=next_christmas_date, location=location, ) - message = system_prompt + example_questions + messages = generate_chatml_messages_with_context(example_questions, system_message=system_prompt, model_name=model) + state.chat_lock.acquire() try: - response = gpt4all_model.generate(message, max_tokens=200, top_k=2, temp=0, n_batch=512) + response = offline_chat_model.create_chat_completion(messages, max_tokens=200, top_k=2, temp=0) + response = response[0]["choices"][0]["message"]["content"] finally: state.chat_lock.release() @@ -133,7 +128,7 @@ def converse_offline( references=[], online_results=[], conversation_log={}, - model: str = "mistral-7b-instruct-v0.1.Q4_0.gguf", + model: str = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", loaded_model: Union[Any, None] = None, completion_func=None, conversation_commands=[ConversationCommand.Default], @@ -145,15 +140,9 @@ def converse_offline( """ Converse with user using Llama """ - try: - from gpt4all import GPT4All - except ModuleNotFoundError as e: - logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.") - raise e - - assert loaded_model is None or isinstance(loaded_model, GPT4All), "loaded_model must be of type GPT4All or None" - gpt4all_model = loaded_model or GPT4All(model) # Initialize Variables + assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured" + offline_chat_model = loaded_model or download_model(model) compiled_references_message = "\n\n".join({f"{item}" for item in references}) conversation_primer = prompts.query_prompt.format(query=user_query) @@ -191,72 +180,44 @@ def converse_offline( prompts.system_prompt_message_gpt4all.format(current_date=current_date), conversation_log, model_name=model, + loaded_model=offline_chat_model, max_prompt_size=max_prompt_size, tokenizer_name=tokenizer_name, ) g = ThreadedGenerator(references, online_results, completion_func=completion_func) - t = Thread(target=llm_thread, args=(g, messages, gpt4all_model)) + t = Thread(target=llm_thread, args=(g, messages, offline_chat_model)) t.start() return g def llm_thread(g, messages: List[ChatMessage], model: Any): - user_message = messages[-1] - system_message = messages[0] - conversation_history = messages[1:-1] - - formatted_messages = [ - prompts.khoj_message_gpt4all.format(message=message.content) - if message.role == "assistant" - else prompts.user_message_gpt4all.format(message=message.content) - for message in conversation_history - ] - stop_phrases = ["", "INST]", "Notes:"] - chat_history = "".join(formatted_messages) - templated_system_message = prompts.system_prompt_gpt4all.format(message=system_message.content) - templated_user_message = prompts.user_message_gpt4all.format(message=user_message.content) - prompted_message = templated_system_message + chat_history + templated_user_message - response_queue: deque[str] = deque(maxlen=3) # Create a response queue with a maximum length of 3 - hit_stop_phrase = False state.chat_lock.acquire() - response_iterator = send_message_to_model_offline(prompted_message, loaded_model=model, streaming=True) try: + response_iterator = send_message_to_model_offline( + messages, loaded_model=model, stop=stop_phrases, streaming=True + ) for response in response_iterator: - response_queue.append(response) - hit_stop_phrase = any(stop_phrase in "".join(response_queue) for stop_phrase in stop_phrases) - if hit_stop_phrase: - logger.debug(f"Stop response as hit stop phrase: {''.join(response_queue)}") - break - # Start streaming the response at a lag once the queue is full - # This allows stop word testing before sending the response - if len(response_queue) == response_queue.maxlen: - g.send(response_queue[0]) + g.send(response["choices"][0]["delta"].get("content", "")) finally: - if not hit_stop_phrase: - if len(response_queue) == response_queue.maxlen: - # remove already sent reponse chunk - response_queue.popleft() - # send the remaining response - g.send("".join(response_queue)) state.chat_lock.release() g.close() def send_message_to_model_offline( - message, loaded_model=None, model="mistral-7b-instruct-v0.1.Q4_0.gguf", streaming=False, system_message="" -) -> str: - try: - from gpt4all import GPT4All - except ModuleNotFoundError as e: - logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.") - raise e - - assert loaded_model is None or isinstance(loaded_model, GPT4All), "loaded_model must be of type GPT4All or None" - gpt4all_model = loaded_model or GPT4All(model) - - return gpt4all_model.generate( - system_message + message, max_tokens=200, top_k=2, temp=0, n_batch=512, streaming=streaming - ) + messages: List[ChatMessage], + loaded_model=None, + model="NousResearch/Hermes-2-Pro-Mistral-7B-GGUF", + streaming=False, + stop=[], +): + assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured" + offline_chat_model = loaded_model or download_model(model) + messages_dict = [{"role": message.role, "content": message.content} for message in messages] + response = offline_chat_model.create_chat_completion(messages_dict, stop=stop, stream=streaming) + if streaming: + return response + else: + return response["choices"][0]["message"].get("content", "") diff --git a/src/khoj/processor/conversation/offline/utils.py b/src/khoj/processor/conversation/offline/utils.py index 9a2223c6..c60498ee 100644 --- a/src/khoj/processor/conversation/offline/utils.py +++ b/src/khoj/processor/conversation/offline/utils.py @@ -1,43 +1,53 @@ +import glob import logging +import os + +from huggingface_hub.constants import HF_HUB_CACHE from khoj.utils import state logger = logging.getLogger(__name__) -def download_model(model_name: str): - try: - import gpt4all - except ModuleNotFoundError as e: - logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.") - raise e +def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf"): + from llama_cpp.llama import Llama + + # Initialize Model Parameters + kwargs = {"n_threads": 4, "n_ctx": 4096, "verbose": False} # Decide whether to load model to GPU or CPU - chat_model_config = None + device = "gpu" if state.chat_on_gpu and state.device != "cpu" else "cpu" + kwargs["n_gpu_layers"] = -1 if device == "gpu" else 0 + + # Check if the model is already downloaded + model_path = load_model_from_cache(repo_id, filename) try: - # Download the chat model and its config - chat_model_config = gpt4all.GPT4All.retrieve_model(model_name=model_name, allow_download=True) - - # Try load chat model to GPU if: - # 1. Loading chat model to GPU isn't disabled via CLI and - # 2. Machine has GPU - # 3. GPU has enough free memory to load the chat model with max context length of 4096 - device = ( - "gpu" - if state.chat_on_gpu and gpt4all.pyllmodel.LLModel().list_gpu(chat_model_config["path"], 4096) - else "cpu" - ) - except ValueError: - device = "cpu" - except Exception as e: - if chat_model_config is None: - device = "cpu" # Fallback to CPU as can't determine if GPU has enough memory - logger.debug(f"Unable to download model config from gpt4all website: {e}") + if model_path: + chat_model = Llama(model_path, **kwargs) else: - raise e + Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs) + except: + # Load model on CPU if GPU is not available + kwargs["n_gpu_layers"], device = 0, "cpu" + if model_path: + chat_model = Llama(model_path, **kwargs) + else: + chat_model = Llama.from_pretrained(repo_id=repo_id, filename=filename, **kwargs) - # Now load the downloaded chat model onto appropriate device - chat_model = gpt4all.GPT4All(model_name=model_name, n_ctx=4096, device=device, allow_download=False) - logger.debug(f"Loaded chat model to {device.upper()}.") + logger.debug(f"{'Loaded' if model_path else 'Downloaded'} chat model to {device.upper()}") return chat_model + + +def load_model_from_cache(repo_id: str, filename: str, repo_type="models"): + # Construct the path to the model file in the cache directory + repo_org, repo_name = repo_id.split("/") + object_id = "--".join([repo_type, repo_org, repo_name]) + model_path = os.path.sep.join([HF_HUB_CACHE, object_id, "snapshots", "**", filename]) + + # Check if the model file exists + paths = glob.glob(model_path) + if paths: + return paths[0] + else: + return None diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index 15a4970e..ff1ca1e1 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -3,29 +3,28 @@ import logging import queue from datetime import datetime from time import perf_counter -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import tiktoken from langchain.schema import ChatMessage +from llama_cpp.llama import Llama from transformers import AutoTokenizer from khoj.database.adapters import ConversationAdapters from khoj.database.models import ClientApplication, KhojUser +from khoj.processor.conversation.offline.utils import download_model from khoj.utils.helpers import is_none_or_empty, merge_dicts logger = logging.getLogger(__name__) model_to_prompt_size = { - "gpt-3.5-turbo": 3000, - "gpt-3.5-turbo-0125": 3000, - "gpt-4-0125-preview": 7000, - "gpt-4-turbo-preview": 7000, - "llama-2-7b-chat.ggmlv3.q4_0.bin": 1548, - "mistral-7b-instruct-v0.1.Q4_0.gguf": 1548, -} -model_to_tokenizer = { - "llama-2-7b-chat.ggmlv3.q4_0.bin": "hf-internal-testing/llama-tokenizer", - "mistral-7b-instruct-v0.1.Q4_0.gguf": "mistralai/Mistral-7B-Instruct-v0.1", + "gpt-3.5-turbo": 12000, + "gpt-3.5-turbo-0125": 12000, + "gpt-4-0125-preview": 20000, + "gpt-4-turbo-preview": 20000, + "TheBloke/Mistral-7B-Instruct-v0.2-GGUF": 3500, + "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF": 3500, } +model_to_tokenizer: Dict[str, str] = {} class ThreadedGenerator: @@ -134,9 +133,10 @@ Khoj: "{inferred_queries if ("text-to-image" in intent_type) else chat_response} def generate_chatml_messages_with_context( user_message, - system_message, + system_message=None, conversation_log={}, model_name="gpt-3.5-turbo", + loaded_model: Optional[Llama] = None, max_prompt_size=None, tokenizer_name=None, ): @@ -159,7 +159,7 @@ def generate_chatml_messages_with_context( chat_notes = f'\n\n Notes:\n{chat.get("context")}' if chat.get("context") else "\n" chat_logs += [chat["message"] + chat_notes] - rest_backnforths = [] + rest_backnforths: List[ChatMessage] = [] # Extract in reverse chronological order for user_msg, assistant_msg in zip(chat_logs[-2::-2], chat_logs[::-2]): if len(rest_backnforths) >= 2 * lookback_turns: @@ -176,22 +176,31 @@ def generate_chatml_messages_with_context( messages.append(ChatMessage(content=system_message, role="system")) # Truncate oldest messages from conversation history until under max supported prompt size by model - messages = truncate_messages(messages, max_prompt_size, model_name, tokenizer_name) + messages = truncate_messages(messages, max_prompt_size, model_name, loaded_model, tokenizer_name) # Return message in chronological order return messages[::-1] def truncate_messages( - messages: list[ChatMessage], max_prompt_size, model_name: str, tokenizer_name=None + messages: list[ChatMessage], + max_prompt_size, + model_name: str, + loaded_model: Optional[Llama] = None, + tokenizer_name=None, ) -> list[ChatMessage]: """Truncate messages to fit within max prompt size supported by model""" try: - if model_name.startswith("gpt-"): + if loaded_model: + encoder = loaded_model.tokenizer() + elif model_name.startswith("gpt-"): encoder = tiktoken.encoding_for_model(model_name) else: - encoder = AutoTokenizer.from_pretrained(tokenizer_name or model_to_tokenizer[model_name]) + try: + encoder = download_model(model_name).tokenizer() + except: + encoder = AutoTokenizer.from_pretrained(tokenizer_name or model_to_tokenizer[model_name]) except: default_tokenizer = "hf-internal-testing/llama-tokenizer" encoder = AutoTokenizer.from_pretrained(default_tokenizer) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 5c3637b6..e1267291 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -370,27 +370,31 @@ async def send_message_to_model_wrapper( if conversation_config is None: raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.") - truncated_messages = generate_chatml_messages_with_context( - user_message=message, system_message=system_message, model_name=conversation_config.chat_model - ) + chat_model = conversation_config.chat_model if conversation_config.model_type == "offline": if state.gpt4all_processor_config is None or state.gpt4all_processor_config.loaded_model is None: - state.gpt4all_processor_config = GPT4AllProcessorModel(conversation_config.chat_model) + state.gpt4all_processor_config = GPT4AllProcessorModel(chat_model) loaded_model = state.gpt4all_processor_config.loaded_model + truncated_messages = generate_chatml_messages_with_context( + user_message=message, system_message=system_message, model_name=chat_model, loaded_model=loaded_model + ) + return send_message_to_model_offline( - message=truncated_messages[-1].content, + messages=truncated_messages, loaded_model=loaded_model, - model=conversation_config.chat_model, + model=chat_model, streaming=False, - system_message=truncated_messages[0].content, ) elif conversation_config.model_type == "openai": openai_chat_config = await ConversationAdapters.aget_openai_conversation_config() api_key = openai_chat_config.api_key - chat_model = conversation_config.chat_model + truncated_messages = generate_chatml_messages_with_context( + user_message=message, system_message=system_message, model_name=chat_model + ) + openai_response = send_message_to_model( messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type ) diff --git a/src/khoj/utils/config.py b/src/khoj/utils/config.py index 79a8957e..5184b85a 100644 --- a/src/khoj/utils/config.py +++ b/src/khoj/utils/config.py @@ -75,10 +75,7 @@ class GPT4AllProcessorConfig: class GPT4AllProcessorModel: - def __init__( - self, - chat_model: str = "mistral-7b-instruct-v0.1.Q4_0.gguf", - ): + def __init__(self, chat_model: str = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF"): self.chat_model = chat_model self.loaded_model = None try: diff --git a/src/khoj/utils/constants.py b/src/khoj/utils/constants.py index b4c00df4..f7747a99 100644 --- a/src/khoj/utils/constants.py +++ b/src/khoj/utils/constants.py @@ -6,7 +6,7 @@ empty_escape_sequences = "\n|\r|\t| " app_env_filepath = "~/.khoj/env" telemetry_server = "https://khoj.beta.haletic.com/v1/telemetry" content_directory = "~/.khoj/content/" -default_offline_chat_model = "mistral-7b-instruct-v0.1.Q4_0.gguf" +default_offline_chat_model = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF" default_online_chat_model = "gpt-4-turbo-preview" empty_config = { diff --git a/src/khoj/utils/rawconfig.py b/src/khoj/utils/rawconfig.py index 7218cca9..0d17e9e4 100644 --- a/src/khoj/utils/rawconfig.py +++ b/src/khoj/utils/rawconfig.py @@ -91,7 +91,7 @@ class OpenAIProcessorConfig(ConfigBase): class OfflineChatProcessorConfig(ConfigBase): enable_offline_chat: Optional[bool] = False - chat_model: Optional[str] = "mistral-7b-instruct-v0.1.Q4_0.gguf" + chat_model: Optional[str] = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF" class ConversationProcessorConfig(ConfigBase): diff --git a/tests/helpers.py b/tests/helpers.py index 321e08cf..26c7e2af 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -40,9 +40,9 @@ class ChatModelOptionsFactory(factory.django.DjangoModelFactory): class Meta: model = ChatModelOptions - max_prompt_size = 2000 + max_prompt_size = 3500 tokenizer = None - chat_model = "mistral-7b-instruct-v0.1.Q4_0.gguf" + chat_model = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF" model_type = "offline" diff --git a/tests/test_gpt4all_chat_actors.py b/tests/test_offline_chat_actors.py similarity index 93% rename from tests/test_gpt4all_chat_actors.py rename to tests/test_offline_chat_actors.py index 5cce4fc5..f82a5fe5 100644 --- a/tests/test_gpt4all_chat_actors.py +++ b/tests/test_offline_chat_actors.py @@ -5,18 +5,12 @@ import pytest SKIP_TESTS = True pytestmark = pytest.mark.skipif( SKIP_TESTS, - reason="The GPT4All library has some quirks that make it hard to test in CI. This causes some tests to fail. Hence, disable it in CI.", + reason="Disable in CI to avoid long test runs.", ) import freezegun from freezegun import freeze_time -try: - from gpt4all import GPT4All -except ModuleNotFoundError as e: - print("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.") - - from khoj.processor.conversation.offline.chat_model import ( converse_offline, extract_questions_offline, @@ -25,14 +19,12 @@ from khoj.processor.conversation.offline.chat_model import ( from khoj.processor.conversation.offline.utils import download_model from khoj.processor.conversation.utils import message_to_log from khoj.routers.helpers import aget_relevant_output_modes - -MODEL_NAME = "mistral-7b-instruct-v0.1.Q4_0.gguf" +from khoj.utils.constants import default_offline_chat_model @pytest.fixture(scope="session") def loaded_model(): - download_model(MODEL_NAME) - return GPT4All(MODEL_NAME) + return download_model(default_offline_chat_model) freezegun.configure(extend_ignore_list=["transformers"]) @@ -40,7 +32,6 @@ freezegun.configure(extend_ignore_list=["transformers"]) # Test # ---------------------------------------------------------------------------------------------------- -@pytest.mark.xfail(reason="Search actor isn't very date aware nor capable of formatting") @pytest.mark.chatquality @freeze_time("1984-04-02", ignore=["transformers"]) def test_extract_question_with_date_filter_from_relative_day(loaded_model): @@ -149,20 +140,22 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model): message_list = [ ("What is the name of Mr. Anderson's daughter?", "Miss Barbara", []), ] + query = "Does he have any sons?" # Act response = extract_questions_offline( - "Does he have any sons?", + query, conversation_log=populate_chat_history(message_list), loaded_model=loaded_model, use_history=True, ) - all_expected_in_response = [ - "Anderson", + any_expected_with_barbara = [ + "sibling", + "brother", ] - any_expected_in_response = [ + any_expected_with_anderson = [ "son", "sons", "children", @@ -170,12 +163,21 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model): # Assert assert len(response) >= 1 - assert all([expected_response in response[0] for expected_response in all_expected_in_response]), ( - "Expected chat actor to ask for clarification in response, but got: " + response[0] - ) - assert any([expected_response in response[0] for expected_response in any_expected_in_response]), ( - "Expected chat actor to ask for clarification in response, but got: " + response[0] - ) + assert response[-1] == query, "Expected last question to be the user query, but got: " + response[-1] + # Ensure the remaining generated search queries use proper nouns and chat history context + for question in response[:-1]: + if "Barbara" in question: + assert any([expected_relation in question for expected_relation in any_expected_with_barbara]), ( + "Expected search queries using proper nouns and chat history for context, but got: " + question + ) + elif "Anderson" in question: + assert any([expected_response in question for expected_response in any_expected_with_anderson]), ( + "Expected search queries using proper nouns and chat history for context, but got: " + question + ) + else: + assert False, ( + "Expected search queries using proper nouns and chat history for context, but got: " + question + ) # ---------------------------------------------------------------------------------------------------- @@ -312,6 +314,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content(loaded_model): # ---------------------------------------------------------------------------------------------------- +@pytest.mark.xfail(reason="Chat actor lies when it doesn't know the answer") @pytest.mark.chatquality def test_refuse_answering_unanswerable_question(loaded_model): "Chat actor should not try make up answers to unanswerable questions." @@ -436,7 +439,6 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(loaded # ---------------------------------------------------------------------------------------------------- -@pytest.mark.xfail(reason="Chat actor doesn't ask clarifying questions when context is insufficient") @pytest.mark.chatquality def test_ask_for_clarification_if_not_enough_context_in_question(loaded_model): "Chat actor should ask for clarification if question cannot be answered unambiguously with the provided context" diff --git a/tests/test_gpt4all_chat_director.py b/tests/test_offline_chat_director.py similarity index 98% rename from tests/test_gpt4all_chat_director.py rename to tests/test_offline_chat_director.py index 87ed5116..ed47bed7 100644 --- a/tests/test_gpt4all_chat_director.py +++ b/tests/test_offline_chat_director.py @@ -14,7 +14,7 @@ from tests.helpers import ConversationFactory SKIP_TESTS = True pytestmark = pytest.mark.skipif( SKIP_TESTS, - reason="The GPT4All library has some quirks that make it hard to test in CI. This causes some tests to fail. Hence, disable it in CI.", + reason="Disable in CI to avoid long test runs.", ) fake = Faker() @@ -47,7 +47,7 @@ def populate_chat_history(message_list, user): @pytest.mark.xfail(AssertionError, reason="Chat director not capable of answering this question yet") @pytest.mark.chatquality @pytest.mark.django_db(transaction=True) -def test_chat_with_no_chat_history_or_retrieved_content_gpt4all(client_offline_chat): +def test_offline_chat_with_no_chat_history_or_retrieved_content(client_offline_chat): # Act response = client_offline_chat.get(f'/api/chat?q="Hello, my name is Testatron. Who are you?"&stream=true') response_message = response.content.decode("utf-8") @@ -338,7 +338,7 @@ def test_answer_requires_date_aware_aggregation_across_provided_notes(client_off # Assert assert response.status_code == 200 - assert "23" in response_message + assert "26" in response_message # ---------------------------------------------------------------------------------------------------- @@ -514,7 +514,7 @@ async def test_get_correct_tools_general(client_offline_chat): # ---------------------------------------------------------------------------------------------------- @pytest.mark.anyio @pytest.mark.django_db(transaction=True) -async def test_get_correct_tools_with_chat_history(client_offline_chat): +async def test_get_correct_tools_with_chat_history(client_offline_chat, default_user2): # Arrange user_query = "What's the latest in the Israel/Palestine conflict?" chat_log = [ @@ -525,7 +525,7 @@ async def test_get_correct_tools_with_chat_history(client_offline_chat): ), ("What's up in New York City?", "A Pride parade has recently been held in New York City, on July 31st.", []), ] - chat_history = populate_chat_history(chat_log) + chat_history = populate_chat_history(chat_log, default_user2) # Act tools = await aget_relevant_information_sources(user_query, chat_history)