Mirror of https://github.com/khoj-ai/khoj.git (synced 2024-11-23 15:38:55 +01:00)
Replace Falcon 🦅 model with Llama V2 🦙 for offline chat (#352)
* Working example with LlamaV2 running locally on my machine
  - Download from huggingface
  - Plug in to GPT4All
  - Update prompts to fit the llama format
* Add appropriate prompts for extracting questions based on a query based on llama format
* Rename Falcon to Llama and make some improvements to the extract_questions flow
* Do further tuning to extract question prompts and unit tests
* Disable extracting questions dynamically from Llama, as results are still unreliable
parent 55965eea7d
commit 124d97c26d
11 changed files with 248 additions and 141 deletions
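The prompt changes below replace Falcon's "### Instruct:" / "### Response:" style with the Llama 2 chat format, where the system prompt sits inside a <<SYS>> block of the first [INST] turn and each user turn is wrapped in <s>[INST] ... [/INST]. A minimal sketch of the shape the new templates produce (assembled by hand, values illustrative only):

    # Rough shape of a Llama 2 chat prompt, as used by the new templates in this diff.
    system_message = "You are Khoj, a friendly, smart and helpful personal assistant."
    prompt = (
        f"<s>[INST] <<SYS>>\n{system_message}\n<</SYS>>Hi there! [/INST] Hello! How can I help you today? </s>"
        "<s>[INST]Where was I born?[/INST]"
    )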
@@ -58,7 +58,7 @@ dependencies = [
     "pypdf >= 3.9.0",
     "requests >= 2.26.0",
     "bs4 >= 0.0.1",
-    "gpt4all==1.0.5",
+    "gpt4all >= 1.0.7",
 ]
 dynamic = ["version"]
 
@@ -229,7 +229,7 @@
 </h3>
 </div>
 <div class="card-description-row">
-    <p class="card-description">Setup offline chat (Falcon 7B)</p>
+    <p class="card-description">Setup offline chat (Llama V2)</p>
 </div>
 <div id="clear-enable-offline-chat" class="card-action-row {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat %}enabled{% else %}disabled{% endif %}">
 <button class="card-button" onclick="toggleEnableLocalLLLM(false)">
@@ -1,6 +1,5 @@
 from typing import Union, List
 from datetime import datetime
-import sys
 import logging
 from threading import Thread
 
@@ -8,7 +7,6 @@ from langchain.schema import ChatMessage
 
 from gpt4all import GPT4All
 
-
 from khoj.processor.conversation.utils import ThreadedGenerator, generate_chatml_messages_with_context
 from khoj.processor.conversation import prompts
 from khoj.utils.constants import empty_escape_sequences
@@ -16,20 +14,21 @@ from khoj.utils.constants import empty_escape_sequences
 logger = logging.getLogger(__name__)
 
 
-def extract_questions_falcon(
+def extract_questions_offline(
     text: str,
-    model: str = "ggml-model-gpt4all-falcon-q4_0.bin",
+    model: str = "llama-2-7b-chat.ggmlv3.q4_K_S.bin",
     loaded_model: Union[GPT4All, None] = None,
     conversation_log={},
-    use_history: bool = False,
-    run_extraction: bool = False,
+    use_history: bool = True,
+    should_extract_questions: bool = True,
 ):
     """
     Infer search queries to retrieve relevant notes to answer user query
     """
     all_questions = text.split("? ")
     all_questions = [q + "?" for q in all_questions[:-1]] + [all_questions[-1]]
-    if not run_extraction:
+
+    if not should_extract_questions:
         return all_questions
 
     gpt4all_model = loaded_model or GPT4All(model)
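Note that with should_extract_questions=False (the value the API now passes), the function never calls the model; it just splits the raw message on "? ". A small sketch of that fallback path, using a made-up query:

    text = "What is the Sun? What is the Moon?"
    all_questions = text.split("? ")
    all_questions = [q + "?" for q in all_questions[:-1]] + [all_questions[-1]]
    # -> ["What is the Sun?", "What is the Moon?"]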
@@ -38,51 +37,85 @@ def extract_questions_falcon(
     chat_history = ""
 
     if use_history:
-        chat_history = "".join(
-            [
-                f'Q: {chat["intent"]["query"]}\n\n{chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}\n\nA: {chat["message"]}\n\n'
-                for chat in conversation_log.get("chat", [])[-4:]
-                if chat["by"] == "khoj"
-            ]
-        )
+        for chat in conversation_log.get("chat", [])[-4:]:
+            if chat["by"] == "khoj":
+                chat_history += f"Q: {chat['intent']['query']}\n"
+                chat_history += f"A: {chat['message']}\n"
 
-    prompt = prompts.extract_questions_falcon.format(
-        chat_history=chat_history,
-        text=text,
+    current_date = datetime.now().strftime("%Y-%m-%d")
+    last_year = datetime.now().year - 1
+    last_christmas_date = f"{last_year}-12-25"
+    next_christmas_date = f"{datetime.now().year}-12-25"
+    system_prompt = prompts.extract_questions_system_prompt_llamav2.format(
+        message=(prompts.system_prompt_message_extract_questions_llamav2)
     )
-    message = prompts.general_conversation_falcon.format(query=prompt)
-    response = gpt4all_model.generate(message, max_tokens=200, top_k=2)
+    example_questions = prompts.extract_questions_llamav2_sample.format(
+        query=text,
+        chat_history=chat_history,
+        current_date=current_date,
+        last_year=last_year,
+        last_christmas_date=last_christmas_date,
+        next_christmas_date=next_christmas_date,
+    )
+    message = system_prompt + example_questions
+    response = gpt4all_model.generate(message, max_tokens=200, top_k=2, temp=0)
 
     # Extract, Clean Message from GPT's Response
     try:
         # This will expect to be a list with a single string with a list of questions
         questions = (
             str(response)
             .strip(empty_escape_sequences)
             .replace("['", '["')
+            .replace("<s>", "")
+            .replace("</s>", "")
             .replace("']", '"]')
             .replace("', '", '", "')
+            .replace('["', "")
+            .replace('"]', "")
-            .split('", "')
+            .split("? ")
         )
         questions = [q + "?" for q in questions[:-1]] + [questions[-1]]
+        questions = filter_questions(questions)
     except:
-        logger.warning(f"Falcon returned invalid JSON. Falling back to using user message as search query.\n{response}")
+        logger.warning(f"Llama returned invalid JSON. Falling back to using user message as search query.\n{response}")
         return all_questions
-    logger.debug(f"Extracted Questions by Falcon: {questions}")
+    logger.debug(f"Extracted Questions by Llama: {questions}")
     questions.extend(all_questions)
     return questions
 
 
-def converse_falcon(
+def filter_questions(questions: List[str]):
+    # Skip questions that seem to be apologizing for not being able to answer the question
+    hint_words = [
+        "sorry",
+        "apologize",
+        "unable",
+        "can't",
+        "cannot",
+        "don't know",
+        "don't understand",
+        "do not know",
+        "do not understand",
+    ]
+    filtered_questions = []
+    for q in questions:
+        if not any([word in q.lower() for word in hint_words]):
+            filtered_questions.append(q)
+
+    return filtered_questions
+
+
+def converse_offline(
     references,
     user_query,
     conversation_log={},
-    model: str = "ggml-model-gpt4all-falcon-q4_0.bin",
+    model: str = "llama-2-7b-chat.ggmlv3.q4_K_S.bin",
     loaded_model: Union[GPT4All, None] = None,
     completion_func=None,
 ) -> ThreadedGenerator:
     """
-    Converse with user using Falcon
+    Converse with user using Llama
     """
     gpt4all_model = loaded_model or GPT4All(model)
     # Initialize Variables
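A hedged usage sketch of the renamed extraction helper above (assumes the Llama 2 GGML file is already in GPT4All's cache; the query is just an example):

    from gpt4all import GPT4All
    from khoj.processor.conversation.gpt4all.chat_model import extract_questions_offline

    loaded_model = GPT4All("llama-2-7b-chat.ggmlv3.q4_K_S.bin")
    # With should_extract_questions=True the model proposes search queries;
    # filter_questions() then drops apologetic non-answers like "I'm sorry, I cannot ...".
    queries = extract_questions_offline(
        "How was my trip to Cambodia?", loaded_model=loaded_model, should_extract_questions=True
    )
    print(queries)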
@@ -92,18 +125,18 @@ def converse_falcon(
     # Get Conversation Primer appropriate to Conversation Type
     # TODO If compiled_references_message is too long, we need to truncate it.
     if compiled_references_message == "":
-        conversation_primer = prompts.conversation_falcon.format(query=user_query)
+        conversation_primer = prompts.conversation_llamav2.format(query=user_query)
     else:
-        conversation_primer = prompts.notes_conversation.format(
-            current_date=current_date, query=user_query, references=compiled_references_message
+        conversation_primer = prompts.notes_conversation_llamav2.format(
+            query=user_query, references=compiled_references_message
         )
 
     # Setup Prompt with Primer or Conversation History
     messages = generate_chatml_messages_with_context(
         conversation_primer,
-        prompts.personality.format(),
+        prompts.system_prompt_message_llamav2,
         conversation_log,
-        model_name="text-davinci-001",  # This isn't actually the model, but this helps us get an approximate encoding to run message truncation.
+        model_name=model,
     )
 
     g = ThreadedGenerator(references, completion_func=completion_func)
@@ -113,24 +146,22 @@ def converse_falcon(
 
 
 def llm_thread(g, messages: List[ChatMessage], model: GPT4All):
-    user_message = messages[0]
-    system_message = messages[-1]
+    user_message = messages[-1]
+    system_message = messages[0]
     conversation_history = messages[1:-1]
 
     formatted_messages = [
-        prompts.chat_history_falcon_from_assistant.format(message=system_message)
+        prompts.chat_history_llamav2_from_assistant.format(message=message.content)
         if message.role == "assistant"
-        else prompts.chat_history_falcon_from_user.format(message=message.content)
+        else prompts.chat_history_llamav2_from_user.format(message=message.content)
         for message in conversation_history
     ]
 
     chat_history = "".join(formatted_messages)
-    full_message = system_message.content + chat_history + user_message.content
-
-    prompted_message = prompts.general_conversation_falcon.format(query=full_message)
-    response_iterator = model.generate(
-        prompted_message, streaming=True, max_tokens=256, top_k=1, temp=0, repeat_penalty=2.0
-    )
+    templated_system_message = prompts.system_prompt_llamav2.format(message=system_message.content)
+    templated_user_message = prompts.general_conversation_llamav2.format(query=user_message.content)
+    prompted_message = templated_system_message + chat_history + templated_user_message
+    response_iterator = model.generate(prompted_message, streaming=True, max_tokens=2000)
    for response in response_iterator:
         logger.info(response)
         g.send(response)
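For orientation, a rough sketch of the string llm_thread now hands to GPT4All, assembled by hand from the new templates for a short conversation (the history and query are made up):

    # System prompt wrapped per system_prompt_llamav2
    system = (
        "<s>[INST] <<SYS>>\nYou are Khoj, a friendly, smart and helpful personal assistant.\n"
        "<</SYS>>Hi there! [/INST] Hello! How can I help you today? </s>"
    )
    # One earlier user/assistant exchange, per chat_history_llamav2_from_user/_from_assistant
    chat_history = "<s>[INST]Where was I born?[/INST]" + "You were born in Testville.</s>"
    # The new user turn, per general_conversation_llamav2
    user_turn = "<s>[INST]And where do I live now?[/INST]"
    prompted_message = system + chat_history + user_turn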
@@ -0,0 +1,3 @@
+model_name_to_url = {
+    "llama-2-7b-chat.ggmlv3.q4_K_S.bin": "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/resolve/main/llama-2-7b-chat.ggmlv3.q3_K_M.bin"
+}

src/khoj/processor/conversation/gpt4all/utils.py (new file, +33)
@@ -0,0 +1,33 @@
+import os
+import logging
+import requests
+from gpt4all import GPT4All
+import tqdm
+
+from khoj.processor.conversation.gpt4all import model_metadata
+
+logger = logging.getLogger(__name__)
+
+
+def download_model(model_name):
+    url = model_metadata.model_name_to_url.get(model_name)
+    if not url:
+        logger.debug(f"Model {model_name} not found in model metadata. Skipping download.")
+        return GPT4All(model_name)
+
+    filename = os.path.expanduser(f"~/.cache/gpt4all/{model_name}")
+    if os.path.exists(filename):
+        return GPT4All(model_name)
+
+    try:
+        os.makedirs(os.path.dirname(filename), exist_ok=True)
+        logger.debug(f"Downloading model {model_name} from {url} to {filename}...")
+        with requests.get(url, stream=True) as r:
+            r.raise_for_status()
+            with open(filename, "wb") as f:
+                for chunk in r.iter_content(chunk_size=8192):
+                    f.write(chunk)
+        return GPT4All(model_name)
+    except Exception as e:
+        logger.error(f"Failed to download model {model_name} from {url} to {filename}. Error: {e}")
+        return None
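A hedged usage sketch of the new download helper (the model name and URL come from model_metadata above; no retry or checksum handling is implied):

    from khoj.processor.conversation.gpt4all.utils import download_model

    # Streams the GGML weights to ~/.cache/gpt4all/ on first use, then returns a
    # ready GPT4All instance, or None if the download failed.
    gpt4all_model = download_model("llama-2-7b-chat.ggmlv3.q4_K_S.bin")
    if gpt4all_model:
        print(gpt4all_model.generate("<s>[INST]Say hello[/INST]", max_tokens=32))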
@@ -18,34 +18,54 @@ Question: {query}
 """.strip()
 )
 
-general_conversation_falcon = PromptTemplate.from_template(
-    """
-Using your general knowledge and our past conversations as context, answer the following question.
-### Instruct:
-{query}
-### Response:
-""".strip()
-)
+system_prompt_message_llamav2 = f"""You are Khoj, a friendly, smart and helpful personal assistant.
+Using your general knowledge and our past conversations as context, answer the following question."""
 
-chat_history_falcon_from_user = PromptTemplate.from_template(
+system_prompt_message_extract_questions_llamav2 = f"""You are Khoj, a kind and intelligent personal assistant.
+- When the user asks you a question, you ask follow-up questions to clarify the necessary information you need in order to answer from the user's perspective.
+- Try to be as specific as possible. For example, rather than use "they" or "it", use the name of the person or thing you are referring to.
+- Write the question as if you can search for the answer on the user's personal notes.
+- Add as much context from the previous questions and notes as required into your search queries.
+- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
+- Provide search queries as a list of questions
+What follow-up questions, if any, will you need to ask to answer the user's question?
+"""
+
+system_prompt_llamav2 = PromptTemplate.from_template(
     """
-### Human:
+<s>[INST] <<SYS>>
 {message}
-""".strip()
+<</SYS>>Hi there! [/INST] Hello! How can I help you today? </s>"""
 )
 
-chat_history_falcon_from_assistant = PromptTemplate.from_template(
+extract_questions_system_prompt_llamav2 = PromptTemplate.from_template(
     """
-### Assistant:
+<s>[INST] <<SYS>>
 {message}
+<</SYS>>[/INST]</s>"""
 )
 
+general_conversation_llamav2 = PromptTemplate.from_template(
+    """
+<s>[INST]{query}[/INST]
+""".strip()
+)
+
-conversation_falcon = PromptTemplate.from_template(
+chat_history_llamav2_from_user = PromptTemplate.from_template(
     """
-Using our past conversations as context, answer the following question.
+<s>[INST]{message}[/INST]
+""".strip()
+)
 
-Question: {query}
+chat_history_llamav2_from_assistant = PromptTemplate.from_template(
+    """
+{message}</s>
 """.strip()
 )
 
+conversation_llamav2 = PromptTemplate.from_template(
+    """
+<s>[INST]{query}[/INST]
+""".strip()
+)
+
@@ -63,13 +83,10 @@ Question: {query}
 """.strip()
 )
 
-notes_conversation_falcon = PromptTemplate.from_template(
+notes_conversation_llamav2 = PromptTemplate.from_template(
     """
 Using the notes and our past conversations as context, answer the following question. If the answer is not contained within the notes, say "I don't know."
 
 Notes:
 {references}
 
 Question: {query}
 """.strip()
 )
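As the converse_offline change earlier shows, the notes primer now takes only the references and the user query (the current_date argument was dropped). A small formatting sketch with a made-up note:

    from khoj.processor.conversation import prompts

    primer = prompts.notes_conversation_llamav2.format(
        references="Testatron was born on 1st April 1984 in Testville.",
        query="Where was I born?",
    )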
@@ -109,37 +126,22 @@ Question: {user_query}
 Answer (in second person):"""
 )
 
-extract_questions_falcon = PromptTemplate.from_template(
+extract_questions_llamav2_sample = PromptTemplate.from_template(
     """
-You are Khoj, an extremely smart and helpful search assistant with the ability to retrieve information from the user's notes.
-- The user will provide their questions and answers to you for context.
-- Add as much context from the previous questions and answers as required into your search queries.
-- Break messages into multiple search queries when required to retrieve the relevant information.
-- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
-
-What searches, if any, will you need to perform to answer the users question?
-
-Q: How was my trip to Cambodia?
-
-["How was my trip to Cambodia?"]
-
-A: The trip was amazing. I went to the Angkor Wat temple and it was beautiful.
-
-Q: Who did i visit that temple with?
-
-["Who did I visit the Angkor Wat Temple in Cambodia with?"]
-
-A: You visited the Angkor Wat Temple in Cambodia with Pablo, Namita and Xi.
-
-Q: How many tennis balls fit in the back of a 2002 Honda Civic?
-
-["What is the size of a tennis ball?", "What is the trunk size of a 2002 Honda Civic?"]
-
-A: 1085 tennis balls will fit in the trunk of a Honda Civic
-
+<s>[INST]<<SYS>>Current Date: {current_date}<</SYS>>[/INST]</s>
+<s>[INST]<<SYS>>
+Use these notes from the user's previous conversations to provide a response:
 {chat_history}
-Q: {text}
+<</SYS>>[/INST]</s>
 
+<s>[INST]How was my trip to Cambodia?[/INST][]</s>
+<s>[INST]Who did I visit the temple with on that trip?[/INST]Who did I visit the temple with in Cambodia?</s>
+<s>[INST]How should I take care of my plants?[/INST]What kind of plants do I have? What issues do my plants have?</s>
+<s>[INST]How many tennis balls fit in the back of a 2002 Honda Civic?[/INST]What is the size of a tennis ball? What is the trunk size of a 2002 Honda Civic?</s>
+<s>[INST]What did I do for Christmas last year?[/INST]What did I do for Christmas {last_year} dt>='{last_christmas_date}' dt<'{next_christmas_date}'</s>
+<s>[INST]How are you feeling today?[/INST]</s>
+<s>[INST]Is Alice older than Bob?[/INST]When was Alice born? What is Bob's age?</s>
+<s>[INST]{query}[/INST]
 """
 )
 
@@ -13,7 +13,7 @@ import queue
 from khoj.utils.helpers import merge_dicts
 
 logger = logging.getLogger(__name__)
-max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192, "text-davinci-001": 910}
+max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192, "llama-2-7b-chat.ggmlv3.q4_K_S.bin": 850}
 
 
 class ThreadedGenerator:
@@ -102,7 +102,10 @@ def generate_chatml_messages_with_context(
 
 def truncate_messages(messages, max_prompt_size, model_name):
     """Truncate messages to fit within max prompt size supported by model"""
-    encoder = tiktoken.encoding_for_model(model_name)
+    try:
+        encoder = tiktoken.encoding_for_model(model_name)
+    except KeyError:
+        encoder = tiktoken.encoding_for_model("text-davinci-001")
     tokens = sum([len(encoder.encode(message.content)) for message in messages])
     while tokens > max_prompt_size and len(messages) > 1:
         messages.pop()
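A short sketch of the new encoder fallback for model names tiktoken does not recognise (the message content is illustrative):

    import tiktoken
    from langchain.schema import ChatMessage

    model_name = "llama-2-7b-chat.ggmlv3.q4_K_S.bin"
    try:
        encoder = tiktoken.encoding_for_model(model_name)
    except KeyError:
        # Offline model names are not registered with tiktoken, so approximate
        # token counts with an OpenAI encoding instead.
        encoder = tiktoken.encoding_for_model("text-davinci-001")

    messages = [ChatMessage(content="Hello, my name is Testatron. Who are you?", role="user")]
    tokens = sum(len(encoder.encode(message.content)) for message in messages)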
@@ -38,7 +38,7 @@ from khoj.utils.yaml import save_config_to_file_updated_state
 from fastapi.responses import StreamingResponse, Response
 from khoj.routers.helpers import perform_chat_checks, generate_chat_response, update_telemetry_state
 from khoj.processor.conversation.openai.gpt import extract_questions
-from khoj.processor.conversation.gpt4all.chat_model import extract_questions_falcon, converse_falcon
+from khoj.processor.conversation.gpt4all.chat_model import extract_questions_offline
 from fastapi.requests import Request
 
 
@@ -715,7 +715,9 @@ async def extract_references_and_questions(
         inferred_queries = extract_questions(q, model=chat_model, api_key=api_key, conversation_log=meta_log)
     else:
         loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
-        inferred_queries = extract_questions_falcon(q, loaded_model=loaded_model, conversation_log=meta_log)
+        inferred_queries = extract_questions_offline(
+            q, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
+        )
 
     # Collate search results as context for GPT
     with timer("Searching knowledge base took", logger):
@@ -8,7 +8,7 @@ from fastapi import HTTPException, Request
 from khoj.utils import state
 from khoj.utils.helpers import timer, log_telemetry
 from khoj.processor.conversation.openai.gpt import converse
-from khoj.processor.conversation.gpt4all.chat_model import converse_falcon
+from khoj.processor.conversation.gpt4all.chat_model import converse_offline
 from khoj.processor.conversation.utils import reciprocal_conversation_to_chatml, message_to_log, ThreadedGenerator
 
 logger = logging.getLogger(__name__)
@@ -111,7 +111,7 @@ def generate_chat_response(
         )
     else:
         loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
-        chat_response = converse_falcon(
+        chat_response = converse_offline(
             references=compiled_references,
             user_query=q,
             loaded_model=loaded_model,
@@ -5,8 +5,7 @@ from enum import Enum
 from dataclasses import dataclass
 from pathlib import Path
 from typing import TYPE_CHECKING, Dict, List, Optional, Union, Any
 
-from gpt4all import GPT4All
+from khoj.processor.conversation.gpt4all.utils import download_model
 
 # External Packages
 import torch
@@ -79,7 +78,7 @@ class SearchModels:
 
 @dataclass
 class GPT4AllProcessorConfig:
-    chat_model: Optional[str] = "ggml-model-gpt4all-falcon-q4_0.bin"
+    chat_model: Optional[str] = "llama-2-7b-chat.ggmlv3.q4_K_S.bin"
     loaded_model: Union[Any, None] = None
 
 
@@ -96,7 +95,7 @@ class ConversationProcessorConfigModel:
         self.meta_log: dict = {}
 
         if not self.openai_model and self.enable_offline_chat:
-            self.gpt4all_model.loaded_model = GPT4All(self.gpt4all_model.chat_model)  # type: ignore
+            self.gpt4all_model.loaded_model = download_model(self.gpt4all_model.chat_model)
         else:
             self.gpt4all_model.loaded_model = None
 
@@ -16,14 +16,18 @@ from freezegun import freeze_time
 from gpt4all import GPT4All
 
 # Internal Packages
-from khoj.processor.conversation.gpt4all.chat_model import converse_falcon, extract_questions_falcon
+from khoj.processor.conversation.gpt4all.chat_model import converse_offline, extract_questions_offline, filter_questions
+from khoj.processor.conversation.gpt4all.utils import download_model
 
 from khoj.processor.conversation.utils import message_to_log
 
+MODEL_NAME = "llama-2-7b-chat.ggmlv3.q4_K_S.bin"
+
 
 @pytest.fixture(scope="session")
 def loaded_model():
-    return GPT4All("ggml-model-gpt4all-falcon-q4_0.bin")
+    download_model(MODEL_NAME)
+    return GPT4All(MODEL_NAME)
 
 
 freezegun.configure(extend_ignore_list=["transformers"])
@@ -35,24 +39,40 @@ freezegun.configure(extend_ignore_list=["transformers"])
 @freeze_time("1984-04-02")
 def test_extract_question_with_date_filter_from_relative_day(loaded_model):
     # Act
-    response = extract_questions_falcon(
-        "Where did I go for dinner yesterday?", loaded_model=loaded_model, run_extraction=True
-    )
+    response = extract_questions_offline("Where did I go for dinner yesterday?", loaded_model=loaded_model)
 
     assert len(response) >= 1
     assert response[-1] == "Where did I go for dinner yesterday?"
 
+    assert any(
+        [
+            "dt>='1984-04-01'" in response[0] and "dt<'1984-04-02'" in response[0],
+            "dt>='1984-04-01'" in response[0] and "dt<='1984-04-01'" in response[0],
+            'dt>="1984-04-01"' in response[0] and 'dt<"1984-04-02"' in response[0],
+            'dt>="1984-04-01"' in response[0] and 'dt<="1984-04-01"' in response[0],
+        ]
+    )
 
 
 # ----------------------------------------------------------------------------------------------------
 @pytest.mark.xfail(reason="Chat actor still isn't very date aware nor capable of formatting")
 @pytest.mark.chatquality
 @freeze_time("1984-04-02")
 def test_extract_question_with_date_filter_from_relative_month(loaded_model):
     # Act
-    response = extract_questions_falcon("Which countries did I visit last month?", loaded_model=loaded_model)
+    response = extract_questions_offline("Which countries did I visit last month?", loaded_model=loaded_model)
 
     # Assert
-    assert len(response) == 1
-    assert response == ["Which countries did I visit last month?"]
+    assert len(response) >= 1
+    # The user query should be the last question in the response
+    assert response[-1] == ["Which countries did I visit last month?"]
+    assert any(
+        [
+            "dt>='1984-03-01'" in response[0] and "dt<'1984-04-01'" in response[0],
+            "dt>='1984-03-01'" in response[0] and "dt<='1984-03-31'" in response[0],
+            'dt>="1984-03-01"' in response[0] and 'dt<"1984-04-01"' in response[0],
+            'dt>="1984-03-01"' in response[0] and 'dt<="1984-03-31"' in response[0],
+        ]
+    )
 
 
 # ----------------------------------------------------------------------------------------------------
@@ -60,9 +80,7 @@ def test_extract_question_with_date_filter_from_relative_month(loaded_model):
 @freeze_time("1984-04-02")
 def test_extract_question_with_date_filter_from_relative_year(loaded_model):
     # Act
-    response = extract_questions_falcon(
-        "Which countries have I visited this year?", loaded_model=loaded_model, run_extraction=True
-    )
+    response = extract_questions_offline("Which countries have I visited this year?", loaded_model=loaded_model)
 
     # Assert
     assert len(response) >= 1
@@ -73,25 +91,26 @@ def test_extract_question_with_date_filter_from_relative_year(loaded_model):
 @pytest.mark.chatquality
 def test_extract_multiple_explicit_questions_from_message(loaded_model):
     # Act
-    response = extract_questions_falcon("What is the Sun? What is the Moon?", loaded_model=loaded_model)
+    response = extract_questions_offline("What is the Sun? What is the Moon?", loaded_model=loaded_model)
 
     # Assert
     expected_responses = ["What is the Sun?", "What is the Moon?"]
-    assert len(response) == 2
-    assert expected_responses == response
+    assert len(response) >= 2
+    assert expected_responses[0] == response[-2]
+    assert expected_responses[1] == response[-1]
 
 
 # ----------------------------------------------------------------------------------------------------
 @pytest.mark.chatquality
 def test_extract_multiple_implicit_questions_from_message(loaded_model):
     # Act
-    response = extract_questions_falcon("Is Morpheus taller than Neo?", loaded_model=loaded_model, run_extraction=True)
+    response = extract_questions_offline("Is Morpheus taller than Neo?", loaded_model=loaded_model)
 
     # Assert
     expected_responses = [
-        ("morpheus", "neo"),
+        ("morpheus", "neo", "height", "taller", "shorter"),
     ]
-    assert len(response) == 2
+    assert len(response) == 3
     assert any([start in response[0].lower() and end in response[1].lower() for start, end in expected_responses]), (
         "Expected two search queries in response but got: " + response[0]
     )
@@ -106,18 +125,19 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model):
     ]
 
     # Act
-    response = extract_questions_falcon(
+    response = extract_questions_offline(
         "Does he have any sons?",
         conversation_log=populate_chat_history(message_list),
         loaded_model=loaded_model,
-        run_extraction=True,
         use_history=True,
     )
 
     expected_responses = [
-        "do not have",
-        "clarify",
-        "am sorry",
+        "Vader",
+        "sons",
+        "son",
+        "Darth",
+        "children",
     ]
 
     # Assert
@@ -128,7 +148,7 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model):
 
 
 # ----------------------------------------------------------------------------------------------------
-@pytest.mark.xfail(reason="Chat actor does not consistently follow template instructions.")
+# @pytest.mark.xfail(reason="Chat actor does not consistently follow template instructions.")
 @pytest.mark.chatquality
 def test_generate_search_query_using_answer_from_chat_history(loaded_model):
     # Arrange
@@ -137,17 +157,24 @@ def test_generate_search_query_using_answer_from_chat_history(loaded_model):
     ]
 
     # Act
-    response = extract_questions_falcon(
+    response = extract_questions_offline(
         "Is she a Jedi?",
         conversation_log=populate_chat_history(message_list),
         loaded_model=loaded_model,
-        run_extraction=True,
         use_history=True,
     )
 
+    expected_responses = [
+        "Leia",
+        "Vader",
+        "daughter",
+    ]
+
     # Assert
-    assert len(response) == 1
-    assert "Leia" in response[0]
+    assert len(response) >= 1
+    assert any([expected_response in response[0] for expected_response in expected_responses]), (
+        "Expected chat actor to mention Darth Vader's daughter, but got: " + response[0]
+    )
 
 
 # ----------------------------------------------------------------------------------------------------
@@ -160,10 +187,9 @@ def test_generate_search_query_with_date_and_context_from_chat_history(loaded_mo
     ]
 
     # Act
-    response = extract_questions_falcon(
+    response = extract_questions_offline(
         "What was the Pizza place we ate at over there?",
         conversation_log=populate_chat_history(message_list),
-        run_extraction=True,
         loaded_model=loaded_model,
     )
 
@@ -185,7 +211,7 @@ def test_generate_search_query_with_date_and_context_from_chat_history(loaded_mo
 @pytest.mark.chatquality
 def test_chat_with_no_chat_history_or_retrieved_content(loaded_model):
     # Act
-    response_gen = converse_falcon(
+    response_gen = converse_offline(
         references=[], # Assume no context retrieved from notes for the user_query
         user_query="Hello, my name is Testatron. Who are you?",
         loaded_model=loaded_model,
@@ -201,7 +227,6 @@ def test_chat_with_no_chat_history_or_retrieved_content(loaded_model):
 
 
 # ----------------------------------------------------------------------------------------------------
-@pytest.mark.xfail(reason="Chat actor isn't really good at proper nouns yet.")
 @pytest.mark.chatquality
 def test_answer_from_chat_history_and_previously_retrieved_content(loaded_model):
     "Chat actor needs to use context in previous notes and chat history to answer question"
@@ -216,7 +241,7 @@ def test_answer_from_chat_history_and_previously_retrieved_content(loaded_model)
     ]
 
     # Act
-    response_gen = converse_falcon(
+    response_gen = converse_offline(
         references=[], # Assume no context retrieved from notes for the user_query
         user_query="Where was I born?",
         conversation_log=populate_chat_history(message_list),
@@ -241,7 +266,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content(loaded_model):
     ]
 
     # Act
-    response_gen = converse_falcon(
+    response_gen = converse_offline(
         references=[
             "Testatron was born on 1st April 1984 in Testville."
         ], # Assume context retrieved from notes for the user_query
@@ -257,7 +282,6 @@ def test_answer_from_chat_history_and_currently_retrieved_content(loaded_model):
 
 
 # ----------------------------------------------------------------------------------------------------
-@pytest.mark.xfail(reason="Chat actor is rather liable to lying.")
 @pytest.mark.chatquality
 def test_refuse_answering_unanswerable_question(loaded_model):
     "Chat actor should not try make up answers to unanswerable questions."
@@ -268,7 +292,7 @@ def test_refuse_answering_unanswerable_question(loaded_model):
     ]
 
     # Act
-    response_gen = converse_falcon(
+    response_gen = converse_offline(
         references=[], # Assume no context retrieved from notes for the user_query
         user_query="Where was I born?",
         conversation_log=populate_chat_history(message_list),
@@ -309,7 +333,7 @@ Expenses:Food:Dining 10.00 USD""",
     ]
 
     # Act
-    response_gen = converse_falcon(
+    response_gen = converse_offline(
         references=context, # Assume context retrieved from notes for the user_query
         user_query="What did I have for Dinner today?",
         loaded_model=loaded_model,
@@ -341,7 +365,7 @@ Expenses:Food:Dining 10.00 USD""",
     ]
 
     # Act
-    response_gen = converse_falcon(
+    response_gen = converse_offline(
         references=context, # Assume context retrieved from notes for the user_query
         user_query="How much did I spend on dining this year?",
         loaded_model=loaded_model,
@@ -365,7 +389,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(loaded
     ]
 
     # Act
-    response_gen = converse_falcon(
+    response_gen = converse_offline(
         references=[], # Assume no context retrieved from notes for the user_query
         user_query="Write a haiku about unit testing in 3 lines",
         conversation_log=populate_chat_history(message_list),
@@ -382,7 +406,6 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(loaded
 
 
 # ----------------------------------------------------------------------------------------------------
-@pytest.mark.xfail(reason="Chat actor not consistently capable of asking for clarification yet.")
 @pytest.mark.chatquality
 def test_ask_for_clarification_if_not_enough_context_in_question(loaded_model):
     "Chat actor should ask for clarification if question cannot be answered unambiguously with the provided context"
@@ -397,7 +420,7 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""
     ]
 
     # Act
-    response_gen = converse_falcon(
+    response_gen = converse_offline(
         references=context, # Assume context retrieved from notes for the user_query
         user_query="How many kids does my older sister have?",
         loaded_model=loaded_model,
@@ -411,6 +434,17 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""
     )
 
 
+def test_filter_questions():
+    test_questions = [
+        "I don't know how to answer that",
+        "I cannot answer anything about the nuclear secrets",
+        "Who is on the basketball team?",
+    ]
+    filtered_questions = filter_questions(test_questions)
+    assert len(filtered_questions) == 1
+    assert filtered_questions[0] == "Who is on the basketball team?"
+
+
 # Helpers
 # ----------------------------------------------------------------------------------------------------
 def populate_chat_history(message_list):