Add support for our first Local LLM 🤖🏠 (#330)

* Add support for gpt4all's falcon model as an additional conversation processor
- Update the UI pages to allow the user to point to the new endpoints for GPT
- Update the internal schemas to support both GPT4 models and OpenAI
- Add unit tests benchmarking some of the Falcon performance
* Add exc_info to include stack trace in error logs for text processors
* Pull shared functions into utils.py to be used across gpt4 and gpt
* Add migration for new processor conversation schema
* Skip GPT4All actor tests due to typing issues
* Fix Obsidian processor configuration in auto-configure flow
* Rename enable_local_llm to enable_offline_chat
This commit is contained in:
sabaimran 2023-07-26 23:27:08 +00:00 committed by GitHub
parent 23d77ee338
commit 8b2af0b5ef
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
34 changed files with 1258 additions and 291 deletions

View file

@ -30,7 +30,6 @@ jobs:
fail-fast: false
matrix:
python_version:
- '3.8'
- '3.9'
- '3.10'
- '3.11'

View file

@ -58,6 +58,7 @@ dependencies = [
"pypdf >= 3.9.0",
"requests >= 2.26.0",
"bs4 >= 0.0.1",
"gpt4all==1.0.5",
]
dynamic = ["version"]

View file

@ -373,7 +373,7 @@ CONFIG is json obtained from Khoj config API."
(ignore-error json-end-of-file (json-parse-buffer :object-type 'alist :array-type 'list :null-object json-null :false-object json-false))))
(default-index-dir (khoj--get-directory-from-config default-config '(content-type org embeddings-file)))
(default-chat-dir (khoj--get-directory-from-config default-config '(processor conversation conversation-logfile)))
(chat-model (or khoj-chat-model (alist-get 'chat-model (alist-get 'conversation (alist-get 'processor default-config)))))
(chat-model (or khoj-chat-model (alist-get 'chat-model (alist-get 'openai (alist-get 'conversation (alist-get 'processor default-config))))))
(default-model (alist-get 'model (alist-get 'conversation (alist-get 'processor default-config))))
(config (or current-config default-config)))
@ -423,15 +423,27 @@ CONFIG is json obtained from Khoj config API."
;; Configure processors
(cond
((not khoj-openai-api-key)
(setq config (delq (assoc 'processor config) config)))
(let* ((processor (assoc 'processor config))
(conversation (assoc 'conversation processor))
(openai (assoc 'openai conversation)))
(when openai
;; Unset the `openai' field in the khoj conversation processor config
(message "khoj.el: disable Khoj Chat using OpenAI as your OpenAI API key got removed from config")
(setcdr conversation (delq openai (cdr conversation)))
(setcdr processor (delq conversation (cdr processor)))
(setq config (delq processor config))
(push conversation (cdr processor))
(push processor config))))
((not current-config)
(message "khoj.el: Chat not configured yet.")
(setq config (delq (assoc 'processor config) config))
(cl-pushnew `(processor . ((conversation . ((conversation-logfile . ,(format "%s/conversation.json" default-chat-dir))
(chat-model . ,chat-model)
(model . ,default-model)
(openai-api-key . ,khoj-openai-api-key)))))
(openai . (
(chat-model . ,chat-model)
(api-key . ,khoj-openai-api-key)
))
))))
config))
((not (alist-get 'conversation (alist-get 'processor config)))
@ -440,21 +452,19 @@ CONFIG is json obtained from Khoj config API."
(setq new-processor-type (delq (assoc 'conversation new-processor-type) new-processor-type))
(cl-pushnew `(conversation . ((conversation-logfile . ,(format "%s/conversation.json" default-chat-dir))
(chat-model . ,chat-model)
(model . ,default-model)
(openai-api-key . ,khoj-openai-api-key)))
new-processor-type)
(setq config (delq (assoc 'processor config) config))
(cl-pushnew `(processor . ,new-processor-type) config)))
;; Else if khoj is not configured with specified openai api key
((not (and (equal (alist-get 'openai-api-key (alist-get 'conversation (alist-get 'processor config))) khoj-openai-api-key)
(equal (alist-get 'chat-model (alist-get 'conversation (alist-get 'processor config))) khoj-chat-model)))
((not (and (equal (alist-get 'api-key (alist-get 'openai (alist-get 'conversation (alist-get 'processor config)))) khoj-openai-api-key)
(equal (alist-get 'chat-model (alist-get 'openai (alist-get 'conversation (alist-get 'processor config)))) khoj-chat-model)))
(message "khoj.el: Chat configuration has gone stale.")
(let* ((chat-directory (khoj--get-directory-from-config config '(processor conversation conversation-logfile)))
(new-processor-type (alist-get 'processor config)))
(setq new-processor-type (delq (assoc 'conversation new-processor-type) new-processor-type))
(cl-pushnew `(conversation . ((conversation-logfile . ,(format "%s/conversation.json" chat-directory))
(model . ,default-model)
(chat-model . ,khoj-chat-model)
(openai-api-key . ,khoj-openai-api-key)))
new-processor-type)

View file

@ -36,7 +36,7 @@ export async function configureKhojBackend(vault: Vault, setting: KhojSetting, n
let khojDefaultMdIndexDirectory = getIndexDirectoryFromBackendConfig(defaultConfig["content-type"]["markdown"]["embeddings-file"]);
let khojDefaultPdfIndexDirectory = getIndexDirectoryFromBackendConfig(defaultConfig["content-type"]["pdf"]["embeddings-file"]);
let khojDefaultChatDirectory = getIndexDirectoryFromBackendConfig(defaultConfig["processor"]["conversation"]["conversation-logfile"]);
let khojDefaultChatModelName = defaultConfig["processor"]["conversation"]["model"];
let khojDefaultChatModelName = defaultConfig["processor"]["conversation"]["openai"]["chat-model"];
// Get current config if khoj backend configured, else get default config from khoj backend
await request(khoj_already_configured ? khojConfigUrl : `${khojConfigUrl}/default`)
@ -142,25 +142,35 @@ export async function configureKhojBackend(vault: Vault, setting: KhojSetting, n
data["processor"] = {
"conversation": {
"conversation-logfile": `${khojDefaultChatDirectory}/conversation.json`,
"model": khojDefaultChatModelName,
"openai-api-key": setting.openaiApiKey,
}
"openai": {
"chat-model": khojDefaultChatModelName,
"api-key": setting.openaiApiKey,
}
},
}
}
// Else if khoj config has no conversation processor config
else if (!data["processor"]["conversation"]) {
data["processor"]["conversation"] = {
"conversation-logfile": `${khojDefaultChatDirectory}/conversation.json`,
"model": khojDefaultChatModelName,
"openai-api-key": setting.openaiApiKey,
else if (!data["processor"]["conversation"] || !data["processor"]["conversation"]["openai"]) {
data["processor"] = {
"conversation": {
"conversation-logfile": `${khojDefaultChatDirectory}/conversation.json`,
"openai": {
"chat-model": khojDefaultChatModelName,
"api-key": setting.openaiApiKey,
}
},
}
}
// Else if khoj is not configured with OpenAI API key from khoj plugin settings
else if (data["processor"]["conversation"]["openai-api-key"] !== setting.openaiApiKey) {
data["processor"]["conversation"] = {
"conversation-logfile": data["processor"]["conversation"]["conversation-logfile"],
"model": data["processor"]["conversation"]["model"],
"openai-api-key": setting.openaiApiKey,
else if (data["processor"]["conversation"]["openai"]["api-key"] !== setting.openaiApiKey) {
data["processor"] = {
"conversation": {
"conversation-logfile": data["processor"]["conversation"]["conversation-logfile"],
"openai": {
"chat-model": data["processor"]["conversation"]["openai"]["chat-model"],
"api-key": setting.openaiApiKey,
}
},
}
}

View file

@ -11,7 +11,6 @@ import schedule
from fastapi.staticfiles import StaticFiles
# Internal Packages
from khoj.processor.conversation.gpt import summarize
from khoj.processor.jsonl.jsonl_to_jsonl import JsonlToJsonl
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
@ -28,7 +27,7 @@ from khoj.utils.config import (
ConversationProcessorConfigModel,
)
from khoj.utils.helpers import LRU, resolve_absolute_path, merge_dicts
from khoj.utils.rawconfig import FullConfig, ProcessorConfig, SearchConfig, ContentConfig
from khoj.utils.rawconfig import FullConfig, ProcessorConfig, SearchConfig, ContentConfig, ConversationProcessorConfig
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.word_filter import WordFilter
from khoj.search_filter.file_filter import FileFilter
@ -64,7 +63,7 @@ def configure_server(config: FullConfig, regenerate: bool, search_type: Optional
state.config_lock.acquire()
state.processor_config = configure_processor(state.config.processor)
except Exception as e:
logger.error(f"🚨 Failed to configure processor")
logger.error(f"🚨 Failed to configure processor", exc_info=True)
raise e
finally:
state.config_lock.release()
@ -75,7 +74,7 @@ def configure_server(config: FullConfig, regenerate: bool, search_type: Optional
state.SearchType = configure_search_types(state.config)
state.search_models = configure_search(state.search_models, state.config.search_type)
except Exception as e:
logger.error(f"🚨 Failed to configure search models")
logger.error(f"🚨 Failed to configure search models", exc_info=True)
raise e
finally:
state.config_lock.release()
@ -88,7 +87,7 @@ def configure_server(config: FullConfig, regenerate: bool, search_type: Optional
state.content_index, state.config.content_type, state.search_models, regenerate, search_type
)
except Exception as e:
logger.error(f"🚨 Failed to index content")
logger.error(f"🚨 Failed to index content", exc_info=True)
raise e
finally:
state.config_lock.release()
@ -117,7 +116,7 @@ if not state.demo:
)
logger.info("📬 Content index updated via Scheduler")
except Exception as e:
logger.error(f"🚨 Error updating content index via Scheduler: {e}")
logger.error(f"🚨 Error updating content index via Scheduler: {e}", exc_info=True)
finally:
state.config_lock.release()
@ -258,7 +257,9 @@ def configure_content(
return content_index
def configure_processor(processor_config: Optional[ProcessorConfig]):
def configure_processor(
processor_config: Optional[ProcessorConfig], state_processor_config: Optional[ProcessorConfigModel] = None
):
if not processor_config:
logger.warning("🚨 No Processor configuration available.")
return None
@ -266,16 +267,47 @@ def configure_processor(processor_config: Optional[ProcessorConfig]):
processor = ProcessorConfigModel()
# Initialize Conversation Processor
if processor_config.conversation:
logger.info("💬 Setting up conversation processor")
processor.conversation = configure_conversation_processor(processor_config.conversation)
logger.info("💬 Setting up conversation processor")
processor.conversation = configure_conversation_processor(processor_config, state_processor_config)
return processor
def configure_conversation_processor(conversation_processor_config):
conversation_processor = ConversationProcessorConfigModel(conversation_processor_config)
conversation_logfile = resolve_absolute_path(conversation_processor.conversation_logfile)
def configure_conversation_processor(
processor_config: Optional[ProcessorConfig], state_processor_config: Optional[ProcessorConfigModel] = None
):
if (
not processor_config
or not processor_config.conversation
or not processor_config.conversation.conversation_logfile
):
default_config = constants.default_config
default_conversation_logfile = resolve_absolute_path(
default_config["processor"]["conversation"]["conversation-logfile"] # type: ignore
)
conversation_logfile = resolve_absolute_path(default_conversation_logfile)
conversation_config = processor_config.conversation if processor_config else None
conversation_processor = ConversationProcessorConfigModel(
conversation_config=ConversationProcessorConfig(
conversation_logfile=conversation_logfile,
openai=(conversation_config.openai if (conversation_config is not None) else None),
enable_offline_chat=(
conversation_config.enable_offline_chat if (conversation_config is not None) else False
),
)
)
else:
conversation_processor = ConversationProcessorConfigModel(
conversation_config=processor_config.conversation,
)
conversation_logfile = resolve_absolute_path(conversation_processor.conversation_logfile)
# Load Conversation Logs from Disk
if state_processor_config and state_processor_config.conversation and state_processor_config.conversation.meta_log:
conversation_processor.meta_log = state_processor_config.conversation.meta_log
conversation_processor.chat_session = state_processor_config.conversation.chat_session
logger.debug(f"Loaded conversation logs from state")
return conversation_processor
if conversation_logfile.is_file():
# Load Metadata Logs from Conversation Logfile
@ -302,12 +334,8 @@ def save_chat_session():
return
# Summarize Conversation Logs for this Session
chat_session = state.processor_config.conversation.chat_session
openai_api_key = state.processor_config.conversation.openai_api_key
conversation_log = state.processor_config.conversation.meta_log
chat_model = state.processor_config.conversation.chat_model
session = {
"summary": summarize(chat_session, model=chat_model, api_key=openai_api_key),
"session-start": conversation_log.get("session", [{"session-end": 0}])[-1]["session-end"],
"session-end": len(conversation_log["chat"]),
}
@ -344,6 +372,6 @@ def upload_telemetry():
log[field] = str(log[field])
requests.post(constants.telemetry_server, json=state.telemetry)
except Exception as e:
logger.error(f"📡 Error uploading telemetry: {e}")
logger.error(f"📡 Error uploading telemetry: {e}", exc_info=True)
else:
state.telemetry = []

View file

@ -0,0 +1 @@
<svg viewBox="0 0 320 320" xmlns="http://www.w3.org/2000/svg"><path d="m297.06 130.97c7.26-21.79 4.76-45.66-6.85-65.48-17.46-30.4-52.56-46.04-86.84-38.68-15.25-17.18-37.16-26.95-60.13-26.81-35.04-.08-66.13 22.48-76.91 55.82-22.51 4.61-41.94 18.7-53.31 38.67-17.59 30.32-13.58 68.54 9.92 94.54-7.26 21.79-4.76 45.66 6.85 65.48 17.46 30.4 52.56 46.04 86.84 38.68 15.24 17.18 37.16 26.95 60.13 26.8 35.06.09 66.16-22.49 76.94-55.86 22.51-4.61 41.94-18.7 53.31-38.67 17.57-30.32 13.55-68.51-9.94-94.51zm-120.28 168.11c-14.03.02-27.62-4.89-38.39-13.88.49-.26 1.34-.73 1.89-1.07l63.72-36.8c3.26-1.85 5.26-5.32 5.24-9.07v-89.83l26.93 15.55c.29.14.48.42.52.74v74.39c-.04 33.08-26.83 59.9-59.91 59.97zm-128.84-55.03c-7.03-12.14-9.56-26.37-7.15-40.18.47.28 1.3.79 1.89 1.13l63.72 36.8c3.23 1.89 7.23 1.89 10.47 0l77.79-44.92v31.1c.02.32-.13.63-.38.83l-64.41 37.19c-28.69 16.52-65.33 6.7-81.92-21.95zm-16.77-139.09c7-12.16 18.05-21.46 31.21-26.29 0 .55-.03 1.52-.03 2.2v73.61c-.02 3.74 1.98 7.21 5.23 9.06l77.79 44.91-26.93 15.55c-.27.18-.61.21-.91.08l-64.42-37.22c-28.63-16.58-38.45-53.21-21.95-81.89zm221.26 51.49-77.79-44.92 26.93-15.54c.27-.18.61-.21.91-.08l64.42 37.19c28.68 16.57 38.51 53.26 21.94 81.94-7.01 12.14-18.05 21.44-31.2 26.28v-75.81c.03-3.74-1.96-7.2-5.2-9.06zm26.8-40.34c-.47-.29-1.3-.79-1.89-1.13l-63.72-36.8c-3.23-1.89-7.23-1.89-10.47 0l-77.79 44.92v-31.1c-.02-.32.13-.63.38-.83l64.41-37.16c28.69-16.55 65.37-6.7 81.91 22 6.99 12.12 9.52 26.31 7.15 40.1zm-168.51 55.43-26.94-15.55c-.29-.14-.48-.42-.52-.74v-74.39c.02-33.12 26.89-59.96 60.01-59.94 14.01 0 27.57 4.92 38.34 13.88-.49.26-1.33.73-1.89 1.07l-63.72 36.8c-3.26 1.85-5.26 5.31-5.24 9.06l-.04 89.79zm14.63-31.54 34.65-20.01 34.65 20v40.01l-34.65 20-34.65-20z"/></svg>

After

Width:  |  Height:  |  Size: 1.7 KiB

View file

@ -167,10 +167,42 @@
text-align: left;
}
button.card-button.happy {
color: rgb(0, 146, 0);
}
img.configured-icon {
max-width: 16px;
}
div.card-action-row.enabled{
display: block;
}
img.configured-icon.enabled {
display: inline;
}
div.card-action-row.disabled,
img.configured-icon.disabled {
display: none;
}
.loader {
border: 16px solid #f3f3f3; /* Light grey */
border-top: 16px solid var(--primary);
border-radius: 50%;
width: 16px;
height: 16px;
animation: spin 2s linear infinite;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
div.finalize-actions {
grid-auto-flow: column;
grid-gap: 24px;

View file

@ -135,8 +135,8 @@
.then(response => response.json())
.then(data => {
if (data.detail) {
// If the server returns a 500 error with detail, render it as a message.
renderMessage("Hi 👋🏾, to get started <br/>1. Get your <a class='inline-chat-link' href='https://platform.openai.com/account/api-keys'>OpenAI API key</a><br/>2. Save it in the Khoj <a class='inline-chat-link' href='/config/processor/conversation'>chat settings</a> <br/>3. Click Configure on the Khoj <a class='inline-chat-link' href='/config'>settings page</a>", "khoj");
// If the server returns a 500 error with detail, render a setup hint.
renderMessage("Hi 👋🏾, to get started <br/>1. Get your <a class='inline-chat-link' href='https://platform.openai.com/account/api-keys'>OpenAI API key</a><br/>2. Save it in the Khoj <a class='inline-chat-link' href='/config/processor/conversation/openai'>chat settings</a> <br/>3. Click Configure on the Khoj <a class='inline-chat-link' href='/config'>settings page</a>", "khoj");
// Disable chat input field and update placeholder text
document.getElementById("chat-input").setAttribute("disabled", "disabled");

View file

@ -20,7 +20,7 @@
</h3>
</div>
<div class="card-description-row">
<p class="card-description">Set repositories for Khoj to index</p>
<p class="card-description">Set repositories to index</p>
</div>
<div class="card-action-row">
<a class="card-button" href="/config/content_type/github">
@ -90,7 +90,7 @@
</h3>
</div>
<div class="card-description-row">
<p class="card-description">Set markdown files for Khoj to index</p>
<p class="card-description">Set markdown files to index</p>
</div>
<div class="card-action-row">
<a class="card-button" href="/config/content_type/markdown">
@ -125,7 +125,7 @@
</h3>
</div>
<div class="card-description-row">
<p class="card-description">Set org files for Khoj to index</p>
<p class="card-description">Set org files to index</p>
</div>
<div class="card-action-row">
<a class="card-button" href="/config/content_type/org">
@ -160,7 +160,7 @@
</h3>
</div>
<div class="card-description-row">
<p class="card-description">Set PDF files for Khoj to index</p>
<p class="card-description">Set PDF files to index</p>
</div>
<div class="card-action-row">
<a class="card-button" href="/config/content_type/pdf">
@ -187,10 +187,10 @@
<div class="section-cards">
<div class="card">
<div class="card-title-row">
<img class="card-icon" src="/static/assets/icons/chat.svg" alt="Chat">
<img class="card-icon" src="/static/assets/icons/openai-logomark.svg" alt="Chat">
<h3 class="card-title">
Chat
{% if current_config.processor and current_config.processor.conversation %}
{% if current_config.processor and current_config.processor.conversation.openai %}
{% if current_model_state.conversation == False %}
<img id="misconfigured-icon-conversation-processor" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="Embeddings have not been generated yet for this content type. Either the configuration is invalid, or you just need to click Configure.">
{% else %}
@ -200,11 +200,11 @@
</h3>
</div>
<div class="card-description-row">
<p class="card-description">Setup Khoj Chat with OpenAI</p>
<p class="card-description">Setup chat using OpenAI</p>
</div>
<div class="card-action-row">
<a class="card-button" href="/config/processor/conversation">
{% if current_config.processor and current_config.processor.conversation %}
<a class="card-button" href="/config/processor/conversation/openai">
{% if current_config.processor and current_config.processor.conversation.openai %}
Update
{% else %}
Setup
@ -212,7 +212,7 @@
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
</a>
</div>
{% if current_config.processor and current_config.processor.conversation %}
{% if current_config.processor and current_config.processor.conversation.openai %}
<div id="clear-conversation" class="card-action-row">
<button class="card-button" onclick="clearConversationProcessor()">
Disable
@ -220,6 +220,31 @@
</div>
{% endif %}
</div>
<div class="card">
<div class="card-title-row">
<img class="card-icon" src="/static/assets/icons/chat.svg" alt="Chat">
<h3 class="card-title">
Offline Chat
<img id="configured-icon-conversation-enable-offline-chat" class="configured-icon {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat %}enabled{% else %}disabled{% endif %}" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
</h3>
</div>
<div class="card-description-row">
<p class="card-description">Setup offline chat (Falcon 7B)</p>
</div>
<div id="clear-enable-offline-chat" class="card-action-row {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat %}enabled{% else %}disabled{% endif %}">
<button class="card-button" onclick="toggleEnableLocalLLLM(false)">
Disable
</button>
</div>
<div id="set-enable-offline-chat" class="card-action-row {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat %}disabled{% else %}enabled{% endif %}">
<button class="card-button happy" onclick="toggleEnableLocalLLLM(true)">
Enable
</button>
</div>
<div id="toggle-enable-offline-chat" class="card-action-row disabled">
<div class="loader"></div>
</div>
</div>
</div>
</div>
<div class="section">
@ -263,9 +288,59 @@
})
};
function toggleEnableLocalLLLM(enable) {
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
var toggleEnableLocalLLLMButton = document.getElementById("toggle-enable-offline-chat");
toggleEnableLocalLLLMButton.classList.remove("disabled");
toggleEnableLocalLLLMButton.classList.add("enabled");
fetch('/api/config/data/processor/conversation/enable_offline_chat' + '?enable_offline_chat=' + enable, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'X-CSRFToken': csrfToken
},
})
.then(response => response.json())
.then(data => {
if (data.status == "ok") {
// Toggle the Enabled/Disabled UI based on the action/response.
var enableLocalLLLMButton = document.getElementById("set-enable-offline-chat");
var disableLocalLLLMButton = document.getElementById("clear-enable-offline-chat");
var configuredIcon = document.getElementById("configured-icon-conversation-enable-offline-chat");
var toggleEnableLocalLLLMButton = document.getElementById("toggle-enable-offline-chat");
toggleEnableLocalLLLMButton.classList.remove("enabled");
toggleEnableLocalLLLMButton.classList.add("disabled");
if (enable) {
enableLocalLLLMButton.classList.add("disabled");
enableLocalLLLMButton.classList.remove("enabled");
configuredIcon.classList.add("enabled");
configuredIcon.classList.remove("disabled");
disableLocalLLLMButton.classList.remove("disabled");
disableLocalLLLMButton.classList.add("enabled");
} else {
enableLocalLLLMButton.classList.remove("disabled");
enableLocalLLLMButton.classList.add("enabled");
configuredIcon.classList.remove("enabled");
configuredIcon.classList.add("disabled");
disableLocalLLLMButton.classList.add("disabled");
disableLocalLLLMButton.classList.remove("enabled");
}
}
})
}
function clearConversationProcessor() {
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
fetch('/api/delete/config/data/processor/conversation', {
fetch('/api/delete/config/data/processor/conversation/openai', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
@ -319,7 +394,7 @@
function updateIndex(force, successText, errorText, button, loadingText, emoji) {
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
button.disabled = true;
button.innerHTML = emoji + loadingText;
button.innerHTML = emoji + " " + loadingText;
fetch('/api/update?&client=web&force=' + force, {
method: 'GET',
headers: {

View file

@ -13,7 +13,7 @@
<label for="openai-api-key" title="Get your OpenAI key from https://platform.openai.com/account/api-keys">OpenAI API key</label>
</td>
<td>
<input type="text" id="openai-api-key" name="openai-api-key" value="{{ current_config['openai_api_key'] }}">
<input type="text" id="openai-api-key" name="openai-api-key" value="{{ current_config['api_key'] }}">
</td>
</tr>
<tr>
@ -25,24 +25,6 @@
</td>
</tr>
</table>
<table style="display: none;">
<tr>
<td>
<label for="conversation-logfile">Conversation Logfile</label>
</td>
<td>
<input type="text" id="conversation-logfile" name="conversation-logfile" value="{{ current_config['conversation_logfile'] }}">
</td>
</tr>
<tr>
<td>
<label for="model">Model</label>
</td>
<td>
<input type="text" id="model" name="model" value="{{ current_config['model'] }}">
</td>
</tr>
</table>
<div class="section">
<div id="success" style="display: none;" ></div>
<button id="submit" type="submit">Save</button>
@ -54,21 +36,23 @@
submit.addEventListener("click", function(event) {
event.preventDefault();
var openai_api_key = document.getElementById("openai-api-key").value;
var conversation_logfile = document.getElementById("conversation-logfile").value;
var model = document.getElementById("model").value;
var chat_model = document.getElementById("chat-model").value;
if (openai_api_key == "" || chat_model == "") {
document.getElementById("success").innerHTML = "⚠️ Please fill all the fields.";
document.getElementById("success").style.display = "block";
return;
}
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
fetch('/api/config/data/processor/conversation', {
fetch('/api/config/data/processor/conversation/openai', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'X-CSRFToken': csrfToken
},
body: JSON.stringify({
"openai_api_key": openai_api_key,
"conversation_logfile": conversation_logfile,
"model": model,
"api_key": openai_api_key,
"chat_model": chat_model
})
})

View file

View file

@ -0,0 +1,66 @@
"""
Current format of khoj.yml
---
app:
should-log-telemetry: true
content-type:
...
processor:
conversation:
chat-model: gpt-3.5-turbo
conversation-logfile: ~/.khoj/processor/conversation/conversation_logs.json
model: text-davinci-003
openai-api-key: sk-secret-key
search-type:
...
New format of khoj.yml
---
app:
should-log-telemetry: true
content-type:
...
processor:
conversation:
openai:
chat-model: gpt-3.5-turbo
openai-api-key: sk-secret-key
conversation-logfile: ~/.khoj/processor/conversation/conversation_logs.json
enable-offline-chat: false
search-type:
...
"""
from khoj.utils.yaml import load_config_from_file, save_config_to_file
def migrate_processor_conversation_schema(args):
raw_config = load_config_from_file(args.config_file)
raw_config["version"] = args.version_no
if "processor" not in raw_config:
return args
if raw_config["processor"] is None:
return args
if "conversation" not in raw_config["processor"]:
return args
# Add enable_offline_chat to khoj config schema
if "enable-offline-chat" not in raw_config["processor"]["conversation"]:
raw_config["processor"]["conversation"]["enable-offline-chat"] = False
save_config_to_file(raw_config, args.config_file)
current_openai_api_key = raw_config["processor"]["conversation"].get("openai-api-key", None)
current_chat_model = raw_config["processor"]["conversation"].get("chat-model", None)
if current_openai_api_key is None and current_chat_model is None:
return args
conversation_logfile = raw_config["processor"]["conversation"].get("conversation-logfile", None)
raw_config["processor"]["conversation"] = {
"openai": {"chat-model": current_chat_model, "api-key": current_openai_api_key},
"conversation-logfile": conversation_logfile,
"enable-offline-chat": False,
}
save_config_to_file(raw_config, args.config_file)
return args

View file

@ -0,0 +1,16 @@
from khoj.utils.yaml import load_config_from_file, save_config_to_file
def migrate_config_to_version(args):
raw_config = load_config_from_file(args.config_file)
# Add version to khoj config schema
if "version" not in raw_config:
raw_config["version"] = args.version_no
save_config_to_file(raw_config, args.config_file)
# regenerate khoj index on first start of this version
# this should refresh index and apply index corruption fixes from #325
args.regenerate = True
return args

View file

@ -0,0 +1,137 @@
from typing import Union, List
from datetime import datetime
import sys
import logging
from threading import Thread
from langchain.schema import ChatMessage
from gpt4all import GPT4All
from khoj.processor.conversation.utils import ThreadedGenerator, generate_chatml_messages_with_context
from khoj.processor.conversation import prompts
from khoj.utils.constants import empty_escape_sequences
logger = logging.getLogger(__name__)
def extract_questions_falcon(
text: str,
model: str = "ggml-model-gpt4all-falcon-q4_0.bin",
loaded_model: Union[GPT4All, None] = None,
conversation_log={},
use_history: bool = False,
run_extraction: bool = False,
):
"""
Infer search queries to retrieve relevant notes to answer user query
"""
all_questions = text.split("? ")
all_questions = [q + "?" for q in all_questions[:-1]] + [all_questions[-1]]
if not run_extraction:
return all_questions
gpt4all_model = loaded_model or GPT4All(model)
# Extract Past User Message and Inferred Questions from Conversation Log
chat_history = ""
if use_history:
chat_history = "".join(
[
f'Q: {chat["intent"]["query"]}\n\n{chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}\n\nA: {chat["message"]}\n\n'
for chat in conversation_log.get("chat", [])[-4:]
if chat["by"] == "khoj"
]
)
prompt = prompts.extract_questions_falcon.format(
chat_history=chat_history,
text=text,
)
message = prompts.general_conversation_falcon.format(query=prompt)
response = gpt4all_model.generate(message, max_tokens=200, top_k=2)
# Extract, Clean Message from GPT's Response
try:
questions = (
str(response)
.strip(empty_escape_sequences)
.replace("['", '["')
.replace("']", '"]')
.replace("', '", '", "')
.replace('["', "")
.replace('"]', "")
.split('", "')
)
except:
logger.warning(f"Falcon returned invalid JSON. Falling back to using user message as search query.\n{response}")
return all_questions
logger.debug(f"Extracted Questions by Falcon: {questions}")
questions.extend(all_questions)
return questions
def converse_falcon(
references,
user_query,
conversation_log={},
model: str = "ggml-model-gpt4all-falcon-q4_0.bin",
loaded_model: Union[GPT4All, None] = None,
completion_func=None,
) -> ThreadedGenerator:
"""
Converse with user using Falcon
"""
gpt4all_model = loaded_model or GPT4All(model)
# Initialize Variables
current_date = datetime.now().strftime("%Y-%m-%d")
compiled_references_message = "\n\n".join({f"{item}" for item in references})
# Get Conversation Primer appropriate to Conversation Type
# TODO If compiled_references_message is too long, we need to truncate it.
if compiled_references_message == "":
conversation_primer = prompts.conversation_falcon.format(query=user_query)
else:
conversation_primer = prompts.notes_conversation.format(
current_date=current_date, query=user_query, references=compiled_references_message
)
# Setup Prompt with Primer or Conversation History
messages = generate_chatml_messages_with_context(
conversation_primer,
prompts.personality.format(),
conversation_log,
model_name="text-davinci-001", # This isn't actually the model, but this helps us get an approximate encoding to run message truncation.
)
g = ThreadedGenerator(references, completion_func=completion_func)
t = Thread(target=llm_thread, args=(g, messages, gpt4all_model))
t.start()
return g
def llm_thread(g, messages: List[ChatMessage], model: GPT4All):
user_message = messages[0]
system_message = messages[-1]
conversation_history = messages[1:-1]
formatted_messages = [
prompts.chat_history_falcon_from_assistant.format(message=system_message)
if message.role == "assistant"
else prompts.chat_history_falcon_from_user.format(message=message.content)
for message in conversation_history
]
chat_history = "".join(formatted_messages)
full_message = system_message.content + chat_history + user_message.content
prompted_message = prompts.general_conversation_falcon.format(query=full_message)
response_iterator = model.generate(
prompted_message, streaming=True, max_tokens=256, top_k=1, temp=0, repeat_penalty=2.0
)
for response in response_iterator:
logger.info(response)
g.send(response)
g.close()

View file

@ -9,11 +9,11 @@ from langchain.schema import ChatMessage
# Internal Packages
from khoj.utils.constants import empty_escape_sequences
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import (
from khoj.processor.conversation.openai.utils import (
chat_completion_with_backoff,
completion_with_backoff,
generate_chatml_messages_with_context,
)
from khoj.processor.conversation.utils import generate_chatml_messages_with_context
logger = logging.getLogger(__name__)

View file

@ -0,0 +1,101 @@
# Standard Packages
import os
import logging
from typing import Any
from threading import Thread
# External Packages
from langchain.chat_models import ChatOpenAI
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.callbacks.base import BaseCallbackManager
import openai
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
wait_random_exponential,
)
# Internal Packages
from khoj.processor.conversation.utils import ThreadedGenerator
logger = logging.getLogger(__name__)
class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
def __init__(self, gen: ThreadedGenerator):
super().__init__()
self.gen = gen
def on_llm_new_token(self, token: str, **kwargs) -> Any:
self.gen.send(token)
@retry(
retry=(
retry_if_exception_type(openai.error.Timeout)
| retry_if_exception_type(openai.error.APIError)
| retry_if_exception_type(openai.error.APIConnectionError)
| retry_if_exception_type(openai.error.RateLimitError)
| retry_if_exception_type(openai.error.ServiceUnavailableError)
),
wait=wait_random_exponential(min=1, max=10),
stop=stop_after_attempt(3),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
def completion_with_backoff(**kwargs):
messages = kwargs.pop("messages")
if not "openai_api_key" in kwargs:
kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY")
llm = ChatOpenAI(**kwargs, request_timeout=20, max_retries=1)
return llm(messages=messages)
@retry(
retry=(
retry_if_exception_type(openai.error.Timeout)
| retry_if_exception_type(openai.error.APIError)
| retry_if_exception_type(openai.error.APIConnectionError)
| retry_if_exception_type(openai.error.RateLimitError)
| retry_if_exception_type(openai.error.ServiceUnavailableError)
),
wait=wait_exponential(multiplier=1, min=4, max=10),
stop=stop_after_attempt(3),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
def chat_completion_with_backoff(
messages, compiled_references, model_name, temperature, openai_api_key=None, completion_func=None
):
g = ThreadedGenerator(compiled_references, completion_func=completion_func)
t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key))
t.start()
return g
def llm_thread(g, messages, model_name, temperature, openai_api_key=None):
callback_handler = StreamingChatCallbackHandler(g)
chat = ChatOpenAI(
streaming=True,
verbose=True,
callback_manager=BaseCallbackManager([callback_handler]),
model_name=model_name, # type: ignore
temperature=temperature,
openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
request_timeout=20,
max_retries=1,
client=None,
)
chat(messages=messages)
g.close()
def extract_summaries(metadata):
"""Extract summaries from metadata"""
return "".join([f'\n{session["summary"]}' for session in metadata])

View file

@ -18,6 +18,36 @@ Question: {query}
""".strip()
)
general_conversation_falcon = PromptTemplate.from_template(
"""
Using your general knowledge and our past conversations as context, answer the following question.
### Instruct:
{query}
### Response:
""".strip()
)
chat_history_falcon_from_user = PromptTemplate.from_template(
"""
### Human:
{message}
""".strip()
)
chat_history_falcon_from_assistant = PromptTemplate.from_template(
"""
### Assistant:
{message}
""".strip()
)
conversation_falcon = PromptTemplate.from_template(
"""
Using our past conversations as context, answer the following question.
Question: {query}
""".strip()
)
## Notes Conversation
## --
@ -33,6 +63,17 @@ Question: {query}
""".strip()
)
notes_conversation_falcon = PromptTemplate.from_template(
"""
Using the notes and our past conversations as context, answer the following question. If the answer is not contained within the notes, say "I don't know."
Notes:
{references}
Question: {query}
""".strip()
)
## Summarize Chat
## --
@ -68,6 +109,40 @@ Question: {user_query}
Answer (in second person):"""
)
extract_questions_falcon = PromptTemplate.from_template(
"""
You are Khoj, an extremely smart and helpful search assistant with the ability to retrieve information from the user's notes.
- The user will provide their questions and answers to you for context.
- Add as much context from the previous questions and answers as required into your search queries.
- Break messages into multiple search queries when required to retrieve the relevant information.
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
What searches, if any, will you need to perform to answer the users question?
Q: How was my trip to Cambodia?
["How was my trip to Cambodia?"]
A: The trip was amazing. I went to the Angkor Wat temple and it was beautiful.
Q: Who did i visit that temple with?
["Who did I visit the Angkor Wat Temple in Cambodia with?"]
A: You visited the Angkor Wat Temple in Cambodia with Pablo, Namita and Xi.
Q: How many tennis balls fit in the back of a 2002 Honda Civic?
["What is the size of a tennis ball?", "What is the trunk size of a 2002 Honda Civic?"]
A: 1085 tennis balls will fit in the trunk of a Honda Civic
{chat_history}
Q: {text}
"""
)
## Extract Questions
## --

View file

@ -1,35 +1,19 @@
# Standard Packages
import os
import logging
from datetime import datetime
from time import perf_counter
from typing import Any
from threading import Thread
import json
# External Packages
from langchain.chat_models import ChatOpenAI
from langchain.schema import ChatMessage
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.callbacks.base import BaseCallbackManager
import openai
from datetime import datetime
import tiktoken
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
wait_random_exponential,
)
# External packages
from langchain.schema import ChatMessage
# Internal Packages
from khoj.utils.helpers import merge_dicts
import queue
from khoj.utils.helpers import merge_dicts
logger = logging.getLogger(__name__)
max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192}
max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192, "text-davinci-001": 910}
class ThreadedGenerator:
@ -49,9 +33,9 @@ class ThreadedGenerator:
time_to_response = perf_counter() - self.start_time
logger.info(f"Chat streaming took: {time_to_response:.3f} seconds")
if self.completion_func:
# The completion func effective acts as a callback.
# It adds the aggregated response to the conversation history. It's constructed in api.py.
self.completion_func(gpt_response=self.response)
# The completion func effectively acts as a callback.
# It adds the aggregated response to the conversation history.
self.completion_func(chat_response=self.response)
raise StopIteration
return item
@ -65,75 +49,25 @@ class ThreadedGenerator:
self.queue.put(StopIteration)
class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
def __init__(self, gen: ThreadedGenerator):
super().__init__()
self.gen = gen
def on_llm_new_token(self, token: str, **kwargs) -> Any:
self.gen.send(token)
@retry(
retry=(
retry_if_exception_type(openai.error.Timeout)
| retry_if_exception_type(openai.error.APIError)
| retry_if_exception_type(openai.error.APIConnectionError)
| retry_if_exception_type(openai.error.RateLimitError)
| retry_if_exception_type(openai.error.ServiceUnavailableError)
),
wait=wait_random_exponential(min=1, max=10),
stop=stop_after_attempt(3),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
def completion_with_backoff(**kwargs):
messages = kwargs.pop("messages")
if not "openai_api_key" in kwargs:
kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY")
llm = ChatOpenAI(**kwargs, request_timeout=20, max_retries=1)
return llm(messages=messages)
@retry(
retry=(
retry_if_exception_type(openai.error.Timeout)
| retry_if_exception_type(openai.error.APIError)
| retry_if_exception_type(openai.error.APIConnectionError)
| retry_if_exception_type(openai.error.RateLimitError)
| retry_if_exception_type(openai.error.ServiceUnavailableError)
),
wait=wait_exponential(multiplier=1, min=4, max=10),
stop=stop_after_attempt(3),
before_sleep=before_sleep_log(logger, logging.DEBUG),
reraise=True,
)
def chat_completion_with_backoff(
messages, compiled_references, model_name, temperature, openai_api_key=None, completion_func=None
def message_to_log(
user_message, chat_response, user_message_metadata={}, khoj_message_metadata={}, conversation_log=[]
):
g = ThreadedGenerator(compiled_references, completion_func=completion_func)
t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key))
t.start()
return g
"""Create json logs from messages, metadata for conversation log"""
default_khoj_message_metadata = {
"intent": {"type": "remember", "memory-type": "notes", "query": user_message},
"trigger-emotion": "calm",
}
khoj_response_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
# Create json log from Human's message
human_log = merge_dicts({"message": user_message, "by": "you"}, user_message_metadata)
def llm_thread(g, messages, model_name, temperature, openai_api_key=None):
callback_handler = StreamingChatCallbackHandler(g)
chat = ChatOpenAI(
streaming=True,
verbose=True,
callback_manager=BaseCallbackManager([callback_handler]),
model_name=model_name, # type: ignore
temperature=temperature,
openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
request_timeout=20,
max_retries=1,
client=None,
)
# Create json log from GPT's response
khoj_log = merge_dicts(khoj_message_metadata, default_khoj_message_metadata)
khoj_log = merge_dicts({"message": chat_response, "by": "khoj", "created": khoj_response_time}, khoj_log)
chat(messages=messages)
g.close()
conversation_log.extend([human_log, khoj_log])
return conversation_log
def generate_chatml_messages_with_context(
@ -192,27 +126,3 @@ def truncate_messages(messages, max_prompt_size, model_name):
def reciprocal_conversation_to_chatml(message_pair):
"""Convert a single back and forth between user and assistant to chatml format"""
return [ChatMessage(content=message, role=role) for message, role in zip(message_pair, ["user", "assistant"])]
def message_to_log(user_message, gpt_message, user_message_metadata={}, khoj_message_metadata={}, conversation_log=[]):
"""Create json logs from messages, metadata for conversation log"""
default_khoj_message_metadata = {
"intent": {"type": "remember", "memory-type": "notes", "query": user_message},
"trigger-emotion": "calm",
}
khoj_response_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
# Create json log from Human's message
human_log = merge_dicts({"message": user_message, "by": "you"}, user_message_metadata)
# Create json log from GPT's response
khoj_log = merge_dicts(khoj_message_metadata, default_khoj_message_metadata)
khoj_log = merge_dicts({"message": gpt_message, "by": "khoj", "created": khoj_response_time}, khoj_log)
conversation_log.extend([human_log, khoj_log])
return conversation_log
def extract_summaries(metadata):
"""Extract summaries from metadata"""
return "".join([f'\n{session["summary"]}' for session in metadata])

View file

@ -52,7 +52,7 @@ class GithubToJsonl(TextToJsonl):
try:
markdown_files, org_files = self.get_files(repo_url, repo)
except Exception as e:
logger.error(f"Unable to download github repo {repo_shorthand}")
logger.error(f"Unable to download github repo {repo_shorthand}", exc_info=True)
raise e
logger.info(f"Found {len(markdown_files)} markdown files in github repo {repo_shorthand}")

View file

@ -219,7 +219,7 @@ class NotionToJsonl(TextToJsonl):
page = self.get_page(page_id)
content = self.get_page_children(page_id)
except Exception as e:
logger.error(f"Error getting page {page_id}: {e}")
logger.error(f"Error getting page {page_id}: {e}", exc_info=True)
return None, None
properties = page["properties"]
title_field = "title"

View file

@ -5,7 +5,7 @@ import time
import yaml
import logging
import json
from typing import Iterable, List, Optional, Union
from typing import List, Optional, Union
# External Packages
from fastapi import APIRouter, HTTPException, Header, Request
@ -26,16 +26,19 @@ from khoj.utils.rawconfig import (
SearchConfig,
SearchResponse,
TextContentConfig,
ConversationProcessorConfig,
OpenAIProcessorConfig,
GithubContentConfig,
NotionContentConfig,
ConversationProcessorConfig,
)
from khoj.utils.helpers import resolve_absolute_path
from khoj.utils.state import SearchType
from khoj.utils import state, constants
from khoj.utils.yaml import save_config_to_file_updated_state
from fastapi.responses import StreamingResponse, Response
from khoj.routers.helpers import perform_chat_checks, generate_chat_response, update_telemetry_state
from khoj.processor.conversation.gpt import extract_questions
from khoj.processor.conversation.openai.gpt import extract_questions
from khoj.processor.conversation.gpt4all.chat_model import extract_questions_falcon, converse_falcon
from fastapi.requests import Request
@ -50,6 +53,8 @@ if not state.demo:
if state.config is None:
state.config = FullConfig()
state.config.search_type = SearchConfig.parse_obj(constants.default_config["search-type"])
if state.processor_config is None:
state.processor_config = configure_processor(state.config.processor)
@api.get("/config/data", response_model=FullConfig)
def get_config_data():
@ -181,22 +186,28 @@ if not state.demo:
except Exception as e:
return {"status": "error", "message": str(e)}
@api.post("/delete/config/data/processor/conversation", status_code=200)
@api.post("/delete/config/data/processor/conversation/openai", status_code=200)
async def remove_processor_conversation_config_data(
request: Request,
client: Optional[str] = None,
):
if not state.config or not state.config.processor or not state.config.processor.conversation:
if (
not state.config
or not state.config.processor
or not state.config.processor.conversation
or not state.config.processor.conversation.openai
):
return {"status": "ok"}
state.config.processor.conversation = None
state.config.processor.conversation.openai = None
state.processor_config = configure_processor(state.config.processor, state.processor_config)
update_telemetry_state(
request=request,
telemetry_type="api",
api="delete_processor_config",
api="delete_processor_openai_config",
client=client,
metadata={"processor_type": "conversation"},
metadata={"processor_conversation_type": "openai"},
)
try:
@ -233,23 +244,66 @@ if not state.demo:
except Exception as e:
return {"status": "error", "message": str(e)}
@api.post("/config/data/processor/conversation", status_code=200)
async def set_processor_conversation_config_data(
@api.post("/config/data/processor/conversation/openai", status_code=200)
async def set_processor_openai_config_data(
request: Request,
updated_config: Union[ConversationProcessorConfig, None],
updated_config: Union[OpenAIProcessorConfig, None],
client: Optional[str] = None,
):
_initialize_config()
state.config.processor = ProcessorConfig(conversation=updated_config)
state.processor_config = configure_processor(state.config.processor)
if not state.config.processor or not state.config.processor.conversation:
default_config = constants.default_config
default_conversation_logfile = resolve_absolute_path(
default_config["processor"]["conversation"]["conversation-logfile"] # type: ignore
)
conversation_logfile = resolve_absolute_path(default_conversation_logfile)
state.config.processor = ProcessorConfig(conversation=ConversationProcessorConfig(conversation_logfile=conversation_logfile)) # type: ignore
assert state.config.processor.conversation is not None
state.config.processor.conversation.openai = updated_config
state.processor_config = configure_processor(state.config.processor, state.processor_config)
update_telemetry_state(
request=request,
telemetry_type="api",
api="set_content_config",
api="set_processor_config",
client=client,
metadata={"processor_type": "conversation"},
metadata={"processor_conversation_type": "conversation"},
)
try:
save_config_to_file_updated_state()
return {"status": "ok"}
except Exception as e:
return {"status": "error", "message": str(e)}
@api.post("/config/data/processor/conversation/enable_offline_chat", status_code=200)
async def set_processor_enable_offline_chat_config_data(
request: Request,
enable_offline_chat: bool,
client: Optional[str] = None,
):
_initialize_config()
if not state.config.processor or not state.config.processor.conversation:
default_config = constants.default_config
default_conversation_logfile = resolve_absolute_path(
default_config["processor"]["conversation"]["conversation-logfile"] # type: ignore
)
conversation_logfile = resolve_absolute_path(default_conversation_logfile)
state.config.processor = ProcessorConfig(conversation=ConversationProcessorConfig(conversation_logfile=conversation_logfile)) # type: ignore
assert state.config.processor.conversation is not None
state.config.processor.conversation.enable_offline_chat = enable_offline_chat
state.processor_config = configure_processor(state.config.processor, state.processor_config)
update_telemetry_state(
request=request,
telemetry_type="api",
api="set_processor_config",
client=client,
metadata={"processor_conversation_type": f"{'enable' if enable_offline_chat else 'disable'}_local_llm"},
)
try:
@ -569,7 +623,9 @@ def chat_history(
perform_chat_checks()
# Load Conversation History
meta_log = state.processor_config.conversation.meta_log
meta_log = {}
if state.processor_config.conversation:
meta_log = state.processor_config.conversation.meta_log
update_telemetry_state(
request=request,
@ -598,24 +654,25 @@ async def chat(
perform_chat_checks()
compiled_references, inferred_queries = await extract_references_and_questions(request, q, (n or 5))
# Get the (streamed) chat response from GPT.
gpt_response = generate_chat_response(
# Get the (streamed) chat response from the LLM of choice.
llm_response = generate_chat_response(
q,
meta_log=state.processor_config.conversation.meta_log,
compiled_references=compiled_references,
inferred_queries=inferred_queries,
)
if gpt_response is None:
return Response(content=gpt_response, media_type="text/plain", status_code=500)
if llm_response is None:
return Response(content=llm_response, media_type="text/plain", status_code=500)
if stream:
return StreamingResponse(gpt_response, media_type="text/event-stream", status_code=200)
return StreamingResponse(llm_response, media_type="text/event-stream", status_code=200)
# Get the full response from the generator if the stream is not requested.
aggregated_gpt_response = ""
while True:
try:
aggregated_gpt_response += next(gpt_response)
aggregated_gpt_response += next(llm_response)
except StopIteration:
break
@ -645,8 +702,6 @@ async def extract_references_and_questions(
meta_log = state.processor_config.conversation.meta_log
# Initialize Variables
api_key = state.processor_config.conversation.openai_api_key
chat_model = state.processor_config.conversation.chat_model
conversation_type = "general" if q.startswith("@general") else "notes"
compiled_references = []
inferred_queries = []
@ -654,7 +709,13 @@ async def extract_references_and_questions(
if conversation_type == "notes":
# Infer search queries from user message
with timer("Extracting search queries took", logger):
inferred_queries = extract_questions(q, model=chat_model, api_key=api_key, conversation_log=meta_log)
if state.processor_config.conversation and state.processor_config.conversation.openai_model:
api_key = state.processor_config.conversation.openai_model.api_key
chat_model = state.processor_config.conversation.openai_model.chat_model
inferred_queries = extract_questions(q, model=chat_model, api_key=api_key, conversation_log=meta_log)
else:
loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
inferred_queries = extract_questions_falcon(q, loaded_model=loaded_model, conversation_log=meta_log)
# Collate search results as context for GPT
with timer("Searching knowledge base took", logger):

View file

@ -7,22 +7,23 @@ from fastapi import HTTPException, Request
from khoj.utils import state
from khoj.utils.helpers import timer, log_telemetry
from khoj.processor.conversation.gpt import converse
from khoj.processor.conversation.utils import message_to_log, reciprocal_conversation_to_chatml
from khoj.processor.conversation.openai.gpt import converse
from khoj.processor.conversation.gpt4all.chat_model import converse_falcon
from khoj.processor.conversation.utils import reciprocal_conversation_to_chatml, message_to_log, ThreadedGenerator
logger = logging.getLogger(__name__)
def perform_chat_checks():
if (
state.processor_config is None
or state.processor_config.conversation is None
or state.processor_config.conversation.openai_api_key is None
if state.processor_config.conversation and (
state.processor_config.conversation.openai_model
or state.processor_config.conversation.gpt4all_model.loaded_model
):
raise HTTPException(
status_code=500, detail="Set your OpenAI API key via Khoj settings and restart it to use Khoj Chat."
)
return
raise HTTPException(
status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings and restart it."
)
def update_telemetry_state(
@ -57,19 +58,19 @@ def generate_chat_response(
meta_log: dict,
compiled_references: List[str] = [],
inferred_queries: List[str] = [],
):
) -> ThreadedGenerator:
def _save_to_conversation_log(
q: str,
gpt_response: str,
chat_response: str,
user_message_time: str,
compiled_references: List[str],
inferred_queries: List[str],
meta_log,
):
state.processor_config.conversation.chat_session += reciprocal_conversation_to_chatml([q, gpt_response])
state.processor_config.conversation.chat_session += reciprocal_conversation_to_chatml([q, chat_response])
state.processor_config.conversation.meta_log["chat"] = message_to_log(
q,
gpt_response,
user_message=q,
chat_response=chat_response,
user_message_metadata={"created": user_message_time},
khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}},
conversation_log=meta_log.get("chat", []),
@ -79,8 +80,6 @@ def generate_chat_response(
meta_log = state.processor_config.conversation.meta_log
# Initialize Variables
api_key = state.processor_config.conversation.openai_api_key
chat_model = state.processor_config.conversation.chat_model
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
conversation_type = "general" if q.startswith("@general") else "notes"
@ -99,12 +98,29 @@ def generate_chat_response(
meta_log=meta_log,
)
gpt_response = converse(
compiled_references, q, meta_log, model=chat_model, api_key=api_key, completion_func=partial_completion
)
if state.processor_config.conversation.openai_model:
api_key = state.processor_config.conversation.openai_model.api_key
chat_model = state.processor_config.conversation.openai_model.chat_model
chat_response = converse(
compiled_references,
q,
meta_log,
model=chat_model,
api_key=api_key,
completion_func=partial_completion,
)
else:
loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
chat_response = converse_falcon(
references=compiled_references,
user_query=q,
loaded_model=loaded_model,
conversation_log=meta_log,
completion_func=partial_completion,
)
except Exception as e:
logger.error(e)
logger.error(e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
return gpt_response
return chat_response

View file

@ -3,7 +3,7 @@ from fastapi import APIRouter
from fastapi import Request
from fastapi.responses import HTMLResponse, FileResponse
from fastapi.templating import Jinja2Templates
from khoj.utils.rawconfig import TextContentConfig, ConversationProcessorConfig, FullConfig
from khoj.utils.rawconfig import TextContentConfig, OpenAIProcessorConfig, FullConfig
# Internal Packages
from khoj.utils import constants, state
@ -151,28 +151,29 @@ if not state.demo:
},
)
@web_client.get("/config/processor/conversation", response_class=HTMLResponse)
@web_client.get("/config/processor/conversation/openai", response_class=HTMLResponse)
def conversation_processor_config_page(request: Request):
default_copy = constants.default_config.copy()
default_processor_config = default_copy["processor"]["conversation"] # type: ignore
default_processor_config = ConversationProcessorConfig(
openai_api_key="",
model=default_processor_config["model"],
conversation_logfile=default_processor_config["conversation-logfile"],
default_processor_config = default_copy["processor"]["conversation"]["openai"] # type: ignore
default_openai_config = OpenAIProcessorConfig(
api_key="",
chat_model=default_processor_config["chat-model"],
)
current_processor_conversation_config = (
state.config.processor.conversation
if state.config and state.config.processor and state.config.processor.conversation
else default_processor_config
current_processor_openai_config = (
state.config.processor.conversation.openai
if state.config
and state.config.processor
and state.config.processor.conversation
and state.config.processor.conversation.openai
else default_openai_config
)
current_processor_conversation_config = json.loads(current_processor_conversation_config.json())
current_processor_openai_config = json.loads(current_processor_openai_config.json())
return templates.TemplateResponse(
"processor_conversation_input.html",
context={
"request": request,
"current_config": current_processor_conversation_config,
"current_config": current_processor_openai_config,
},
)

View file

@ -5,7 +5,9 @@ from importlib.metadata import version
# Internal Packages
from khoj.utils.helpers import resolve_absolute_path
from khoj.utils.yaml import load_config_from_file, parse_config_from_file, save_config_to_file
from khoj.utils.yaml import parse_config_from_file
from khoj.migrations.migrate_version import migrate_config_to_version
from khoj.migrations.migrate_processor_config_openai import migrate_processor_conversation_schema
def cli(args=None):
@ -46,22 +48,14 @@ def cli(args=None):
if not args.config_file.exists():
args.config = None
else:
args = migrate_config(args)
args = run_migrations(args)
args.config = parse_config_from_file(args.config_file)
return args
def migrate_config(args):
raw_config = load_config_from_file(args.config_file)
# Add version to khoj config schema
if "version" not in raw_config:
raw_config["version"] = args.version_no
save_config_to_file(raw_config, args.config_file)
# regenerate khoj index on first start of this version
# this should refresh index and apply index corruption fixes from #325
args.regenerate = True
def run_migrations(args):
migrations = [migrate_config_to_version, migrate_processor_conversation_schema]
for migration in migrations:
args = migration(args)
return args

View file

@ -1,9 +1,12 @@
# System Packages
from __future__ import annotations # to avoid quoting type hints
from enum import Enum
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Union, Any
from gpt4all import GPT4All
# External Packages
import torch
@ -13,7 +16,7 @@ if TYPE_CHECKING:
from sentence_transformers import CrossEncoder
from khoj.search_filter.base_filter import BaseFilter
from khoj.utils.models import BaseEncoder
from khoj.utils.rawconfig import ConversationProcessorConfig, Entry
from khoj.utils.rawconfig import ConversationProcessorConfig, Entry, OpenAIProcessorConfig
class SearchType(str, Enum):
@ -74,15 +77,29 @@ class SearchModels:
plugin_search: Optional[Dict[str, TextSearchModel]] = None
@dataclass
class GPT4AllProcessorConfig:
chat_model: Optional[str] = "ggml-model-gpt4all-falcon-q4_0.bin"
loaded_model: Union[Any, None] = None
class ConversationProcessorConfigModel:
def __init__(self, processor_config: ConversationProcessorConfig):
self.openai_api_key = processor_config.openai_api_key
self.model = processor_config.model
self.chat_model = processor_config.chat_model
self.conversation_logfile = Path(processor_config.conversation_logfile)
def __init__(
self,
conversation_config: ConversationProcessorConfig,
):
self.openai_model = conversation_config.openai
self.gpt4all_model = GPT4AllProcessorConfig()
self.enable_offline_chat = conversation_config.enable_offline_chat
self.conversation_logfile = Path(conversation_config.conversation_logfile)
self.chat_session: List[str] = []
self.meta_log: dict = {}
if not self.openai_model and self.enable_offline_chat:
self.gpt4all_model.loaded_model = GPT4All(self.gpt4all_model.chat_model) # type: ignore
else:
self.gpt4all_model.loaded_model = None
@dataclass
class ProcessorConfigModel:

View file

@ -62,10 +62,12 @@ default_config = {
},
"processor": {
"conversation": {
"openai-api-key": None,
"model": "text-davinci-003",
"openai": {
"api-key": None,
"chat-model": "gpt-3.5-turbo",
},
"enable-offline-chat": False,
"conversation-logfile": "~/.khoj/processor/conversation/conversation_logs.json",
"chat-model": "gpt-3.5-turbo",
}
},
}

View file

@ -27,12 +27,12 @@ class OpenAI(BaseEncoder):
if (
not state.processor_config
or not state.processor_config.conversation
or not state.processor_config.conversation.openai_api_key
or not state.processor_config.conversation.openai_model
):
raise Exception(
f"Set OpenAI API key under processor-config > conversation > openai-api-key in config file: {state.config_file}"
)
openai.api_key = state.processor_config.conversation.openai_api_key
openai.api_key = state.processor_config.conversation.openai_model.api_key
self.embedding_dimensions = None
def encode(self, entries, device=None, **kwargs):

View file

@ -1,7 +1,7 @@
# System Packages
import json
from pathlib import Path
from typing import List, Dict, Optional
from typing import List, Dict, Optional, Union, Any
# External Packages
from pydantic import BaseModel, validator
@ -103,13 +103,17 @@ class SearchConfig(ConfigBase):
image: Optional[ImageSearchConfig]
class ConversationProcessorConfig(ConfigBase):
openai_api_key: str
conversation_logfile: Path
model: Optional[str] = "text-davinci-003"
class OpenAIProcessorConfig(ConfigBase):
api_key: str
chat_model: Optional[str] = "gpt-3.5-turbo"
class ConversationProcessorConfig(ConfigBase):
conversation_logfile: Path
openai: Optional[OpenAIProcessorConfig]
enable_offline_chat: Optional[bool] = False
class ProcessorConfig(ConfigBase):
conversation: Optional[ConversationProcessorConfig]

View file

@ -20,6 +20,7 @@ content-type:
embeddings-file: content_plugin_2_embeddings.pt
input-filter:
- '*2_new.jsonl.gz'
enable-offline-chat: false
search-type:
asymmetric:
cross-encoder: cross-encoder/ms-marco-MiniLM-L-6-v2

View file

@ -0,0 +1,426 @@
# Standard Packages
from datetime import datetime
# External Packages
import pytest
SKIP_TESTS = True
pytestmark = pytest.mark.skipif(
SKIP_TESTS,
reason="The GPT4All library has some quirks that make it hard to test in CI. This causes some tests to fail. Hence, disable it in CI.",
)
import freezegun
from freezegun import freeze_time
from gpt4all import GPT4All
# Internal Packages
from khoj.processor.conversation.gpt4all.chat_model import converse_falcon, extract_questions_falcon
from khoj.processor.conversation.utils import message_to_log
@pytest.fixture(scope="session")
def loaded_model():
return GPT4All("ggml-model-gpt4all-falcon-q4_0.bin")
freezegun.configure(extend_ignore_list=["transformers"])
# Test
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@freeze_time("1984-04-02")
def test_extract_question_with_date_filter_from_relative_day(loaded_model):
# Act
response = extract_questions_falcon(
"Where did I go for dinner yesterday?", loaded_model=loaded_model, run_extraction=True
)
assert len(response) >= 1
assert response[-1] == "Where did I go for dinner yesterday?"
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@freeze_time("1984-04-02")
def test_extract_question_with_date_filter_from_relative_month(loaded_model):
# Act
response = extract_questions_falcon("Which countries did I visit last month?", loaded_model=loaded_model)
# Assert
assert len(response) == 1
assert response == ["Which countries did I visit last month?"]
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@freeze_time("1984-04-02")
def test_extract_question_with_date_filter_from_relative_year(loaded_model):
# Act
response = extract_questions_falcon(
"Which countries have I visited this year?", loaded_model=loaded_model, run_extraction=True
)
# Assert
assert len(response) >= 1
assert response[-1] == "Which countries have I visited this year?"
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_extract_multiple_explicit_questions_from_message(loaded_model):
# Act
response = extract_questions_falcon("What is the Sun? What is the Moon?", loaded_model=loaded_model)
# Assert
expected_responses = ["What is the Sun?", "What is the Moon?"]
assert len(response) == 2
assert expected_responses == response
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_extract_multiple_implicit_questions_from_message(loaded_model):
# Act
response = extract_questions_falcon("Is Morpheus taller than Neo?", loaded_model=loaded_model, run_extraction=True)
# Assert
expected_responses = [
("morpheus", "neo"),
]
assert len(response) == 2
assert any([start in response[0].lower() and end in response[1].lower() for start, end in expected_responses]), (
"Expected two search queries in response but got: " + response[0]
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_generate_search_query_using_question_from_chat_history(loaded_model):
# Arrange
message_list = [
("What is the name of Mr. Vader's daughter?", "Princess Leia", []),
]
# Act
response = extract_questions_falcon(
"Does he have any sons?",
conversation_log=populate_chat_history(message_list),
loaded_model=loaded_model,
run_extraction=True,
use_history=True,
)
expected_responses = [
"do not have",
"clarify",
"am sorry",
]
# Assert
assert len(response) >= 1
assert any([expected_response in response[0] for expected_response in expected_responses]), (
"Expected chat actor to ask for clarification in response, but got: " + response[0]
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat actor does not consistently follow template instructions.")
@pytest.mark.chatquality
def test_generate_search_query_using_answer_from_chat_history(loaded_model):
# Arrange
message_list = [
("What is the name of Mr. Vader's daughter?", "Princess Leia", []),
]
# Act
response = extract_questions_falcon(
"Is she a Jedi?",
conversation_log=populate_chat_history(message_list),
loaded_model=loaded_model,
run_extraction=True,
use_history=True,
)
# Assert
assert len(response) == 1
assert "Leia" in response[0]
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat actor is not sufficiently date-aware")
@pytest.mark.chatquality
def test_generate_search_query_with_date_and_context_from_chat_history(loaded_model):
# Arrange
message_list = [
("When did I visit Masai Mara?", "You visited Masai Mara in April 2000", []),
]
# Act
response = extract_questions_falcon(
"What was the Pizza place we ate at over there?",
conversation_log=populate_chat_history(message_list),
run_extraction=True,
loaded_model=loaded_model,
)
# Assert
expected_responses = [
("dt>='2000-04-01'", "dt<'2000-05-01'"),
("dt>='2000-04-01'", "dt<='2000-04-30'"),
('dt>="2000-04-01"', 'dt<"2000-05-01"'),
('dt>="2000-04-01"', 'dt<="2000-04-30"'),
]
assert len(response) == 1
assert "Masai Mara" in response[0]
assert any([start in response[0] and end in response[0] for start, end in expected_responses]), (
"Expected date filter to limit to April 2000 in response but got: " + response[0]
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_chat_with_no_chat_history_or_retrieved_content(loaded_model):
# Act
response_gen = converse_falcon(
references=[], # Assume no context retrieved from notes for the user_query
user_query="Hello, my name is Testatron. Who are you?",
loaded_model=loaded_model,
)
response = "".join([response_chunk for response_chunk in response_gen])
# Assert
expected_responses = ["Khoj", "khoj", "khooj", "Khooj", "KHOJ"]
assert len(response) > 0
assert any([expected_response in response for expected_response in expected_responses]), (
"Expected assistants name, [K|k]hoj, in response but got: " + response
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat actor isn't really good at proper nouns yet.")
@pytest.mark.chatquality
def test_answer_from_chat_history_and_previously_retrieved_content(loaded_model):
"Chat actor needs to use context in previous notes and chat history to answer question"
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
(
"When was I born?",
"You were born on 1st April 1984.",
["Testatron was born on 1st April 1984 in Testville."],
),
]
# Act
response_gen = converse_falcon(
references=[], # Assume no context retrieved from notes for the user_query
user_query="Where was I born?",
conversation_log=populate_chat_history(message_list),
loaded_model=loaded_model,
)
response = "".join([response_chunk for response_chunk in response_gen])
# Assert
assert len(response) > 0
# Infer who I am and use that to infer I was born in Testville using chat history and previously retrieved notes
assert "Testville" in response
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_from_chat_history_and_currently_retrieved_content(loaded_model):
"Chat actor needs to use context across currently retrieved notes and chat history to answer question"
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
]
# Act
response_gen = converse_falcon(
references=[
"Testatron was born on 1st April 1984 in Testville."
], # Assume context retrieved from notes for the user_query
user_query="Where was I born?",
conversation_log=populate_chat_history(message_list),
loaded_model=loaded_model,
)
response = "".join([response_chunk for response_chunk in response_gen])
# Assert
assert len(response) > 0
assert "Testville" in response
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat actor is rather liable to lying.")
@pytest.mark.chatquality
def test_refuse_answering_unanswerable_question(loaded_model):
"Chat actor should not try make up answers to unanswerable questions."
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
]
# Act
response_gen = converse_falcon(
references=[], # Assume no context retrieved from notes for the user_query
user_query="Where was I born?",
conversation_log=populate_chat_history(message_list),
loaded_model=loaded_model,
)
response = "".join([response_chunk for response_chunk in response_gen])
# Assert
expected_responses = [
"don't know",
"do not know",
"no information",
"do not have",
"don't have",
"cannot answer",
"I'm sorry",
]
assert len(response) > 0
assert any([expected_response in response for expected_response in expected_responses]), (
"Expected chat actor to say they don't know in response, but got: " + response
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_requires_current_date_awareness(loaded_model):
"Chat actor should be able to answer questions relative to current date using provided notes"
# Arrange
context = [
f"""{datetime.now().strftime("%Y-%m-%d")} "Naco Taco" "Tacos for Dinner"
Expenses:Food:Dining 10.00 USD""",
f"""{datetime.now().strftime("%Y-%m-%d")} "Sagar Ratna" "Dosa for Lunch"
Expenses:Food:Dining 10.00 USD""",
f"""2020-04-01 "SuperMercado" "Bananas"
Expenses:Food:Groceries 10.00 USD""",
f"""2020-01-01 "Naco Taco" "Burittos for Dinner"
Expenses:Food:Dining 10.00 USD""",
]
# Act
response_gen = converse_falcon(
references=context, # Assume context retrieved from notes for the user_query
user_query="What did I have for Dinner today?",
loaded_model=loaded_model,
)
response = "".join([response_chunk for response_chunk in response_gen])
# Assert
expected_responses = ["tacos", "Tacos"]
assert len(response) > 0
assert any([expected_response in response for expected_response in expected_responses]), (
"Expected [T|t]acos in response, but got: " + response
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_requires_date_aware_aggregation_across_provided_notes(loaded_model):
"Chat actor should be able to answer questions that require date aware aggregation across multiple notes"
# Arrange
context = [
f"""# {datetime.now().strftime("%Y-%m-%d")} "Naco Taco" "Tacos for Dinner"
Expenses:Food:Dining 10.00 USD""",
f"""{datetime.now().strftime("%Y-%m-%d")} "Sagar Ratna" "Dosa for Lunch"
Expenses:Food:Dining 10.00 USD""",
f"""2020-04-01 "SuperMercado" "Bananas"
Expenses:Food:Groceries 10.00 USD""",
f"""2020-01-01 "Naco Taco" "Burittos for Dinner"
Expenses:Food:Dining 10.00 USD""",
]
# Act
response_gen = converse_falcon(
references=context, # Assume context retrieved from notes for the user_query
user_query="How much did I spend on dining this year?",
loaded_model=loaded_model,
)
response = "".join([response_chunk for response_chunk in response_gen])
# Assert
assert len(response) > 0
assert "20" in response
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_general_question_not_in_chat_history_or_retrieved_content(loaded_model):
"Chat actor should be able to answer general questions not requiring looking at chat history or notes"
# Arrange
message_list = [
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
("When was I born?", "You were born on 1st April 1984.", []),
("Where was I born?", "You were born Testville.", []),
]
# Act
response_gen = converse_falcon(
references=[], # Assume no context retrieved from notes for the user_query
user_query="Write a haiku about unit testing in 3 lines",
conversation_log=populate_chat_history(message_list),
loaded_model=loaded_model,
)
response = "".join([response_chunk for response_chunk in response_gen])
# Assert
expected_responses = ["test", "testing"]
assert len(response.splitlines()) >= 3 # haikus are 3 lines long, but Falcon tends to add a lot of new lines.
assert any([expected_response in response.lower() for expected_response in expected_responses]), (
"Expected [T|t]est in response, but got: " + response
)
# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat actor not consistently capable of asking for clarification yet.")
@pytest.mark.chatquality
def test_ask_for_clarification_if_not_enough_context_in_question(loaded_model):
"Chat actor should ask for clarification if question cannot be answered unambiguously with the provided context"
# Arrange
context = [
f"""# Ramya
My sister, Ramya, is married to Kali Devi. They have 2 kids, Ravi and Rani.""",
f"""# Fang
My sister, Fang Liu is married to Xi Li. They have 1 kid, Xiao Li.""",
f"""# Aiyla
My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""",
]
# Act
response_gen = converse_falcon(
references=context, # Assume context retrieved from notes for the user_query
user_query="How many kids does my older sister have?",
loaded_model=loaded_model,
)
response = "".join([response_chunk for response_chunk in response_gen])
# Assert
expected_responses = ["which sister", "Which sister", "which of your sister", "Which of your sister"]
assert any([expected_response in response for expected_response in expected_responses]), (
"Expected chat actor to ask for clarification in response, but got: " + response
)
# Helpers
# ----------------------------------------------------------------------------------------------------
def populate_chat_history(message_list):
# Generate conversation logs
conversation_log = {"chat": []}
for user_message, chat_response, context in message_list:
message_to_log(
user_message,
chat_response,
{"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}},
conversation_log=conversation_log["chat"],
)
return conversation_log

View file

@ -8,7 +8,7 @@ import freezegun
from freezegun import freeze_time
# Internal Packages
from khoj.processor.conversation.gpt import converse, extract_questions
from khoj.processor.conversation.openai.gpt import converse, extract_questions
from khoj.processor.conversation.utils import message_to_log