diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b54236f8..5b0e1f6e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -30,7 +30,6 @@ jobs: fail-fast: false matrix: python_version: - - '3.8' - '3.9' - '3.10' - '3.11' diff --git a/pyproject.toml b/pyproject.toml index e0e6f720..3b22b31e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,7 @@ dependencies = [ "pypdf >= 3.9.0", "requests >= 2.26.0", "bs4 >= 0.0.1", + "gpt4all==1.0.5", ] dynamic = ["version"] diff --git a/src/interface/emacs/khoj.el b/src/interface/emacs/khoj.el index 3685c973..f8620ace 100644 --- a/src/interface/emacs/khoj.el +++ b/src/interface/emacs/khoj.el @@ -373,7 +373,7 @@ CONFIG is json obtained from Khoj config API." (ignore-error json-end-of-file (json-parse-buffer :object-type 'alist :array-type 'list :null-object json-null :false-object json-false)))) (default-index-dir (khoj--get-directory-from-config default-config '(content-type org embeddings-file))) (default-chat-dir (khoj--get-directory-from-config default-config '(processor conversation conversation-logfile))) - (chat-model (or khoj-chat-model (alist-get 'chat-model (alist-get 'conversation (alist-get 'processor default-config))))) + (chat-model (or khoj-chat-model (alist-get 'chat-model (alist-get 'openai (alist-get 'conversation (alist-get 'processor default-config)))))) (default-model (alist-get 'model (alist-get 'conversation (alist-get 'processor default-config)))) (config (or current-config default-config))) @@ -423,15 +423,27 @@ CONFIG is json obtained from Khoj config API." ;; Configure processors (cond ((not khoj-openai-api-key) - (setq config (delq (assoc 'processor config) config))) + (let* ((processor (assoc 'processor config)) + (conversation (assoc 'conversation processor)) + (openai (assoc 'openai conversation))) + (when openai + ;; Unset the `openai' field in the khoj conversation processor config + (message "khoj.el: disable Khoj Chat using OpenAI as your OpenAI API key got removed from config") + (setcdr conversation (delq openai (cdr conversation))) + (setcdr processor (delq conversation (cdr processor))) + (setq config (delq processor config)) + (push conversation (cdr processor)) + (push processor config)))) ((not current-config) (message "khoj.el: Chat not configured yet.") (setq config (delq (assoc 'processor config) config)) (cl-pushnew `(processor . ((conversation . ((conversation-logfile . ,(format "%s/conversation.json" default-chat-dir)) - (chat-model . ,chat-model) - (model . ,default-model) - (openai-api-key . ,khoj-openai-api-key))))) + (openai . ( + (chat-model . ,chat-model) + (api-key . ,khoj-openai-api-key) + )) + )))) config)) ((not (alist-get 'conversation (alist-get 'processor config))) @@ -440,21 +452,19 @@ CONFIG is json obtained from Khoj config API." (setq new-processor-type (delq (assoc 'conversation new-processor-type) new-processor-type)) (cl-pushnew `(conversation . ((conversation-logfile . ,(format "%s/conversation.json" default-chat-dir)) (chat-model . ,chat-model) - (model . ,default-model) (openai-api-key . ,khoj-openai-api-key))) new-processor-type) (setq config (delq (assoc 'processor config) config)) (cl-pushnew `(processor . 
,new-processor-type) config))) ;; Else if khoj is not configured with specified openai api key - ((not (and (equal (alist-get 'openai-api-key (alist-get 'conversation (alist-get 'processor config))) khoj-openai-api-key) - (equal (alist-get 'chat-model (alist-get 'conversation (alist-get 'processor config))) khoj-chat-model))) + ((not (and (equal (alist-get 'api-key (alist-get 'openai (alist-get 'conversation (alist-get 'processor config)))) khoj-openai-api-key) + (equal (alist-get 'chat-model (alist-get 'openai (alist-get 'conversation (alist-get 'processor config)))) khoj-chat-model))) (message "khoj.el: Chat configuration has gone stale.") (let* ((chat-directory (khoj--get-directory-from-config config '(processor conversation conversation-logfile))) (new-processor-type (alist-get 'processor config))) (setq new-processor-type (delq (assoc 'conversation new-processor-type) new-processor-type)) (cl-pushnew `(conversation . ((conversation-logfile . ,(format "%s/conversation.json" chat-directory)) - (model . ,default-model) (chat-model . ,khoj-chat-model) (openai-api-key . ,khoj-openai-api-key))) new-processor-type) diff --git a/src/interface/obsidian/src/utils.ts b/src/interface/obsidian/src/utils.ts index 20991da9..13d72de1 100644 --- a/src/interface/obsidian/src/utils.ts +++ b/src/interface/obsidian/src/utils.ts @@ -36,7 +36,7 @@ export async function configureKhojBackend(vault: Vault, setting: KhojSetting, n let khojDefaultMdIndexDirectory = getIndexDirectoryFromBackendConfig(defaultConfig["content-type"]["markdown"]["embeddings-file"]); let khojDefaultPdfIndexDirectory = getIndexDirectoryFromBackendConfig(defaultConfig["content-type"]["pdf"]["embeddings-file"]); let khojDefaultChatDirectory = getIndexDirectoryFromBackendConfig(defaultConfig["processor"]["conversation"]["conversation-logfile"]); - let khojDefaultChatModelName = defaultConfig["processor"]["conversation"]["model"]; + let khojDefaultChatModelName = defaultConfig["processor"]["conversation"]["openai"]["chat-model"]; // Get current config if khoj backend configured, else get default config from khoj backend await request(khoj_already_configured ? 
khojConfigUrl : `${khojConfigUrl}/default`) @@ -142,25 +142,35 @@ export async function configureKhojBackend(vault: Vault, setting: KhojSetting, n data["processor"] = { "conversation": { "conversation-logfile": `${khojDefaultChatDirectory}/conversation.json`, - "model": khojDefaultChatModelName, - "openai-api-key": setting.openaiApiKey, - } + "openai": { + "chat-model": khojDefaultChatModelName, + "api-key": setting.openaiApiKey, + } + }, } } // Else if khoj config has no conversation processor config - else if (!data["processor"]["conversation"]) { - data["processor"]["conversation"] = { - "conversation-logfile": `${khojDefaultChatDirectory}/conversation.json`, - "model": khojDefaultChatModelName, - "openai-api-key": setting.openaiApiKey, + else if (!data["processor"]["conversation"] || !data["processor"]["conversation"]["openai"]) { + data["processor"] = { + "conversation": { + "conversation-logfile": `${khojDefaultChatDirectory}/conversation.json`, + "openai": { + "chat-model": khojDefaultChatModelName, + "api-key": setting.openaiApiKey, + } + }, } } // Else if khoj is not configured with OpenAI API key from khoj plugin settings - else if (data["processor"]["conversation"]["openai-api-key"] !== setting.openaiApiKey) { - data["processor"]["conversation"] = { - "conversation-logfile": data["processor"]["conversation"]["conversation-logfile"], - "model": data["processor"]["conversation"]["model"], - "openai-api-key": setting.openaiApiKey, + else if (data["processor"]["conversation"]["openai"]["api-key"] !== setting.openaiApiKey) { + data["processor"] = { + "conversation": { + "conversation-logfile": data["processor"]["conversation"]["conversation-logfile"], + "openai": { + "chat-model": data["processor"]["conversation"]["openai"]["chat-model"], + "api-key": setting.openaiApiKey, + } + }, } } diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 469f33ff..5adde214 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -11,7 +11,6 @@ import schedule from fastapi.staticfiles import StaticFiles # Internal Packages -from khoj.processor.conversation.gpt import summarize from khoj.processor.jsonl.jsonl_to_jsonl import JsonlToJsonl from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl @@ -28,7 +27,7 @@ from khoj.utils.config import ( ConversationProcessorConfigModel, ) from khoj.utils.helpers import LRU, resolve_absolute_path, merge_dicts -from khoj.utils.rawconfig import FullConfig, ProcessorConfig, SearchConfig, ContentConfig +from khoj.utils.rawconfig import FullConfig, ProcessorConfig, SearchConfig, ContentConfig, ConversationProcessorConfig from khoj.search_filter.date_filter import DateFilter from khoj.search_filter.word_filter import WordFilter from khoj.search_filter.file_filter import FileFilter @@ -64,7 +63,7 @@ def configure_server(config: FullConfig, regenerate: bool, search_type: Optional state.config_lock.acquire() state.processor_config = configure_processor(state.config.processor) except Exception as e: - logger.error(f"🚨 Failed to configure processor") + logger.error(f"🚨 Failed to configure processor", exc_info=True) raise e finally: state.config_lock.release() @@ -75,7 +74,7 @@ def configure_server(config: FullConfig, regenerate: bool, search_type: Optional state.SearchType = configure_search_types(state.config) state.search_models = configure_search(state.search_models, state.config.search_type) except Exception as e: - logger.error(f"🚨 Failed to configure search models") + 
logger.error(f"🚨 Failed to configure search models", exc_info=True) raise e finally: state.config_lock.release() @@ -88,7 +87,7 @@ def configure_server(config: FullConfig, regenerate: bool, search_type: Optional state.content_index, state.config.content_type, state.search_models, regenerate, search_type ) except Exception as e: - logger.error(f"🚨 Failed to index content") + logger.error(f"🚨 Failed to index content", exc_info=True) raise e finally: state.config_lock.release() @@ -117,7 +116,7 @@ if not state.demo: ) logger.info("📬 Content index updated via Scheduler") except Exception as e: - logger.error(f"🚨 Error updating content index via Scheduler: {e}") + logger.error(f"🚨 Error updating content index via Scheduler: {e}", exc_info=True) finally: state.config_lock.release() @@ -258,7 +257,9 @@ def configure_content( return content_index -def configure_processor(processor_config: Optional[ProcessorConfig]): +def configure_processor( + processor_config: Optional[ProcessorConfig], state_processor_config: Optional[ProcessorConfigModel] = None +): if not processor_config: logger.warning("🚨 No Processor configuration available.") return None @@ -266,16 +267,47 @@ def configure_processor(processor_config: Optional[ProcessorConfig]): processor = ProcessorConfigModel() # Initialize Conversation Processor - if processor_config.conversation: - logger.info("💬 Setting up conversation processor") - processor.conversation = configure_conversation_processor(processor_config.conversation) + logger.info("💬 Setting up conversation processor") + processor.conversation = configure_conversation_processor(processor_config, state_processor_config) return processor -def configure_conversation_processor(conversation_processor_config): - conversation_processor = ConversationProcessorConfigModel(conversation_processor_config) - conversation_logfile = resolve_absolute_path(conversation_processor.conversation_logfile) +def configure_conversation_processor( + processor_config: Optional[ProcessorConfig], state_processor_config: Optional[ProcessorConfigModel] = None +): + if ( + not processor_config + or not processor_config.conversation + or not processor_config.conversation.conversation_logfile + ): + default_config = constants.default_config + default_conversation_logfile = resolve_absolute_path( + default_config["processor"]["conversation"]["conversation-logfile"] # type: ignore + ) + conversation_logfile = resolve_absolute_path(default_conversation_logfile) + conversation_config = processor_config.conversation if processor_config else None + conversation_processor = ConversationProcessorConfigModel( + conversation_config=ConversationProcessorConfig( + conversation_logfile=conversation_logfile, + openai=(conversation_config.openai if (conversation_config is not None) else None), + enable_offline_chat=( + conversation_config.enable_offline_chat if (conversation_config is not None) else False + ), + ) + ) + else: + conversation_processor = ConversationProcessorConfigModel( + conversation_config=processor_config.conversation, + ) + conversation_logfile = resolve_absolute_path(conversation_processor.conversation_logfile) + + # Load Conversation Logs from Disk + if state_processor_config and state_processor_config.conversation and state_processor_config.conversation.meta_log: + conversation_processor.meta_log = state_processor_config.conversation.meta_log + conversation_processor.chat_session = state_processor_config.conversation.chat_session + logger.debug(f"Loaded conversation logs from state") + return 
conversation_processor if conversation_logfile.is_file(): # Load Metadata Logs from Conversation Logfile @@ -302,12 +334,8 @@ def save_chat_session(): return # Summarize Conversation Logs for this Session - chat_session = state.processor_config.conversation.chat_session - openai_api_key = state.processor_config.conversation.openai_api_key conversation_log = state.processor_config.conversation.meta_log - chat_model = state.processor_config.conversation.chat_model session = { - "summary": summarize(chat_session, model=chat_model, api_key=openai_api_key), "session-start": conversation_log.get("session", [{"session-end": 0}])[-1]["session-end"], "session-end": len(conversation_log["chat"]), } @@ -344,6 +372,6 @@ def upload_telemetry(): log[field] = str(log[field]) requests.post(constants.telemetry_server, json=state.telemetry) except Exception as e: - logger.error(f"📡 Error uploading telemetry: {e}") + logger.error(f"📡 Error uploading telemetry: {e}", exc_info=True) else: state.telemetry = [] diff --git a/src/khoj/interface/web/assets/icons/openai-logomark.svg b/src/khoj/interface/web/assets/icons/openai-logomark.svg new file mode 100644 index 00000000..c0bcb8bc --- /dev/null +++ b/src/khoj/interface/web/assets/icons/openai-logomark.svg @@ -0,0 +1 @@ + diff --git a/src/khoj/interface/web/base_config.html b/src/khoj/interface/web/base_config.html index 70d4be70..cbaa230a 100644 --- a/src/khoj/interface/web/base_config.html +++ b/src/khoj/interface/web/base_config.html @@ -167,10 +167,42 @@ text-align: left; } + button.card-button.happy { + color: rgb(0, 146, 0); + } + img.configured-icon { max-width: 16px; } + div.card-action-row.enabled{ + display: block; + } + + img.configured-icon.enabled { + display: inline; + } + + div.card-action-row.disabled, + img.configured-icon.disabled { + display: none; + } + + .loader { + border: 16px solid #f3f3f3; /* Light grey */ + border-top: 16px solid var(--primary); + border-radius: 50%; + width: 16px; + height: 16px; + animation: spin 2s linear infinite; + } + + @keyframes spin { + 0% { transform: rotate(0deg); } + 100% { transform: rotate(360deg); } + } + + div.finalize-actions { grid-auto-flow: column; grid-gap: 24px; diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index 37a5824c..cbdfefe9 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -135,8 +135,8 @@ .then(response => response.json()) .then(data => { if (data.detail) { - // If the server returns a 500 error with detail, render it as a message. - renderMessage("Hi 👋🏾, to get started
1. Get your OpenAI API key
2. Save it in the Khoj chat settings
3. Click Configure on the Khoj settings page", "khoj"); + // If the server returns a 500 error with detail, render a setup hint. + renderMessage("Hi 👋🏾, to get started
1. Get your OpenAI API key
2. Save it in the Khoj chat settings
3. Click Configure on the Khoj settings page", "khoj"); // Disable chat input field and update placeholder text document.getElementById("chat-input").setAttribute("disabled", "disabled"); diff --git a/src/khoj/interface/web/config.html b/src/khoj/interface/web/config.html index 81968d94..eead2ade 100644 --- a/src/khoj/interface/web/config.html +++ b/src/khoj/interface/web/config.html @@ -20,7 +20,7 @@
-    Set repositories for Khoj to index
+    Set repositories to index
@@ -90,7 +90,7 @@
-    Set markdown files for Khoj to index
+    Set markdown files to index
@@ -125,7 +125,7 @@
-    Set org files for Khoj to index
+    Set org files to index
@@ -160,7 +160,7 @@
-    Set PDF files for Khoj to index
+    Set PDF files to index
@@ -187,10 +187,10 @@
+    Chat
+    Offline Chat
+    Configured
+    Setup offline chat (Falcon 7B)
@@ -263,9 +288,59 @@ }) }; + function toggleEnableLocalLLLM(enable) { + const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1]; + var toggleEnableLocalLLLMButton = document.getElementById("toggle-enable-offline-chat"); + toggleEnableLocalLLLMButton.classList.remove("disabled"); + toggleEnableLocalLLLMButton.classList.add("enabled"); + + fetch('/api/config/data/processor/conversation/enable_offline_chat' + '?enable_offline_chat=' + enable, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'X-CSRFToken': csrfToken + }, + }) + .then(response => response.json()) + .then(data => { + if (data.status == "ok") { + // Toggle the Enabled/Disabled UI based on the action/response. + var enableLocalLLLMButton = document.getElementById("set-enable-offline-chat"); + var disableLocalLLLMButton = document.getElementById("clear-enable-offline-chat"); + var configuredIcon = document.getElementById("configured-icon-conversation-enable-offline-chat"); + var toggleEnableLocalLLLMButton = document.getElementById("toggle-enable-offline-chat"); + + toggleEnableLocalLLLMButton.classList.remove("enabled"); + toggleEnableLocalLLLMButton.classList.add("disabled"); + + + if (enable) { + enableLocalLLLMButton.classList.add("disabled"); + enableLocalLLLMButton.classList.remove("enabled"); + + configuredIcon.classList.add("enabled"); + configuredIcon.classList.remove("disabled"); + + disableLocalLLLMButton.classList.remove("disabled"); + disableLocalLLLMButton.classList.add("enabled"); + } else { + enableLocalLLLMButton.classList.remove("disabled"); + enableLocalLLLMButton.classList.add("enabled"); + + configuredIcon.classList.remove("enabled"); + configuredIcon.classList.add("disabled"); + + disableLocalLLLMButton.classList.add("disabled"); + disableLocalLLLMButton.classList.remove("enabled"); + } + + } + }) + } + function clearConversationProcessor() { const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1]; - fetch('/api/delete/config/data/processor/conversation', { + fetch('/api/delete/config/data/processor/conversation/openai', { method: 'POST', headers: { 'Content-Type': 'application/json', @@ -319,7 +394,7 @@ function updateIndex(force, successText, errorText, button, loadingText, emoji) { const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1]; button.disabled = true; - button.innerHTML = emoji + loadingText; + button.innerHTML = emoji + " " + loadingText; fetch('/api/update?&client=web&force=' + force, { method: 'GET', headers: { diff --git a/src/khoj/interface/web/processor_conversation_input.html b/src/khoj/interface/web/processor_conversation_input.html index 24cbc666..627d3ccf 100644 --- a/src/khoj/interface/web/processor_conversation_input.html +++ b/src/khoj/interface/web/processor_conversation_input.html @@ -13,7 +13,7 @@ - + @@ -25,24 +25,6 @@ - - - - - - - - - -
@@ -54,21 +36,23 @@ submit.addEventListener("click", function(event) { event.preventDefault(); var openai_api_key = document.getElementById("openai-api-key").value; - var conversation_logfile = document.getElementById("conversation-logfile").value; - var model = document.getElementById("model").value; var chat_model = document.getElementById("chat-model").value; + if (openai_api_key == "" || chat_model == "") { + document.getElementById("success").innerHTML = "⚠️ Please fill all the fields."; + document.getElementById("success").style.display = "block"; + return; + } + const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1]; - fetch('/api/config/data/processor/conversation', { + fetch('/api/config/data/processor/conversation/openai', { method: 'POST', headers: { 'Content-Type': 'application/json', 'X-CSRFToken': csrfToken }, body: JSON.stringify({ - "openai_api_key": openai_api_key, - "conversation_logfile": conversation_logfile, - "model": model, + "api_key": openai_api_key, "chat_model": chat_model }) }) diff --git a/src/khoj/migrations/__init__.py b/src/khoj/migrations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/khoj/migrations/migrate_processor_config_openai.py b/src/khoj/migrations/migrate_processor_config_openai.py new file mode 100644 index 00000000..54912159 --- /dev/null +++ b/src/khoj/migrations/migrate_processor_config_openai.py @@ -0,0 +1,66 @@ +""" +Current format of khoj.yml +--- +app: + should-log-telemetry: true +content-type: + ... +processor: + conversation: + chat-model: gpt-3.5-turbo + conversation-logfile: ~/.khoj/processor/conversation/conversation_logs.json + model: text-davinci-003 + openai-api-key: sk-secret-key +search-type: + ... + +New format of khoj.yml +--- +app: + should-log-telemetry: true +content-type: + ... +processor: + conversation: + openai: + chat-model: gpt-3.5-turbo + openai-api-key: sk-secret-key + conversation-logfile: ~/.khoj/processor/conversation/conversation_logs.json + enable-offline-chat: false +search-type: + ... 
+""" +from khoj.utils.yaml import load_config_from_file, save_config_to_file + + +def migrate_processor_conversation_schema(args): + raw_config = load_config_from_file(args.config_file) + + raw_config["version"] = args.version_no + + if "processor" not in raw_config: + return args + if raw_config["processor"] is None: + return args + if "conversation" not in raw_config["processor"]: + return args + + # Add enable_offline_chat to khoj config schema + if "enable-offline-chat" not in raw_config["processor"]["conversation"]: + raw_config["processor"]["conversation"]["enable-offline-chat"] = False + save_config_to_file(raw_config, args.config_file) + + current_openai_api_key = raw_config["processor"]["conversation"].get("openai-api-key", None) + current_chat_model = raw_config["processor"]["conversation"].get("chat-model", None) + if current_openai_api_key is None and current_chat_model is None: + return args + + conversation_logfile = raw_config["processor"]["conversation"].get("conversation-logfile", None) + + raw_config["processor"]["conversation"] = { + "openai": {"chat-model": current_chat_model, "api-key": current_openai_api_key}, + "conversation-logfile": conversation_logfile, + "enable-offline-chat": False, + } + save_config_to_file(raw_config, args.config_file) + return args diff --git a/src/khoj/migrations/migrate_version.py b/src/khoj/migrations/migrate_version.py new file mode 100644 index 00000000..d002fe1a --- /dev/null +++ b/src/khoj/migrations/migrate_version.py @@ -0,0 +1,16 @@ +from khoj.utils.yaml import load_config_from_file, save_config_to_file + + +def migrate_config_to_version(args): + raw_config = load_config_from_file(args.config_file) + + # Add version to khoj config schema + if "version" not in raw_config: + raw_config["version"] = args.version_no + save_config_to_file(raw_config, args.config_file) + + # regenerate khoj index on first start of this version + # this should refresh index and apply index corruption fixes from #325 + args.regenerate = True + + return args diff --git a/src/khoj/processor/conversation/gpt4all/__init__.py b/src/khoj/processor/conversation/gpt4all/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/khoj/processor/conversation/gpt4all/chat_model.py b/src/khoj/processor/conversation/gpt4all/chat_model.py new file mode 100644 index 00000000..9c1b710a --- /dev/null +++ b/src/khoj/processor/conversation/gpt4all/chat_model.py @@ -0,0 +1,137 @@ +from typing import Union, List +from datetime import datetime +import sys +import logging +from threading import Thread + +from langchain.schema import ChatMessage + +from gpt4all import GPT4All + + +from khoj.processor.conversation.utils import ThreadedGenerator, generate_chatml_messages_with_context +from khoj.processor.conversation import prompts +from khoj.utils.constants import empty_escape_sequences + +logger = logging.getLogger(__name__) + + +def extract_questions_falcon( + text: str, + model: str = "ggml-model-gpt4all-falcon-q4_0.bin", + loaded_model: Union[GPT4All, None] = None, + conversation_log={}, + use_history: bool = False, + run_extraction: bool = False, +): + """ + Infer search queries to retrieve relevant notes to answer user query + """ + all_questions = text.split("? ") + all_questions = [q + "?" 
for q in all_questions[:-1]] + [all_questions[-1]] + if not run_extraction: + return all_questions + + gpt4all_model = loaded_model or GPT4All(model) + + # Extract Past User Message and Inferred Questions from Conversation Log + chat_history = "" + + if use_history: + chat_history = "".join( + [ + f'Q: {chat["intent"]["query"]}\n\n{chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}\n\nA: {chat["message"]}\n\n' + for chat in conversation_log.get("chat", [])[-4:] + if chat["by"] == "khoj" + ] + ) + + prompt = prompts.extract_questions_falcon.format( + chat_history=chat_history, + text=text, + ) + message = prompts.general_conversation_falcon.format(query=prompt) + response = gpt4all_model.generate(message, max_tokens=200, top_k=2) + + # Extract, Clean Message from GPT's Response + try: + questions = ( + str(response) + .strip(empty_escape_sequences) + .replace("['", '["') + .replace("']", '"]') + .replace("', '", '", "') + .replace('["', "") + .replace('"]', "") + .split('", "') + ) + except: + logger.warning(f"Falcon returned invalid JSON. Falling back to using user message as search query.\n{response}") + return all_questions + logger.debug(f"Extracted Questions by Falcon: {questions}") + questions.extend(all_questions) + return questions + + +def converse_falcon( + references, + user_query, + conversation_log={}, + model: str = "ggml-model-gpt4all-falcon-q4_0.bin", + loaded_model: Union[GPT4All, None] = None, + completion_func=None, +) -> ThreadedGenerator: + """ + Converse with user using Falcon + """ + gpt4all_model = loaded_model or GPT4All(model) + # Initialize Variables + current_date = datetime.now().strftime("%Y-%m-%d") + compiled_references_message = "\n\n".join({f"{item}" for item in references}) + + # Get Conversation Primer appropriate to Conversation Type + # TODO If compiled_references_message is too long, we need to truncate it. + if compiled_references_message == "": + conversation_primer = prompts.conversation_falcon.format(query=user_query) + else: + conversation_primer = prompts.notes_conversation.format( + current_date=current_date, query=user_query, references=compiled_references_message + ) + + # Setup Prompt with Primer or Conversation History + messages = generate_chatml_messages_with_context( + conversation_primer, + prompts.personality.format(), + conversation_log, + model_name="text-davinci-001", # This isn't actually the model, but this helps us get an approximate encoding to run message truncation. 
+ ) + + g = ThreadedGenerator(references, completion_func=completion_func) + t = Thread(target=llm_thread, args=(g, messages, gpt4all_model)) + t.start() + return g + + +def llm_thread(g, messages: List[ChatMessage], model: GPT4All): + user_message = messages[0] + system_message = messages[-1] + conversation_history = messages[1:-1] + + formatted_messages = [ + prompts.chat_history_falcon_from_assistant.format(message=system_message) + if message.role == "assistant" + else prompts.chat_history_falcon_from_user.format(message=message.content) + for message in conversation_history + ] + + chat_history = "".join(formatted_messages) + full_message = system_message.content + chat_history + user_message.content + + prompted_message = prompts.general_conversation_falcon.format(query=full_message) + response_iterator = model.generate( + prompted_message, streaming=True, max_tokens=256, top_k=1, temp=0, repeat_penalty=2.0 + ) + for response in response_iterator: + logger.info(response) + g.send(response) + g.close() diff --git a/src/khoj/processor/conversation/openai/__init__.py b/src/khoj/processor/conversation/openai/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/khoj/processor/conversation/gpt.py b/src/khoj/processor/conversation/openai/gpt.py similarity index 97% rename from src/khoj/processor/conversation/gpt.py rename to src/khoj/processor/conversation/openai/gpt.py index e053be15..bf391f09 100644 --- a/src/khoj/processor/conversation/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -9,11 +9,11 @@ from langchain.schema import ChatMessage # Internal Packages from khoj.utils.constants import empty_escape_sequences from khoj.processor.conversation import prompts -from khoj.processor.conversation.utils import ( +from khoj.processor.conversation.openai.utils import ( chat_completion_with_backoff, completion_with_backoff, - generate_chatml_messages_with_context, ) +from khoj.processor.conversation.utils import generate_chatml_messages_with_context logger = logging.getLogger(__name__) diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py new file mode 100644 index 00000000..130532e0 --- /dev/null +++ b/src/khoj/processor/conversation/openai/utils.py @@ -0,0 +1,101 @@ +# Standard Packages +import os +import logging +from typing import Any +from threading import Thread + +# External Packages +from langchain.chat_models import ChatOpenAI +from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler +from langchain.callbacks.base import BaseCallbackManager +import openai +from tenacity import ( + before_sleep_log, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, + wait_random_exponential, +) + +# Internal Packages +from khoj.processor.conversation.utils import ThreadedGenerator + + +logger = logging.getLogger(__name__) + + +class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler): + def __init__(self, gen: ThreadedGenerator): + super().__init__() + self.gen = gen + + def on_llm_new_token(self, token: str, **kwargs) -> Any: + self.gen.send(token) + + +@retry( + retry=( + retry_if_exception_type(openai.error.Timeout) + | retry_if_exception_type(openai.error.APIError) + | retry_if_exception_type(openai.error.APIConnectionError) + | retry_if_exception_type(openai.error.RateLimitError) + | retry_if_exception_type(openai.error.ServiceUnavailableError) + ), + wait=wait_random_exponential(min=1, max=10), + stop=stop_after_attempt(3), + 
before_sleep=before_sleep_log(logger, logging.DEBUG), + reraise=True, +) +def completion_with_backoff(**kwargs): + messages = kwargs.pop("messages") + if not "openai_api_key" in kwargs: + kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY") + llm = ChatOpenAI(**kwargs, request_timeout=20, max_retries=1) + return llm(messages=messages) + + +@retry( + retry=( + retry_if_exception_type(openai.error.Timeout) + | retry_if_exception_type(openai.error.APIError) + | retry_if_exception_type(openai.error.APIConnectionError) + | retry_if_exception_type(openai.error.RateLimitError) + | retry_if_exception_type(openai.error.ServiceUnavailableError) + ), + wait=wait_exponential(multiplier=1, min=4, max=10), + stop=stop_after_attempt(3), + before_sleep=before_sleep_log(logger, logging.DEBUG), + reraise=True, +) +def chat_completion_with_backoff( + messages, compiled_references, model_name, temperature, openai_api_key=None, completion_func=None +): + g = ThreadedGenerator(compiled_references, completion_func=completion_func) + t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key)) + t.start() + return g + + +def llm_thread(g, messages, model_name, temperature, openai_api_key=None): + callback_handler = StreamingChatCallbackHandler(g) + chat = ChatOpenAI( + streaming=True, + verbose=True, + callback_manager=BaseCallbackManager([callback_handler]), + model_name=model_name, # type: ignore + temperature=temperature, + openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"), + request_timeout=20, + max_retries=1, + client=None, + ) + + chat(messages=messages) + + g.close() + + +def extract_summaries(metadata): + """Extract summaries from metadata""" + return "".join([f'\n{session["summary"]}' for session in metadata]) diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index c04e9042..931ba91b 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -18,6 +18,36 @@ Question: {query} """.strip() ) +general_conversation_falcon = PromptTemplate.from_template( + """ +Using your general knowledge and our past conversations as context, answer the following question. +### Instruct: +{query} +### Response: +""".strip() +) + +chat_history_falcon_from_user = PromptTemplate.from_template( + """ +### Human: +{message} +""".strip() +) + +chat_history_falcon_from_assistant = PromptTemplate.from_template( + """ +### Assistant: +{message} +""".strip() +) + +conversation_falcon = PromptTemplate.from_template( + """ +Using our past conversations as context, answer the following question. + +Question: {query} +""".strip() +) ## Notes Conversation ## -- @@ -33,6 +63,17 @@ Question: {query} """.strip() ) +notes_conversation_falcon = PromptTemplate.from_template( + """ +Using the notes and our past conversations as context, answer the following question. If the answer is not contained within the notes, say "I don't know." + +Notes: +{references} + +Question: {query} +""".strip() +) + ## Summarize Chat ## -- @@ -68,6 +109,40 @@ Question: {user_query} Answer (in second person):""" ) +extract_questions_falcon = PromptTemplate.from_template( + """ +You are Khoj, an extremely smart and helpful search assistant with the ability to retrieve information from the user's notes. +- The user will provide their questions and answers to you for context. +- Add as much context from the previous questions and answers as required into your search queries. 
+- Break messages into multiple search queries when required to retrieve the relevant information. +- Add date filters to your search queries from questions and answers when required to retrieve the relevant information. + +What searches, if any, will you need to perform to answer the users question? + +Q: How was my trip to Cambodia? + +["How was my trip to Cambodia?"] + +A: The trip was amazing. I went to the Angkor Wat temple and it was beautiful. + +Q: Who did i visit that temple with? + +["Who did I visit the Angkor Wat Temple in Cambodia with?"] + +A: You visited the Angkor Wat Temple in Cambodia with Pablo, Namita and Xi. + +Q: How many tennis balls fit in the back of a 2002 Honda Civic? + +["What is the size of a tennis ball?", "What is the trunk size of a 2002 Honda Civic?"] + +A: 1085 tennis balls will fit in the trunk of a Honda Civic + +{chat_history} +Q: {text} + +""" +) + ## Extract Questions ## -- diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index e77b7899..5cc5e1c6 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -1,35 +1,19 @@ # Standard Packages -import os import logging -from datetime import datetime from time import perf_counter -from typing import Any -from threading import Thread import json - -# External Packages -from langchain.chat_models import ChatOpenAI -from langchain.schema import ChatMessage -from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler -from langchain.callbacks.base import BaseCallbackManager -import openai +from datetime import datetime import tiktoken -from tenacity import ( - before_sleep_log, - retry, - retry_if_exception_type, - stop_after_attempt, - wait_exponential, - wait_random_exponential, -) + +# External packages +from langchain.schema import ChatMessage # Internal Packages -from khoj.utils.helpers import merge_dicts import queue - +from khoj.utils.helpers import merge_dicts logger = logging.getLogger(__name__) -max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192} +max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192, "text-davinci-001": 910} class ThreadedGenerator: @@ -49,9 +33,9 @@ class ThreadedGenerator: time_to_response = perf_counter() - self.start_time logger.info(f"Chat streaming took: {time_to_response:.3f} seconds") if self.completion_func: - # The completion func effective acts as a callback. - # It adds the aggregated response to the conversation history. It's constructed in api.py. - self.completion_func(gpt_response=self.response) + # The completion func effectively acts as a callback. + # It adds the aggregated response to the conversation history. 
+ self.completion_func(chat_response=self.response) raise StopIteration return item @@ -65,75 +49,25 @@ class ThreadedGenerator: self.queue.put(StopIteration) -class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler): - def __init__(self, gen: ThreadedGenerator): - super().__init__() - self.gen = gen - - def on_llm_new_token(self, token: str, **kwargs) -> Any: - self.gen.send(token) - - -@retry( - retry=( - retry_if_exception_type(openai.error.Timeout) - | retry_if_exception_type(openai.error.APIError) - | retry_if_exception_type(openai.error.APIConnectionError) - | retry_if_exception_type(openai.error.RateLimitError) - | retry_if_exception_type(openai.error.ServiceUnavailableError) - ), - wait=wait_random_exponential(min=1, max=10), - stop=stop_after_attempt(3), - before_sleep=before_sleep_log(logger, logging.DEBUG), - reraise=True, -) -def completion_with_backoff(**kwargs): - messages = kwargs.pop("messages") - if not "openai_api_key" in kwargs: - kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY") - llm = ChatOpenAI(**kwargs, request_timeout=20, max_retries=1) - return llm(messages=messages) - - -@retry( - retry=( - retry_if_exception_type(openai.error.Timeout) - | retry_if_exception_type(openai.error.APIError) - | retry_if_exception_type(openai.error.APIConnectionError) - | retry_if_exception_type(openai.error.RateLimitError) - | retry_if_exception_type(openai.error.ServiceUnavailableError) - ), - wait=wait_exponential(multiplier=1, min=4, max=10), - stop=stop_after_attempt(3), - before_sleep=before_sleep_log(logger, logging.DEBUG), - reraise=True, -) -def chat_completion_with_backoff( - messages, compiled_references, model_name, temperature, openai_api_key=None, completion_func=None +def message_to_log( + user_message, chat_response, user_message_metadata={}, khoj_message_metadata={}, conversation_log=[] ): - g = ThreadedGenerator(compiled_references, completion_func=completion_func) - t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key)) - t.start() - return g + """Create json logs from messages, metadata for conversation log""" + default_khoj_message_metadata = { + "intent": {"type": "remember", "memory-type": "notes", "query": user_message}, + "trigger-emotion": "calm", + } + khoj_response_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + # Create json log from Human's message + human_log = merge_dicts({"message": user_message, "by": "you"}, user_message_metadata) -def llm_thread(g, messages, model_name, temperature, openai_api_key=None): - callback_handler = StreamingChatCallbackHandler(g) - chat = ChatOpenAI( - streaming=True, - verbose=True, - callback_manager=BaseCallbackManager([callback_handler]), - model_name=model_name, # type: ignore - temperature=temperature, - openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"), - request_timeout=20, - max_retries=1, - client=None, - ) + # Create json log from GPT's response + khoj_log = merge_dicts(khoj_message_metadata, default_khoj_message_metadata) + khoj_log = merge_dicts({"message": chat_response, "by": "khoj", "created": khoj_response_time}, khoj_log) - chat(messages=messages) - - g.close() + conversation_log.extend([human_log, khoj_log]) + return conversation_log def generate_chatml_messages_with_context( @@ -192,27 +126,3 @@ def truncate_messages(messages, max_prompt_size, model_name): def reciprocal_conversation_to_chatml(message_pair): """Convert a single back and forth between user and assistant to chatml format""" return [ChatMessage(content=message, role=role) 
for message, role in zip(message_pair, ["user", "assistant"])] - - -def message_to_log(user_message, gpt_message, user_message_metadata={}, khoj_message_metadata={}, conversation_log=[]): - """Create json logs from messages, metadata for conversation log""" - default_khoj_message_metadata = { - "intent": {"type": "remember", "memory-type": "notes", "query": user_message}, - "trigger-emotion": "calm", - } - khoj_response_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - - # Create json log from Human's message - human_log = merge_dicts({"message": user_message, "by": "you"}, user_message_metadata) - - # Create json log from GPT's response - khoj_log = merge_dicts(khoj_message_metadata, default_khoj_message_metadata) - khoj_log = merge_dicts({"message": gpt_message, "by": "khoj", "created": khoj_response_time}, khoj_log) - - conversation_log.extend([human_log, khoj_log]) - return conversation_log - - -def extract_summaries(metadata): - """Extract summaries from metadata""" - return "".join([f'\n{session["summary"]}' for session in metadata]) diff --git a/src/khoj/processor/github/github_to_jsonl.py b/src/khoj/processor/github/github_to_jsonl.py index ddab24ce..ddfa6a67 100644 --- a/src/khoj/processor/github/github_to_jsonl.py +++ b/src/khoj/processor/github/github_to_jsonl.py @@ -52,7 +52,7 @@ class GithubToJsonl(TextToJsonl): try: markdown_files, org_files = self.get_files(repo_url, repo) except Exception as e: - logger.error(f"Unable to download github repo {repo_shorthand}") + logger.error(f"Unable to download github repo {repo_shorthand}", exc_info=True) raise e logger.info(f"Found {len(markdown_files)} markdown files in github repo {repo_shorthand}") diff --git a/src/khoj/processor/notion/notion_to_jsonl.py b/src/khoj/processor/notion/notion_to_jsonl.py index 489f0341..cb4c5f84 100644 --- a/src/khoj/processor/notion/notion_to_jsonl.py +++ b/src/khoj/processor/notion/notion_to_jsonl.py @@ -219,7 +219,7 @@ class NotionToJsonl(TextToJsonl): page = self.get_page(page_id) content = self.get_page_children(page_id) except Exception as e: - logger.error(f"Error getting page {page_id}: {e}") + logger.error(f"Error getting page {page_id}: {e}", exc_info=True) return None, None properties = page["properties"] title_field = "title" diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 834e8997..ead61c53 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -5,7 +5,7 @@ import time import yaml import logging import json -from typing import Iterable, List, Optional, Union +from typing import List, Optional, Union # External Packages from fastapi import APIRouter, HTTPException, Header, Request @@ -26,16 +26,19 @@ from khoj.utils.rawconfig import ( SearchConfig, SearchResponse, TextContentConfig, - ConversationProcessorConfig, + OpenAIProcessorConfig, GithubContentConfig, NotionContentConfig, + ConversationProcessorConfig, ) +from khoj.utils.helpers import resolve_absolute_path from khoj.utils.state import SearchType from khoj.utils import state, constants from khoj.utils.yaml import save_config_to_file_updated_state from fastapi.responses import StreamingResponse, Response from khoj.routers.helpers import perform_chat_checks, generate_chat_response, update_telemetry_state -from khoj.processor.conversation.gpt import extract_questions +from khoj.processor.conversation.openai.gpt import extract_questions +from khoj.processor.conversation.gpt4all.chat_model import extract_questions_falcon, converse_falcon from fastapi.requests import Request @@ -50,6 +53,8 @@ if not 
state.demo: if state.config is None: state.config = FullConfig() state.config.search_type = SearchConfig.parse_obj(constants.default_config["search-type"]) + if state.processor_config is None: + state.processor_config = configure_processor(state.config.processor) @api.get("/config/data", response_model=FullConfig) def get_config_data(): @@ -181,22 +186,28 @@ if not state.demo: except Exception as e: return {"status": "error", "message": str(e)} - @api.post("/delete/config/data/processor/conversation", status_code=200) + @api.post("/delete/config/data/processor/conversation/openai", status_code=200) async def remove_processor_conversation_config_data( request: Request, client: Optional[str] = None, ): - if not state.config or not state.config.processor or not state.config.processor.conversation: + if ( + not state.config + or not state.config.processor + or not state.config.processor.conversation + or not state.config.processor.conversation.openai + ): return {"status": "ok"} - state.config.processor.conversation = None + state.config.processor.conversation.openai = None + state.processor_config = configure_processor(state.config.processor, state.processor_config) update_telemetry_state( request=request, telemetry_type="api", - api="delete_processor_config", + api="delete_processor_openai_config", client=client, - metadata={"processor_type": "conversation"}, + metadata={"processor_conversation_type": "openai"}, ) try: @@ -233,23 +244,66 @@ if not state.demo: except Exception as e: return {"status": "error", "message": str(e)} - @api.post("/config/data/processor/conversation", status_code=200) - async def set_processor_conversation_config_data( + @api.post("/config/data/processor/conversation/openai", status_code=200) + async def set_processor_openai_config_data( request: Request, - updated_config: Union[ConversationProcessorConfig, None], + updated_config: Union[OpenAIProcessorConfig, None], client: Optional[str] = None, ): _initialize_config() - state.config.processor = ProcessorConfig(conversation=updated_config) - state.processor_config = configure_processor(state.config.processor) + if not state.config.processor or not state.config.processor.conversation: + default_config = constants.default_config + default_conversation_logfile = resolve_absolute_path( + default_config["processor"]["conversation"]["conversation-logfile"] # type: ignore + ) + conversation_logfile = resolve_absolute_path(default_conversation_logfile) + state.config.processor = ProcessorConfig(conversation=ConversationProcessorConfig(conversation_logfile=conversation_logfile)) # type: ignore + + assert state.config.processor.conversation is not None + state.config.processor.conversation.openai = updated_config + state.processor_config = configure_processor(state.config.processor, state.processor_config) update_telemetry_state( request=request, telemetry_type="api", - api="set_content_config", + api="set_processor_config", client=client, - metadata={"processor_type": "conversation"}, + metadata={"processor_conversation_type": "conversation"}, + ) + + try: + save_config_to_file_updated_state() + return {"status": "ok"} + except Exception as e: + return {"status": "error", "message": str(e)} + + @api.post("/config/data/processor/conversation/enable_offline_chat", status_code=200) + async def set_processor_enable_offline_chat_config_data( + request: Request, + enable_offline_chat: bool, + client: Optional[str] = None, + ): + _initialize_config() + + if not state.config.processor or not state.config.processor.conversation: + 
default_config = constants.default_config + default_conversation_logfile = resolve_absolute_path( + default_config["processor"]["conversation"]["conversation-logfile"] # type: ignore + ) + conversation_logfile = resolve_absolute_path(default_conversation_logfile) + state.config.processor = ProcessorConfig(conversation=ConversationProcessorConfig(conversation_logfile=conversation_logfile)) # type: ignore + + assert state.config.processor.conversation is not None + state.config.processor.conversation.enable_offline_chat = enable_offline_chat + state.processor_config = configure_processor(state.config.processor, state.processor_config) + + update_telemetry_state( + request=request, + telemetry_type="api", + api="set_processor_config", + client=client, + metadata={"processor_conversation_type": f"{'enable' if enable_offline_chat else 'disable'}_local_llm"}, ) try: @@ -569,7 +623,9 @@ def chat_history( perform_chat_checks() # Load Conversation History - meta_log = state.processor_config.conversation.meta_log + meta_log = {} + if state.processor_config.conversation: + meta_log = state.processor_config.conversation.meta_log update_telemetry_state( request=request, @@ -598,24 +654,25 @@ async def chat( perform_chat_checks() compiled_references, inferred_queries = await extract_references_and_questions(request, q, (n or 5)) - # Get the (streamed) chat response from GPT. - gpt_response = generate_chat_response( + # Get the (streamed) chat response from the LLM of choice. + llm_response = generate_chat_response( q, meta_log=state.processor_config.conversation.meta_log, compiled_references=compiled_references, inferred_queries=inferred_queries, ) - if gpt_response is None: - return Response(content=gpt_response, media_type="text/plain", status_code=500) + + if llm_response is None: + return Response(content=llm_response, media_type="text/plain", status_code=500) if stream: - return StreamingResponse(gpt_response, media_type="text/event-stream", status_code=200) + return StreamingResponse(llm_response, media_type="text/event-stream", status_code=200) # Get the full response from the generator if the stream is not requested. 
aggregated_gpt_response = "" while True: try: - aggregated_gpt_response += next(gpt_response) + aggregated_gpt_response += next(llm_response) except StopIteration: break @@ -645,8 +702,6 @@ async def extract_references_and_questions( meta_log = state.processor_config.conversation.meta_log # Initialize Variables - api_key = state.processor_config.conversation.openai_api_key - chat_model = state.processor_config.conversation.chat_model conversation_type = "general" if q.startswith("@general") else "notes" compiled_references = [] inferred_queries = [] @@ -654,7 +709,13 @@ async def extract_references_and_questions( if conversation_type == "notes": # Infer search queries from user message with timer("Extracting search queries took", logger): - inferred_queries = extract_questions(q, model=chat_model, api_key=api_key, conversation_log=meta_log) + if state.processor_config.conversation and state.processor_config.conversation.openai_model: + api_key = state.processor_config.conversation.openai_model.api_key + chat_model = state.processor_config.conversation.openai_model.chat_model + inferred_queries = extract_questions(q, model=chat_model, api_key=api_key, conversation_log=meta_log) + else: + loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model + inferred_queries = extract_questions_falcon(q, loaded_model=loaded_model, conversation_log=meta_log) # Collate search results as context for GPT with timer("Searching knowledge base took", logger): diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index fe4cdac2..91608b10 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -7,22 +7,23 @@ from fastapi import HTTPException, Request from khoj.utils import state from khoj.utils.helpers import timer, log_telemetry -from khoj.processor.conversation.gpt import converse -from khoj.processor.conversation.utils import message_to_log, reciprocal_conversation_to_chatml - +from khoj.processor.conversation.openai.gpt import converse +from khoj.processor.conversation.gpt4all.chat_model import converse_falcon +from khoj.processor.conversation.utils import reciprocal_conversation_to_chatml, message_to_log, ThreadedGenerator logger = logging.getLogger(__name__) def perform_chat_checks(): - if ( - state.processor_config is None - or state.processor_config.conversation is None - or state.processor_config.conversation.openai_api_key is None + if state.processor_config.conversation and ( + state.processor_config.conversation.openai_model + or state.processor_config.conversation.gpt4all_model.loaded_model ): - raise HTTPException( - status_code=500, detail="Set your OpenAI API key via Khoj settings and restart it to use Khoj Chat." - ) + return + + raise HTTPException( + status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings and restart it." 
+ ) def update_telemetry_state( @@ -57,19 +58,19 @@ def generate_chat_response( meta_log: dict, compiled_references: List[str] = [], inferred_queries: List[str] = [], -): +) -> ThreadedGenerator: def _save_to_conversation_log( q: str, - gpt_response: str, + chat_response: str, user_message_time: str, compiled_references: List[str], inferred_queries: List[str], meta_log, ): - state.processor_config.conversation.chat_session += reciprocal_conversation_to_chatml([q, gpt_response]) + state.processor_config.conversation.chat_session += reciprocal_conversation_to_chatml([q, chat_response]) state.processor_config.conversation.meta_log["chat"] = message_to_log( - q, - gpt_response, + user_message=q, + chat_response=chat_response, user_message_metadata={"created": user_message_time}, khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}}, conversation_log=meta_log.get("chat", []), @@ -79,8 +80,6 @@ def generate_chat_response( meta_log = state.processor_config.conversation.meta_log # Initialize Variables - api_key = state.processor_config.conversation.openai_api_key - chat_model = state.processor_config.conversation.chat_model user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") conversation_type = "general" if q.startswith("@general") else "notes" @@ -99,12 +98,29 @@ def generate_chat_response( meta_log=meta_log, ) - gpt_response = converse( - compiled_references, q, meta_log, model=chat_model, api_key=api_key, completion_func=partial_completion - ) + if state.processor_config.conversation.openai_model: + api_key = state.processor_config.conversation.openai_model.api_key + chat_model = state.processor_config.conversation.openai_model.chat_model + chat_response = converse( + compiled_references, + q, + meta_log, + model=chat_model, + api_key=api_key, + completion_func=partial_completion, + ) + else: + loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model + chat_response = converse_falcon( + references=compiled_references, + user_query=q, + loaded_model=loaded_model, + conversation_log=meta_log, + completion_func=partial_completion, + ) except Exception as e: - logger.error(e) + logger.error(e, exc_info=True) raise HTTPException(status_code=500, detail=str(e)) - return gpt_response + return chat_response diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py index 29179cdd..663b8675 100644 --- a/src/khoj/routers/web_client.py +++ b/src/khoj/routers/web_client.py @@ -3,7 +3,7 @@ from fastapi import APIRouter from fastapi import Request from fastapi.responses import HTMLResponse, FileResponse from fastapi.templating import Jinja2Templates -from khoj.utils.rawconfig import TextContentConfig, ConversationProcessorConfig, FullConfig +from khoj.utils.rawconfig import TextContentConfig, OpenAIProcessorConfig, FullConfig # Internal Packages from khoj.utils import constants, state @@ -151,28 +151,29 @@ if not state.demo: }, ) - @web_client.get("/config/processor/conversation", response_class=HTMLResponse) + @web_client.get("/config/processor/conversation/openai", response_class=HTMLResponse) def conversation_processor_config_page(request: Request): default_copy = constants.default_config.copy() - default_processor_config = default_copy["processor"]["conversation"] # type: ignore - default_processor_config = ConversationProcessorConfig( - openai_api_key="", - model=default_processor_config["model"], - conversation_logfile=default_processor_config["conversation-logfile"], + default_processor_config = 
default_copy["processor"]["conversation"]["openai"] # type: ignore + default_openai_config = OpenAIProcessorConfig( + api_key="", chat_model=default_processor_config["chat-model"], ) - current_processor_conversation_config = ( - state.config.processor.conversation - if state.config and state.config.processor and state.config.processor.conversation - else default_processor_config + current_processor_openai_config = ( + state.config.processor.conversation.openai + if state.config + and state.config.processor + and state.config.processor.conversation + and state.config.processor.conversation.openai + else default_openai_config ) - current_processor_conversation_config = json.loads(current_processor_conversation_config.json()) + current_processor_openai_config = json.loads(current_processor_openai_config.json()) return templates.TemplateResponse( "processor_conversation_input.html", context={ "request": request, - "current_config": current_processor_conversation_config, + "current_config": current_processor_openai_config, }, ) diff --git a/src/khoj/utils/cli.py b/src/khoj/utils/cli.py index 49acd6e1..9236ab11 100644 --- a/src/khoj/utils/cli.py +++ b/src/khoj/utils/cli.py @@ -5,7 +5,9 @@ from importlib.metadata import version # Internal Packages from khoj.utils.helpers import resolve_absolute_path -from khoj.utils.yaml import load_config_from_file, parse_config_from_file, save_config_to_file +from khoj.utils.yaml import parse_config_from_file +from khoj.migrations.migrate_version import migrate_config_to_version +from khoj.migrations.migrate_processor_config_openai import migrate_processor_conversation_schema def cli(args=None): @@ -46,22 +48,14 @@ def cli(args=None): if not args.config_file.exists(): args.config = None else: - args = migrate_config(args) + args = run_migrations(args) args.config = parse_config_from_file(args.config_file) return args -def migrate_config(args): - raw_config = load_config_from_file(args.config_file) - - # Add version to khoj config schema - if "version" not in raw_config: - raw_config["version"] = args.version_no - save_config_to_file(raw_config, args.config_file) - - # regenerate khoj index on first start of this version - # this should refresh index and apply index corruption fixes from #325 - args.regenerate = True - +def run_migrations(args): + migrations = [migrate_config_to_version, migrate_processor_conversation_schema] + for migration in migrations: + args = migration(args) return args diff --git a/src/khoj/utils/config.py b/src/khoj/utils/config.py index 6ba8b639..1dffc9b8 100644 --- a/src/khoj/utils/config.py +++ b/src/khoj/utils/config.py @@ -1,9 +1,12 @@ # System Packages from __future__ import annotations # to avoid quoting type hints + from enum import Enum from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Union, Any + +from gpt4all import GPT4All # External Packages import torch @@ -13,7 +16,7 @@ if TYPE_CHECKING: from sentence_transformers import CrossEncoder from khoj.search_filter.base_filter import BaseFilter from khoj.utils.models import BaseEncoder - from khoj.utils.rawconfig import ConversationProcessorConfig, Entry + from khoj.utils.rawconfig import ConversationProcessorConfig, Entry, OpenAIProcessorConfig class SearchType(str, Enum): @@ -74,15 +77,29 @@ class SearchModels: plugin_search: Optional[Dict[str, TextSearchModel]] = None +@dataclass +class GPT4AllProcessorConfig: + chat_model: Optional[str] = 
"ggml-model-gpt4all-falcon-q4_0.bin" + loaded_model: Union[Any, None] = None + + class ConversationProcessorConfigModel: - def __init__(self, processor_config: ConversationProcessorConfig): - self.openai_api_key = processor_config.openai_api_key - self.model = processor_config.model - self.chat_model = processor_config.chat_model - self.conversation_logfile = Path(processor_config.conversation_logfile) + def __init__( + self, + conversation_config: ConversationProcessorConfig, + ): + self.openai_model = conversation_config.openai + self.gpt4all_model = GPT4AllProcessorConfig() + self.enable_offline_chat = conversation_config.enable_offline_chat + self.conversation_logfile = Path(conversation_config.conversation_logfile) self.chat_session: List[str] = [] self.meta_log: dict = {} + if not self.openai_model and self.enable_offline_chat: + self.gpt4all_model.loaded_model = GPT4All(self.gpt4all_model.chat_model) # type: ignore + else: + self.gpt4all_model.loaded_model = None + @dataclass class ProcessorConfigModel: diff --git a/src/khoj/utils/constants.py b/src/khoj/utils/constants.py index f1de7d76..1b0efc00 100644 --- a/src/khoj/utils/constants.py +++ b/src/khoj/utils/constants.py @@ -62,10 +62,12 @@ default_config = { }, "processor": { "conversation": { - "openai-api-key": None, - "model": "text-davinci-003", + "openai": { + "api-key": None, + "chat-model": "gpt-3.5-turbo", + }, + "enable-offline-chat": False, "conversation-logfile": "~/.khoj/processor/conversation/conversation_logs.json", - "chat-model": "gpt-3.5-turbo", } }, } diff --git a/src/khoj/utils/models.py b/src/khoj/utils/models.py index b5850851..b5bbe292 100644 --- a/src/khoj/utils/models.py +++ b/src/khoj/utils/models.py @@ -27,12 +27,12 @@ class OpenAI(BaseEncoder): if ( not state.processor_config or not state.processor_config.conversation - or not state.processor_config.conversation.openai_api_key + or not state.processor_config.conversation.openai_model ): raise Exception( f"Set OpenAI API key under processor-config > conversation > openai-api-key in config file: {state.config_file}" ) - openai.api_key = state.processor_config.conversation.openai_api_key + openai.api_key = state.processor_config.conversation.openai_model.api_key self.embedding_dimensions = None def encode(self, entries, device=None, **kwargs): diff --git a/src/khoj/utils/rawconfig.py b/src/khoj/utils/rawconfig.py index d3c9a4ea..af7dda67 100644 --- a/src/khoj/utils/rawconfig.py +++ b/src/khoj/utils/rawconfig.py @@ -1,7 +1,7 @@ # System Packages import json from pathlib import Path -from typing import List, Dict, Optional +from typing import List, Dict, Optional, Union, Any # External Packages from pydantic import BaseModel, validator @@ -103,13 +103,17 @@ class SearchConfig(ConfigBase): image: Optional[ImageSearchConfig] -class ConversationProcessorConfig(ConfigBase): - openai_api_key: str - conversation_logfile: Path - model: Optional[str] = "text-davinci-003" +class OpenAIProcessorConfig(ConfigBase): + api_key: str chat_model: Optional[str] = "gpt-3.5-turbo" +class ConversationProcessorConfig(ConfigBase): + conversation_logfile: Path + openai: Optional[OpenAIProcessorConfig] + enable_offline_chat: Optional[bool] = False + + class ProcessorConfig(ConfigBase): conversation: Optional[ConversationProcessorConfig] diff --git a/tests/data/config.yml b/tests/data/config.yml index a4258028..96009a42 100644 --- a/tests/data/config.yml +++ b/tests/data/config.yml @@ -20,6 +20,7 @@ content-type: embeddings-file: content_plugin_2_embeddings.pt input-filter: - 
'*2_new.jsonl.gz' +enable-offline-chat: false search-type: asymmetric: cross-encoder: cross-encoder/ms-marco-MiniLM-L-6-v2 diff --git a/tests/test_gpt4all_chat_actors.py b/tests/test_gpt4all_chat_actors.py new file mode 100644 index 00000000..f5b3955a --- /dev/null +++ b/tests/test_gpt4all_chat_actors.py @@ -0,0 +1,426 @@ +# Standard Packages +from datetime import datetime + +# External Packages +import pytest + +SKIP_TESTS = True +pytestmark = pytest.mark.skipif( + SKIP_TESTS, + reason="The GPT4All library has some quirks that make it hard to test in CI. This causes some tests to fail. Hence, disable it in CI.", +) + +import freezegun +from freezegun import freeze_time + +from gpt4all import GPT4All + +# Internal Packages +from khoj.processor.conversation.gpt4all.chat_model import converse_falcon, extract_questions_falcon + +from khoj.processor.conversation.utils import message_to_log + + +@pytest.fixture(scope="session") +def loaded_model(): + return GPT4All("ggml-model-gpt4all-falcon-q4_0.bin") + + +freezegun.configure(extend_ignore_list=["transformers"]) + + +# Test +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.chatquality +@freeze_time("1984-04-02") +def test_extract_question_with_date_filter_from_relative_day(loaded_model): + # Act + response = extract_questions_falcon( + "Where did I go for dinner yesterday?", loaded_model=loaded_model, run_extraction=True + ) + + assert len(response) >= 1 + assert response[-1] == "Where did I go for dinner yesterday?" + + +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.chatquality +@freeze_time("1984-04-02") +def test_extract_question_with_date_filter_from_relative_month(loaded_model): + # Act + response = extract_questions_falcon("Which countries did I visit last month?", loaded_model=loaded_model) + + # Assert + assert len(response) == 1 + assert response == ["Which countries did I visit last month?"] + + +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.chatquality +@freeze_time("1984-04-02") +def test_extract_question_with_date_filter_from_relative_year(loaded_model): + # Act + response = extract_questions_falcon( + "Which countries have I visited this year?", loaded_model=loaded_model, run_extraction=True + ) + + # Assert + assert len(response) >= 1 + assert response[-1] == "Which countries have I visited this year?" + + +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.chatquality +def test_extract_multiple_explicit_questions_from_message(loaded_model): + # Act + response = extract_questions_falcon("What is the Sun? 
What is the Moon?", loaded_model=loaded_model) + + # Assert + expected_responses = ["What is the Sun?", "What is the Moon?"] + assert len(response) == 2 + assert expected_responses == response + + +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.chatquality +def test_extract_multiple_implicit_questions_from_message(loaded_model): + # Act + response = extract_questions_falcon("Is Morpheus taller than Neo?", loaded_model=loaded_model, run_extraction=True) + + # Assert + expected_responses = [ + ("morpheus", "neo"), + ] + assert len(response) == 2 + assert any([start in response[0].lower() and end in response[1].lower() for start, end in expected_responses]), ( + "Expected two search queries in response but got: " + response[0] + ) + + +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.chatquality +def test_generate_search_query_using_question_from_chat_history(loaded_model): + # Arrange + message_list = [ + ("What is the name of Mr. Vader's daughter?", "Princess Leia", []), + ] + + # Act + response = extract_questions_falcon( + "Does he have any sons?", + conversation_log=populate_chat_history(message_list), + loaded_model=loaded_model, + run_extraction=True, + use_history=True, + ) + + expected_responses = [ + "do not have", + "clarify", + "am sorry", + ] + + # Assert + assert len(response) >= 1 + assert any([expected_response in response[0] for expected_response in expected_responses]), ( + "Expected chat actor to ask for clarification in response, but got: " + response[0] + ) + + +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.xfail(reason="Chat actor does not consistently follow template instructions.") +@pytest.mark.chatquality +def test_generate_search_query_using_answer_from_chat_history(loaded_model): + # Arrange + message_list = [ + ("What is the name of Mr. 
Vader's daughter?", "Princess Leia", []), + ] + + # Act + response = extract_questions_falcon( + "Is she a Jedi?", + conversation_log=populate_chat_history(message_list), + loaded_model=loaded_model, + run_extraction=True, + use_history=True, + ) + + # Assert + assert len(response) == 1 + assert "Leia" in response[0] + + +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.xfail(reason="Chat actor is not sufficiently date-aware") +@pytest.mark.chatquality +def test_generate_search_query_with_date_and_context_from_chat_history(loaded_model): + # Arrange + message_list = [ + ("When did I visit Masai Mara?", "You visited Masai Mara in April 2000", []), + ] + + # Act + response = extract_questions_falcon( + "What was the Pizza place we ate at over there?", + conversation_log=populate_chat_history(message_list), + run_extraction=True, + loaded_model=loaded_model, + ) + + # Assert + expected_responses = [ + ("dt>='2000-04-01'", "dt<'2000-05-01'"), + ("dt>='2000-04-01'", "dt<='2000-04-30'"), + ('dt>="2000-04-01"', 'dt<"2000-05-01"'), + ('dt>="2000-04-01"', 'dt<="2000-04-30"'), + ] + assert len(response) == 1 + assert "Masai Mara" in response[0] + assert any([start in response[0] and end in response[0] for start, end in expected_responses]), ( + "Expected date filter to limit to April 2000 in response but got: " + response[0] + ) + + +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.chatquality +def test_chat_with_no_chat_history_or_retrieved_content(loaded_model): + # Act + response_gen = converse_falcon( + references=[], # Assume no context retrieved from notes for the user_query + user_query="Hello, my name is Testatron. Who are you?", + loaded_model=loaded_model, + ) + response = "".join([response_chunk for response_chunk in response_gen]) + + # Assert + expected_responses = ["Khoj", "khoj", "khooj", "Khooj", "KHOJ"] + assert len(response) > 0 + assert any([expected_response in response for expected_response in expected_responses]), ( + "Expected assistants name, [K|k]hoj, in response but got: " + response + ) + + +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.xfail(reason="Chat actor isn't really good at proper nouns yet.") +@pytest.mark.chatquality +def test_answer_from_chat_history_and_previously_retrieved_content(loaded_model): + "Chat actor needs to use context in previous notes and chat history to answer question" + # Arrange + message_list = [ + ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. 
How can I help?", []), + ( + "When was I born?", + "You were born on 1st April 1984.", + ["Testatron was born on 1st April 1984 in Testville."], + ), + ] + + # Act + response_gen = converse_falcon( + references=[], # Assume no context retrieved from notes for the user_query + user_query="Where was I born?", + conversation_log=populate_chat_history(message_list), + loaded_model=loaded_model, + ) + response = "".join([response_chunk for response_chunk in response_gen]) + + # Assert + assert len(response) > 0 + # Infer who I am and use that to infer I was born in Testville using chat history and previously retrieved notes + assert "Testville" in response + + +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.chatquality +def test_answer_from_chat_history_and_currently_retrieved_content(loaded_model): + "Chat actor needs to use context across currently retrieved notes and chat history to answer question" + # Arrange + message_list = [ + ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []), + ("When was I born?", "You were born on 1st April 1984.", []), + ] + + # Act + response_gen = converse_falcon( + references=[ + "Testatron was born on 1st April 1984 in Testville." + ], # Assume context retrieved from notes for the user_query + user_query="Where was I born?", + conversation_log=populate_chat_history(message_list), + loaded_model=loaded_model, + ) + response = "".join([response_chunk for response_chunk in response_gen]) + + # Assert + assert len(response) > 0 + assert "Testville" in response + + +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.xfail(reason="Chat actor is rather liable to lying.") +@pytest.mark.chatquality +def test_refuse_answering_unanswerable_question(loaded_model): + "Chat actor should not try make up answers to unanswerable questions." + # Arrange + message_list = [ + ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. 
How can I help?", []), + ("When was I born?", "You were born on 1st April 1984.", []), + ] + + # Act + response_gen = converse_falcon( + references=[], # Assume no context retrieved from notes for the user_query + user_query="Where was I born?", + conversation_log=populate_chat_history(message_list), + loaded_model=loaded_model, + ) + response = "".join([response_chunk for response_chunk in response_gen]) + + # Assert + expected_responses = [ + "don't know", + "do not know", + "no information", + "do not have", + "don't have", + "cannot answer", + "I'm sorry", + ] + assert len(response) > 0 + assert any([expected_response in response for expected_response in expected_responses]), ( + "Expected chat actor to say they don't know in response, but got: " + response + ) + + +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.chatquality +def test_answer_requires_current_date_awareness(loaded_model): + "Chat actor should be able to answer questions relative to current date using provided notes" + # Arrange + context = [ + f"""{datetime.now().strftime("%Y-%m-%d")} "Naco Taco" "Tacos for Dinner" +Expenses:Food:Dining 10.00 USD""", + f"""{datetime.now().strftime("%Y-%m-%d")} "Sagar Ratna" "Dosa for Lunch" +Expenses:Food:Dining 10.00 USD""", + f"""2020-04-01 "SuperMercado" "Bananas" +Expenses:Food:Groceries 10.00 USD""", + f"""2020-01-01 "Naco Taco" "Burittos for Dinner" +Expenses:Food:Dining 10.00 USD""", + ] + + # Act + response_gen = converse_falcon( + references=context, # Assume context retrieved from notes for the user_query + user_query="What did I have for Dinner today?", + loaded_model=loaded_model, + ) + response = "".join([response_chunk for response_chunk in response_gen]) + + # Assert + expected_responses = ["tacos", "Tacos"] + assert len(response) > 0 + assert any([expected_response in response for expected_response in expected_responses]), ( + "Expected [T|t]acos in response, but got: " + response + ) + + +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.chatquality +def test_answer_requires_date_aware_aggregation_across_provided_notes(loaded_model): + "Chat actor should be able to answer questions that require date aware aggregation across multiple notes" + # Arrange + context = [ + f"""# {datetime.now().strftime("%Y-%m-%d")} "Naco Taco" "Tacos for Dinner" +Expenses:Food:Dining 10.00 USD""", + f"""{datetime.now().strftime("%Y-%m-%d")} "Sagar Ratna" "Dosa for Lunch" +Expenses:Food:Dining 10.00 USD""", + f"""2020-04-01 "SuperMercado" "Bananas" +Expenses:Food:Groceries 10.00 USD""", + f"""2020-01-01 "Naco Taco" "Burittos for Dinner" +Expenses:Food:Dining 10.00 USD""", + ] + + # Act + response_gen = converse_falcon( + references=context, # Assume context retrieved from notes for the user_query + user_query="How much did I spend on dining this year?", + loaded_model=loaded_model, + ) + response = "".join([response_chunk for response_chunk in response_gen]) + + # Assert + assert len(response) > 0 + assert "20" in response + + +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.chatquality +def test_answer_general_question_not_in_chat_history_or_retrieved_content(loaded_model): + "Chat actor should be able to answer general questions not requiring looking at chat history or notes" + # Arrange + message_list = [ + ("Hello, my name is Testatron. 
Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []), + ("When was I born?", "You were born on 1st April 1984.", []), + ("Where was I born?", "You were born Testville.", []), + ] + + # Act + response_gen = converse_falcon( + references=[], # Assume no context retrieved from notes for the user_query + user_query="Write a haiku about unit testing in 3 lines", + conversation_log=populate_chat_history(message_list), + loaded_model=loaded_model, + ) + response = "".join([response_chunk for response_chunk in response_gen]) + + # Assert + expected_responses = ["test", "testing"] + assert len(response.splitlines()) >= 3 # haikus are 3 lines long, but Falcon tends to add a lot of new lines. + assert any([expected_response in response.lower() for expected_response in expected_responses]), ( + "Expected [T|t]est in response, but got: " + response + ) + + +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.xfail(reason="Chat actor not consistently capable of asking for clarification yet.") +@pytest.mark.chatquality +def test_ask_for_clarification_if_not_enough_context_in_question(loaded_model): + "Chat actor should ask for clarification if question cannot be answered unambiguously with the provided context" + # Arrange + context = [ + f"""# Ramya +My sister, Ramya, is married to Kali Devi. They have 2 kids, Ravi and Rani.""", + f"""# Fang +My sister, Fang Liu is married to Xi Li. They have 1 kid, Xiao Li.""", + f"""# Aiyla +My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""", + ] + + # Act + response_gen = converse_falcon( + references=context, # Assume context retrieved from notes for the user_query + user_query="How many kids does my older sister have?", + loaded_model=loaded_model, + ) + response = "".join([response_chunk for response_chunk in response_gen]) + + # Assert + expected_responses = ["which sister", "Which sister", "which of your sister", "Which of your sister"] + assert any([expected_response in response for expected_response in expected_responses]), ( + "Expected chat actor to ask for clarification in response, but got: " + response + ) + + +# Helpers +# ---------------------------------------------------------------------------------------------------- +def populate_chat_history(message_list): + # Generate conversation logs + conversation_log = {"chat": []} + for user_message, chat_response, context in message_list: + message_to_log( + user_message, + chat_response, + {"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}}, + conversation_log=conversation_log["chat"], + ) + return conversation_log diff --git a/tests/test_chat_actors.py b/tests/test_openai_chat_actors.py similarity index 99% rename from tests/test_chat_actors.py rename to tests/test_openai_chat_actors.py index a1f91188..e84f41f7 100644 --- a/tests/test_chat_actors.py +++ b/tests/test_openai_chat_actors.py @@ -8,7 +8,7 @@ import freezegun from freezegun import freeze_time # Internal Packages -from khoj.processor.conversation.gpt import converse, extract_questions +from khoj.processor.conversation.openai.gpt import converse, extract_questions from khoj.processor.conversation.utils import message_to_log diff --git a/tests/test_chat_director.py b/tests/test_openai_chat_director.py similarity index 100% rename from tests/test_chat_director.py rename to tests/test_openai_chat_director.py