+
@@ -346,7 +346,7 @@
featuresHintText.classList.add("show");
}
- fetch('/api/config/data/processor/conversation/enable_offline_chat' + '?enable_offline_chat=' + enable, {
+ fetch('/api/config/data/processor/conversation/offline_chat' + '?enable_offline_chat=' + enable, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
diff --git a/src/khoj/migrations/migrate_offline_chat_schema.py b/src/khoj/migrations/migrate_offline_chat_schema.py
new file mode 100644
index 00000000..873783a3
--- /dev/null
+++ b/src/khoj/migrations/migrate_offline_chat_schema.py
@@ -0,0 +1,83 @@
+"""
+Current format of khoj.yml
+---
+app:
+ ...
+content-type:
+ ...
+processor:
+ conversation:
+ enable-offline-chat: false
+ conversation-logfile: ~/.khoj/processor/conversation/conversation_logs.json
+ openai:
+ ...
+search-type:
+ ...
+
+New format of khoj.yml
+---
+app:
+ ...
+content-type:
+ ...
+processor:
+ conversation:
+ offline-chat:
+ enable-offline-chat: false
+ chat-model: llama-2-7b-chat.ggmlv3.q4_0.bin
+ tokenizer: null
+ max_prompt_size: null
+ conversation-logfile: ~/.khoj/processor/conversation/conversation_logs.json
+ openai:
+ ...
+search-type:
+ ...
+"""
+import logging
+from packaging import version
+
+from khoj.utils.yaml import load_config_from_file, save_config_to_file
+
+
+logger = logging.getLogger(__name__)
+
+
+def migrate_offline_chat_schema(args):
+ schema_version = "0.12.3"
+ raw_config = load_config_from_file(args.config_file)
+ previous_version = raw_config.get("version")
+
+ if "processor" not in raw_config:
+ return args
+ if raw_config["processor"] is None:
+ return args
+ if "conversation" not in raw_config["processor"]:
+ return args
+
+ if previous_version is None or version.parse(previous_version) < version.parse("0.12.3"):
+ logger.info(
+ f"Upgrading config schema to {schema_version} from {previous_version} to make (offline) chat more configuration"
+ )
+ raw_config["version"] = schema_version
+
+ # Create max-prompt-size and tokenizer fields in conversation processor schema
+ raw_config["processor"]["conversation"]["max-prompt-size"] = None
+ raw_config["processor"]["conversation"]["tokenizer"] = None
+
+ # Create offline chat schema based on existing enable_offline_chat field in khoj config schema
+ offline_chat_model = (
+ raw_config["processor"]["conversation"]
+ .get("offline-chat", {})
+ .get("chat-model", "llama-2-7b-chat.ggmlv3.q4_0.bin")
+ )
+ raw_config["processor"]["conversation"]["offline-chat"] = {
+ "enable-offline-chat": raw_config["processor"]["conversation"].get("enable-offline-chat", False),
+ "chat-model": offline_chat_model,
+ }
+
+ # Delete old enable-offline-chat field from conversation processor schema
+ if "enable-offline-chat" in raw_config["processor"]["conversation"]:
+ del raw_config["processor"]["conversation"]["enable-offline-chat"]
+
+ save_config_to_file(raw_config, args.config_file)
+ return args
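
As a rough sketch of how this migration behaves in isolation (the real entrypoint is run_migrations in src/khoj/utils/cli.py), it can be driven with any object exposing a config_file attribute; the path below is only illustrative:

from pathlib import Path
from types import SimpleNamespace

from khoj.migrations.migrate_offline_chat_schema import migrate_offline_chat_schema

# Stand-in for the argparse namespace cli.py builds; only config_file is read here.
args = SimpleNamespace(config_file=Path("~/.khoj/khoj.yml").expanduser())
migrate_offline_chat_schema(args)
# khoj.yml now nests enable-offline-chat and chat-model under processor.conversation.offline-chat,
# and gains max-prompt-size and tokenizer keys, as described in the docstring above.
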
diff --git a/src/khoj/processor/conversation/gpt4all/chat_model.py b/src/khoj/processor/conversation/gpt4all/chat_model.py
index 9bc9ea52..7e92d002 100644
--- a/src/khoj/processor/conversation/gpt4all/chat_model.py
+++ b/src/khoj/processor/conversation/gpt4all/chat_model.py
@@ -16,7 +16,7 @@ logger = logging.getLogger(__name__)
def extract_questions_offline(
text: str,
- model: str = "llama-2-7b-chat.ggmlv3.q4_K_S.bin",
+ model: str = "llama-2-7b-chat.ggmlv3.q4_0.bin",
loaded_model: Union[Any, None] = None,
conversation_log={},
use_history: bool = True,
@@ -113,7 +113,7 @@ def filter_questions(questions: List[str]):
]
filtered_questions = []
for q in questions:
- if not any([word in q.lower() for word in hint_words]):
+ if not any([word in q.lower() for word in hint_words]) and not is_none_or_empty(q):
filtered_questions.append(q)
return filtered_questions
@@ -123,10 +123,12 @@ def converse_offline(
references,
user_query,
conversation_log={},
- model: str = "llama-2-7b-chat.ggmlv3.q4_K_S.bin",
+ model: str = "llama-2-7b-chat.ggmlv3.q4_0.bin",
loaded_model: Union[Any, None] = None,
completion_func=None,
conversation_command=ConversationCommand.Default,
+ max_prompt_size=None,
+ tokenizer_name=None,
) -> Union[ThreadedGenerator, Iterator[str]]:
"""
Converse with user using Llama
@@ -158,6 +160,8 @@ def converse_offline(
prompts.system_prompt_message_llamav2,
conversation_log,
model_name=model,
+ max_prompt_size=max_prompt_size,
+ tokenizer_name=tokenizer_name,
)
g = ThreadedGenerator(references, completion_func=completion_func)
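
A hedged sketch of how a caller can thread the new max_prompt_size and tokenizer_name knobs through converse_offline (mirroring generate_chat_response in src/khoj/routers/helpers.py; the reference and query values are illustrative):

# Illustrative call only: loaded_model comes from download_model(), and the two new
# keyword arguments are the user-configured overrides added in this change.
response = converse_offline(
    references=["Notes on the Cambodia trip, compiled by search"],
    user_query="How was my trip to Cambodia?",
    loaded_model=loaded_model,
    model="llama-2-7b-chat.ggmlv3.q4_0.bin",
    max_prompt_size=1548,
    tokenizer_name="hf-internal-testing/llama-tokenizer",
)
for chunk in response:  # ThreadedGenerator streams chunks as the model produces them
    print(chunk, end="")
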
diff --git a/src/khoj/processor/conversation/gpt4all/model_metadata.py b/src/khoj/processor/conversation/gpt4all/model_metadata.py
deleted file mode 100644
index 065e3720..00000000
--- a/src/khoj/processor/conversation/gpt4all/model_metadata.py
+++ /dev/null
@@ -1,3 +0,0 @@
-model_name_to_url = {
- "llama-2-7b-chat.ggmlv3.q4_K_S.bin": "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/resolve/main/llama-2-7b-chat.ggmlv3.q4_K_S.bin"
-}
diff --git a/src/khoj/processor/conversation/gpt4all/utils.py b/src/khoj/processor/conversation/gpt4all/utils.py
index 4042fbe2..d5201780 100644
--- a/src/khoj/processor/conversation/gpt4all/utils.py
+++ b/src/khoj/processor/conversation/gpt4all/utils.py
@@ -1,24 +1,8 @@
-import os
import logging
-import requests
-import hashlib
-from tqdm import tqdm
-
-from khoj.processor.conversation.gpt4all import model_metadata
logger = logging.getLogger(__name__)
-expected_checksum = {"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "cfa87b15d92fb15a2d7c354b0098578b"}
-
-
-def get_md5_checksum(filename: str):
- hash_md5 = hashlib.md5()
- with open(filename, "rb") as f:
- for chunk in iter(lambda: f.read(8192), b""):
- hash_md5.update(chunk)
- return hash_md5.hexdigest()
-
def download_model(model_name: str):
try:
@@ -27,57 +11,12 @@ def download_model(model_name: str):
logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
raise e
- url = model_metadata.model_name_to_url.get(model_name)
- model_path = os.path.expanduser(f"~/.cache/gpt4all/")
- if not url:
- logger.debug(f"Model {model_name} not found in model metadata. Skipping download.")
- return GPT4All(model_name=model_name, model_path=model_path)
-
- filename = os.path.expanduser(f"~/.cache/gpt4all/{model_name}")
- if os.path.exists(filename):
- # Check if the user is connected to the internet
- try:
- requests.get("https://www.google.com/", timeout=5)
- except:
- logger.debug("User is offline. Disabling allowed download flag")
- return GPT4All(model_name=model_name, model_path=model_path, allow_download=False)
- return GPT4All(model_name=model_name, model_path=model_path)
-
- # Download the model to a tmp file. Once the download is completed, move the tmp file to the actual file
- tmp_filename = filename + ".tmp"
-
+ # Use GPU for Chat Model, if available
try:
- os.makedirs(os.path.dirname(tmp_filename), exist_ok=True)
- logger.debug(f"Downloading model {model_name} from {url} to {filename}...")
- with requests.get(url, stream=True) as r:
- r.raise_for_status()
- total_size = int(r.headers.get("content-length", 0))
- with open(tmp_filename, "wb") as f, tqdm(
- unit="B", # unit string to be displayed.
- unit_scale=True, # let tqdm to determine the scale in kilo, mega..etc.
- unit_divisor=1024, # is used when unit_scale is true
- total=total_size, # the total iteration.
- desc=model_name, # prefix to be displayed on progress bar.
- ) as progress_bar:
- for chunk in r.iter_content(chunk_size=8192):
- f.write(chunk)
- progress_bar.update(len(chunk))
+ model = GPT4All(model_name=model_name, device="gpu")
+ logger.debug("Loaded chat model to GPU.")
+ except ValueError:
+ model = GPT4All(model_name=model_name)
+ logger.debug("Loaded chat model to CPU.")
- # Verify the checksum
- if expected_checksum.get(model_name) != get_md5_checksum(tmp_filename):
- logger.error(
- f"Checksum verification failed for {filename}. Removing the tmp file. Offline model will not be available."
- )
- os.remove(tmp_filename)
- raise ValueError(f"Checksum verification failed for downloading {model_name} from {url}.")
-
- # Move the tmp file to the actual file
- os.rename(tmp_filename, filename)
- logger.debug(f"Successfully downloaded model {model_name} from {url} to {filename}")
- return GPT4All(model_name)
- except Exception as e:
- logger.error(f"Failed to download model {model_name} from {url} to {filename}. Error: {e}", exc_info=True)
- # Remove the tmp file if it exists
- if os.path.exists(tmp_filename):
- os.remove(tmp_filename)
- return None
+ return model
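
Usage stays a one-liner; GPT4All now owns the download and cache, and the device="gpu" attempt quietly falls back to CPU when no supported GPU backend is available:

from khoj.processor.conversation.gpt4all.utils import download_model

# Fetches the model into GPT4All's default cache on first use (no manual checksum handling),
# then tries the GPU backend and falls back to CPU on ValueError.
loaded_model = download_model("llama-2-7b-chat.ggmlv3.q4_0.bin")
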
diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py
index 96510586..73b4f176 100644
--- a/src/khoj/processor/conversation/openai/gpt.py
+++ b/src/khoj/processor/conversation/openai/gpt.py
@@ -116,6 +116,8 @@ def converse(
temperature: float = 0.2,
completion_func=None,
conversation_command=ConversationCommand.Default,
+ max_prompt_size=None,
+ tokenizer_name=None,
):
"""
Converse with user using OpenAI's ChatGPT
@@ -141,6 +143,8 @@ def converse(
prompts.personality.format(),
conversation_log,
model,
+ max_prompt_size,
+ tokenizer_name,
)
truncated_messages = "\n".join({f"{message.content[:40]}..." for message in messages})
logger.debug(f"Conversation Context for GPT: {truncated_messages}")
diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py
index 4de3c623..d487609d 100644
--- a/src/khoj/processor/conversation/prompts.py
+++ b/src/khoj/processor/conversation/prompts.py
@@ -23,7 +23,7 @@ no_notes_found = PromptTemplate.from_template(
""".strip()
)
-system_prompt_message_llamav2 = f"""You are Khoj, a friendly, smart and helpful personal assistant.
+system_prompt_message_llamav2 = f"""You are Khoj, a smart, inquisitive and helpful personal assistant.
Using your general knowledge and our past conversations as context, answer the following question.
If you do not know the answer, say 'I don't know.'"""
@@ -51,13 +51,13 @@ extract_questions_system_prompt_llamav2 = PromptTemplate.from_template(
general_conversation_llamav2 = PromptTemplate.from_template(
"""
-[INST]{query}[/INST]
+[INST] {query} [/INST]
""".strip()
)
chat_history_llamav2_from_user = PromptTemplate.from_template(
"""
-[INST]{message}[/INST]
+[INST] {message} [/INST]
""".strip()
)
@@ -69,7 +69,7 @@ chat_history_llamav2_from_assistant = PromptTemplate.from_template(
conversation_llamav2 = PromptTemplate.from_template(
"""
-[INST]{query}[/INST]
+[INST] {query} [/INST]
""".strip()
)
@@ -91,7 +91,7 @@ Question: {query}
notes_conversation_llamav2 = PromptTemplate.from_template(
"""
-Notes:
+User's Notes:
{references}
Question: {query}
""".strip()
@@ -134,19 +134,25 @@ Answer (in second person):"""
extract_questions_llamav2_sample = PromptTemplate.from_template(
"""
-[INST]<>Current Date: {current_date}<>[/INST]
-[INST]How was my trip to Cambodia?[/INST][]
-[INST]Who did I visit the temple with on that trip?[/INST]Who did I visit the temple with in Cambodia?
-[INST]How should I take care of my plants?[/INST]What kind of plants do I have? What issues do my plants have?
-[INST]How many tennis balls fit in the back of a 2002 Honda Civic?[/INST]What is the size of a tennis ball? What is the trunk size of a 2002 Honda Civic?
-[INST]What did I do for Christmas last year?[/INST]What did I do for Christmas {last_year} dt>='{last_christmas_date}' dt<'{next_christmas_date}'
-[INST]How are you feeling today?[/INST]
-[INST]Is Alice older than Bob?[/INST]When was Alice born? What is Bob's age?
-[INST]<>
+[INST] <>Current Date: {current_date}<> [/INST]
+[INST] How was my trip to Cambodia? [/INST]
+How was my trip to Cambodia?
+[INST] Who did I visit the temple with on that trip? [/INST]
+Who did I visit the temple with in Cambodia?
+[INST] How should I take care of my plants? [/INST]
+What kind of plants do I have? What issues do my plants have?
+[INST] How many tennis balls fit in the back of a 2002 Honda Civic? [/INST]
+What is the size of a tennis ball? What is the trunk size of a 2002 Honda Civic?
+[INST] What did I do for Christmas last year? [/INST]
+What did I do for Christmas {last_year} dt>='{last_christmas_date}' dt<'{next_christmas_date}'
+[INST] How are you feeling today? [/INST]
+[INST] Is Alice older than Bob? [/INST]
+When was Alice born? What is Bob's age?
+[INST] <>
Use these notes from the user's previous conversations to provide a response:
{chat_history}
-<>[/INST]
-[INST]{query}[/INST]
+<> [/INST]
+[INST] {query} [/INST]
"""
)
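
The added whitespace around the [INST] tags follows the Llama 2 chat formatting convention; as a quick illustration of what a rendered template now looks like:

from khoj.processor.conversation import prompts

rendered = prompts.general_conversation_llamav2.format(query="How are you feeling today?")
# rendered == "[INST] How are you feeling today? [/INST]"
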
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 4a92c367..83d51f2d 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -3,24 +3,27 @@ import logging
from time import perf_counter
import json
from datetime import datetime
+import queue
import tiktoken
# External packages
from langchain.schema import ChatMessage
-from transformers import LlamaTokenizerFast
+from transformers import AutoTokenizer
# Internal Packages
-import queue
from khoj.utils.helpers import merge_dicts
+
logger = logging.getLogger(__name__)
-max_prompt_size = {
+model_to_prompt_size = {
"gpt-3.5-turbo": 4096,
"gpt-4": 8192,
- "llama-2-7b-chat.ggmlv3.q4_K_S.bin": 1548,
+ "llama-2-7b-chat.ggmlv3.q4_0.bin": 1548,
"gpt-3.5-turbo-16k": 15000,
}
-tokenizer = {"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "hf-internal-testing/llama-tokenizer"}
+model_to_tokenizer = {
+ "llama-2-7b-chat.ggmlv3.q4_0.bin": "hf-internal-testing/llama-tokenizer",
+}
class ThreadedGenerator:
@@ -82,9 +85,26 @@ def message_to_log(
def generate_chatml_messages_with_context(
- user_message, system_message, conversation_log={}, model_name="gpt-3.5-turbo", lookback_turns=2
+ user_message,
+ system_message,
+ conversation_log={},
+ model_name="gpt-3.5-turbo",
+ max_prompt_size=None,
+ tokenizer_name=None,
):
"""Generate messages for ChatGPT with context from previous conversation"""
+ # Set max prompt size from user config, the value pre-configured for the model, or the default prompt size
+ try:
+ max_prompt_size = max_prompt_size or model_to_prompt_size[model_name]
+ except:
+ max_prompt_size = 2000
+ logger.warning(
+ f"Fallback to default prompt size: {max_prompt_size}.\nConfigure max_prompt_size for unsupported model: {model_name} in Khoj settings to longer context window."
+ )
+
+ # Scale lookback turns proportional to max prompt size supported by model
+ lookback_turns = max_prompt_size // 750
+
# Extract Chat History for Context
chat_logs = []
for chat in conversation_log.get("chat", []):
@@ -105,19 +125,28 @@ def generate_chatml_messages_with_context(
messages = user_chatml_message + rest_backnforths + system_chatml_message
# Truncate oldest messages from conversation history until under max supported prompt size by model
- messages = truncate_messages(messages, max_prompt_size[model_name], model_name)
+ messages = truncate_messages(messages, max_prompt_size, model_name, tokenizer_name)
# Return message in chronological order
return messages[::-1]
-def truncate_messages(messages: list[ChatMessage], max_prompt_size, model_name) -> list[ChatMessage]:
+def truncate_messages(
+ messages: list[ChatMessage], max_prompt_size, model_name: str, tokenizer_name=None
+) -> list[ChatMessage]:
"""Truncate messages to fit within max prompt size supported by model"""
- if "llama" in model_name:
- encoder = LlamaTokenizerFast.from_pretrained(tokenizer[model_name])
- else:
- encoder = tiktoken.encoding_for_model(model_name)
+ try:
+ if model_name.startswith("gpt-"):
+ encoder = tiktoken.encoding_for_model(model_name)
+ else:
+ encoder = AutoTokenizer.from_pretrained(tokenizer_name or model_to_tokenizer[model_name])
+ except:
+ default_tokenizer = "hf-internal-testing/llama-tokenizer"
+ encoder = AutoTokenizer.from_pretrained(default_tokenizer)
+ logger.warning(
+ f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing."
+ )
system_message = messages.pop()
system_message_tokens = len(encoder.encode(system_message.content))
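
A hedged sketch of the new override path: for a chat model Khoj does not recognise, the caller (ultimately the user's khoj.yml) can pass max_prompt_size and tokenizer_name directly; left unset, they fall back to 2000 tokens and the hf-internal-testing/llama-tokenizer with a logged warning. The model name below is hypothetical:

from khoj.processor.conversation.utils import generate_chatml_messages_with_context

messages = generate_chatml_messages_with_context(
    user_message="What did I do for Christmas last year?",
    system_message="You are Khoj, a smart, inquisitive and helpful personal assistant.",
    model_name="my-local-model.bin",                   # hypothetical, not in model_to_prompt_size
    max_prompt_size=4096,                              # user-configured context window
    tokenizer_name="hf-internal-testing/llama-tokenizer",
)
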
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index 780a6c57..0331500b 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -284,10 +284,11 @@ if not state.demo:
except Exception as e:
return {"status": "error", "message": str(e)}
- @api.post("/config/data/processor/conversation/enable_offline_chat", status_code=200)
+ @api.post("/config/data/processor/conversation/offline_chat", status_code=200)
async def set_processor_enable_offline_chat_config_data(
request: Request,
enable_offline_chat: bool,
+ offline_chat_model: Optional[str] = None,
client: Optional[str] = None,
):
_initialize_config()
@@ -301,7 +302,9 @@ if not state.demo:
state.config.processor = ProcessorConfig(conversation=ConversationProcessorConfig(conversation_logfile=conversation_logfile)) # type: ignore
assert state.config.processor.conversation is not None
- state.config.processor.conversation.enable_offline_chat = enable_offline_chat
+ state.config.processor.conversation.offline_chat.enable_offline_chat = enable_offline_chat
+ if offline_chat_model is not None:
+ state.config.processor.conversation.offline_chat.chat_model = offline_chat_model
state.processor_config = configure_processor(state.config.processor, state.processor_config)
update_telemetry_state(
@@ -713,7 +716,7 @@ async def chat(
conversation_command = ConversationCommand.General
if conversation_command == ConversationCommand.Help:
- model_type = "offline" if state.processor_config.conversation.enable_offline_chat else "openai"
+ model_type = "offline" if state.processor_config.conversation.offline_chat.enable_offline_chat else "openai"
formatted_help = help_message.format(model=model_type, version=state.khoj_version)
return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200)
@@ -788,7 +791,7 @@ async def extract_references_and_questions(
# Infer search queries from user message
with timer("Extracting search queries took", logger):
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
- if state.processor_config.conversation.enable_offline_chat:
+ if state.processor_config.conversation.offline_chat.enable_offline_chat:
loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
inferred_queries = extract_questions_offline(
defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
@@ -804,7 +807,7 @@ async def extract_references_and_questions(
with timer("Searching knowledge base took", logger):
result_list = []
for query in inferred_queries:
- n_items = min(n, 3) if state.processor_config.conversation.enable_offline_chat else n
+ n_items = min(n, 3) if state.processor_config.conversation.offline_chat.enable_offline_chat else n
result_list.extend(
await search(
f"{query} {filters_in_query}",
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index 267af330..6b42f29c 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -113,7 +113,7 @@ def generate_chat_response(
meta_log=meta_log,
)
- if state.processor_config.conversation.enable_offline_chat:
+ if state.processor_config.conversation.offline_chat.enable_offline_chat:
loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
chat_response = converse_offline(
references=compiled_references,
@@ -122,6 +122,9 @@ def generate_chat_response(
conversation_log=meta_log,
completion_func=partial_completion,
conversation_command=conversation_command,
+ model=state.processor_config.conversation.offline_chat.chat_model,
+ max_prompt_size=state.processor_config.conversation.max_prompt_size,
+ tokenizer_name=state.processor_config.conversation.tokenizer,
)
elif state.processor_config.conversation.openai_model:
@@ -135,6 +138,8 @@ def generate_chat_response(
api_key=api_key,
completion_func=partial_completion,
conversation_command=conversation_command,
+ max_prompt_size=state.processor_config.conversation.max_prompt_size,
+ tokenizer_name=state.processor_config.conversation.tokenizer,
)
except Exception as e:
diff --git a/src/khoj/utils/cli.py b/src/khoj/utils/cli.py
index 78a9ccf9..1d6106cb 100644
--- a/src/khoj/utils/cli.py
+++ b/src/khoj/utils/cli.py
@@ -9,6 +9,7 @@ from khoj.utils.yaml import parse_config_from_file
from khoj.migrations.migrate_version import migrate_config_to_version
from khoj.migrations.migrate_processor_config_openai import migrate_processor_conversation_schema
from khoj.migrations.migrate_offline_model import migrate_offline_model
+from khoj.migrations.migrate_offline_chat_schema import migrate_offline_chat_schema
def cli(args=None):
@@ -55,7 +56,12 @@ def cli(args=None):
def run_migrations(args):
- migrations = [migrate_config_to_version, migrate_processor_conversation_schema, migrate_offline_model]
+ migrations = [
+ migrate_config_to_version,
+ migrate_processor_conversation_schema,
+ migrate_offline_model,
+ migrate_offline_chat_schema,
+ ]
for migration in migrations:
args = migration(args)
return args
diff --git a/src/khoj/utils/config.py b/src/khoj/utils/config.py
index a6532346..3930ec98 100644
--- a/src/khoj/utils/config.py
+++ b/src/khoj/utils/config.py
@@ -84,7 +84,6 @@ class SearchModels:
@dataclass
class GPT4AllProcessorConfig:
- chat_model: Optional[str] = "llama-2-7b-chat.ggmlv3.q4_K_S.bin"
loaded_model: Union[Any, None] = None
@@ -95,18 +94,20 @@ class ConversationProcessorConfigModel:
):
self.openai_model = conversation_config.openai
self.gpt4all_model = GPT4AllProcessorConfig()
- self.enable_offline_chat = conversation_config.enable_offline_chat
+ self.offline_chat = conversation_config.offline_chat
+ self.max_prompt_size = conversation_config.max_prompt_size
+ self.tokenizer = conversation_config.tokenizer
self.conversation_logfile = Path(conversation_config.conversation_logfile)
self.chat_session: List[str] = []
self.meta_log: dict = {}
- if self.enable_offline_chat:
+ if self.offline_chat.enable_offline_chat:
try:
- self.gpt4all_model.loaded_model = download_model(self.gpt4all_model.chat_model)
+ self.gpt4all_model.loaded_model = download_model(self.offline_chat.chat_model)
except ValueError as e:
+ self.offline_chat.enable_offline_chat = False
self.gpt4all_model.loaded_model = None
logger.error(f"Error while loading offline chat model: {e}", exc_info=True)
- self.enable_offline_chat = False
else:
self.gpt4all_model.loaded_model = None
diff --git a/src/khoj/utils/rawconfig.py b/src/khoj/utils/rawconfig.py
index 0a916db4..f7c42266 100644
--- a/src/khoj/utils/rawconfig.py
+++ b/src/khoj/utils/rawconfig.py
@@ -91,10 +91,17 @@ class OpenAIProcessorConfig(ConfigBase):
chat_model: Optional[str] = "gpt-3.5-turbo"
+class OfflineChatProcessorConfig(ConfigBase):
+ enable_offline_chat: Optional[bool] = False
+ chat_model: Optional[str] = "llama-2-7b-chat.ggmlv3.q4_0.bin"
+
+
class ConversationProcessorConfig(ConfigBase):
conversation_logfile: Path
openai: Optional[OpenAIProcessorConfig]
- enable_offline_chat: Optional[bool] = False
+ offline_chat: Optional[OfflineChatProcessorConfig]
+ max_prompt_size: Optional[int]
+ tokenizer: Optional[str]
class ProcessorConfig(ConfigBase):
diff --git a/tests/conftest.py b/tests/conftest.py
index d851341d..f75dfceb 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -16,6 +16,7 @@ from khoj.utils.helpers import resolve_absolute_path
from khoj.utils.rawconfig import (
ContentConfig,
ConversationProcessorConfig,
+ OfflineChatProcessorConfig,
OpenAIProcessorConfig,
ProcessorConfig,
TextContentConfig,
@@ -205,8 +206,9 @@ def processor_config_offline_chat(tmp_path_factory):
# Setup conversation processor
processor_config = ProcessorConfig()
+ offline_chat = OfflineChatProcessorConfig(enable_offline_chat=True)
processor_config.conversation = ConversationProcessorConfig(
- enable_offline_chat=True,
+ offline_chat=offline_chat,
conversation_logfile=processor_dir.joinpath("conversation_logs.json"),
)
diff --git a/tests/test_gpt4all_chat_actors.py b/tests/test_gpt4all_chat_actors.py
index d7904ff8..76ed26e7 100644
--- a/tests/test_gpt4all_chat_actors.py
+++ b/tests/test_gpt4all_chat_actors.py
@@ -24,7 +24,7 @@ from khoj.processor.conversation.gpt4all.utils import download_model
from khoj.processor.conversation.utils import message_to_log
-MODEL_NAME = "llama-2-7b-chat.ggmlv3.q4_K_S.bin"
+MODEL_NAME = "llama-2-7b-chat.ggmlv3.q4_0.bin"
@pytest.fixture(scope="session")
@@ -128,15 +128,15 @@ def test_extract_multiple_explicit_questions_from_message(loaded_model):
@pytest.mark.chatquality
def test_extract_multiple_implicit_questions_from_message(loaded_model):
# Act
- response = extract_questions_offline("Is Morpheus taller than Neo?", loaded_model=loaded_model)
+ response = extract_questions_offline("Is Carl taller than Ross?", loaded_model=loaded_model)
# Assert
- expected_responses = ["height", "taller", "shorter", "heights"]
+ expected_responses = ["height", "taller", "shorter", "heights", "who"]
assert len(response) <= 3
for question in response:
assert any([expected_response in question.lower() for expected_response in expected_responses]), (
- "Expected chat actor to ask follow-up questions about Morpheus and Neo, but got: " + question
+ "Expected chat actor to ask follow-up questions about Carl and Ross, but got: " + question
)
@@ -145,7 +145,7 @@ def test_extract_multiple_implicit_questions_from_message(loaded_model):
def test_generate_search_query_using_question_from_chat_history(loaded_model):
# Arrange
message_list = [
- ("What is the name of Mr. Vader's daughter?", "Princess Leia", []),
+ ("What is the name of Mr. Anderson's daughter?", "Miss Barbara", []),
]
# Act
@@ -156,17 +156,22 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model):
use_history=True,
)
- expected_responses = [
- "Vader",
- "sons",
+ all_expected_in_response = [
+ "Anderson",
+ ]
+
+ any_expected_in_response = [
"son",
- "Darth",
+ "sons",
"children",
]
# Assert
assert len(response) >= 1
- assert any([expected_response in response[0] for expected_response in expected_responses]), (
+ assert all([expected_response in response[0] for expected_response in all_expected_in_response]), (
+ "Expected chat actor to ask for clarification in response, but got: " + response[0]
+ )
+ assert any([expected_response in response[0] for expected_response in any_expected_in_response]), (
"Expected chat actor to ask for clarification in response, but got: " + response[0]
)
@@ -176,20 +181,20 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model):
def test_generate_search_query_using_answer_from_chat_history(loaded_model):
# Arrange
message_list = [
- ("What is the name of Mr. Vader's daughter?", "Princess Leia", []),
+ ("What is the name of Mr. Anderson's daughter?", "Miss Barbara", []),
]
# Act
response = extract_questions_offline(
- "Is she a Jedi?",
+ "Is she a Doctor?",
conversation_log=populate_chat_history(message_list),
loaded_model=loaded_model,
use_history=True,
)
expected_responses = [
- "Leia",
- "Vader",
+ "Barbara",
+ "Robert",
"daughter",
]