@@ -263,9 +288,59 @@
})
};
+ function toggleEnableLocalLLLM(enable) {
+ const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
+ var toggleEnableLocalLLLMButton = document.getElementById("toggle-enable-offline-chat");
+ toggleEnableLocalLLLMButton.classList.remove("disabled");
+ toggleEnableLocalLLLMButton.classList.add("enabled");
+
+ fetch('/api/config/data/processor/conversation/enable_offline_chat' + '?enable_offline_chat=' + enable, {
+ method: 'POST',
+ headers: {
+ 'Content-Type': 'application/json',
+ 'X-CSRFToken': csrfToken
+ },
+ })
+ .then(response => response.json())
+ .then(data => {
+ if (data.status == "ok") {
+ // Toggle the Enabled/Disabled UI based on the action/response.
+ var enableLocalLLLMButton = document.getElementById("set-enable-offline-chat");
+ var disableLocalLLLMButton = document.getElementById("clear-enable-offline-chat");
+ var configuredIcon = document.getElementById("configured-icon-conversation-enable-offline-chat");
+ var toggleEnableLocalLLLMButton = document.getElementById("toggle-enable-offline-chat");
+
+ toggleEnableLocalLLLMButton.classList.remove("enabled");
+ toggleEnableLocalLLLMButton.classList.add("disabled");
+
+
+ if (enable) {
+ enableLocalLLLMButton.classList.add("disabled");
+ enableLocalLLLMButton.classList.remove("enabled");
+
+ configuredIcon.classList.add("enabled");
+ configuredIcon.classList.remove("disabled");
+
+ disableLocalLLLMButton.classList.remove("disabled");
+ disableLocalLLLMButton.classList.add("enabled");
+ } else {
+ enableLocalLLLMButton.classList.remove("disabled");
+ enableLocalLLLMButton.classList.add("enabled");
+
+ configuredIcon.classList.remove("enabled");
+ configuredIcon.classList.add("disabled");
+
+ disableLocalLLLMButton.classList.add("disabled");
+ disableLocalLLLMButton.classList.remove("enabled");
+ }
+
+ }
+ })
+ }
+
function clearConversationProcessor() {
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
- fetch('/api/delete/config/data/processor/conversation', {
+ fetch('/api/delete/config/data/processor/conversation/openai', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
@@ -319,7 +394,7 @@
function updateIndex(force, successText, errorText, button, loadingText, emoji) {
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
button.disabled = true;
- button.innerHTML = emoji + loadingText;
+ button.innerHTML = emoji + " " + loadingText;
fetch('/api/update?&client=web&force=' + force, {
method: 'GET',
headers: {
diff --git a/src/khoj/interface/web/processor_conversation_input.html b/src/khoj/interface/web/processor_conversation_input.html
index 24cbc666..627d3ccf 100644
--- a/src/khoj/interface/web/processor_conversation_input.html
+++ b/src/khoj/interface/web/processor_conversation_input.html
@@ -13,7 +13,7 @@
Save
@@ -54,21 +36,23 @@
submit.addEventListener("click", function(event) {
event.preventDefault();
var openai_api_key = document.getElementById("openai-api-key").value;
- var conversation_logfile = document.getElementById("conversation-logfile").value;
- var model = document.getElementById("model").value;
var chat_model = document.getElementById("chat-model").value;
+ if (openai_api_key == "" || chat_model == "") {
+ document.getElementById("success").innerHTML = "⚠️ Please fill all the fields.";
+ document.getElementById("success").style.display = "block";
+ return;
+ }
+
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
- fetch('/api/config/data/processor/conversation', {
+ fetch('/api/config/data/processor/conversation/openai', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'X-CSRFToken': csrfToken
},
body: JSON.stringify({
- "openai_api_key": openai_api_key,
- "conversation_logfile": conversation_logfile,
- "model": model,
+ "api_key": openai_api_key,
"chat_model": chat_model
})
})
diff --git a/src/khoj/migrations/__init__.py b/src/khoj/migrations/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/khoj/migrations/migrate_processor_config_openai.py b/src/khoj/migrations/migrate_processor_config_openai.py
new file mode 100644
index 00000000..54912159
--- /dev/null
+++ b/src/khoj/migrations/migrate_processor_config_openai.py
@@ -0,0 +1,66 @@
+"""
+Current format of khoj.yml
+---
+app:
+ should-log-telemetry: true
+content-type:
+ ...
+processor:
+ conversation:
+ chat-model: gpt-3.5-turbo
+ conversation-logfile: ~/.khoj/processor/conversation/conversation_logs.json
+ model: text-davinci-003
+ openai-api-key: sk-secret-key
+search-type:
+ ...
+
+New format of khoj.yml
+---
+app:
+ should-log-telemetry: true
+content-type:
+ ...
+processor:
+ conversation:
+ openai:
+ chat-model: gpt-3.5-turbo
+      api-key: sk-secret-key
+ conversation-logfile: ~/.khoj/processor/conversation/conversation_logs.json
+ enable-offline-chat: false
+search-type:
+ ...
+"""
+from khoj.utils.yaml import load_config_from_file, save_config_to_file
+
+
+def migrate_processor_conversation_schema(args):
+ raw_config = load_config_from_file(args.config_file)
+
+ raw_config["version"] = args.version_no
+
+ if "processor" not in raw_config:
+ return args
+ if raw_config["processor"] is None:
+ return args
+ if "conversation" not in raw_config["processor"]:
+ return args
+
+ # Add enable_offline_chat to khoj config schema
+ if "enable-offline-chat" not in raw_config["processor"]["conversation"]:
+ raw_config["processor"]["conversation"]["enable-offline-chat"] = False
+ save_config_to_file(raw_config, args.config_file)
+
+ current_openai_api_key = raw_config["processor"]["conversation"].get("openai-api-key", None)
+ current_chat_model = raw_config["processor"]["conversation"].get("chat-model", None)
+ if current_openai_api_key is None and current_chat_model is None:
+ return args
+
+ conversation_logfile = raw_config["processor"]["conversation"].get("conversation-logfile", None)
+
+ raw_config["processor"]["conversation"] = {
+ "openai": {"chat-model": current_chat_model, "api-key": current_openai_api_key},
+ "conversation-logfile": conversation_logfile,
+ "enable-offline-chat": False,
+ }
+ save_config_to_file(raw_config, args.config_file)
+ return args
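Illustrative sketch (not part of the patch): how this migration reshapes the conversation section of an in-memory khoj.yml dict, mirroring the docstring above and the key names the code writes.

```python
# Old flat conversation section (the "Current format" in the docstring above).
old_conversation = {
    "chat-model": "gpt-3.5-turbo",
    "conversation-logfile": "~/.khoj/processor/conversation/conversation_logs.json",
    "model": "text-davinci-003",
    "openai-api-key": "sk-secret-key",
}

# New nested schema: OpenAI settings move under an "openai" key, the completion
# "model" field is dropped, and offline chat defaults to disabled.
new_conversation = {
    "openai": {
        "chat-model": old_conversation.get("chat-model"),
        "api-key": old_conversation.get("openai-api-key"),
    },
    "conversation-logfile": old_conversation.get("conversation-logfile"),
    "enable-offline-chat": False,
}

assert "model" not in new_conversation
```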
diff --git a/src/khoj/migrations/migrate_version.py b/src/khoj/migrations/migrate_version.py
new file mode 100644
index 00000000..d002fe1a
--- /dev/null
+++ b/src/khoj/migrations/migrate_version.py
@@ -0,0 +1,16 @@
+from khoj.utils.yaml import load_config_from_file, save_config_to_file
+
+
+def migrate_config_to_version(args):
+ raw_config = load_config_from_file(args.config_file)
+
+ # Add version to khoj config schema
+ if "version" not in raw_config:
+ raw_config["version"] = args.version_no
+ save_config_to_file(raw_config, args.config_file)
+
+ # regenerate khoj index on first start of this version
+ # this should refresh index and apply index corruption fixes from #325
+ args.regenerate = True
+
+ return args
diff --git a/src/khoj/processor/conversation/gpt4all/__init__.py b/src/khoj/processor/conversation/gpt4all/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/khoj/processor/conversation/gpt4all/chat_model.py b/src/khoj/processor/conversation/gpt4all/chat_model.py
new file mode 100644
index 00000000..9c1b710a
--- /dev/null
+++ b/src/khoj/processor/conversation/gpt4all/chat_model.py
@@ -0,0 +1,137 @@
+from typing import Union, List
+from datetime import datetime
+import sys
+import logging
+from threading import Thread
+
+from langchain.schema import ChatMessage
+
+from gpt4all import GPT4All
+
+
+from khoj.processor.conversation.utils import ThreadedGenerator, generate_chatml_messages_with_context
+from khoj.processor.conversation import prompts
+from khoj.utils.constants import empty_escape_sequences
+
+logger = logging.getLogger(__name__)
+
+
+def extract_questions_falcon(
+ text: str,
+ model: str = "ggml-model-gpt4all-falcon-q4_0.bin",
+ loaded_model: Union[GPT4All, None] = None,
+ conversation_log={},
+ use_history: bool = False,
+ run_extraction: bool = False,
+):
+ """
+ Infer search queries to retrieve relevant notes to answer user query
+ """
+ all_questions = text.split("? ")
+ all_questions = [q + "?" for q in all_questions[:-1]] + [all_questions[-1]]
+ if not run_extraction:
+ return all_questions
+
+ gpt4all_model = loaded_model or GPT4All(model)
+
+ # Extract Past User Message and Inferred Questions from Conversation Log
+ chat_history = ""
+
+ if use_history:
+ chat_history = "".join(
+ [
+ f'Q: {chat["intent"]["query"]}\n\n{chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}\n\nA: {chat["message"]}\n\n'
+ for chat in conversation_log.get("chat", [])[-4:]
+ if chat["by"] == "khoj"
+ ]
+ )
+
+ prompt = prompts.extract_questions_falcon.format(
+ chat_history=chat_history,
+ text=text,
+ )
+ message = prompts.general_conversation_falcon.format(query=prompt)
+ response = gpt4all_model.generate(message, max_tokens=200, top_k=2)
+
+    # Extract, Clean Message from Falcon's Response
+ try:
+ questions = (
+ str(response)
+ .strip(empty_escape_sequences)
+ .replace("['", '["')
+ .replace("']", '"]')
+ .replace("', '", '", "')
+ .replace('["', "")
+ .replace('"]', "")
+ .split('", "')
+ )
+ except:
+ logger.warning(f"Falcon returned invalid JSON. Falling back to using user message as search query.\n{response}")
+ return all_questions
+ logger.debug(f"Extracted Questions by Falcon: {questions}")
+ questions.extend(all_questions)
+ return questions
+
+
+def converse_falcon(
+ references,
+ user_query,
+ conversation_log={},
+ model: str = "ggml-model-gpt4all-falcon-q4_0.bin",
+ loaded_model: Union[GPT4All, None] = None,
+ completion_func=None,
+) -> ThreadedGenerator:
+ """
+ Converse with user using Falcon
+ """
+ gpt4all_model = loaded_model or GPT4All(model)
+ # Initialize Variables
+ current_date = datetime.now().strftime("%Y-%m-%d")
+ compiled_references_message = "\n\n".join({f"{item}" for item in references})
+
+ # Get Conversation Primer appropriate to Conversation Type
+ # TODO If compiled_references_message is too long, we need to truncate it.
+ if compiled_references_message == "":
+ conversation_primer = prompts.conversation_falcon.format(query=user_query)
+ else:
+ conversation_primer = prompts.notes_conversation.format(
+ current_date=current_date, query=user_query, references=compiled_references_message
+ )
+
+ # Setup Prompt with Primer or Conversation History
+ messages = generate_chatml_messages_with_context(
+ conversation_primer,
+ prompts.personality.format(),
+ conversation_log,
+ model_name="text-davinci-001", # This isn't actually the model, but this helps us get an approximate encoding to run message truncation.
+ )
+
+ g = ThreadedGenerator(references, completion_func=completion_func)
+ t = Thread(target=llm_thread, args=(g, messages, gpt4all_model))
+ t.start()
+ return g
+
+
+def llm_thread(g, messages: List[ChatMessage], model: GPT4All):
+ user_message = messages[0]
+ system_message = messages[-1]
+ conversation_history = messages[1:-1]
+
+ formatted_messages = [
+        prompts.chat_history_falcon_from_assistant.format(message=message.content)
+ if message.role == "assistant"
+ else prompts.chat_history_falcon_from_user.format(message=message.content)
+ for message in conversation_history
+ ]
+
+ chat_history = "".join(formatted_messages)
+ full_message = system_message.content + chat_history + user_message.content
+
+ prompted_message = prompts.general_conversation_falcon.format(query=full_message)
+ response_iterator = model.generate(
+ prompted_message, streaming=True, max_tokens=256, top_k=1, temp=0, repeat_penalty=2.0
+ )
+ for response in response_iterator:
+ logger.info(response)
+ g.send(response)
+ g.close()
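A minimal usage sketch for the new offline chat functions (assuming the gpt4all package downloads the falcon model on first use; the sample note below is hypothetical). It mirrors how the chat actor tests added later in this diff drive them.

```python
from gpt4all import GPT4All

from khoj.processor.conversation.gpt4all.chat_model import converse_falcon, extract_questions_falcon

# Load the model once and reuse it; GPT4All downloads the weights on first use.
loaded_model = GPT4All("ggml-model-gpt4all-falcon-q4_0.bin")

# Without run_extraction=True this just splits the message on "? " into queries.
queries = extract_questions_falcon("Where did I go for dinner yesterday?", loaded_model=loaded_model)

# converse_falcon returns a ThreadedGenerator; join the streamed chunks for a blocking answer.
response_gen = converse_falcon(
    references=['2020-01-01 "Naco Taco" "Tacos for Dinner"'],  # hypothetical retrieved note
    user_query="Where did I go for dinner yesterday?",
    loaded_model=loaded_model,
)
response = "".join(chunk for chunk in response_gen)
print(queries, response)
```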
diff --git a/src/khoj/processor/conversation/openai/__init__.py b/src/khoj/processor/conversation/openai/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/khoj/processor/conversation/gpt.py b/src/khoj/processor/conversation/openai/gpt.py
similarity index 97%
rename from src/khoj/processor/conversation/gpt.py
rename to src/khoj/processor/conversation/openai/gpt.py
index e053be15..bf391f09 100644
--- a/src/khoj/processor/conversation/gpt.py
+++ b/src/khoj/processor/conversation/openai/gpt.py
@@ -9,11 +9,11 @@ from langchain.schema import ChatMessage
# Internal Packages
from khoj.utils.constants import empty_escape_sequences
from khoj.processor.conversation import prompts
-from khoj.processor.conversation.utils import (
+from khoj.processor.conversation.openai.utils import (
chat_completion_with_backoff,
completion_with_backoff,
- generate_chatml_messages_with_context,
)
+from khoj.processor.conversation.utils import generate_chatml_messages_with_context
logger = logging.getLogger(__name__)
diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py
new file mode 100644
index 00000000..130532e0
--- /dev/null
+++ b/src/khoj/processor/conversation/openai/utils.py
@@ -0,0 +1,101 @@
+# Standard Packages
+import os
+import logging
+from typing import Any
+from threading import Thread
+
+# External Packages
+from langchain.chat_models import ChatOpenAI
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+from langchain.callbacks.base import BaseCallbackManager
+import openai
+from tenacity import (
+ before_sleep_log,
+ retry,
+ retry_if_exception_type,
+ stop_after_attempt,
+ wait_exponential,
+ wait_random_exponential,
+)
+
+# Internal Packages
+from khoj.processor.conversation.utils import ThreadedGenerator
+
+
+logger = logging.getLogger(__name__)
+
+
+class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
+ def __init__(self, gen: ThreadedGenerator):
+ super().__init__()
+ self.gen = gen
+
+ def on_llm_new_token(self, token: str, **kwargs) -> Any:
+ self.gen.send(token)
+
+
+@retry(
+ retry=(
+ retry_if_exception_type(openai.error.Timeout)
+ | retry_if_exception_type(openai.error.APIError)
+ | retry_if_exception_type(openai.error.APIConnectionError)
+ | retry_if_exception_type(openai.error.RateLimitError)
+ | retry_if_exception_type(openai.error.ServiceUnavailableError)
+ ),
+ wait=wait_random_exponential(min=1, max=10),
+ stop=stop_after_attempt(3),
+ before_sleep=before_sleep_log(logger, logging.DEBUG),
+ reraise=True,
+)
+def completion_with_backoff(**kwargs):
+ messages = kwargs.pop("messages")
+ if not "openai_api_key" in kwargs:
+ kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY")
+ llm = ChatOpenAI(**kwargs, request_timeout=20, max_retries=1)
+ return llm(messages=messages)
+
+
+@retry(
+ retry=(
+ retry_if_exception_type(openai.error.Timeout)
+ | retry_if_exception_type(openai.error.APIError)
+ | retry_if_exception_type(openai.error.APIConnectionError)
+ | retry_if_exception_type(openai.error.RateLimitError)
+ | retry_if_exception_type(openai.error.ServiceUnavailableError)
+ ),
+ wait=wait_exponential(multiplier=1, min=4, max=10),
+ stop=stop_after_attempt(3),
+ before_sleep=before_sleep_log(logger, logging.DEBUG),
+ reraise=True,
+)
+def chat_completion_with_backoff(
+ messages, compiled_references, model_name, temperature, openai_api_key=None, completion_func=None
+):
+ g = ThreadedGenerator(compiled_references, completion_func=completion_func)
+ t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key))
+ t.start()
+ return g
+
+
+def llm_thread(g, messages, model_name, temperature, openai_api_key=None):
+ callback_handler = StreamingChatCallbackHandler(g)
+ chat = ChatOpenAI(
+ streaming=True,
+ verbose=True,
+ callback_manager=BaseCallbackManager([callback_handler]),
+ model_name=model_name, # type: ignore
+ temperature=temperature,
+ openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
+ request_timeout=20,
+ max_retries=1,
+ client=None,
+ )
+
+ chat(messages=messages)
+
+ g.close()
+
+
+def extract_summaries(metadata):
+ """Extract summaries from metadata"""
+ return "".join([f'\n{session["summary"]}' for session in metadata])
diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py
index c04e9042..931ba91b 100644
--- a/src/khoj/processor/conversation/prompts.py
+++ b/src/khoj/processor/conversation/prompts.py
@@ -18,6 +18,36 @@ Question: {query}
""".strip()
)
+general_conversation_falcon = PromptTemplate.from_template(
+ """
+Using your general knowledge and our past conversations as context, answer the following question.
+### Instruct:
+{query}
+### Response:
+""".strip()
+)
+
+chat_history_falcon_from_user = PromptTemplate.from_template(
+ """
+### Human:
+{message}
+""".strip()
+)
+
+chat_history_falcon_from_assistant = PromptTemplate.from_template(
+ """
+### Assistant:
+{message}
+""".strip()
+)
+
+conversation_falcon = PromptTemplate.from_template(
+ """
+Using our past conversations as context, answer the following question.
+
+Question: {query}
+""".strip()
+)
## Notes Conversation
## --
@@ -33,6 +63,17 @@ Question: {query}
""".strip()
)
+notes_conversation_falcon = PromptTemplate.from_template(
+ """
+Using the notes and our past conversations as context, answer the following question. If the answer is not contained within the notes, say "I don't know."
+
+Notes:
+{references}
+
+Question: {query}
+""".strip()
+)
+
## Summarize Chat
## --
@@ -68,6 +109,40 @@ Question: {user_query}
Answer (in second person):"""
)
+extract_questions_falcon = PromptTemplate.from_template(
+ """
+You are Khoj, an extremely smart and helpful search assistant with the ability to retrieve information from the user's notes.
+- The user will provide their questions and answers to you for context.
+- Add as much context from the previous questions and answers as required into your search queries.
+- Break messages into multiple search queries when required to retrieve the relevant information.
+- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
+
+What searches, if any, will you need to perform to answer the user's question?
+
+Q: How was my trip to Cambodia?
+
+["How was my trip to Cambodia?"]
+
+A: The trip was amazing. I went to the Angkor Wat temple and it was beautiful.
+
+Q: Who did i visit that temple with?
+
+["Who did I visit the Angkor Wat Temple in Cambodia with?"]
+
+A: You visited the Angkor Wat Temple in Cambodia with Pablo, Namita and Xi.
+
+Q: How many tennis balls fit in the back of a 2002 Honda Civic?
+
+["What is the size of a tennis ball?", "What is the trunk size of a 2002 Honda Civic?"]
+
+A: 1085 tennis balls will fit in the trunk of a Honda Civic
+
+{chat_history}
+Q: {text}
+
+"""
+)
+
## Extract Questions
## --
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index e77b7899..5cc5e1c6 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -1,35 +1,19 @@
# Standard Packages
-import os
import logging
-from datetime import datetime
from time import perf_counter
-from typing import Any
-from threading import Thread
import json
-
-# External Packages
-from langchain.chat_models import ChatOpenAI
-from langchain.schema import ChatMessage
-from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
-from langchain.callbacks.base import BaseCallbackManager
-import openai
+from datetime import datetime
import tiktoken
-from tenacity import (
- before_sleep_log,
- retry,
- retry_if_exception_type,
- stop_after_attempt,
- wait_exponential,
- wait_random_exponential,
-)
+
+# External packages
+from langchain.schema import ChatMessage
# Internal Packages
-from khoj.utils.helpers import merge_dicts
import queue
-
+from khoj.utils.helpers import merge_dicts
logger = logging.getLogger(__name__)
-max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192}
+max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192, "text-davinci-001": 910}
class ThreadedGenerator:
@@ -49,9 +33,9 @@ class ThreadedGenerator:
time_to_response = perf_counter() - self.start_time
logger.info(f"Chat streaming took: {time_to_response:.3f} seconds")
if self.completion_func:
- # The completion func effective acts as a callback.
- # It adds the aggregated response to the conversation history. It's constructed in api.py.
- self.completion_func(gpt_response=self.response)
+ # The completion func effectively acts as a callback.
+ # It adds the aggregated response to the conversation history.
+ self.completion_func(chat_response=self.response)
raise StopIteration
return item
@@ -65,75 +49,25 @@ class ThreadedGenerator:
self.queue.put(StopIteration)
-class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
- def __init__(self, gen: ThreadedGenerator):
- super().__init__()
- self.gen = gen
-
- def on_llm_new_token(self, token: str, **kwargs) -> Any:
- self.gen.send(token)
-
-
-@retry(
- retry=(
- retry_if_exception_type(openai.error.Timeout)
- | retry_if_exception_type(openai.error.APIError)
- | retry_if_exception_type(openai.error.APIConnectionError)
- | retry_if_exception_type(openai.error.RateLimitError)
- | retry_if_exception_type(openai.error.ServiceUnavailableError)
- ),
- wait=wait_random_exponential(min=1, max=10),
- stop=stop_after_attempt(3),
- before_sleep=before_sleep_log(logger, logging.DEBUG),
- reraise=True,
-)
-def completion_with_backoff(**kwargs):
- messages = kwargs.pop("messages")
- if not "openai_api_key" in kwargs:
- kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY")
- llm = ChatOpenAI(**kwargs, request_timeout=20, max_retries=1)
- return llm(messages=messages)
-
-
-@retry(
- retry=(
- retry_if_exception_type(openai.error.Timeout)
- | retry_if_exception_type(openai.error.APIError)
- | retry_if_exception_type(openai.error.APIConnectionError)
- | retry_if_exception_type(openai.error.RateLimitError)
- | retry_if_exception_type(openai.error.ServiceUnavailableError)
- ),
- wait=wait_exponential(multiplier=1, min=4, max=10),
- stop=stop_after_attempt(3),
- before_sleep=before_sleep_log(logger, logging.DEBUG),
- reraise=True,
-)
-def chat_completion_with_backoff(
- messages, compiled_references, model_name, temperature, openai_api_key=None, completion_func=None
+def message_to_log(
+ user_message, chat_response, user_message_metadata={}, khoj_message_metadata={}, conversation_log=[]
):
- g = ThreadedGenerator(compiled_references, completion_func=completion_func)
- t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key))
- t.start()
- return g
+ """Create json logs from messages, metadata for conversation log"""
+ default_khoj_message_metadata = {
+ "intent": {"type": "remember", "memory-type": "notes", "query": user_message},
+ "trigger-emotion": "calm",
+ }
+ khoj_response_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+ # Create json log from Human's message
+ human_log = merge_dicts({"message": user_message, "by": "you"}, user_message_metadata)
-def llm_thread(g, messages, model_name, temperature, openai_api_key=None):
- callback_handler = StreamingChatCallbackHandler(g)
- chat = ChatOpenAI(
- streaming=True,
- verbose=True,
- callback_manager=BaseCallbackManager([callback_handler]),
- model_name=model_name, # type: ignore
- temperature=temperature,
- openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
- request_timeout=20,
- max_retries=1,
- client=None,
- )
+    # Create json log from Khoj's response
+ khoj_log = merge_dicts(khoj_message_metadata, default_khoj_message_metadata)
+ khoj_log = merge_dicts({"message": chat_response, "by": "khoj", "created": khoj_response_time}, khoj_log)
- chat(messages=messages)
-
- g.close()
+ conversation_log.extend([human_log, khoj_log])
+ return conversation_log
def generate_chatml_messages_with_context(
@@ -192,27 +126,3 @@ def truncate_messages(messages, max_prompt_size, model_name):
def reciprocal_conversation_to_chatml(message_pair):
"""Convert a single back and forth between user and assistant to chatml format"""
return [ChatMessage(content=message, role=role) for message, role in zip(message_pair, ["user", "assistant"])]
-
-
-def message_to_log(user_message, gpt_message, user_message_metadata={}, khoj_message_metadata={}, conversation_log=[]):
- """Create json logs from messages, metadata for conversation log"""
- default_khoj_message_metadata = {
- "intent": {"type": "remember", "memory-type": "notes", "query": user_message},
- "trigger-emotion": "calm",
- }
- khoj_response_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-
- # Create json log from Human's message
- human_log = merge_dicts({"message": user_message, "by": "you"}, user_message_metadata)
-
- # Create json log from GPT's response
- khoj_log = merge_dicts(khoj_message_metadata, default_khoj_message_metadata)
- khoj_log = merge_dicts({"message": gpt_message, "by": "khoj", "created": khoj_response_time}, khoj_log)
-
- conversation_log.extend([human_log, khoj_log])
- return conversation_log
-
-
-def extract_summaries(metadata):
- """Extract summaries from metadata"""
- return "".join([f'\n{session["summary"]}' for session in metadata])
diff --git a/src/khoj/processor/github/github_to_jsonl.py b/src/khoj/processor/github/github_to_jsonl.py
index ddab24ce..ddfa6a67 100644
--- a/src/khoj/processor/github/github_to_jsonl.py
+++ b/src/khoj/processor/github/github_to_jsonl.py
@@ -52,7 +52,7 @@ class GithubToJsonl(TextToJsonl):
try:
markdown_files, org_files = self.get_files(repo_url, repo)
except Exception as e:
- logger.error(f"Unable to download github repo {repo_shorthand}")
+ logger.error(f"Unable to download github repo {repo_shorthand}", exc_info=True)
raise e
logger.info(f"Found {len(markdown_files)} markdown files in github repo {repo_shorthand}")
diff --git a/src/khoj/processor/notion/notion_to_jsonl.py b/src/khoj/processor/notion/notion_to_jsonl.py
index 489f0341..cb4c5f84 100644
--- a/src/khoj/processor/notion/notion_to_jsonl.py
+++ b/src/khoj/processor/notion/notion_to_jsonl.py
@@ -219,7 +219,7 @@ class NotionToJsonl(TextToJsonl):
page = self.get_page(page_id)
content = self.get_page_children(page_id)
except Exception as e:
- logger.error(f"Error getting page {page_id}: {e}")
+ logger.error(f"Error getting page {page_id}: {e}", exc_info=True)
return None, None
properties = page["properties"]
title_field = "title"
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index 834e8997..ead61c53 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -5,7 +5,7 @@ import time
import yaml
import logging
import json
-from typing import Iterable, List, Optional, Union
+from typing import List, Optional, Union
# External Packages
from fastapi import APIRouter, HTTPException, Header, Request
@@ -26,16 +26,19 @@ from khoj.utils.rawconfig import (
SearchConfig,
SearchResponse,
TextContentConfig,
- ConversationProcessorConfig,
+ OpenAIProcessorConfig,
GithubContentConfig,
NotionContentConfig,
+ ConversationProcessorConfig,
)
+from khoj.utils.helpers import resolve_absolute_path
from khoj.utils.state import SearchType
from khoj.utils import state, constants
from khoj.utils.yaml import save_config_to_file_updated_state
from fastapi.responses import StreamingResponse, Response
from khoj.routers.helpers import perform_chat_checks, generate_chat_response, update_telemetry_state
-from khoj.processor.conversation.gpt import extract_questions
+from khoj.processor.conversation.openai.gpt import extract_questions
+from khoj.processor.conversation.gpt4all.chat_model import extract_questions_falcon, converse_falcon
from fastapi.requests import Request
@@ -50,6 +53,8 @@ if not state.demo:
if state.config is None:
state.config = FullConfig()
state.config.search_type = SearchConfig.parse_obj(constants.default_config["search-type"])
+ if state.processor_config is None:
+ state.processor_config = configure_processor(state.config.processor)
@api.get("/config/data", response_model=FullConfig)
def get_config_data():
@@ -181,22 +186,28 @@ if not state.demo:
except Exception as e:
return {"status": "error", "message": str(e)}
- @api.post("/delete/config/data/processor/conversation", status_code=200)
+ @api.post("/delete/config/data/processor/conversation/openai", status_code=200)
async def remove_processor_conversation_config_data(
request: Request,
client: Optional[str] = None,
):
- if not state.config or not state.config.processor or not state.config.processor.conversation:
+ if (
+ not state.config
+ or not state.config.processor
+ or not state.config.processor.conversation
+ or not state.config.processor.conversation.openai
+ ):
return {"status": "ok"}
- state.config.processor.conversation = None
+ state.config.processor.conversation.openai = None
+ state.processor_config = configure_processor(state.config.processor, state.processor_config)
update_telemetry_state(
request=request,
telemetry_type="api",
- api="delete_processor_config",
+ api="delete_processor_openai_config",
client=client,
- metadata={"processor_type": "conversation"},
+ metadata={"processor_conversation_type": "openai"},
)
try:
@@ -233,23 +244,66 @@ if not state.demo:
except Exception as e:
return {"status": "error", "message": str(e)}
- @api.post("/config/data/processor/conversation", status_code=200)
- async def set_processor_conversation_config_data(
+ @api.post("/config/data/processor/conversation/openai", status_code=200)
+ async def set_processor_openai_config_data(
request: Request,
- updated_config: Union[ConversationProcessorConfig, None],
+ updated_config: Union[OpenAIProcessorConfig, None],
client: Optional[str] = None,
):
_initialize_config()
- state.config.processor = ProcessorConfig(conversation=updated_config)
- state.processor_config = configure_processor(state.config.processor)
+ if not state.config.processor or not state.config.processor.conversation:
+ default_config = constants.default_config
+ default_conversation_logfile = resolve_absolute_path(
+ default_config["processor"]["conversation"]["conversation-logfile"] # type: ignore
+ )
+ conversation_logfile = resolve_absolute_path(default_conversation_logfile)
+ state.config.processor = ProcessorConfig(conversation=ConversationProcessorConfig(conversation_logfile=conversation_logfile)) # type: ignore
+
+ assert state.config.processor.conversation is not None
+ state.config.processor.conversation.openai = updated_config
+ state.processor_config = configure_processor(state.config.processor, state.processor_config)
update_telemetry_state(
request=request,
telemetry_type="api",
- api="set_content_config",
+ api="set_processor_config",
client=client,
- metadata={"processor_type": "conversation"},
+ metadata={"processor_conversation_type": "conversation"},
+ )
+
+ try:
+ save_config_to_file_updated_state()
+ return {"status": "ok"}
+ except Exception as e:
+ return {"status": "error", "message": str(e)}
+
+ @api.post("/config/data/processor/conversation/enable_offline_chat", status_code=200)
+ async def set_processor_enable_offline_chat_config_data(
+ request: Request,
+ enable_offline_chat: bool,
+ client: Optional[str] = None,
+ ):
+ _initialize_config()
+
+ if not state.config.processor or not state.config.processor.conversation:
+ default_config = constants.default_config
+ default_conversation_logfile = resolve_absolute_path(
+ default_config["processor"]["conversation"]["conversation-logfile"] # type: ignore
+ )
+ conversation_logfile = resolve_absolute_path(default_conversation_logfile)
+ state.config.processor = ProcessorConfig(conversation=ConversationProcessorConfig(conversation_logfile=conversation_logfile)) # type: ignore
+
+ assert state.config.processor.conversation is not None
+ state.config.processor.conversation.enable_offline_chat = enable_offline_chat
+ state.processor_config = configure_processor(state.config.processor, state.processor_config)
+
+ update_telemetry_state(
+ request=request,
+ telemetry_type="api",
+ api="set_processor_config",
+ client=client,
+ metadata={"processor_conversation_type": f"{'enable' if enable_offline_chat else 'disable'}_local_llm"},
)
try:
@@ -569,7 +623,9 @@ def chat_history(
perform_chat_checks()
# Load Conversation History
- meta_log = state.processor_config.conversation.meta_log
+ meta_log = {}
+ if state.processor_config.conversation:
+ meta_log = state.processor_config.conversation.meta_log
update_telemetry_state(
request=request,
@@ -598,24 +654,25 @@ async def chat(
perform_chat_checks()
compiled_references, inferred_queries = await extract_references_and_questions(request, q, (n or 5))
- # Get the (streamed) chat response from GPT.
- gpt_response = generate_chat_response(
+ # Get the (streamed) chat response from the LLM of choice.
+ llm_response = generate_chat_response(
q,
meta_log=state.processor_config.conversation.meta_log,
compiled_references=compiled_references,
inferred_queries=inferred_queries,
)
- if gpt_response is None:
- return Response(content=gpt_response, media_type="text/plain", status_code=500)
+
+ if llm_response is None:
+ return Response(content=llm_response, media_type="text/plain", status_code=500)
if stream:
- return StreamingResponse(gpt_response, media_type="text/event-stream", status_code=200)
+ return StreamingResponse(llm_response, media_type="text/event-stream", status_code=200)
# Get the full response from the generator if the stream is not requested.
aggregated_gpt_response = ""
while True:
try:
- aggregated_gpt_response += next(gpt_response)
+ aggregated_gpt_response += next(llm_response)
except StopIteration:
break
@@ -645,8 +702,6 @@ async def extract_references_and_questions(
meta_log = state.processor_config.conversation.meta_log
# Initialize Variables
- api_key = state.processor_config.conversation.openai_api_key
- chat_model = state.processor_config.conversation.chat_model
conversation_type = "general" if q.startswith("@general") else "notes"
compiled_references = []
inferred_queries = []
@@ -654,7 +709,13 @@ async def extract_references_and_questions(
if conversation_type == "notes":
# Infer search queries from user message
with timer("Extracting search queries took", logger):
- inferred_queries = extract_questions(q, model=chat_model, api_key=api_key, conversation_log=meta_log)
+ if state.processor_config.conversation and state.processor_config.conversation.openai_model:
+ api_key = state.processor_config.conversation.openai_model.api_key
+ chat_model = state.processor_config.conversation.openai_model.chat_model
+ inferred_queries = extract_questions(q, model=chat_model, api_key=api_key, conversation_log=meta_log)
+ else:
+ loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
+ inferred_queries = extract_questions_falcon(q, loaded_model=loaded_model, conversation_log=meta_log)
# Collate search results as context for GPT
with timer("Searching knowledge base took", logger):
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index fe4cdac2..91608b10 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -7,22 +7,23 @@ from fastapi import HTTPException, Request
from khoj.utils import state
from khoj.utils.helpers import timer, log_telemetry
-from khoj.processor.conversation.gpt import converse
-from khoj.processor.conversation.utils import message_to_log, reciprocal_conversation_to_chatml
-
+from khoj.processor.conversation.openai.gpt import converse
+from khoj.processor.conversation.gpt4all.chat_model import converse_falcon
+from khoj.processor.conversation.utils import reciprocal_conversation_to_chatml, message_to_log, ThreadedGenerator
logger = logging.getLogger(__name__)
def perform_chat_checks():
- if (
- state.processor_config is None
- or state.processor_config.conversation is None
- or state.processor_config.conversation.openai_api_key is None
+ if state.processor_config.conversation and (
+ state.processor_config.conversation.openai_model
+ or state.processor_config.conversation.gpt4all_model.loaded_model
):
- raise HTTPException(
- status_code=500, detail="Set your OpenAI API key via Khoj settings and restart it to use Khoj Chat."
- )
+ return
+
+ raise HTTPException(
+ status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings and restart it."
+ )
def update_telemetry_state(
@@ -57,19 +58,19 @@ def generate_chat_response(
meta_log: dict,
compiled_references: List[str] = [],
inferred_queries: List[str] = [],
-):
+) -> ThreadedGenerator:
def _save_to_conversation_log(
q: str,
- gpt_response: str,
+ chat_response: str,
user_message_time: str,
compiled_references: List[str],
inferred_queries: List[str],
meta_log,
):
- state.processor_config.conversation.chat_session += reciprocal_conversation_to_chatml([q, gpt_response])
+ state.processor_config.conversation.chat_session += reciprocal_conversation_to_chatml([q, chat_response])
state.processor_config.conversation.meta_log["chat"] = message_to_log(
- q,
- gpt_response,
+ user_message=q,
+ chat_response=chat_response,
user_message_metadata={"created": user_message_time},
khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}},
conversation_log=meta_log.get("chat", []),
@@ -79,8 +80,6 @@ def generate_chat_response(
meta_log = state.processor_config.conversation.meta_log
# Initialize Variables
- api_key = state.processor_config.conversation.openai_api_key
- chat_model = state.processor_config.conversation.chat_model
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
conversation_type = "general" if q.startswith("@general") else "notes"
@@ -99,12 +98,29 @@ def generate_chat_response(
meta_log=meta_log,
)
- gpt_response = converse(
- compiled_references, q, meta_log, model=chat_model, api_key=api_key, completion_func=partial_completion
- )
+ if state.processor_config.conversation.openai_model:
+ api_key = state.processor_config.conversation.openai_model.api_key
+ chat_model = state.processor_config.conversation.openai_model.chat_model
+ chat_response = converse(
+ compiled_references,
+ q,
+ meta_log,
+ model=chat_model,
+ api_key=api_key,
+ completion_func=partial_completion,
+ )
+ else:
+ loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
+ chat_response = converse_falcon(
+ references=compiled_references,
+ user_query=q,
+ loaded_model=loaded_model,
+ conversation_log=meta_log,
+ completion_func=partial_completion,
+ )
except Exception as e:
- logger.error(e)
+ logger.error(e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
- return gpt_response
+ return chat_response
diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py
index 29179cdd..663b8675 100644
--- a/src/khoj/routers/web_client.py
+++ b/src/khoj/routers/web_client.py
@@ -3,7 +3,7 @@ from fastapi import APIRouter
from fastapi import Request
from fastapi.responses import HTMLResponse, FileResponse
from fastapi.templating import Jinja2Templates
-from khoj.utils.rawconfig import TextContentConfig, ConversationProcessorConfig, FullConfig
+from khoj.utils.rawconfig import TextContentConfig, OpenAIProcessorConfig, FullConfig
# Internal Packages
from khoj.utils import constants, state
@@ -151,28 +151,29 @@ if not state.demo:
},
)
- @web_client.get("/config/processor/conversation", response_class=HTMLResponse)
+ @web_client.get("/config/processor/conversation/openai", response_class=HTMLResponse)
def conversation_processor_config_page(request: Request):
default_copy = constants.default_config.copy()
- default_processor_config = default_copy["processor"]["conversation"] # type: ignore
- default_processor_config = ConversationProcessorConfig(
- openai_api_key="",
- model=default_processor_config["model"],
- conversation_logfile=default_processor_config["conversation-logfile"],
+ default_processor_config = default_copy["processor"]["conversation"]["openai"] # type: ignore
+ default_openai_config = OpenAIProcessorConfig(
+ api_key="",
chat_model=default_processor_config["chat-model"],
)
- current_processor_conversation_config = (
- state.config.processor.conversation
- if state.config and state.config.processor and state.config.processor.conversation
- else default_processor_config
+ current_processor_openai_config = (
+ state.config.processor.conversation.openai
+ if state.config
+ and state.config.processor
+ and state.config.processor.conversation
+ and state.config.processor.conversation.openai
+ else default_openai_config
)
- current_processor_conversation_config = json.loads(current_processor_conversation_config.json())
+ current_processor_openai_config = json.loads(current_processor_openai_config.json())
return templates.TemplateResponse(
"processor_conversation_input.html",
context={
"request": request,
- "current_config": current_processor_conversation_config,
+ "current_config": current_processor_openai_config,
},
)
diff --git a/src/khoj/utils/cli.py b/src/khoj/utils/cli.py
index 49acd6e1..9236ab11 100644
--- a/src/khoj/utils/cli.py
+++ b/src/khoj/utils/cli.py
@@ -5,7 +5,9 @@ from importlib.metadata import version
# Internal Packages
from khoj.utils.helpers import resolve_absolute_path
-from khoj.utils.yaml import load_config_from_file, parse_config_from_file, save_config_to_file
+from khoj.utils.yaml import parse_config_from_file
+from khoj.migrations.migrate_version import migrate_config_to_version
+from khoj.migrations.migrate_processor_config_openai import migrate_processor_conversation_schema
def cli(args=None):
@@ -46,22 +48,14 @@ def cli(args=None):
if not args.config_file.exists():
args.config = None
else:
- args = migrate_config(args)
+ args = run_migrations(args)
args.config = parse_config_from_file(args.config_file)
return args
-def migrate_config(args):
- raw_config = load_config_from_file(args.config_file)
-
- # Add version to khoj config schema
- if "version" not in raw_config:
- raw_config["version"] = args.version_no
- save_config_to_file(raw_config, args.config_file)
-
- # regenerate khoj index on first start of this version
- # this should refresh index and apply index corruption fixes from #325
- args.regenerate = True
-
+def run_migrations(args):
+ migrations = [migrate_config_to_version, migrate_processor_conversation_schema]
+ for migration in migrations:
+ args = migration(args)
return args
diff --git a/src/khoj/utils/config.py b/src/khoj/utils/config.py
index 6ba8b639..1dffc9b8 100644
--- a/src/khoj/utils/config.py
+++ b/src/khoj/utils/config.py
@@ -1,9 +1,12 @@
# System Packages
from __future__ import annotations # to avoid quoting type hints
+
from enum import Enum
from dataclasses import dataclass
from pathlib import Path
-from typing import TYPE_CHECKING, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union, Any
+
+from gpt4all import GPT4All
# External Packages
import torch
@@ -13,7 +16,7 @@ if TYPE_CHECKING:
from sentence_transformers import CrossEncoder
from khoj.search_filter.base_filter import BaseFilter
from khoj.utils.models import BaseEncoder
- from khoj.utils.rawconfig import ConversationProcessorConfig, Entry
+ from khoj.utils.rawconfig import ConversationProcessorConfig, Entry, OpenAIProcessorConfig
class SearchType(str, Enum):
@@ -74,15 +77,29 @@ class SearchModels:
plugin_search: Optional[Dict[str, TextSearchModel]] = None
+@dataclass
+class GPT4AllProcessorConfig:
+ chat_model: Optional[str] = "ggml-model-gpt4all-falcon-q4_0.bin"
+ loaded_model: Union[Any, None] = None
+
+
class ConversationProcessorConfigModel:
- def __init__(self, processor_config: ConversationProcessorConfig):
- self.openai_api_key = processor_config.openai_api_key
- self.model = processor_config.model
- self.chat_model = processor_config.chat_model
- self.conversation_logfile = Path(processor_config.conversation_logfile)
+ def __init__(
+ self,
+ conversation_config: ConversationProcessorConfig,
+ ):
+ self.openai_model = conversation_config.openai
+ self.gpt4all_model = GPT4AllProcessorConfig()
+ self.enable_offline_chat = conversation_config.enable_offline_chat
+ self.conversation_logfile = Path(conversation_config.conversation_logfile)
self.chat_session: List[str] = []
self.meta_log: dict = {}
+ if not self.openai_model and self.enable_offline_chat:
+ self.gpt4all_model.loaded_model = GPT4All(self.gpt4all_model.chat_model) # type: ignore
+ else:
+ self.gpt4all_model.loaded_model = None
+
@dataclass
class ProcessorConfigModel:
diff --git a/src/khoj/utils/constants.py b/src/khoj/utils/constants.py
index f1de7d76..1b0efc00 100644
--- a/src/khoj/utils/constants.py
+++ b/src/khoj/utils/constants.py
@@ -62,10 +62,12 @@ default_config = {
},
"processor": {
"conversation": {
- "openai-api-key": None,
- "model": "text-davinci-003",
+ "openai": {
+ "api-key": None,
+ "chat-model": "gpt-3.5-turbo",
+ },
+ "enable-offline-chat": False,
"conversation-logfile": "~/.khoj/processor/conversation/conversation_logs.json",
- "chat-model": "gpt-3.5-turbo",
}
},
}
diff --git a/src/khoj/utils/models.py b/src/khoj/utils/models.py
index b5850851..b5bbe292 100644
--- a/src/khoj/utils/models.py
+++ b/src/khoj/utils/models.py
@@ -27,12 +27,12 @@ class OpenAI(BaseEncoder):
if (
not state.processor_config
or not state.processor_config.conversation
- or not state.processor_config.conversation.openai_api_key
+ or not state.processor_config.conversation.openai_model
):
raise Exception(
f"Set OpenAI API key under processor-config > conversation > openai-api-key in config file: {state.config_file}"
)
- openai.api_key = state.processor_config.conversation.openai_api_key
+ openai.api_key = state.processor_config.conversation.openai_model.api_key
self.embedding_dimensions = None
def encode(self, entries, device=None, **kwargs):
diff --git a/src/khoj/utils/rawconfig.py b/src/khoj/utils/rawconfig.py
index d3c9a4ea..af7dda67 100644
--- a/src/khoj/utils/rawconfig.py
+++ b/src/khoj/utils/rawconfig.py
@@ -1,7 +1,7 @@
# System Packages
import json
from pathlib import Path
-from typing import List, Dict, Optional
+from typing import List, Dict, Optional, Union, Any
# External Packages
from pydantic import BaseModel, validator
@@ -103,13 +103,17 @@ class SearchConfig(ConfigBase):
image: Optional[ImageSearchConfig]
-class ConversationProcessorConfig(ConfigBase):
- openai_api_key: str
- conversation_logfile: Path
- model: Optional[str] = "text-davinci-003"
+class OpenAIProcessorConfig(ConfigBase):
+ api_key: str
chat_model: Optional[str] = "gpt-3.5-turbo"
+class ConversationProcessorConfig(ConfigBase):
+ conversation_logfile: Path
+ openai: Optional[OpenAIProcessorConfig]
+ enable_offline_chat: Optional[bool] = False
+
+
class ProcessorConfig(ConfigBase):
conversation: Optional[ConversationProcessorConfig]
diff --git a/tests/data/config.yml b/tests/data/config.yml
index a4258028..96009a42 100644
--- a/tests/data/config.yml
+++ b/tests/data/config.yml
@@ -20,6 +20,7 @@ content-type:
embeddings-file: content_plugin_2_embeddings.pt
input-filter:
- '*2_new.jsonl.gz'
+enable-offline-chat: false
search-type:
asymmetric:
cross-encoder: cross-encoder/ms-marco-MiniLM-L-6-v2
diff --git a/tests/test_gpt4all_chat_actors.py b/tests/test_gpt4all_chat_actors.py
new file mode 100644
index 00000000..f5b3955a
--- /dev/null
+++ b/tests/test_gpt4all_chat_actors.py
@@ -0,0 +1,426 @@
+# Standard Packages
+from datetime import datetime
+
+# External Packages
+import pytest
+
+SKIP_TESTS = True
+pytestmark = pytest.mark.skipif(
+ SKIP_TESTS,
+ reason="The GPT4All library has some quirks that make it hard to test in CI. This causes some tests to fail. Hence, disable it in CI.",
+)
+
+import freezegun
+from freezegun import freeze_time
+
+from gpt4all import GPT4All
+
+# Internal Packages
+from khoj.processor.conversation.gpt4all.chat_model import converse_falcon, extract_questions_falcon
+
+from khoj.processor.conversation.utils import message_to_log
+
+
+@pytest.fixture(scope="session")
+def loaded_model():
+ return GPT4All("ggml-model-gpt4all-falcon-q4_0.bin")
+
+
+freezegun.configure(extend_ignore_list=["transformers"])
+
+
+# Test
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.chatquality
+@freeze_time("1984-04-02")
+def test_extract_question_with_date_filter_from_relative_day(loaded_model):
+ # Act
+ response = extract_questions_falcon(
+ "Where did I go for dinner yesterday?", loaded_model=loaded_model, run_extraction=True
+ )
+
+ assert len(response) >= 1
+ assert response[-1] == "Where did I go for dinner yesterday?"
+
+
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.chatquality
+@freeze_time("1984-04-02")
+def test_extract_question_with_date_filter_from_relative_month(loaded_model):
+ # Act
+ response = extract_questions_falcon("Which countries did I visit last month?", loaded_model=loaded_model)
+
+ # Assert
+ assert len(response) == 1
+ assert response == ["Which countries did I visit last month?"]
+
+
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.chatquality
+@freeze_time("1984-04-02")
+def test_extract_question_with_date_filter_from_relative_year(loaded_model):
+ # Act
+ response = extract_questions_falcon(
+ "Which countries have I visited this year?", loaded_model=loaded_model, run_extraction=True
+ )
+
+ # Assert
+ assert len(response) >= 1
+ assert response[-1] == "Which countries have I visited this year?"
+
+
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.chatquality
+def test_extract_multiple_explicit_questions_from_message(loaded_model):
+ # Act
+ response = extract_questions_falcon("What is the Sun? What is the Moon?", loaded_model=loaded_model)
+
+ # Assert
+ expected_responses = ["What is the Sun?", "What is the Moon?"]
+ assert len(response) == 2
+ assert expected_responses == response
+
+
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.chatquality
+def test_extract_multiple_implicit_questions_from_message(loaded_model):
+ # Act
+ response = extract_questions_falcon("Is Morpheus taller than Neo?", loaded_model=loaded_model, run_extraction=True)
+
+ # Assert
+ expected_responses = [
+ ("morpheus", "neo"),
+ ]
+ assert len(response) == 2
+ assert any([start in response[0].lower() and end in response[1].lower() for start, end in expected_responses]), (
+ "Expected two search queries in response but got: " + response[0]
+ )
+
+
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.chatquality
+def test_generate_search_query_using_question_from_chat_history(loaded_model):
+ # Arrange
+ message_list = [
+ ("What is the name of Mr. Vader's daughter?", "Princess Leia", []),
+ ]
+
+ # Act
+ response = extract_questions_falcon(
+ "Does he have any sons?",
+ conversation_log=populate_chat_history(message_list),
+ loaded_model=loaded_model,
+ run_extraction=True,
+ use_history=True,
+ )
+
+ expected_responses = [
+ "do not have",
+ "clarify",
+ "am sorry",
+ ]
+
+ # Assert
+ assert len(response) >= 1
+ assert any([expected_response in response[0] for expected_response in expected_responses]), (
+ "Expected chat actor to ask for clarification in response, but got: " + response[0]
+ )
+
+
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.xfail(reason="Chat actor does not consistently follow template instructions.")
+@pytest.mark.chatquality
+def test_generate_search_query_using_answer_from_chat_history(loaded_model):
+ # Arrange
+ message_list = [
+ ("What is the name of Mr. Vader's daughter?", "Princess Leia", []),
+ ]
+
+ # Act
+ response = extract_questions_falcon(
+ "Is she a Jedi?",
+ conversation_log=populate_chat_history(message_list),
+ loaded_model=loaded_model,
+ run_extraction=True,
+ use_history=True,
+ )
+
+ # Assert
+ assert len(response) == 1
+ assert "Leia" in response[0]
+
+
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.xfail(reason="Chat actor is not sufficiently date-aware")
+@pytest.mark.chatquality
+def test_generate_search_query_with_date_and_context_from_chat_history(loaded_model):
+ # Arrange
+ message_list = [
+ ("When did I visit Masai Mara?", "You visited Masai Mara in April 2000", []),
+ ]
+
+ # Act
+ response = extract_questions_falcon(
+ "What was the Pizza place we ate at over there?",
+ conversation_log=populate_chat_history(message_list),
+ run_extraction=True,
+ loaded_model=loaded_model,
+ )
+
+ # Assert
+ expected_responses = [
+ ("dt>='2000-04-01'", "dt<'2000-05-01'"),
+ ("dt>='2000-04-01'", "dt<='2000-04-30'"),
+ ('dt>="2000-04-01"', 'dt<"2000-05-01"'),
+ ('dt>="2000-04-01"', 'dt<="2000-04-30"'),
+ ]
+ assert len(response) == 1
+ assert "Masai Mara" in response[0]
+ assert any([start in response[0] and end in response[0] for start, end in expected_responses]), (
+ "Expected date filter to limit to April 2000 in response but got: " + response[0]
+ )
+
+
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.chatquality
+def test_chat_with_no_chat_history_or_retrieved_content(loaded_model):
+ # Act
+ response_gen = converse_falcon(
+ references=[], # Assume no context retrieved from notes for the user_query
+ user_query="Hello, my name is Testatron. Who are you?",
+ loaded_model=loaded_model,
+ )
+ response = "".join([response_chunk for response_chunk in response_gen])
+
+ # Assert
+ expected_responses = ["Khoj", "khoj", "khooj", "Khooj", "KHOJ"]
+ assert len(response) > 0
+ assert any([expected_response in response for expected_response in expected_responses]), (
+ "Expected assistants name, [K|k]hoj, in response but got: " + response
+ )
+
+
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.xfail(reason="Chat actor isn't really good at proper nouns yet.")
+@pytest.mark.chatquality
+def test_answer_from_chat_history_and_previously_retrieved_content(loaded_model):
+ "Chat actor needs to use context in previous notes and chat history to answer question"
+ # Arrange
+ message_list = [
+ ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
+ (
+ "When was I born?",
+ "You were born on 1st April 1984.",
+ ["Testatron was born on 1st April 1984 in Testville."],
+ ),
+ ]
+
+ # Act
+ response_gen = converse_falcon(
+ references=[], # Assume no context retrieved from notes for the user_query
+ user_query="Where was I born?",
+ conversation_log=populate_chat_history(message_list),
+ loaded_model=loaded_model,
+ )
+ response = "".join([response_chunk for response_chunk in response_gen])
+
+ # Assert
+ assert len(response) > 0
+ # Infer who I am and use that to infer I was born in Testville using chat history and previously retrieved notes
+ assert "Testville" in response
+
+
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.chatquality
+def test_answer_from_chat_history_and_currently_retrieved_content(loaded_model):
+ "Chat actor needs to use context across currently retrieved notes and chat history to answer question"
+ # Arrange
+ message_list = [
+ ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
+ ("When was I born?", "You were born on 1st April 1984.", []),
+ ]
+
+ # Act
+ response_gen = converse_falcon(
+ references=[
+ "Testatron was born on 1st April 1984 in Testville."
+ ], # Assume context retrieved from notes for the user_query
+ user_query="Where was I born?",
+ conversation_log=populate_chat_history(message_list),
+ loaded_model=loaded_model,
+ )
+ response = "".join([response_chunk for response_chunk in response_gen])
+
+ # Assert
+ assert len(response) > 0
+ assert "Testville" in response
+
+
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.xfail(reason="Chat actor is rather liable to lying.")
+@pytest.mark.chatquality
+def test_refuse_answering_unanswerable_question(loaded_model):
+ "Chat actor should not try make up answers to unanswerable questions."
+ # Arrange
+ message_list = [
+ ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
+ ("When was I born?", "You were born on 1st April 1984.", []),
+ ]
+
+ # Act
+ response_gen = converse_falcon(
+ references=[], # Assume no context retrieved from notes for the user_query
+ user_query="Where was I born?",
+ conversation_log=populate_chat_history(message_list),
+ loaded_model=loaded_model,
+ )
+ response = "".join([response_chunk for response_chunk in response_gen])
+
+ # Assert
+ expected_responses = [
+ "don't know",
+ "do not know",
+ "no information",
+ "do not have",
+ "don't have",
+ "cannot answer",
+ "I'm sorry",
+ ]
+ assert len(response) > 0
+ assert any([expected_response in response for expected_response in expected_responses]), (
+ "Expected chat actor to say they don't know in response, but got: " + response
+ )
+
+
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.chatquality
+def test_answer_requires_current_date_awareness(loaded_model):
+ "Chat actor should be able to answer questions relative to current date using provided notes"
+ # Arrange
+ context = [
+ f"""{datetime.now().strftime("%Y-%m-%d")} "Naco Taco" "Tacos for Dinner"
+Expenses:Food:Dining 10.00 USD""",
+ f"""{datetime.now().strftime("%Y-%m-%d")} "Sagar Ratna" "Dosa for Lunch"
+Expenses:Food:Dining 10.00 USD""",
+ f"""2020-04-01 "SuperMercado" "Bananas"
+Expenses:Food:Groceries 10.00 USD""",
+ f"""2020-01-01 "Naco Taco" "Burittos for Dinner"
+Expenses:Food:Dining 10.00 USD""",
+ ]
+
+ # Act
+ response_gen = converse_falcon(
+ references=context, # Assume context retrieved from notes for the user_query
+ user_query="What did I have for Dinner today?",
+ loaded_model=loaded_model,
+ )
+ response = "".join([response_chunk for response_chunk in response_gen])
+
+ # Assert
+ expected_responses = ["tacos", "Tacos"]
+ assert len(response) > 0
+ assert any([expected_response in response for expected_response in expected_responses]), (
+ "Expected [T|t]acos in response, but got: " + response
+ )
+
+
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.chatquality
+def test_answer_requires_date_aware_aggregation_across_provided_notes(loaded_model):
+ "Chat actor should be able to answer questions that require date aware aggregation across multiple notes"
+ # Arrange
+ context = [
+ f"""# {datetime.now().strftime("%Y-%m-%d")} "Naco Taco" "Tacos for Dinner"
+Expenses:Food:Dining 10.00 USD""",
+ f"""{datetime.now().strftime("%Y-%m-%d")} "Sagar Ratna" "Dosa for Lunch"
+Expenses:Food:Dining 10.00 USD""",
+ f"""2020-04-01 "SuperMercado" "Bananas"
+Expenses:Food:Groceries 10.00 USD""",
+ f"""2020-01-01 "Naco Taco" "Burittos for Dinner"
+Expenses:Food:Dining 10.00 USD""",
+ ]
+
+ # Act
+ response_gen = converse_falcon(
+ references=context, # Assume context retrieved from notes for the user_query
+ user_query="How much did I spend on dining this year?",
+ loaded_model=loaded_model,
+ )
+ response = "".join([response_chunk for response_chunk in response_gen])
+
+ # Assert
+ assert len(response) > 0
+ assert "20" in response
+
+
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.chatquality
+def test_answer_general_question_not_in_chat_history_or_retrieved_content(loaded_model):
+ "Chat actor should be able to answer general questions not requiring looking at chat history or notes"
+ # Arrange
+ message_list = [
+ ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
+ ("When was I born?", "You were born on 1st April 1984.", []),
+ ("Where was I born?", "You were born Testville.", []),
+ ]
+
+ # Act
+ response_gen = converse_falcon(
+ references=[], # Assume no context retrieved from notes for the user_query
+ user_query="Write a haiku about unit testing in 3 lines",
+ conversation_log=populate_chat_history(message_list),
+ loaded_model=loaded_model,
+ )
+ response = "".join([response_chunk for response_chunk in response_gen])
+
+ # Assert
+ expected_responses = ["test", "testing"]
+ assert len(response.splitlines()) >= 3 # haikus are 3 lines long, but Falcon tends to add a lot of new lines.
+ assert any([expected_response in response.lower() for expected_response in expected_responses]), (
+ "Expected [T|t]est in response, but got: " + response
+ )
+
+
+# ----------------------------------------------------------------------------------------------------
+@pytest.mark.xfail(reason="Chat actor not consistently capable of asking for clarification yet.")
+@pytest.mark.chatquality
+def test_ask_for_clarification_if_not_enough_context_in_question(loaded_model):
+ "Chat actor should ask for clarification if question cannot be answered unambiguously with the provided context"
+ # Arrange
+ context = [
+ f"""# Ramya
+My sister, Ramya, is married to Kali Devi. They have 2 kids, Ravi and Rani.""",
+ f"""# Fang
+My sister, Fang Liu is married to Xi Li. They have 1 kid, Xiao Li.""",
+ f"""# Aiyla
+My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""",
+ ]
+
+ # Act
+ response_gen = converse_falcon(
+ references=context, # Assume context retrieved from notes for the user_query
+ user_query="How many kids does my older sister have?",
+ loaded_model=loaded_model,
+ )
+ response = "".join([response_chunk for response_chunk in response_gen])
+
+ # Assert
+ expected_responses = ["which sister", "Which sister", "which of your sister", "Which of your sister"]
+ assert any([expected_response in response for expected_response in expected_responses]), (
+ "Expected chat actor to ask for clarification in response, but got: " + response
+ )
+
+
+# Helpers
+# ----------------------------------------------------------------------------------------------------
+def populate_chat_history(message_list):
+ # Generate conversation logs
+ conversation_log = {"chat": []}
+ for user_message, chat_response, context in message_list:
+ message_to_log(
+ user_message,
+ chat_response,
+ {"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}},
+ conversation_log=conversation_log["chat"],
+ )
+ return conversation_log
diff --git a/tests/test_chat_actors.py b/tests/test_openai_chat_actors.py
similarity index 99%
rename from tests/test_chat_actors.py
rename to tests/test_openai_chat_actors.py
index a1f91188..e84f41f7 100644
--- a/tests/test_chat_actors.py
+++ b/tests/test_openai_chat_actors.py
@@ -8,7 +8,7 @@ import freezegun
from freezegun import freeze_time
# Internal Packages
-from khoj.processor.conversation.gpt import converse, extract_questions
+from khoj.processor.conversation.openai.gpt import converse, extract_questions
from khoj.processor.conversation.utils import message_to_log
diff --git a/tests/test_chat_director.py b/tests/test_openai_chat_director.py
similarity index 100%
rename from tests/test_chat_director.py
rename to tests/test_openai_chat_director.py