diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 7e6cc409..769f015c 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -19,7 +19,7 @@ from khoj.utils.config import ( ) from khoj.utils.helpers import resolve_absolute_path, merge_dicts from khoj.utils.fs_syncer import collect_files -from khoj.utils.rawconfig import FullConfig, ProcessorConfig, ConversationProcessorConfig +from khoj.utils.rawconfig import FullConfig, OfflineChatProcessorConfig, ProcessorConfig, ConversationProcessorConfig from khoj.routers.indexer import configure_content, load_content, configure_search @@ -168,9 +168,7 @@ def configure_conversation_processor( conversation_config=ConversationProcessorConfig( conversation_logfile=conversation_logfile, openai=(conversation_config.openai if (conversation_config is not None) else None), - enable_offline_chat=( - conversation_config.enable_offline_chat if (conversation_config is not None) else False - ), + offline_chat=conversation_config.offline_chat if conversation_config else OfflineChatProcessorConfig(), ) ) else: diff --git a/src/khoj/interface/web/config.html b/src/khoj/interface/web/config.html index 3b295a88..d41ca26b 100644 --- a/src/khoj/interface/web/config.html +++ b/src/khoj/interface/web/config.html @@ -236,7 +236,7 @@
-

Setup chat using OpenAI

+

Setup online chat using OpenAI

@@ -261,21 +261,21 @@ Chat

Offline Chat - Configured - {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat and not current_model_state.conversation_gpt4all %} + Configured + {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.offline_chat.enable_offline_chat and not current_model_state.conversation_gpt4all %} Not Configured {% endif %}

-

Setup offline chat (Llama V2)

+

Setup offline chat

-
+
-
+
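The hunk below points the web client's chat toggle at the renamed settings endpoint (`/api/config/data/processor/conversation/offline_chat`, see `src/khoj/routers/api.py` further down, which also gains an optional `offline_chat_model` query parameter). For reference, a minimal sketch of exercising that endpoint directly — the local server address, port, and model filename are assumptions for illustration, not part of this diff:

```python
# Hypothetical smoke test for the renamed settings endpoint.
# Assumes a Khoj server running locally on its default port and the `requests` package.
import requests

response = requests.post(
    "http://localhost:42110/api/config/data/processor/conversation/offline_chat",
    params={
        "enable_offline_chat": True,
        # Optional: also switch the offline chat model (illustrative filename).
        "offline_chat_model": "llama-2-7b-chat.ggmlv3.q4_0.bin",
    },
)
response.raise_for_status()
```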
@@ -346,7 +346,7 @@
                 featuresHintText.classList.add("show");
             }
 
-            fetch('/api/config/data/processor/conversation/enable_offline_chat' + '?enable_offline_chat=' + enable, {
+            fetch('/api/config/data/processor/conversation/offline_chat' + '?enable_offline_chat=' + enable, {
                 method: 'POST',
                 headers: {
                     'Content-Type': 'application/json',
diff --git a/src/khoj/migrations/migrate_offline_chat_schema.py b/src/khoj/migrations/migrate_offline_chat_schema.py
new file mode 100644
index 00000000..873783a3
--- /dev/null
+++ b/src/khoj/migrations/migrate_offline_chat_schema.py
@@ -0,0 +1,83 @@
+"""
+Current format of khoj.yml
+---
+app:
+    ...
+content-type:
+    ...
+processor:
+  conversation:
+    enable-offline-chat: false
+    conversation-logfile: ~/.khoj/processor/conversation/conversation_logs.json
+    openai:
+        ...
+search-type:
+    ...
+
+New format of khoj.yml
+---
+app:
+    ...
+content-type:
+    ...
+processor:
+  conversation:
+    offline-chat:
+        enable-offline-chat: false
+        chat-model: llama-2-7b-chat.ggmlv3.q4_0.bin
+    tokenizer: null
+    max-prompt-size: null
+    conversation-logfile: ~/.khoj/processor/conversation/conversation_logs.json
+    openai:
+        ...
+search-type:
+    ...
+"""
+import logging
+from packaging import version
+
+from khoj.utils.yaml import load_config_from_file, save_config_to_file
+
+
+logger = logging.getLogger(__name__)
+
+
+def migrate_offline_chat_schema(args):
+    schema_version = "0.12.3"
+    raw_config = load_config_from_file(args.config_file)
+    previous_version = raw_config.get("version")
+
+    if "processor" not in raw_config:
+        return args
+    if raw_config["processor"] is None:
+        return args
+    if "conversation" not in raw_config["processor"]:
+        return args
+
+    if previous_version is None or version.parse(previous_version) < version.parse("0.12.3"):
+        logger.info(
+            f"Upgrading config schema to {schema_version} from {previous_version} to make (offline) chat more configurable"
+        )
+        raw_config["version"] = schema_version
+
+        # Create max-prompt-size field in conversation processor schema
+        raw_config["processor"]["conversation"]["max-prompt-size"] = None
+        raw_config["processor"]["conversation"]["tokenizer"] = None
+
+        # Create offline chat schema based on existing enable_offline_chat field in khoj config schema
+        offline_chat_model = (
+            raw_config["processor"]["conversation"]
+            .get("offline-chat", {})
+            .get("chat-model", "llama-2-7b-chat.ggmlv3.q4_0.bin")
+        )
+        raw_config["processor"]["conversation"]["offline-chat"] = {
+            "enable-offline-chat": raw_config["processor"]["conversation"].get("enable-offline-chat", False),
+            "chat-model": offline_chat_model,
+        }
+
+        # Delete old enable-offline-chat field from conversation processor schema
+        if "enable-offline-chat" in raw_config["processor"]["conversation"]:
+            del raw_config["processor"]["conversation"]["enable-offline-chat"]
+
+        save_config_to_file(raw_config, args.config_file)
+    return args
diff --git a/src/khoj/processor/conversation/gpt4all/chat_model.py b/src/khoj/processor/conversation/gpt4all/chat_model.py
index 9bc9ea52..7e92d002 100644
--- a/src/khoj/processor/conversation/gpt4all/chat_model.py
+++ b/src/khoj/processor/conversation/gpt4all/chat_model.py
@@ -16,7 +16,7 @@ logger = logging.getLogger(__name__)
 
 def extract_questions_offline(
     text: str,
-    model: str = "llama-2-7b-chat.ggmlv3.q4_K_S.bin",
+    model: str = "llama-2-7b-chat.ggmlv3.q4_0.bin",
    loaded_model: Union[Any, None] = None,
     conversation_log={},
     use_history: bool = True,
@@ -113,7 +113,7 @@ def filter_questions(questions: List[str]):
     ]
     filtered_questions = []
     for q in
questions: - if not any([word in q.lower() for word in hint_words]): + if not any([word in q.lower() for word in hint_words]) and not is_none_or_empty(q): filtered_questions.append(q) return filtered_questions @@ -123,10 +123,12 @@ def converse_offline( references, user_query, conversation_log={}, - model: str = "llama-2-7b-chat.ggmlv3.q4_K_S.bin", + model: str = "llama-2-7b-chat.ggmlv3.q4_0.bin", loaded_model: Union[Any, None] = None, completion_func=None, conversation_command=ConversationCommand.Default, + max_prompt_size=None, + tokenizer_name=None, ) -> Union[ThreadedGenerator, Iterator[str]]: """ Converse with user using Llama @@ -158,6 +160,8 @@ def converse_offline( prompts.system_prompt_message_llamav2, conversation_log, model_name=model, + max_prompt_size=max_prompt_size, + tokenizer_name=tokenizer_name, ) g = ThreadedGenerator(references, completion_func=completion_func) diff --git a/src/khoj/processor/conversation/gpt4all/model_metadata.py b/src/khoj/processor/conversation/gpt4all/model_metadata.py deleted file mode 100644 index 065e3720..00000000 --- a/src/khoj/processor/conversation/gpt4all/model_metadata.py +++ /dev/null @@ -1,3 +0,0 @@ -model_name_to_url = { - "llama-2-7b-chat.ggmlv3.q4_K_S.bin": "https://huggingface.co/TheBloke/Llama-2-7B-Chat-GGML/resolve/main/llama-2-7b-chat.ggmlv3.q4_K_S.bin" -} diff --git a/src/khoj/processor/conversation/gpt4all/utils.py b/src/khoj/processor/conversation/gpt4all/utils.py index 4042fbe2..d5201780 100644 --- a/src/khoj/processor/conversation/gpt4all/utils.py +++ b/src/khoj/processor/conversation/gpt4all/utils.py @@ -1,24 +1,8 @@ -import os import logging -import requests -import hashlib -from tqdm import tqdm - -from khoj.processor.conversation.gpt4all import model_metadata logger = logging.getLogger(__name__) -expected_checksum = {"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "cfa87b15d92fb15a2d7c354b0098578b"} - - -def get_md5_checksum(filename: str): - hash_md5 = hashlib.md5() - with open(filename, "rb") as f: - for chunk in iter(lambda: f.read(8192), b""): - hash_md5.update(chunk) - return hash_md5.hexdigest() - def download_model(model_name: str): try: @@ -27,57 +11,12 @@ def download_model(model_name: str): logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.") raise e - url = model_metadata.model_name_to_url.get(model_name) - model_path = os.path.expanduser(f"~/.cache/gpt4all/") - if not url: - logger.debug(f"Model {model_name} not found in model metadata. Skipping download.") - return GPT4All(model_name=model_name, model_path=model_path) - - filename = os.path.expanduser(f"~/.cache/gpt4all/{model_name}") - if os.path.exists(filename): - # Check if the user is connected to the internet - try: - requests.get("https://www.google.com/", timeout=5) - except: - logger.debug("User is offline. Disabling allowed download flag") - return GPT4All(model_name=model_name, model_path=model_path, allow_download=False) - return GPT4All(model_name=model_name, model_path=model_path) - - # Download the model to a tmp file. 
Once the download is completed, move the tmp file to the actual file - tmp_filename = filename + ".tmp" - + # Use GPU for Chat Model, if available try: - os.makedirs(os.path.dirname(tmp_filename), exist_ok=True) - logger.debug(f"Downloading model {model_name} from {url} to {filename}...") - with requests.get(url, stream=True) as r: - r.raise_for_status() - total_size = int(r.headers.get("content-length", 0)) - with open(tmp_filename, "wb") as f, tqdm( - unit="B", # unit string to be displayed. - unit_scale=True, # let tqdm to determine the scale in kilo, mega..etc. - unit_divisor=1024, # is used when unit_scale is true - total=total_size, # the total iteration. - desc=model_name, # prefix to be displayed on progress bar. - ) as progress_bar: - for chunk in r.iter_content(chunk_size=8192): - f.write(chunk) - progress_bar.update(len(chunk)) + model = GPT4All(model_name=model_name, device="gpu") + logger.debug("Loaded chat model to GPU.") + except ValueError: + model = GPT4All(model_name=model_name) + logger.debug("Loaded chat model to CPU.") - # Verify the checksum - if expected_checksum.get(model_name) != get_md5_checksum(tmp_filename): - logger.error( - f"Checksum verification failed for {filename}. Removing the tmp file. Offline model will not be available." - ) - os.remove(tmp_filename) - raise ValueError(f"Checksum verification failed for downloading {model_name} from {url}.") - - # Move the tmp file to the actual file - os.rename(tmp_filename, filename) - logger.debug(f"Successfully downloaded model {model_name} from {url} to {filename}") - return GPT4All(model_name) - except Exception as e: - logger.error(f"Failed to download model {model_name} from {url} to {filename}. Error: {e}", exc_info=True) - # Remove the tmp file if it exists - if os.path.exists(tmp_filename): - os.remove(tmp_filename) - return None + return model diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index 96510586..73b4f176 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -116,6 +116,8 @@ def converse( temperature: float = 0.2, completion_func=None, conversation_command=ConversationCommand.Default, + max_prompt_size=None, + tokenizer_name=None, ): """ Converse with user using OpenAI's ChatGPT @@ -141,6 +143,8 @@ def converse( prompts.personality.format(), conversation_log, model, + max_prompt_size, + tokenizer_name, ) truncated_messages = "\n".join({f"{message.content[:40]}..." for message in messages}) logger.debug(f"Conversation Context for GPT: {truncated_messages}") diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index 4de3c623..d487609d 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -23,7 +23,7 @@ no_notes_found = PromptTemplate.from_template( """.strip() ) -system_prompt_message_llamav2 = f"""You are Khoj, a friendly, smart and helpful personal assistant. +system_prompt_message_llamav2 = f"""You are Khoj, a smart, inquisitive and helpful personal assistant. Using your general knowledge and our past conversations as context, answer the following question. 
 If you do not know the answer, say 'I don't know.'"""
 
@@ -51,13 +51,13 @@ extract_questions_system_prompt_llamav2 = PromptTemplate.from_template(
 
 general_conversation_llamav2 = PromptTemplate.from_template(
     """
-[INST]{query}[/INST]
+[INST] {query} [/INST]
 """.strip()
 )
 
 chat_history_llamav2_from_user = PromptTemplate.from_template(
     """
-[INST]{message}[/INST]
+[INST] {message} [/INST]
 """.strip()
 )
 
@@ -69,7 +69,7 @@ chat_history_llamav2_from_assistant = PromptTemplate.from_template(
 
 conversation_llamav2 = PromptTemplate.from_template(
     """
-[INST]{query}[/INST]
+[INST] {query} [/INST]
 """.strip()
 )
 
@@ -91,7 +91,7 @@ Question: {query}
 
 notes_conversation_llamav2 = PromptTemplate.from_template(
     """
-Notes:
+User's Notes:
 {references}
 Question: {query}
 """.strip()
 )
 
@@ -134,19 +134,25 @@ Answer (in second person):"""
 
 extract_questions_llamav2_sample = PromptTemplate.from_template(
     """
-[INST]<<SYS>>Current Date: {current_date}<</SYS>>[/INST]
-[INST]How was my trip to Cambodia?[/INST][]
-[INST]Who did I visit the temple with on that trip?[/INST]Who did I visit the temple with in Cambodia?
-[INST]How should I take care of my plants?[/INST]What kind of plants do I have? What issues do my plants have?
-[INST]How many tennis balls fit in the back of a 2002 Honda Civic?[/INST]What is the size of a tennis ball? What is the trunk size of a 2002 Honda Civic?
-[INST]What did I do for Christmas last year?[/INST]What did I do for Christmas {last_year} dt>='{last_christmas_date}' dt<'{next_christmas_date}'
-[INST]How are you feeling today?[/INST]
-[INST]Is Alice older than Bob?[/INST]When was Alice born? What is Bob's age?
-[INST]<<SYS>>
+[INST] <<SYS>>Current Date: {current_date}<</SYS>> [/INST]
+[INST] How was my trip to Cambodia? [/INST]
+How was my trip to Cambodia?
+[INST] Who did I visit the temple with on that trip? [/INST]
+Who did I visit the temple with in Cambodia?
+[INST] How should I take care of my plants? [/INST]
+What kind of plants do I have? What issues do my plants have?
+[INST] How many tennis balls fit in the back of a 2002 Honda Civic? [/INST]
+What is the size of a tennis ball? What is the trunk size of a 2002 Honda Civic?
+[INST] What did I do for Christmas last year? [/INST]
+What did I do for Christmas {last_year} dt>='{last_christmas_date}' dt<'{next_christmas_date}'
+[INST] How are you feeling today? [/INST]
+[INST] Is Alice older than Bob? [/INST]
+When was Alice born? What is Bob's age?
+[INST] <<SYS>>
 Use these notes from the user's previous conversations to provide a response:
 {chat_history}
-<</SYS>>[/INST]
-[INST]{query}[/INST]
+<</SYS>> [/INST]
+[INST] {query} [/INST]
 """
 )
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 4a92c367..83d51f2d 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -3,24 +3,27 @@ import logging
 from time import perf_counter
 import json
 from datetime import datetime
+import queue
 
 import tiktoken
 
 # External packages
 from langchain.schema import ChatMessage
-from transformers import LlamaTokenizerFast
+from transformers import AutoTokenizer
 
 # Internal Packages
-import queue
 from khoj.utils.helpers import merge_dicts
 
+
 logger = logging.getLogger(__name__)
 
-max_prompt_size = {
+model_to_prompt_size = {
     "gpt-3.5-turbo": 4096,
     "gpt-4": 8192,
-    "llama-2-7b-chat.ggmlv3.q4_K_S.bin": 1548,
+    "llama-2-7b-chat.ggmlv3.q4_0.bin": 1548,
     "gpt-3.5-turbo-16k": 15000,
 }
-tokenizer = {"llama-2-7b-chat.ggmlv3.q4_K_S.bin": "hf-internal-testing/llama-tokenizer"}
+model_to_tokenizer = {
+    "llama-2-7b-chat.ggmlv3.q4_0.bin": "hf-internal-testing/llama-tokenizer",
+}
 
 
 class ThreadedGenerator:
@@ -82,9 +85,26 @@ def message_to_log(
 
 
 def generate_chatml_messages_with_context(
-    user_message, system_message, conversation_log={}, model_name="gpt-3.5-turbo", lookback_turns=2
+    user_message,
+    system_message,
+    conversation_log={},
+    model_name="gpt-3.5-turbo",
+    max_prompt_size=None,
+    tokenizer_name=None,
 ):
     """Generate messages for ChatGPT with context from previous conversation"""
+    # Set max prompt size from user config, pre-configured for model or to default prompt size
+    try:
+        max_prompt_size = max_prompt_size or model_to_prompt_size[model_name]
+    except:
+        max_prompt_size = 2000
+        logger.warning(
+            f"Fallback to default prompt size: {max_prompt_size}.\nConfigure max_prompt_size for unsupported model: {model_name} in Khoj settings to use a longer context window."
+ ) + + # Scale lookback turns proportional to max prompt size supported by model + lookback_turns = max_prompt_size // 750 + # Extract Chat History for Context chat_logs = [] for chat in conversation_log.get("chat", []): @@ -105,19 +125,28 @@ def generate_chatml_messages_with_context( messages = user_chatml_message + rest_backnforths + system_chatml_message # Truncate oldest messages from conversation history until under max supported prompt size by model - messages = truncate_messages(messages, max_prompt_size[model_name], model_name) + messages = truncate_messages(messages, max_prompt_size, model_name, tokenizer_name) # Return message in chronological order return messages[::-1] -def truncate_messages(messages: list[ChatMessage], max_prompt_size, model_name) -> list[ChatMessage]: +def truncate_messages( + messages: list[ChatMessage], max_prompt_size, model_name: str, tokenizer_name=None +) -> list[ChatMessage]: """Truncate messages to fit within max prompt size supported by model""" - if "llama" in model_name: - encoder = LlamaTokenizerFast.from_pretrained(tokenizer[model_name]) - else: - encoder = tiktoken.encoding_for_model(model_name) + try: + if model_name.startswith("gpt-"): + encoder = tiktoken.encoding_for_model(model_name) + else: + encoder = AutoTokenizer.from_pretrained(tokenizer_name or model_to_tokenizer[model_name]) + except: + default_tokenizer = "hf-internal-testing/llama-tokenizer" + encoder = AutoTokenizer.from_pretrained(default_tokenizer) + logger.warning( + f"Fallback to default chat model tokenizer: {default_tokenizer}.\nConfigure tokenizer for unsupported model: {model_name} in Khoj settings to improve context stuffing." + ) system_message = messages.pop() system_message_tokens = len(encoder.encode(system_message.content)) diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 780a6c57..0331500b 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -284,10 +284,11 @@ if not state.demo: except Exception as e: return {"status": "error", "message": str(e)} - @api.post("/config/data/processor/conversation/enable_offline_chat", status_code=200) + @api.post("/config/data/processor/conversation/offline_chat", status_code=200) async def set_processor_enable_offline_chat_config_data( request: Request, enable_offline_chat: bool, + offline_chat_model: Optional[str] = None, client: Optional[str] = None, ): _initialize_config() @@ -301,7 +302,9 @@ if not state.demo: state.config.processor = ProcessorConfig(conversation=ConversationProcessorConfig(conversation_logfile=conversation_logfile)) # type: ignore assert state.config.processor.conversation is not None - state.config.processor.conversation.enable_offline_chat = enable_offline_chat + state.config.processor.conversation.offline_chat.enable_offline_chat = enable_offline_chat + if offline_chat_model is not None: + state.config.processor.conversation.offline_chat.chat_model = offline_chat_model state.processor_config = configure_processor(state.config.processor, state.processor_config) update_telemetry_state( @@ -713,7 +716,7 @@ async def chat( conversation_command = ConversationCommand.General if conversation_command == ConversationCommand.Help: - model_type = "offline" if state.processor_config.conversation.enable_offline_chat else "openai" + model_type = "offline" if state.processor_config.conversation.offline_chat.enable_offline_chat else "openai" formatted_help = help_message.format(model=model_type, version=state.khoj_version) return StreamingResponse(iter([formatted_help]), 
media_type="text/event-stream", status_code=200) @@ -788,7 +791,7 @@ async def extract_references_and_questions( # Infer search queries from user message with timer("Extracting search queries took", logger): # If we've reached here, either the user has enabled offline chat or the openai model is enabled. - if state.processor_config.conversation.enable_offline_chat: + if state.processor_config.conversation.offline_chat.enable_offline_chat: loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model inferred_queries = extract_questions_offline( defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False @@ -804,7 +807,7 @@ async def extract_references_and_questions( with timer("Searching knowledge base took", logger): result_list = [] for query in inferred_queries: - n_items = min(n, 3) if state.processor_config.conversation.enable_offline_chat else n + n_items = min(n, 3) if state.processor_config.conversation.offline_chat.enable_offline_chat else n result_list.extend( await search( f"{query} {filters_in_query}", diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 267af330..6b42f29c 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -113,7 +113,7 @@ def generate_chat_response( meta_log=meta_log, ) - if state.processor_config.conversation.enable_offline_chat: + if state.processor_config.conversation.offline_chat.enable_offline_chat: loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model chat_response = converse_offline( references=compiled_references, @@ -122,6 +122,9 @@ def generate_chat_response( conversation_log=meta_log, completion_func=partial_completion, conversation_command=conversation_command, + model=state.processor_config.conversation.offline_chat.chat_model, + max_prompt_size=state.processor_config.conversation.max_prompt_size, + tokenizer_name=state.processor_config.conversation.tokenizer, ) elif state.processor_config.conversation.openai_model: @@ -135,6 +138,8 @@ def generate_chat_response( api_key=api_key, completion_func=partial_completion, conversation_command=conversation_command, + max_prompt_size=state.processor_config.conversation.max_prompt_size, + tokenizer_name=state.processor_config.conversation.tokenizer, ) except Exception as e: diff --git a/src/khoj/utils/cli.py b/src/khoj/utils/cli.py index 78a9ccf9..1d6106cb 100644 --- a/src/khoj/utils/cli.py +++ b/src/khoj/utils/cli.py @@ -9,6 +9,7 @@ from khoj.utils.yaml import parse_config_from_file from khoj.migrations.migrate_version import migrate_config_to_version from khoj.migrations.migrate_processor_config_openai import migrate_processor_conversation_schema from khoj.migrations.migrate_offline_model import migrate_offline_model +from khoj.migrations.migrate_offline_chat_schema import migrate_offline_chat_schema def cli(args=None): @@ -55,7 +56,12 @@ def cli(args=None): def run_migrations(args): - migrations = [migrate_config_to_version, migrate_processor_conversation_schema, migrate_offline_model] + migrations = [ + migrate_config_to_version, + migrate_processor_conversation_schema, + migrate_offline_model, + migrate_offline_chat_schema, + ] for migration in migrations: args = migration(args) return args diff --git a/src/khoj/utils/config.py b/src/khoj/utils/config.py index a6532346..3930ec98 100644 --- a/src/khoj/utils/config.py +++ b/src/khoj/utils/config.py @@ -84,7 +84,6 @@ class SearchModels: @dataclass class GPT4AllProcessorConfig: - chat_model: Optional[str] = 
"llama-2-7b-chat.ggmlv3.q4_K_S.bin" loaded_model: Union[Any, None] = None @@ -95,18 +94,20 @@ class ConversationProcessorConfigModel: ): self.openai_model = conversation_config.openai self.gpt4all_model = GPT4AllProcessorConfig() - self.enable_offline_chat = conversation_config.enable_offline_chat + self.offline_chat = conversation_config.offline_chat + self.max_prompt_size = conversation_config.max_prompt_size + self.tokenizer = conversation_config.tokenizer self.conversation_logfile = Path(conversation_config.conversation_logfile) self.chat_session: List[str] = [] self.meta_log: dict = {} - if self.enable_offline_chat: + if self.offline_chat.enable_offline_chat: try: - self.gpt4all_model.loaded_model = download_model(self.gpt4all_model.chat_model) + self.gpt4all_model.loaded_model = download_model(self.offline_chat.chat_model) except ValueError as e: + self.offline_chat.enable_offline_chat = False self.gpt4all_model.loaded_model = None logger.error(f"Error while loading offline chat model: {e}", exc_info=True) - self.enable_offline_chat = False else: self.gpt4all_model.loaded_model = None diff --git a/src/khoj/utils/rawconfig.py b/src/khoj/utils/rawconfig.py index 0a916db4..f7c42266 100644 --- a/src/khoj/utils/rawconfig.py +++ b/src/khoj/utils/rawconfig.py @@ -91,10 +91,17 @@ class OpenAIProcessorConfig(ConfigBase): chat_model: Optional[str] = "gpt-3.5-turbo" +class OfflineChatProcessorConfig(ConfigBase): + enable_offline_chat: Optional[bool] = False + chat_model: Optional[str] = "llama-2-7b-chat.ggmlv3.q4_0.bin" + + class ConversationProcessorConfig(ConfigBase): conversation_logfile: Path openai: Optional[OpenAIProcessorConfig] - enable_offline_chat: Optional[bool] = False + offline_chat: Optional[OfflineChatProcessorConfig] + max_prompt_size: Optional[int] + tokenizer: Optional[str] class ProcessorConfig(ConfigBase): diff --git a/tests/conftest.py b/tests/conftest.py index d851341d..f75dfceb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,6 +16,7 @@ from khoj.utils.helpers import resolve_absolute_path from khoj.utils.rawconfig import ( ContentConfig, ConversationProcessorConfig, + OfflineChatProcessorConfig, OpenAIProcessorConfig, ProcessorConfig, TextContentConfig, @@ -205,8 +206,9 @@ def processor_config_offline_chat(tmp_path_factory): # Setup conversation processor processor_config = ProcessorConfig() + offline_chat = OfflineChatProcessorConfig(enable_offline_chat=True) processor_config.conversation = ConversationProcessorConfig( - enable_offline_chat=True, + offline_chat=offline_chat, conversation_logfile=processor_dir.joinpath("conversation_logs.json"), ) diff --git a/tests/test_gpt4all_chat_actors.py b/tests/test_gpt4all_chat_actors.py index d7904ff8..76ed26e7 100644 --- a/tests/test_gpt4all_chat_actors.py +++ b/tests/test_gpt4all_chat_actors.py @@ -24,7 +24,7 @@ from khoj.processor.conversation.gpt4all.utils import download_model from khoj.processor.conversation.utils import message_to_log -MODEL_NAME = "llama-2-7b-chat.ggmlv3.q4_K_S.bin" +MODEL_NAME = "llama-2-7b-chat.ggmlv3.q4_0.bin" @pytest.fixture(scope="session") @@ -128,15 +128,15 @@ def test_extract_multiple_explicit_questions_from_message(loaded_model): @pytest.mark.chatquality def test_extract_multiple_implicit_questions_from_message(loaded_model): # Act - response = extract_questions_offline("Is Morpheus taller than Neo?", loaded_model=loaded_model) + response = extract_questions_offline("Is Carl taller than Ross?", loaded_model=loaded_model) # Assert - expected_responses = ["height", "taller", "shorter", 
"heights"] + expected_responses = ["height", "taller", "shorter", "heights", "who"] assert len(response) <= 3 for question in response: assert any([expected_response in question.lower() for expected_response in expected_responses]), ( - "Expected chat actor to ask follow-up questions about Morpheus and Neo, but got: " + question + "Expected chat actor to ask follow-up questions about Carl and Ross, but got: " + question ) @@ -145,7 +145,7 @@ def test_extract_multiple_implicit_questions_from_message(loaded_model): def test_generate_search_query_using_question_from_chat_history(loaded_model): # Arrange message_list = [ - ("What is the name of Mr. Vader's daughter?", "Princess Leia", []), + ("What is the name of Mr. Anderson's daughter?", "Miss Barbara", []), ] # Act @@ -156,17 +156,22 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model): use_history=True, ) - expected_responses = [ - "Vader", - "sons", + all_expected_in_response = [ + "Anderson", + ] + + any_expected_in_response = [ "son", - "Darth", + "sons", "children", ] # Assert assert len(response) >= 1 - assert any([expected_response in response[0] for expected_response in expected_responses]), ( + assert all([expected_response in response[0] for expected_response in all_expected_in_response]), ( + "Expected chat actor to ask for clarification in response, but got: " + response[0] + ) + assert any([expected_response in response[0] for expected_response in any_expected_in_response]), ( "Expected chat actor to ask for clarification in response, but got: " + response[0] ) @@ -176,20 +181,20 @@ def test_generate_search_query_using_question_from_chat_history(loaded_model): def test_generate_search_query_using_answer_from_chat_history(loaded_model): # Arrange message_list = [ - ("What is the name of Mr. Vader's daughter?", "Princess Leia", []), + ("What is the name of Mr. Anderson's daughter?", "Miss Barbara", []), ] # Act response = extract_questions_offline( - "Is she a Jedi?", + "Is she a Doctor?", conversation_log=populate_chat_history(message_list), loaded_model=loaded_model, use_history=True, ) expected_responses = [ - "Leia", - "Vader", + "Barbara", + "Robert", "daughter", ]