Replace Falcon 🦅 model with Llama V2 🦙 for offline chat (#352)

* Working example with LlamaV2 running locally on my machine

- Download from huggingface
- Plug in to GPT4All
- Update prompts to fit the llama format

* Add appropriate prompts for extracting questions based on a query based on llama format

* Rename Falcon to Llama and make some improvements to the extract_questions flow

* Do further tuning to extract question prompts and unit tests

* Disable extracting questions dynamically from Llama, as results are still unreliable
This commit is contained in:
sabaimran 2023-07-28 03:51:20 +00:00 committed by GitHub
parent 55965eea7d
commit 124d97c26d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 248 additions and 141 deletions

View file

@ -58,7 +58,7 @@ dependencies = [
"pypdf >= 3.9.0",
"requests >= 2.26.0",
"bs4 >= 0.0.1",
"gpt4all==1.0.5",
"gpt4all >= 1.0.7",
]
dynamic = ["version"]

View file

@ -229,7 +229,7 @@
</h3>
</div>
<div class="card-description-row">
<p class="card-description">Setup offline chat (Falcon 7B)</p>
<p class="card-description">Setup offline chat (Llama V2)</p>
</div>
<div id="clear-enable-offline-chat" class="card-action-row {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat %}enabled{% else %}disabled{% endif %}">
<button class="card-button" onclick="toggleEnableLocalLLLM(false)">

View file

@ -1,6 +1,5 @@
from typing import Union, List
from datetime import datetime
import sys
import logging
from threading import Thread
@ -8,7 +7,6 @@ from langchain.schema import ChatMessage
from gpt4all import GPT4All
from khoj.processor.conversation.utils import ThreadedGenerator, generate_chatml_messages_with_context
from khoj.processor.conversation import prompts
from khoj.utils.constants import empty_escape_sequences
@ -16,20 +14,21 @@ from khoj.utils.constants import empty_escape_sequences
logger = logging.getLogger(__name__)
def extract_questions_falcon(
def extract_questions_offline(
text: str,
model: str = "ggml-model-gpt4all-falcon-q4_0.bin",
model: str = "llama-2-7b-chat.ggmlv3.q4_K_S.bin",
loaded_model: Union[GPT4All, None] = None,
conversation_log={},
use_history: bool = False,
run_extraction: bool = False,
use_history: bool = True,
should_extract_questions: bool = True,
):
"""
Infer search queries to retrieve relevant notes to answer user query
"""
all_questions = text.split("? ")
all_questions = [q + "?" for q in all_questions[:-1]] + [all_questions[-1]]
if not run_extraction:
if not should_extract_questions:
return all_questions
gpt4all_model = loaded_model or GPT4All(model)
@ -38,51 +37,85 @@ def extract_questions_falcon(
chat_history = ""
if use_history:
chat_history = "".join(
[
f'Q: {chat["intent"]["query"]}\n\n{chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}\n\nA: {chat["message"]}\n\n'
for chat in conversation_log.get("chat", [])[-4:]
if chat["by"] == "khoj"
]
)
for chat in conversation_log.get("chat", [])[-4:]:
if chat["by"] == "khoj":
chat_history += f"Q: {chat['intent']['query']}\n"
chat_history += f"A: {chat['message']}\n"
prompt = prompts.extract_questions_falcon.format(
chat_history=chat_history,
text=text,
current_date = datetime.now().strftime("%Y-%m-%d")
last_year = datetime.now().year - 1
last_christmas_date = f"{last_year}-12-25"
next_christmas_date = f"{datetime.now().year}-12-25"
system_prompt = prompts.extract_questions_system_prompt_llamav2.format(
message=(prompts.system_prompt_message_extract_questions_llamav2)
)
message = prompts.general_conversation_falcon.format(query=prompt)
response = gpt4all_model.generate(message, max_tokens=200, top_k=2)
example_questions = prompts.extract_questions_llamav2_sample.format(
query=text,
chat_history=chat_history,
current_date=current_date,
last_year=last_year,
last_christmas_date=last_christmas_date,
next_christmas_date=next_christmas_date,
)
message = system_prompt + example_questions
response = gpt4all_model.generate(message, max_tokens=200, top_k=2, temp=0)
# Extract, Clean Message from GPT's Response
try:
# This will expect to be a list with a single string with a list of questions
questions = (
str(response)
.strip(empty_escape_sequences)
.replace("['", '["')
.replace("<s>", "")
.replace("</s>", "")
.replace("']", '"]')
.replace("', '", '", "')
.replace('["', "")
.replace('"]', "")
.split('", "')
.split("? ")
)
questions = [q + "?" for q in questions[:-1]] + [questions[-1]]
questions = filter_questions(questions)
except:
logger.warning(f"Falcon returned invalid JSON. Falling back to using user message as search query.\n{response}")
logger.warning(f"Llama returned invalid JSON. Falling back to using user message as search query.\n{response}")
return all_questions
logger.debug(f"Extracted Questions by Falcon: {questions}")
logger.debug(f"Extracted Questions by Llama: {questions}")
questions.extend(all_questions)
return questions
def converse_falcon(
def filter_questions(questions: List[str]):
# Skip questions that seem to be apologizing for not being able to answer the question
hint_words = [
"sorry",
"apologize",
"unable",
"can't",
"cannot",
"don't know",
"don't understand",
"do not know",
"do not understand",
]
filtered_questions = []
for q in questions:
if not any([word in q.lower() for word in hint_words]):
filtered_questions.append(q)
return filtered_questions
def converse_offline(
references,
user_query,
conversation_log={},
model: str = "ggml-model-gpt4all-falcon-q4_0.bin",
model: str = "llama-2-7b-chat.ggmlv3.q4_K_S.bin",
loaded_model: Union[GPT4All, None] = None,
completion_func=None,
) -> ThreadedGenerator:
"""
Converse with user using Falcon
Converse with user using Llama
"""
gpt4all_model = loaded_model or GPT4All(model)
# Initialize Variables
@ -92,18 +125,18 @@ def converse_falcon(
# Get Conversation Primer appropriate to Conversation Type
# TODO If compiled_references_message is too long, we need to truncate it.
if compiled_references_message == "":
conversation_primer = prompts.conversation_falcon.format(query=user_query)
conversation_primer = prompts.conversation_llamav2.format(query=user_query)
else:
conversation_primer = prompts.notes_conversation.format(
current_date=current_date, query=user_query, references=compiled_references_message
conversation_primer = prompts.notes_conversation_llamav2.format(
query=user_query, references=compiled_references_message
)
# Setup Prompt with Primer or Conversation History
messages = generate_chatml_messages_with_context(
conversation_primer,
prompts.personality.format(),
prompts.system_prompt_message_llamav2,
conversation_log,
model_name="text-davinci-001", # This isn't actually the model, but this helps us get an approximate encoding to run message truncation.
model_name=model,
)
g = ThreadedGenerator(references, completion_func=completion_func)
@ -113,24 +146,22 @@ def converse_falcon(
def llm_thread(g, messages: List[ChatMessage], model: GPT4All):
user_message = messages[0]
system_message = messages[-1]
user_message = messages[-1]
system_message = messages[0]
conversation_history = messages[1:-1]
formatted_messages = [
prompts.chat_history_falcon_from_assistant.format(message=system_message)
prompts.chat_history_llamav2_from_assistant.format(message=message.content)
if message.role == "assistant"
else prompts.chat_history_falcon_from_user.format(message=message.content)
else prompts.chat_history_llamav2_from_user.format(message=message.content)
for message in conversation_history
]
chat_history = "".join(formatted_messages)
full_message = system_message.content + chat_history + user_message.content
prompted_message = prompts.general_conversation_falcon.format(query=full_message)
response_iterator = model.generate(
prompted_message, streaming=True, max_tokens=256, top_k=1, temp=0, repeat_penalty=2.0
)
templated_system_message = prompts.system_prompt_llamav2.format(message=system_message.content)
templated_user_message = prompts.general_conversation_llamav2.format(query=user_message.content)
prompted_message = templated_system_message + chat_history + templated_user_message
response_iterator = model.generate(prompted_message, streaming=True, max_tokens=2000)
for response in response_iterator:
logger.info(response)
g.send(response)

View file

@ -0,0 +1,3 @@
model_name_to_url = {
"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/resolve/main/llama-2-7b-chat.ggmlv3.q3_K_M.bin"
}

View file

@ -0,0 +1,33 @@
import os
import logging
import requests
from gpt4all import GPT4All
import tqdm
from khoj.processor.conversation.gpt4all import model_metadata
logger = logging.getLogger(__name__)
def download_model(model_name):
url = model_metadata.model_name_to_url.get(model_name)
if not url:
logger.debug(f"Model {model_name} not found in model metadata. Skipping download.")
return GPT4All(model_name)
filename = os.path.expanduser(f"~/.cache/gpt4all/{model_name}")
if os.path.exists(filename):
return GPT4All(model_name)
try:
os.makedirs(os.path.dirname(filename), exist_ok=True)
logger.debug(f"Downloading model {model_name} from {url} to {filename}...")
with requests.get(url, stream=True) as r:
r.raise_for_status()
with open(filename, "wb") as f:
for chunk in r.iter_content(chunk_size=8192):
f.write(chunk)
return GPT4All(model_name)
except Exception as e:
logger.error(f"Failed to download model {model_name} from {url} to {filename}. Error: {e}")
return None

View file

@ -18,34 +18,54 @@ Question: {query}
""".strip()
)
general_conversation_falcon = PromptTemplate.from_template(
"""
Using your general knowledge and our past conversations as context, answer the following question.
### Instruct:
{query}
### Response:
""".strip()
)
system_prompt_message_llamav2 = f"""You are Khoj, a friendly, smart and helpful personal assistant.
Using your general knowledge and our past conversations as context, answer the following question."""
chat_history_falcon_from_user = PromptTemplate.from_template(
system_prompt_message_extract_questions_llamav2 = f"""You are Khoj, a kind and intelligent personal assistant.
- When the user asks you a question, you ask follow-up questions to clarify the necessary information you need in order to answer from the user's perspective.
- Try to be as specific as possible. For example, rather than use "they" or "it", use the name of the person or thing you are referring to.
- Write the question as if you can search for the answer on the user's personal notes.
- Add as much context from the previous questions and notes as required into your search queries.
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
- Provide search queries as a list of questions
What follow-up questions, if any, will you need to ask to answer the user's question?
"""
system_prompt_llamav2 = PromptTemplate.from_template(
"""
### Human:
<s>[INST] <<SYS>>
{message}
""".strip()
<</SYS>>Hi there! [/INST] Hello! How can I help you today? </s>"""
)
chat_history_falcon_from_assistant = PromptTemplate.from_template(
extract_questions_system_prompt_llamav2 = PromptTemplate.from_template(
"""
### Assistant:
<s>[INST] <<SYS>>
{message}
<</SYS>>[/INST]</s>"""
)
general_conversation_llamav2 = PromptTemplate.from_template(
"""
<s>[INST]{query}[/INST]
""".strip()
)
conversation_falcon = PromptTemplate.from_template(
chat_history_llamav2_from_user = PromptTemplate.from_template(
"""
Using our past conversations as context, answer the following question.
<s>[INST]{message}[/INST]
""".strip()
)
Question: {query}
chat_history_llamav2_from_assistant = PromptTemplate.from_template(
"""
{message}</s>
""".strip()
)
conversation_llamav2 = PromptTemplate.from_template(
"""
<s>[INST]{query}[/INST]
""".strip()
)
@ -63,13 +83,10 @@ Question: {query}
""".strip()
)
notes_conversation_falcon = PromptTemplate.from_template(
notes_conversation_llamav2 = PromptTemplate.from_template(
"""
Using the notes and our past conversations as context, answer the following question. If the answer is not contained within the notes, say "I don't know."
Notes:
{references}
Question: {query}
""".strip()
)
@ -109,37 +126,22 @@ Question: {user_query}
Answer (in second person):"""
)
extract_questions_falcon = PromptTemplate.from_template(
extract_questions_llamav2_sample = PromptTemplate.from_template(
"""
You are Khoj, an extremely smart and helpful search assistant with the ability to retrieve information from the user's notes.
- The user will provide their questions and answers to you for context.
- Add as much context from the previous questions and answers as required into your search queries.
- Break messages into multiple search queries when required to retrieve the relevant information.
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
What searches, if any, will you need to perform to answer the users question?
Q: How was my trip to Cambodia?
["How was my trip to Cambodia?"]
A: The trip was amazing. I went to the Angkor Wat temple and it was beautiful.
Q: Who did i visit that temple with?
["Who did I visit the Angkor Wat Temple in Cambodia with?"]
A: You visited the Angkor Wat Temple in Cambodia with Pablo, Namita and Xi.
Q: How many tennis balls fit in the back of a 2002 Honda Civic?
["What is the size of a tennis ball?", "What is the trunk size of a 2002 Honda Civic?"]
A: 1085 tennis balls will fit in the trunk of a Honda Civic
<s>[INST]<<SYS>>Current Date: {current_date}<</SYS>>[/INST]</s>
<s>[INST]<<SYS>>
Use these notes from the user's previous conversations to provide a response:
{chat_history}
Q: {text}
<</SYS>>[/INST]</s>
<s>[INST]How was my trip to Cambodia?[/INST][]</s>
<s>[INST]Who did I visit the temple with on that trip?[/INST]Who did I visit the temple with in Cambodia?</s>
<s>[INST]How should I take care of my plants?[/INST]What kind of plants do I have? What issues do my plants have?</s>
<s>[INST]How many tennis balls fit in the back of a 2002 Honda Civic?[/INST]What is the size of a tennis ball? What is the trunk size of a 2002 Honda Civic?</s>
<s>[INST]What did I do for Christmas last year?[/INST]What did I do for Christmas {last_year} dt>='{last_christmas_date}' dt<'{next_christmas_date}'</s>
<s>[INST]How are you feeling today?[/INST]</s>
<s>[INST]Is Alice older than Bob?[/INST]When was Alice born? What is Bob's age?</s>
<s>[INST]{query}[/INST]
"""
)

View file

@ -13,7 +13,7 @@ import queue
from khoj.utils.helpers import merge_dicts
logger = logging.getLogger(__name__)
max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192, "text-davinci-001": 910}
max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192, "llama-2-7b-chat.ggmlv3.q4_K_S.bin": 850}
class ThreadedGenerator:
@ -102,7 +102,10 @@ def generate_chatml_messages_with_context(
def truncate_messages(messages, max_prompt_size, model_name):
"""Truncate messages to fit within max prompt size supported by model"""
encoder = tiktoken.encoding_for_model(model_name)
try:
encoder = tiktoken.encoding_for_model(model_name)
except KeyError:
encoder = tiktoken.encoding_for_model("text-davinci-001")
tokens = sum([len(encoder.encode(message.content)) for message in messages])
while tokens > max_prompt_size and len(messages) > 1:
messages.pop()

View file

@ -38,7 +38,7 @@ from khoj.utils.yaml import save_config_to_file_updated_state
from fastapi.responses import StreamingResponse, Response
from khoj.routers.helpers import perform_chat_checks, generate_chat_response, update_telemetry_state
from khoj.processor.conversation.openai.gpt import extract_questions
from khoj.processor.conversation.gpt4all.chat_model import extract_questions_falcon, converse_falcon
from khoj.processor.conversation.gpt4all.chat_model import extract_questions_offline
from fastapi.requests import Request
@ -715,7 +715,9 @@ async def extract_references_and_questions(
inferred_queries = extract_questions(q, model=chat_model, api_key=api_key, conversation_log=meta_log)
else:
loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
inferred_queries = extract_questions_falcon(q, loaded_model=loaded_model, conversation_log=meta_log)
inferred_queries = extract_questions_offline(
q, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
)
# Collate search results as context for GPT
with timer("Searching knowledge base took", logger):

View file

@ -8,7 +8,7 @@ from fastapi import HTTPException, Request
from khoj.utils import state
from khoj.utils.helpers import timer, log_telemetry
from khoj.processor.conversation.openai.gpt import converse
from khoj.processor.conversation.gpt4all.chat_model import converse_falcon
from khoj.processor.conversation.gpt4all.chat_model import converse_offline
from khoj.processor.conversation.utils import reciprocal_conversation_to_chatml, message_to_log, ThreadedGenerator
logger = logging.getLogger(__name__)
@ -111,7 +111,7 @@ def generate_chat_response(
)
else:
loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
chat_response = converse_falcon(
chat_response = converse_offline(
references=compiled_references,
user_query=q,
loaded_model=loaded_model,

View file

@ -5,8 +5,7 @@ from enum import Enum
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Union, Any
from gpt4all import GPT4All
from khoj.processor.conversation.gpt4all.utils import download_model
# External Packages
import torch
@ -79,7 +78,7 @@ class SearchModels:
@dataclass
class GPT4AllProcessorConfig:
chat_model: Optional[str] = "ggml-model-gpt4all-falcon-q4_0.bin"
chat_model: Optional[str] = "llama-2-7b-chat.ggmlv3.q4_K_S.bin"
loaded_model: Union[Any, None] = None
@ -96,7 +95,7 @@ class ConversationProcessorConfigModel:
self.meta_log: dict = {}
if not self.openai_model and self.enable_offline_chat:
self.gpt4all_model.loaded_model = GPT4All(self.gpt4all_model.chat_model) # type: ignore
self.gpt4all_model.loaded_model = download_model(self.gpt4all_model.chat_model)
else:
self.gpt4all_model.loaded_model = None

View file

@ -16,14 +16,18 @@ from freezegun import freeze_time
from gpt4all import GPT4All
# Internal Packages
from khoj.processor.conversation.gpt4all.chat_model import converse_falcon, extract_questions_falcon
from khoj.processor.conversation.gpt4all.chat_model import converse_offline, extract_questions_offline, filter_questions
from khoj.processor.conversation.gpt4all.utils import download_model
from khoj.processor.conversation.utils import message_to_log
MODEL_NAME = "llama-2-7b-chat.ggmlv3.q4_K_S.bin"
@pytest.fixture(scope="session")
def loaded_model():
return GPT4All("ggml-model-gpt4all-falcon-q4_0.bin")
download_model(MODEL_NAME)
return GPT4All(MODEL_NAME)
freezegun.configure(extend_ignore_list=["transformers"])
@ -35,24 +39,40 @@ freezegun.configure(extend_ignore_list=["transformers"])
@freeze_time("1984-04-02")
def test_extract_question_with_date_filter_from_relative_day(loaded_model):
# Act
response = extract_questions_falcon(
"Where did I go for dinner yesterday?", loaded_model=loaded_model, run_extraction=True
)
response = extract_questions_offline("Where did I go for dinner yesterday?", loaded_model=loaded_model)
assert len(response) >= 1
assert response[-1] == "Where did I go for dinner yesterday?"
assert any(
[
"dt>='1984-04-01'" in response[0] and "dt<'1984-04-02'" in response[0],
"dt>='1984-04-01'" in response[0] and "dt<='1984-04-01'" in response[0],
'dt>="1984-04-01"' in response[0] and 'dt<"1984-04-02"' in response[0],
'dt>="1984-04-01"' in response[0] and 'dt<="1984-04-01"' in response[0],
]
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat actor still isn't very date aware nor capable of formatting")
@pytest.mark.chatquality
@freeze_time("1984-04-02")
def test_extract_question_with_date_filter_from_relative_month(loaded_model):
# Act
response = extract_questions_falcon("Which countries did I visit last month?", loaded_model=loaded_model)
response = extract_questions_offline("Which countries did I visit last month?", loaded_model=loaded_model)
# Assert
assert len(response) == 1
assert response == ["Which countries did I visit last month?"]
assert len(response) >= 1
# The user query should be the last question in the response
assert response[-1] == ["Which countries did I visit last month?"]
assert any(
[
"dt>='1984-03-01'" in response[0] and "dt<'1984-04-01'" in response[0],
"dt>='1984-03-01'" in response[0] and "dt<='1984-03-31'" in response[0],
'dt>="1984-03-01"' in response[0] and 'dt<"1984-04-01"' in response[0],
'dt>="1984-03-01"' in response[0] and 'dt<="1984-03-31"' in response[0],
]
)
# ----------------------------------------------------------------------------------------------------
@ -60,9 +80,7 @@ def test_extract_question_with_date_filter_from_relative_month(loaded_model):
@freeze_time("1984-04-02")
def test_extract_question_with_date_filter_from_relative_year(loaded_model):
# Act
response = extract_questions_falcon(
"Which countries have I visited this year?", loaded_model=loaded_model, run_extraction=True
)
response = extract_questions_offline("Which countries have I visited this year?", loaded_model=loaded_model)
# Assert
assert len(response) >= 1
@ -73,25 +91,26 @@ def test_extract_question_with_date_filter_from_relative_year(loaded_model):
@pytest.mark.chatquality
def test_extract_multiple_explicit_questions_from_message(loaded_model):
# Act
response = extract_questions_falcon("What is the Sun? What is the Moon?", loaded_model=loaded_model)
response = extract_questions_offline("What is the Sun? What is the Moon?", loaded_model=loaded_model)
# Assert
expected_responses = ["What is the Sun?", "What is the Moon?"]
assert len(response) == 2
assert expected_responses == response
assert len(response) >= 2
assert expected_responses[0] == response[-2]
assert expected_responses[1] == response[-1]
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_extract_multiple_implicit_questions_from_message(loaded_model):
# Act
response = extract_questions_falcon("Is Morpheus taller than Neo?", loaded_model=loaded_model, run_extraction=True)
response = extract_questions_offline("Is Morpheus taller than Neo?", loaded_model=loaded_model)
# Assert
expected_responses = [
("morpheus", "neo"),
("morpheus", "neo", "height", "taller", "shorter"),
]
assert len(response) == 2
assert len(response) == 3
assert any([start in response[0].lower() and end in response[1].lower() for start, end in expected_responses]), (
"Expected two search queries in response but got: " + response[0]
)
@ -106,18 +125,19 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model):
]
# Act
response = extract_questions_falcon(
response = extract_questions_offline(
"Does he have any sons?",
conversation_log=populate_chat_history(message_list),
loaded_model=loaded_model,
run_extraction=True,
use_history=True,
)
expected_responses = [
"do not have",
"clarify",
"am sorry",
"Vader",
"sons",
"son",
"Darth",
"children",
]
# Assert
@ -128,7 +148,7 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat actor does not consistently follow template instructions.")
# @pytest.mark.xfail(reason="Chat actor does not consistently follow template instructions.")
@pytest.mark.chatquality
def test_generate_search_query_using_answer_from_chat_history(loaded_model):
# Arrange
@ -137,17 +157,24 @@ def test_generate_search_query_using_answer_from_chat_history(loaded_model):
]
# Act
response = extract_questions_falcon(
response = extract_questions_offline(
"Is she a Jedi?",
conversation_log=populate_chat_history(message_list),
loaded_model=loaded_model,
run_extraction=True,
use_history=True,
)
expected_responses = [
"Leia",
"Vader",
"daughter",
]
# Assert
assert len(response) == 1
assert "Leia" in response[0]
assert len(response) >= 1
assert any([expected_response in response[0] for expected_response in expected_responses]), (
"Expected chat actor to mention Darth Vader's daughter, but got: " + response[0]
)
# ----------------------------------------------------------------------------------------------------
@ -160,10 +187,9 @@ def test_generate_search_query_with_date_and_context_from_chat_history(loaded_mo
]
# Act
response = extract_questions_falcon(
response = extract_questions_offline(
"What was the Pizza place we ate at over there?",
conversation_log=populate_chat_history(message_list),
run_extraction=True,
loaded_model=loaded_model,
)
@ -185,7 +211,7 @@ def test_generate_search_query_with_date_and_context_from_chat_history(loaded_mo
@pytest.mark.chatquality
def test_chat_with_no_chat_history_or_retrieved_content(loaded_model):
# Act
response_gen = converse_falcon(
response_gen = converse_offline(
references=[], # Assume no context retrieved from notes for the user_query
user_query="Hello, my name is Testatron. Who are you?",
loaded_model=loaded_model,
@ -201,7 +227,6 @@ def test_chat_with_no_chat_history_or_retrieved_content(loaded_model):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat actor isn't really good at proper nouns yet.")
@pytest.mark.chatquality
def test_answer_from_chat_history_and_previously_retrieved_content(loaded_model):
"Chat actor needs to use context in previous notes and chat history to answer question"
@ -216,7 +241,7 @@ def test_answer_from_chat_history_and_previously_retrieved_content(loaded_model)
]
# Act
response_gen = converse_falcon(
response_gen = converse_offline(
references=[], # Assume no context retrieved from notes for the user_query
user_query="Where was I born?",
conversation_log=populate_chat_history(message_list),
@ -241,7 +266,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content(loaded_model):
]
# Act
response_gen = converse_falcon(
response_gen = converse_offline(
references=[
"Testatron was born on 1st April 1984 in Testville."
], # Assume context retrieved from notes for the user_query
@ -257,7 +282,6 @@ def test_answer_from_chat_history_and_currently_retrieved_content(loaded_model):
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat actor is rather liable to lying.")
@pytest.mark.chatquality
def test_refuse_answering_unanswerable_question(loaded_model):
"Chat actor should not try make up answers to unanswerable questions."
@ -268,7 +292,7 @@ def test_refuse_answering_unanswerable_question(loaded_model):
]
# Act
response_gen = converse_falcon(
response_gen = converse_offline(
references=[], # Assume no context retrieved from notes for the user_query
user_query="Where was I born?",
conversation_log=populate_chat_history(message_list),
@ -309,7 +333,7 @@ Expenses:Food:Dining 10.00 USD""",
]
# Act
response_gen = converse_falcon(
response_gen = converse_offline(
references=context, # Assume context retrieved from notes for the user_query
user_query="What did I have for Dinner today?",
loaded_model=loaded_model,
@ -341,7 +365,7 @@ Expenses:Food:Dining 10.00 USD""",
]
# Act
response_gen = converse_falcon(
response_gen = converse_offline(
references=context, # Assume context retrieved from notes for the user_query
user_query="How much did I spend on dining this year?",
loaded_model=loaded_model,
@ -365,7 +389,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(loaded
]
# Act
response_gen = converse_falcon(
response_gen = converse_offline(
references=[], # Assume no context retrieved from notes for the user_query
user_query="Write a haiku about unit testing in 3 lines",
conversation_log=populate_chat_history(message_list),
@ -382,7 +406,6 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(loaded
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat actor not consistently capable of asking for clarification yet.")
@pytest.mark.chatquality
def test_ask_for_clarification_if_not_enough_context_in_question(loaded_model):
"Chat actor should ask for clarification if question cannot be answered unambiguously with the provided context"
@ -397,7 +420,7 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""
]
# Act
response_gen = converse_falcon(
response_gen = converse_offline(
references=context, # Assume context retrieved from notes for the user_query
user_query="How many kids does my older sister have?",
loaded_model=loaded_model,
@ -411,6 +434,17 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""
)
def test_filter_questions():
test_questions = [
"I don't know how to answer that",
"I cannot answer anything about the nuclear secrets",
"Who is on the basketball team?",
]
filtered_questions = filter_questions(test_questions)
assert len(filtered_questions) == 1
assert filtered_questions[0] == "Who is on the basketball team?"
# Helpers
# ----------------------------------------------------------------------------------------------------
def populate_chat_history(message_list):