Mirror of https://github.com/khoj-ai/khoj.git, synced 2024-11-23 23:48:56 +01:00
Add support for our first Local LLM 🤖🏠 (#330)
* Add support for gpt4all's falcon model as an additional conversation processor
  - Update the UI pages to allow the user to point to the new endpoints for GPT
  - Update the internal schemas to support both GPT4 models and OpenAI
  - Add unit tests benchmarking some of the Falcon performance
* Add exc_info to include stack trace in error logs for text processors
* Pull shared functions into utils.py to be used across gpt4 and gpt
* Add migration for new processor conversation schema
* Skip GPT4All actor tests due to typing issues
* Fix Obsidian processor configuration in auto-configure flow
* Rename enable_local_llm to enable_offline_chat
parent 23d77ee338
commit 8b2af0b5ef
34 changed files with 1258 additions and 291 deletions
.github/workflows/test.yml (vendored, 1 change)
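Not part of the commit itself, but as orientation for the diffs below: the renamed enable_offline_chat flag and the new nested OpenAI settings are both driven through config API routes added in this commit. A minimal client sketch; the server address and the key/model values are assumptions, only the routes and payload shapes come from the diff:

    # Sketch only: exercises the two conversation-processor endpoints added in this commit.
    # KHOJ_URL, the API key and the chat model below are placeholder values, not part of the diff.
    import requests

    KHOJ_URL = "http://localhost:42110"  # assumed local Khoj server address

    # Configure the OpenAI conversation processor (new nested "openai" schema)
    requests.post(
        f"{KHOJ_URL}/api/config/data/processor/conversation/openai",
        json={"api_key": "sk-...", "chat_model": "gpt-3.5-turbo"},
    ).raise_for_status()

    # Toggle the renamed offline chat flag (was enable_local_llm)
    response = requests.post(
        f"{KHOJ_URL}/api/config/data/processor/conversation/enable_offline_chat",
        params={"enable_offline_chat": True},
    )
    print(response.json())  # expected: {"status": "ok"} on success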
@@ -30,7 +30,6 @@ jobs:
      fail-fast: false
      matrix:
        python_version:
          - '3.8'
          - '3.9'
          - '3.10'
          - '3.11'
@@ -58,6 +58,7 @@ dependencies = [
    "pypdf >= 3.9.0",
    "requests >= 2.26.0",
    "bs4 >= 0.0.1",
    "gpt4all==1.0.5",
]
dynamic = ["version"]
@@ -373,7 +373,7 @@ CONFIG is json obtained from Khoj config API."
|
|||
(ignore-error json-end-of-file (json-parse-buffer :object-type 'alist :array-type 'list :null-object json-null :false-object json-false))))
|
||||
(default-index-dir (khoj--get-directory-from-config default-config '(content-type org embeddings-file)))
|
||||
(default-chat-dir (khoj--get-directory-from-config default-config '(processor conversation conversation-logfile)))
|
||||
(chat-model (or khoj-chat-model (alist-get 'chat-model (alist-get 'conversation (alist-get 'processor default-config)))))
|
||||
(chat-model (or khoj-chat-model (alist-get 'chat-model (alist-get 'openai (alist-get 'conversation (alist-get 'processor default-config))))))
|
||||
(default-model (alist-get 'model (alist-get 'conversation (alist-get 'processor default-config))))
|
||||
(config (or current-config default-config)))
|
||||
|
||||
|
@@ -423,15 +423,27 @@ CONFIG is json obtained from Khoj config API."
|
|||
;; Configure processors
|
||||
(cond
|
||||
((not khoj-openai-api-key)
|
||||
(setq config (delq (assoc 'processor config) config)))
|
||||
(let* ((processor (assoc 'processor config))
|
||||
(conversation (assoc 'conversation processor))
|
||||
(openai (assoc 'openai conversation)))
|
||||
(when openai
|
||||
;; Unset the `openai' field in the khoj conversation processor config
|
||||
(message "khoj.el: disable Khoj Chat using OpenAI as your OpenAI API key got removed from config")
|
||||
(setcdr conversation (delq openai (cdr conversation)))
|
||||
(setcdr processor (delq conversation (cdr processor)))
|
||||
(setq config (delq processor config))
|
||||
(push conversation (cdr processor))
|
||||
(push processor config))))
|
||||
|
||||
((not current-config)
|
||||
(message "khoj.el: Chat not configured yet.")
|
||||
(setq config (delq (assoc 'processor config) config))
|
||||
(cl-pushnew `(processor . ((conversation . ((conversation-logfile . ,(format "%s/conversation.json" default-chat-dir))
|
||||
(chat-model . ,chat-model)
|
||||
(model . ,default-model)
|
||||
(openai-api-key . ,khoj-openai-api-key)))))
|
||||
(openai . (
|
||||
(chat-model . ,chat-model)
|
||||
(api-key . ,khoj-openai-api-key)
|
||||
))
|
||||
))))
|
||||
config))
|
||||
|
||||
((not (alist-get 'conversation (alist-get 'processor config)))
|
||||
|
@@ -440,21 +452,19 @@ CONFIG is json obtained from Khoj config API."
|
|||
(setq new-processor-type (delq (assoc 'conversation new-processor-type) new-processor-type))
|
||||
(cl-pushnew `(conversation . ((conversation-logfile . ,(format "%s/conversation.json" default-chat-dir))
|
||||
(chat-model . ,chat-model)
|
||||
(model . ,default-model)
|
||||
(openai-api-key . ,khoj-openai-api-key)))
|
||||
new-processor-type)
|
||||
(setq config (delq (assoc 'processor config) config))
|
||||
(cl-pushnew `(processor . ,new-processor-type) config)))
|
||||
|
||||
;; Else if khoj is not configured with specified openai api key
|
||||
((not (and (equal (alist-get 'openai-api-key (alist-get 'conversation (alist-get 'processor config))) khoj-openai-api-key)
|
||||
(equal (alist-get 'chat-model (alist-get 'conversation (alist-get 'processor config))) khoj-chat-model)))
|
||||
((not (and (equal (alist-get 'api-key (alist-get 'openai (alist-get 'conversation (alist-get 'processor config)))) khoj-openai-api-key)
|
||||
(equal (alist-get 'chat-model (alist-get 'openai (alist-get 'conversation (alist-get 'processor config)))) khoj-chat-model)))
|
||||
(message "khoj.el: Chat configuration has gone stale.")
|
||||
(let* ((chat-directory (khoj--get-directory-from-config config '(processor conversation conversation-logfile)))
|
||||
(new-processor-type (alist-get 'processor config)))
|
||||
(setq new-processor-type (delq (assoc 'conversation new-processor-type) new-processor-type))
|
||||
(cl-pushnew `(conversation . ((conversation-logfile . ,(format "%s/conversation.json" chat-directory))
|
||||
(model . ,default-model)
|
||||
(chat-model . ,khoj-chat-model)
|
||||
(openai-api-key . ,khoj-openai-api-key)))
|
||||
new-processor-type)
|
||||
|
|
|
@@ -36,7 +36,7 @@ export async function configureKhojBackend(vault: Vault, setting: KhojSetting, n
|
|||
let khojDefaultMdIndexDirectory = getIndexDirectoryFromBackendConfig(defaultConfig["content-type"]["markdown"]["embeddings-file"]);
|
||||
let khojDefaultPdfIndexDirectory = getIndexDirectoryFromBackendConfig(defaultConfig["content-type"]["pdf"]["embeddings-file"]);
|
||||
let khojDefaultChatDirectory = getIndexDirectoryFromBackendConfig(defaultConfig["processor"]["conversation"]["conversation-logfile"]);
|
||||
let khojDefaultChatModelName = defaultConfig["processor"]["conversation"]["model"];
|
||||
let khojDefaultChatModelName = defaultConfig["processor"]["conversation"]["openai"]["chat-model"];
|
||||
|
||||
// Get current config if khoj backend configured, else get default config from khoj backend
|
||||
await request(khoj_already_configured ? khojConfigUrl : `${khojConfigUrl}/default`)
|
||||
|
@@ -142,25 +142,35 @@ export async function configureKhojBackend(vault: Vault, setting: KhojSetting, n
|
|||
data["processor"] = {
|
||||
"conversation": {
|
||||
"conversation-logfile": `${khojDefaultChatDirectory}/conversation.json`,
|
||||
"model": khojDefaultChatModelName,
|
||||
"openai-api-key": setting.openaiApiKey,
|
||||
}
|
||||
"openai": {
|
||||
"chat-model": khojDefaultChatModelName,
|
||||
"api-key": setting.openaiApiKey,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
// Else if khoj config has no conversation processor config
|
||||
else if (!data["processor"]["conversation"]) {
|
||||
data["processor"]["conversation"] = {
|
||||
"conversation-logfile": `${khojDefaultChatDirectory}/conversation.json`,
|
||||
"model": khojDefaultChatModelName,
|
||||
"openai-api-key": setting.openaiApiKey,
|
||||
else if (!data["processor"]["conversation"] || !data["processor"]["conversation"]["openai"]) {
|
||||
data["processor"] = {
|
||||
"conversation": {
|
||||
"conversation-logfile": `${khojDefaultChatDirectory}/conversation.json`,
|
||||
"openai": {
|
||||
"chat-model": khojDefaultChatModelName,
|
||||
"api-key": setting.openaiApiKey,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
// Else if khoj is not configured with OpenAI API key from khoj plugin settings
|
||||
else if (data["processor"]["conversation"]["openai-api-key"] !== setting.openaiApiKey) {
|
||||
data["processor"]["conversation"] = {
|
||||
"conversation-logfile": data["processor"]["conversation"]["conversation-logfile"],
|
||||
"model": data["processor"]["conversation"]["model"],
|
||||
"openai-api-key": setting.openaiApiKey,
|
||||
else if (data["processor"]["conversation"]["openai"]["api-key"] !== setting.openaiApiKey) {
|
||||
data["processor"] = {
|
||||
"conversation": {
|
||||
"conversation-logfile": data["processor"]["conversation"]["conversation-logfile"],
|
||||
"openai": {
|
||||
"chat-model": data["processor"]["conversation"]["openai"]["chat-model"],
|
||||
"api-key": setting.openaiApiKey,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@@ -11,7 +11,6 @@ import schedule
|
|||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
# Internal Packages
|
||||
from khoj.processor.conversation.gpt import summarize
|
||||
from khoj.processor.jsonl.jsonl_to_jsonl import JsonlToJsonl
|
||||
from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
|
||||
from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl
|
||||
|
@@ -28,7 +27,7 @@ from khoj.utils.config import (
|
|||
ConversationProcessorConfigModel,
|
||||
)
|
||||
from khoj.utils.helpers import LRU, resolve_absolute_path, merge_dicts
|
||||
from khoj.utils.rawconfig import FullConfig, ProcessorConfig, SearchConfig, ContentConfig
|
||||
from khoj.utils.rawconfig import FullConfig, ProcessorConfig, SearchConfig, ContentConfig, ConversationProcessorConfig
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
from khoj.search_filter.word_filter import WordFilter
|
||||
from khoj.search_filter.file_filter import FileFilter
|
||||
|
@@ -64,7 +63,7 @@ def configure_server(config: FullConfig, regenerate: bool, search_type: Optional
|
|||
state.config_lock.acquire()
|
||||
state.processor_config = configure_processor(state.config.processor)
|
||||
except Exception as e:
|
||||
logger.error(f"🚨 Failed to configure processor")
|
||||
logger.error(f"🚨 Failed to configure processor", exc_info=True)
|
||||
raise e
|
||||
finally:
|
||||
state.config_lock.release()
|
||||
|
@@ -75,7 +74,7 @@ def configure_server(config: FullConfig, regenerate: bool, search_type: Optional
|
|||
state.SearchType = configure_search_types(state.config)
|
||||
state.search_models = configure_search(state.search_models, state.config.search_type)
|
||||
except Exception as e:
|
||||
logger.error(f"🚨 Failed to configure search models")
|
||||
logger.error(f"🚨 Failed to configure search models", exc_info=True)
|
||||
raise e
|
||||
finally:
|
||||
state.config_lock.release()
|
||||
|
@@ -88,7 +87,7 @@ def configure_server(config: FullConfig, regenerate: bool, search_type: Optional
|
|||
state.content_index, state.config.content_type, state.search_models, regenerate, search_type
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"🚨 Failed to index content")
|
||||
logger.error(f"🚨 Failed to index content", exc_info=True)
|
||||
raise e
|
||||
finally:
|
||||
state.config_lock.release()
|
||||
|
@@ -117,7 +116,7 @@ if not state.demo:
|
|||
)
|
||||
logger.info("📬 Content index updated via Scheduler")
|
||||
except Exception as e:
|
||||
logger.error(f"🚨 Error updating content index via Scheduler: {e}")
|
||||
logger.error(f"🚨 Error updating content index via Scheduler: {e}", exc_info=True)
|
||||
finally:
|
||||
state.config_lock.release()
|
||||
|
||||
|
@@ -258,7 +257,9 @@ def configure_content(
|
|||
return content_index
|
||||
|
||||
|
||||
def configure_processor(processor_config: Optional[ProcessorConfig]):
|
||||
def configure_processor(
|
||||
processor_config: Optional[ProcessorConfig], state_processor_config: Optional[ProcessorConfigModel] = None
|
||||
):
|
||||
if not processor_config:
|
||||
logger.warning("🚨 No Processor configuration available.")
|
||||
return None
|
||||
|
@@ -266,16 +267,47 @@ def configure_processor(processor_config: Optional[ProcessorConfig]):
|
|||
processor = ProcessorConfigModel()
|
||||
|
||||
# Initialize Conversation Processor
|
||||
if processor_config.conversation:
|
||||
logger.info("💬 Setting up conversation processor")
|
||||
processor.conversation = configure_conversation_processor(processor_config.conversation)
|
||||
logger.info("💬 Setting up conversation processor")
|
||||
processor.conversation = configure_conversation_processor(processor_config, state_processor_config)
|
||||
|
||||
return processor
|
||||
|
||||
|
||||
def configure_conversation_processor(conversation_processor_config):
|
||||
conversation_processor = ConversationProcessorConfigModel(conversation_processor_config)
|
||||
conversation_logfile = resolve_absolute_path(conversation_processor.conversation_logfile)
|
||||
def configure_conversation_processor(
|
||||
processor_config: Optional[ProcessorConfig], state_processor_config: Optional[ProcessorConfigModel] = None
|
||||
):
|
||||
if (
|
||||
not processor_config
|
||||
or not processor_config.conversation
|
||||
or not processor_config.conversation.conversation_logfile
|
||||
):
|
||||
default_config = constants.default_config
|
||||
default_conversation_logfile = resolve_absolute_path(
|
||||
default_config["processor"]["conversation"]["conversation-logfile"] # type: ignore
|
||||
)
|
||||
conversation_logfile = resolve_absolute_path(default_conversation_logfile)
|
||||
conversation_config = processor_config.conversation if processor_config else None
|
||||
conversation_processor = ConversationProcessorConfigModel(
|
||||
conversation_config=ConversationProcessorConfig(
|
||||
conversation_logfile=conversation_logfile,
|
||||
openai=(conversation_config.openai if (conversation_config is not None) else None),
|
||||
enable_offline_chat=(
|
||||
conversation_config.enable_offline_chat if (conversation_config is not None) else False
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
conversation_processor = ConversationProcessorConfigModel(
|
||||
conversation_config=processor_config.conversation,
|
||||
)
|
||||
conversation_logfile = resolve_absolute_path(conversation_processor.conversation_logfile)
|
||||
|
||||
# Load Conversation Logs from Disk
|
||||
if state_processor_config and state_processor_config.conversation and state_processor_config.conversation.meta_log:
|
||||
conversation_processor.meta_log = state_processor_config.conversation.meta_log
|
||||
conversation_processor.chat_session = state_processor_config.conversation.chat_session
|
||||
logger.debug(f"Loaded conversation logs from state")
|
||||
return conversation_processor
|
||||
|
||||
if conversation_logfile.is_file():
|
||||
# Load Metadata Logs from Conversation Logfile
|
||||
|
@@ -302,12 +334,8 @@ def save_chat_session():
|
|||
return
|
||||
|
||||
# Summarize Conversation Logs for this Session
|
||||
chat_session = state.processor_config.conversation.chat_session
|
||||
openai_api_key = state.processor_config.conversation.openai_api_key
|
||||
conversation_log = state.processor_config.conversation.meta_log
|
||||
chat_model = state.processor_config.conversation.chat_model
|
||||
session = {
|
||||
"summary": summarize(chat_session, model=chat_model, api_key=openai_api_key),
|
||||
"session-start": conversation_log.get("session", [{"session-end": 0}])[-1]["session-end"],
|
||||
"session-end": len(conversation_log["chat"]),
|
||||
}
|
||||
|
@@ -344,6 +372,6 @@ def upload_telemetry():
|
|||
log[field] = str(log[field])
|
||||
requests.post(constants.telemetry_server, json=state.telemetry)
|
||||
except Exception as e:
|
||||
logger.error(f"📡 Error uploading telemetry: {e}")
|
||||
logger.error(f"📡 Error uploading telemetry: {e}", exc_info=True)
|
||||
else:
|
||||
state.telemetry = []
|
||||
|
src/khoj/interface/web/assets/icons/openai-logomark.svg (new file, 1 line)
@@ -0,0 +1 @@
<svg viewBox="0 0 320 320" xmlns="http://www.w3.org/2000/svg"><path d="m297.06 130.97c7.26-21.79 4.76-45.66-6.85-65.48-17.46-30.4-52.56-46.04-86.84-38.68-15.25-17.18-37.16-26.95-60.13-26.81-35.04-.08-66.13 22.48-76.91 55.82-22.51 4.61-41.94 18.7-53.31 38.67-17.59 30.32-13.58 68.54 9.92 94.54-7.26 21.79-4.76 45.66 6.85 65.48 17.46 30.4 52.56 46.04 86.84 38.68 15.24 17.18 37.16 26.95 60.13 26.8 35.06.09 66.16-22.49 76.94-55.86 22.51-4.61 41.94-18.7 53.31-38.67 17.57-30.32 13.55-68.51-9.94-94.51zm-120.28 168.11c-14.03.02-27.62-4.89-38.39-13.88.49-.26 1.34-.73 1.89-1.07l63.72-36.8c3.26-1.85 5.26-5.32 5.24-9.07v-89.83l26.93 15.55c.29.14.48.42.52.74v74.39c-.04 33.08-26.83 59.9-59.91 59.97zm-128.84-55.03c-7.03-12.14-9.56-26.37-7.15-40.18.47.28 1.3.79 1.89 1.13l63.72 36.8c3.23 1.89 7.23 1.89 10.47 0l77.79-44.92v31.1c.02.32-.13.63-.38.83l-64.41 37.19c-28.69 16.52-65.33 6.7-81.92-21.95zm-16.77-139.09c7-12.16 18.05-21.46 31.21-26.29 0 .55-.03 1.52-.03 2.2v73.61c-.02 3.74 1.98 7.21 5.23 9.06l77.79 44.91-26.93 15.55c-.27.18-.61.21-.91.08l-64.42-37.22c-28.63-16.58-38.45-53.21-21.95-81.89zm221.26 51.49-77.79-44.92 26.93-15.54c.27-.18.61-.21.91-.08l64.42 37.19c28.68 16.57 38.51 53.26 21.94 81.94-7.01 12.14-18.05 21.44-31.2 26.28v-75.81c.03-3.74-1.96-7.2-5.2-9.06zm26.8-40.34c-.47-.29-1.3-.79-1.89-1.13l-63.72-36.8c-3.23-1.89-7.23-1.89-10.47 0l-77.79 44.92v-31.1c-.02-.32.13-.63.38-.83l64.41-37.16c28.69-16.55 65.37-6.7 81.91 22 6.99 12.12 9.52 26.31 7.15 40.1zm-168.51 55.43-26.94-15.55c-.29-.14-.48-.42-.52-.74v-74.39c.02-33.12 26.89-59.96 60.01-59.94 14.01 0 27.57 4.92 38.34 13.88-.49.26-1.33.73-1.89 1.07l-63.72 36.8c-3.26 1.85-5.26 5.31-5.24 9.06l-.04 89.79zm14.63-31.54 34.65-20.01 34.65 20v40.01l-34.65 20-34.65-20z"/></svg>
@@ -167,10 +167,42 @@
|
|||
text-align: left;
|
||||
}
|
||||
|
||||
button.card-button.happy {
|
||||
color: rgb(0, 146, 0);
|
||||
}
|
||||
|
||||
img.configured-icon {
|
||||
max-width: 16px;
|
||||
}
|
||||
|
||||
div.card-action-row.enabled{
|
||||
display: block;
|
||||
}
|
||||
|
||||
img.configured-icon.enabled {
|
||||
display: inline;
|
||||
}
|
||||
|
||||
div.card-action-row.disabled,
|
||||
img.configured-icon.disabled {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.loader {
|
||||
border: 16px solid #f3f3f3; /* Light grey */
|
||||
border-top: 16px solid var(--primary);
|
||||
border-radius: 50%;
|
||||
width: 16px;
|
||||
height: 16px;
|
||||
animation: spin 2s linear infinite;
|
||||
}
|
||||
|
||||
@keyframes spin {
|
||||
0% { transform: rotate(0deg); }
|
||||
100% { transform: rotate(360deg); }
|
||||
}
|
||||
|
||||
|
||||
div.finalize-actions {
|
||||
grid-auto-flow: column;
|
||||
grid-gap: 24px;
|
||||
|
|
|
@@ -135,8 +135,8 @@
|
|||
.then(response => response.json())
|
||||
.then(data => {
|
||||
if (data.detail) {
|
||||
// If the server returns a 500 error with detail, render it as a message.
|
||||
renderMessage("Hi 👋🏾, to get started <br/>1. Get your <a class='inline-chat-link' href='https://platform.openai.com/account/api-keys'>OpenAI API key</a><br/>2. Save it in the Khoj <a class='inline-chat-link' href='/config/processor/conversation'>chat settings</a> <br/>3. Click Configure on the Khoj <a class='inline-chat-link' href='/config'>settings page</a>", "khoj");
|
||||
// If the server returns a 500 error with detail, render a setup hint.
|
||||
renderMessage("Hi 👋🏾, to get started <br/>1. Get your <a class='inline-chat-link' href='https://platform.openai.com/account/api-keys'>OpenAI API key</a><br/>2. Save it in the Khoj <a class='inline-chat-link' href='/config/processor/conversation/openai'>chat settings</a> <br/>3. Click Configure on the Khoj <a class='inline-chat-link' href='/config'>settings page</a>", "khoj");
|
||||
|
||||
// Disable chat input field and update placeholder text
|
||||
document.getElementById("chat-input").setAttribute("disabled", "disabled");
|
||||
|
|
|
@@ -20,7 +20,7 @@
|
|||
</h3>
|
||||
</div>
|
||||
<div class="card-description-row">
|
||||
<p class="card-description">Set repositories for Khoj to index</p>
|
||||
<p class="card-description">Set repositories to index</p>
|
||||
</div>
|
||||
<div class="card-action-row">
|
||||
<a class="card-button" href="/config/content_type/github">
|
||||
|
@@ -90,7 +90,7 @@
|
|||
</h3>
|
||||
</div>
|
||||
<div class="card-description-row">
|
||||
<p class="card-description">Set markdown files for Khoj to index</p>
|
||||
<p class="card-description">Set markdown files to index</p>
|
||||
</div>
|
||||
<div class="card-action-row">
|
||||
<a class="card-button" href="/config/content_type/markdown">
|
||||
|
@@ -125,7 +125,7 @@
|
|||
</h3>
|
||||
</div>
|
||||
<div class="card-description-row">
|
||||
<p class="card-description">Set org files for Khoj to index</p>
|
||||
<p class="card-description">Set org files to index</p>
|
||||
</div>
|
||||
<div class="card-action-row">
|
||||
<a class="card-button" href="/config/content_type/org">
|
||||
|
@@ -160,7 +160,7 @@
|
|||
</h3>
|
||||
</div>
|
||||
<div class="card-description-row">
|
||||
<p class="card-description">Set PDF files for Khoj to index</p>
|
||||
<p class="card-description">Set PDF files to index</p>
|
||||
</div>
|
||||
<div class="card-action-row">
|
||||
<a class="card-button" href="/config/content_type/pdf">
|
||||
|
@@ -187,10 +187,10 @@
|
|||
<div class="section-cards">
|
||||
<div class="card">
|
||||
<div class="card-title-row">
|
||||
<img class="card-icon" src="/static/assets/icons/chat.svg" alt="Chat">
|
||||
<img class="card-icon" src="/static/assets/icons/openai-logomark.svg" alt="Chat">
|
||||
<h3 class="card-title">
|
||||
Chat
|
||||
{% if current_config.processor and current_config.processor.conversation %}
|
||||
{% if current_config.processor and current_config.processor.conversation.openai %}
|
||||
{% if current_model_state.conversation == False %}
|
||||
<img id="misconfigured-icon-conversation-processor" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="Embeddings have not been generated yet for this content type. Either the configuration is invalid, or you just need to click Configure.">
|
||||
{% else %}
|
||||
|
@@ -200,11 +200,11 @@
|
|||
</h3>
|
||||
</div>
|
||||
<div class="card-description-row">
|
||||
<p class="card-description">Setup Khoj Chat with OpenAI</p>
|
||||
<p class="card-description">Setup chat using OpenAI</p>
|
||||
</div>
|
||||
<div class="card-action-row">
|
||||
<a class="card-button" href="/config/processor/conversation">
|
||||
{% if current_config.processor and current_config.processor.conversation %}
|
||||
<a class="card-button" href="/config/processor/conversation/openai">
|
||||
{% if current_config.processor and current_config.processor.conversation.openai %}
|
||||
Update
|
||||
{% else %}
|
||||
Setup
|
||||
|
@@ -212,7 +212,7 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
|
||||
</a>
|
||||
</div>
|
||||
{% if current_config.processor and current_config.processor.conversation %}
|
||||
{% if current_config.processor and current_config.processor.conversation.openai %}
|
||||
<div id="clear-conversation" class="card-action-row">
|
||||
<button class="card-button" onclick="clearConversationProcessor()">
|
||||
Disable
|
||||
|
@@ -220,6 +220,31 @@
|
|||
</div>
|
||||
{% endif %}
|
||||
</div>
|
||||
<div class="card">
|
||||
<div class="card-title-row">
|
||||
<img class="card-icon" src="/static/assets/icons/chat.svg" alt="Chat">
|
||||
<h3 class="card-title">
|
||||
Offline Chat
|
||||
<img id="configured-icon-conversation-enable-offline-chat" class="configured-icon {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat %}enabled{% else %}disabled{% endif %}" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
|
||||
</h3>
|
||||
</div>
|
||||
<div class="card-description-row">
|
||||
<p class="card-description">Setup offline chat (Falcon 7B)</p>
|
||||
</div>
|
||||
<div id="clear-enable-offline-chat" class="card-action-row {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat %}enabled{% else %}disabled{% endif %}">
|
||||
<button class="card-button" onclick="toggleEnableLocalLLLM(false)">
|
||||
Disable
|
||||
</button>
|
||||
</div>
|
||||
<div id="set-enable-offline-chat" class="card-action-row {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat %}disabled{% else %}enabled{% endif %}">
|
||||
<button class="card-button happy" onclick="toggleEnableLocalLLLM(true)">
|
||||
Enable
|
||||
</button>
|
||||
</div>
|
||||
<div id="toggle-enable-offline-chat" class="card-action-row disabled">
|
||||
<div class="loader"></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="section">
|
||||
|
@@ -263,9 +288,59 @@
|
|||
})
|
||||
};
|
||||
|
||||
function toggleEnableLocalLLLM(enable) {
|
||||
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
|
||||
var toggleEnableLocalLLLMButton = document.getElementById("toggle-enable-offline-chat");
|
||||
toggleEnableLocalLLLMButton.classList.remove("disabled");
|
||||
toggleEnableLocalLLLMButton.classList.add("enabled");
|
||||
|
||||
fetch('/api/config/data/processor/conversation/enable_offline_chat' + '?enable_offline_chat=' + enable, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'X-CSRFToken': csrfToken
|
||||
},
|
||||
})
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
if (data.status == "ok") {
|
||||
// Toggle the Enabled/Disabled UI based on the action/response.
|
||||
var enableLocalLLLMButton = document.getElementById("set-enable-offline-chat");
|
||||
var disableLocalLLLMButton = document.getElementById("clear-enable-offline-chat");
|
||||
var configuredIcon = document.getElementById("configured-icon-conversation-enable-offline-chat");
|
||||
var toggleEnableLocalLLLMButton = document.getElementById("toggle-enable-offline-chat");
|
||||
|
||||
toggleEnableLocalLLLMButton.classList.remove("enabled");
|
||||
toggleEnableLocalLLLMButton.classList.add("disabled");
|
||||
|
||||
|
||||
if (enable) {
|
||||
enableLocalLLLMButton.classList.add("disabled");
|
||||
enableLocalLLLMButton.classList.remove("enabled");
|
||||
|
||||
configuredIcon.classList.add("enabled");
|
||||
configuredIcon.classList.remove("disabled");
|
||||
|
||||
disableLocalLLLMButton.classList.remove("disabled");
|
||||
disableLocalLLLMButton.classList.add("enabled");
|
||||
} else {
|
||||
enableLocalLLLMButton.classList.remove("disabled");
|
||||
enableLocalLLLMButton.classList.add("enabled");
|
||||
|
||||
configuredIcon.classList.remove("enabled");
|
||||
configuredIcon.classList.add("disabled");
|
||||
|
||||
disableLocalLLLMButton.classList.add("disabled");
|
||||
disableLocalLLLMButton.classList.remove("enabled");
|
||||
}
|
||||
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
function clearConversationProcessor() {
|
||||
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
|
||||
fetch('/api/delete/config/data/processor/conversation', {
|
||||
fetch('/api/delete/config/data/processor/conversation/openai', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
|
@@ -319,7 +394,7 @@
|
|||
function updateIndex(force, successText, errorText, button, loadingText, emoji) {
|
||||
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
|
||||
button.disabled = true;
|
||||
button.innerHTML = emoji + loadingText;
|
||||
button.innerHTML = emoji + " " + loadingText;
|
||||
fetch('/api/update?&client=web&force=' + force, {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
|
|
|
@@ -13,7 +13,7 @@
|
|||
<label for="openai-api-key" title="Get your OpenAI key from https://platform.openai.com/account/api-keys">OpenAI API key</label>
|
||||
</td>
|
||||
<td>
|
||||
<input type="text" id="openai-api-key" name="openai-api-key" value="{{ current_config['openai_api_key'] }}">
|
||||
<input type="text" id="openai-api-key" name="openai-api-key" value="{{ current_config['api_key'] }}">
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
|
@@ -25,24 +25,6 @@
|
|||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
<table style="display: none;">
|
||||
<tr>
|
||||
<td>
|
||||
<label for="conversation-logfile">Conversation Logfile</label>
|
||||
</td>
|
||||
<td>
|
||||
<input type="text" id="conversation-logfile" name="conversation-logfile" value="{{ current_config['conversation_logfile'] }}">
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<label for="model">Model</label>
|
||||
</td>
|
||||
<td>
|
||||
<input type="text" id="model" name="model" value="{{ current_config['model'] }}">
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
<div class="section">
|
||||
<div id="success" style="display: none;" ></div>
|
||||
<button id="submit" type="submit">Save</button>
|
||||
|
@@ -54,21 +36,23 @@
|
|||
submit.addEventListener("click", function(event) {
|
||||
event.preventDefault();
|
||||
var openai_api_key = document.getElementById("openai-api-key").value;
|
||||
var conversation_logfile = document.getElementById("conversation-logfile").value;
|
||||
var model = document.getElementById("model").value;
|
||||
var chat_model = document.getElementById("chat-model").value;
|
||||
|
||||
if (openai_api_key == "" || chat_model == "") {
|
||||
document.getElementById("success").innerHTML = "⚠️ Please fill all the fields.";
|
||||
document.getElementById("success").style.display = "block";
|
||||
return;
|
||||
}
|
||||
|
||||
const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
|
||||
fetch('/api/config/data/processor/conversation', {
|
||||
fetch('/api/config/data/processor/conversation/openai', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'X-CSRFToken': csrfToken
|
||||
},
|
||||
body: JSON.stringify({
|
||||
"openai_api_key": openai_api_key,
|
||||
"conversation_logfile": conversation_logfile,
|
||||
"model": model,
|
||||
"api_key": openai_api_key,
|
||||
"chat_model": chat_model
|
||||
})
|
||||
})
|
||||
|
|
src/khoj/migrations/__init__.py (new file, empty)
src/khoj/migrations/migrate_processor_config_openai.py (new file, 66 lines)
@@ -0,0 +1,66 @@
"""
Current format of khoj.yml
---
app:
  should-log-telemetry: true
content-type:
  ...
processor:
  conversation:
    chat-model: gpt-3.5-turbo
    conversation-logfile: ~/.khoj/processor/conversation/conversation_logs.json
    model: text-davinci-003
    openai-api-key: sk-secret-key
search-type:
  ...

New format of khoj.yml
---
app:
  should-log-telemetry: true
content-type:
  ...
processor:
  conversation:
    openai:
      chat-model: gpt-3.5-turbo
      openai-api-key: sk-secret-key
    conversation-logfile: ~/.khoj/processor/conversation/conversation_logs.json
    enable-offline-chat: false
search-type:
  ...
"""
from khoj.utils.yaml import load_config_from_file, save_config_to_file


def migrate_processor_conversation_schema(args):
    raw_config = load_config_from_file(args.config_file)

    raw_config["version"] = args.version_no

    if "processor" not in raw_config:
        return args
    if raw_config["processor"] is None:
        return args
    if "conversation" not in raw_config["processor"]:
        return args

    # Add enable_offline_chat to khoj config schema
    if "enable-offline-chat" not in raw_config["processor"]["conversation"]:
        raw_config["processor"]["conversation"]["enable-offline-chat"] = False
        save_config_to_file(raw_config, args.config_file)

    current_openai_api_key = raw_config["processor"]["conversation"].get("openai-api-key", None)
    current_chat_model = raw_config["processor"]["conversation"].get("chat-model", None)
    if current_openai_api_key is None and current_chat_model is None:
        return args

    conversation_logfile = raw_config["processor"]["conversation"].get("conversation-logfile", None)

    raw_config["processor"]["conversation"] = {
        "openai": {"chat-model": current_chat_model, "api-key": current_openai_api_key},
        "conversation-logfile": conversation_logfile,
        "enable-offline-chat": False,
    }
    save_config_to_file(raw_config, args.config_file)
    return args
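Not part of the diff: a short sketch of invoking the migration above on an old-style config. The config path and version string are placeholders; the function only needs an object exposing config_file and version_no:

    # Sketch: run the schema migration against an old-style khoj.yml.
    # The path and version number are placeholders; args just needs the
    # attributes the migration reads (config_file, version_no).
    from argparse import Namespace
    from pathlib import Path

    from khoj.migrations.migrate_processor_config_openai import migrate_processor_conversation_schema

    args = Namespace(
        config_file=Path("~/.khoj/khoj.yml").expanduser(),  # assumed config location
        version_no="0.10.0",  # placeholder version string
    )

    # Rewrites processor.conversation from the flat chat-model/openai-api-key keys
    # into the nested openai block and adds enable-offline-chat: false.
    migrate_processor_conversation_schema(args)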
src/khoj/migrations/migrate_version.py (new file, 16 lines)
@@ -0,0 +1,16 @@
from khoj.utils.yaml import load_config_from_file, save_config_to_file


def migrate_config_to_version(args):
    raw_config = load_config_from_file(args.config_file)

    # Add version to khoj config schema
    if "version" not in raw_config:
        raw_config["version"] = args.version_no
        save_config_to_file(raw_config, args.config_file)

    # regenerate khoj index on first start of this version
    # this should refresh index and apply index corruption fixes from #325
    args.regenerate = True

    return args
src/khoj/processor/conversation/gpt4all/__init__.py (new file, empty)
src/khoj/processor/conversation/gpt4all/chat_model.py (new file, 137 lines)
@@ -0,0 +1,137 @@
|
|||
from typing import Union, List
|
||||
from datetime import datetime
|
||||
import sys
|
||||
import logging
|
||||
from threading import Thread
|
||||
|
||||
from langchain.schema import ChatMessage
|
||||
|
||||
from gpt4all import GPT4All
|
||||
|
||||
|
||||
from khoj.processor.conversation.utils import ThreadedGenerator, generate_chatml_messages_with_context
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.utils.constants import empty_escape_sequences
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def extract_questions_falcon(
|
||||
text: str,
|
||||
model: str = "ggml-model-gpt4all-falcon-q4_0.bin",
|
||||
loaded_model: Union[GPT4All, None] = None,
|
||||
conversation_log={},
|
||||
use_history: bool = False,
|
||||
run_extraction: bool = False,
|
||||
):
|
||||
"""
|
||||
Infer search queries to retrieve relevant notes to answer user query
|
||||
"""
|
||||
all_questions = text.split("? ")
|
||||
all_questions = [q + "?" for q in all_questions[:-1]] + [all_questions[-1]]
|
||||
if not run_extraction:
|
||||
return all_questions
|
||||
|
||||
gpt4all_model = loaded_model or GPT4All(model)
|
||||
|
||||
# Extract Past User Message and Inferred Questions from Conversation Log
|
||||
chat_history = ""
|
||||
|
||||
if use_history:
|
||||
chat_history = "".join(
|
||||
[
|
||||
f'Q: {chat["intent"]["query"]}\n\n{chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}\n\nA: {chat["message"]}\n\n'
|
||||
for chat in conversation_log.get("chat", [])[-4:]
|
||||
if chat["by"] == "khoj"
|
||||
]
|
||||
)
|
||||
|
||||
prompt = prompts.extract_questions_falcon.format(
|
||||
chat_history=chat_history,
|
||||
text=text,
|
||||
)
|
||||
message = prompts.general_conversation_falcon.format(query=prompt)
|
||||
response = gpt4all_model.generate(message, max_tokens=200, top_k=2)
|
||||
|
||||
# Extract, Clean Message from GPT's Response
|
||||
try:
|
||||
questions = (
|
||||
str(response)
|
||||
.strip(empty_escape_sequences)
|
||||
.replace("['", '["')
|
||||
.replace("']", '"]')
|
||||
.replace("', '", '", "')
|
||||
.replace('["', "")
|
||||
.replace('"]', "")
|
||||
.split('", "')
|
||||
)
|
||||
except:
|
||||
logger.warning(f"Falcon returned invalid JSON. Falling back to using user message as search query.\n{response}")
|
||||
return all_questions
|
||||
logger.debug(f"Extracted Questions by Falcon: {questions}")
|
||||
questions.extend(all_questions)
|
||||
return questions
|
||||
|
||||
|
||||
def converse_falcon(
|
||||
references,
|
||||
user_query,
|
||||
conversation_log={},
|
||||
model: str = "ggml-model-gpt4all-falcon-q4_0.bin",
|
||||
loaded_model: Union[GPT4All, None] = None,
|
||||
completion_func=None,
|
||||
) -> ThreadedGenerator:
|
||||
"""
|
||||
Converse with user using Falcon
|
||||
"""
|
||||
gpt4all_model = loaded_model or GPT4All(model)
|
||||
# Initialize Variables
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
compiled_references_message = "\n\n".join({f"{item}" for item in references})
|
||||
|
||||
# Get Conversation Primer appropriate to Conversation Type
|
||||
# TODO If compiled_references_message is too long, we need to truncate it.
|
||||
if compiled_references_message == "":
|
||||
conversation_primer = prompts.conversation_falcon.format(query=user_query)
|
||||
else:
|
||||
conversation_primer = prompts.notes_conversation.format(
|
||||
current_date=current_date, query=user_query, references=compiled_references_message
|
||||
)
|
||||
|
||||
# Setup Prompt with Primer or Conversation History
|
||||
messages = generate_chatml_messages_with_context(
|
||||
conversation_primer,
|
||||
prompts.personality.format(),
|
||||
conversation_log,
|
||||
model_name="text-davinci-001", # This isn't actually the model, but this helps us get an approximate encoding to run message truncation.
|
||||
)
|
||||
|
||||
g = ThreadedGenerator(references, completion_func=completion_func)
|
||||
t = Thread(target=llm_thread, args=(g, messages, gpt4all_model))
|
||||
t.start()
|
||||
return g
|
||||
|
||||
|
||||
def llm_thread(g, messages: List[ChatMessage], model: GPT4All):
|
||||
user_message = messages[0]
|
||||
system_message = messages[-1]
|
||||
conversation_history = messages[1:-1]
|
||||
|
||||
formatted_messages = [
|
||||
prompts.chat_history_falcon_from_assistant.format(message=system_message)
|
||||
if message.role == "assistant"
|
||||
else prompts.chat_history_falcon_from_user.format(message=message.content)
|
||||
for message in conversation_history
|
||||
]
|
||||
|
||||
chat_history = "".join(formatted_messages)
|
||||
full_message = system_message.content + chat_history + user_message.content
|
||||
|
||||
prompted_message = prompts.general_conversation_falcon.format(query=full_message)
|
||||
response_iterator = model.generate(
|
||||
prompted_message, streaming=True, max_tokens=256, top_k=1, temp=0, repeat_penalty=2.0
|
||||
)
|
||||
for response in response_iterator:
|
||||
logger.info(response)
|
||||
g.send(response)
|
||||
g.close()
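The module above is new in this commit; as a rough usage sketch (not part of the diff), its two entry points can be driven directly. The sample query, notes, and the iteration over the returned ThreadedGenerator are illustrative assumptions; the model name is the default used above and is downloaded by GPT4All on first use:

    # Sketch: drive the new gpt4all Falcon processor directly.
    # The references and query are made-up sample data; loading the model
    # fetches ggml-model-gpt4all-falcon-q4_0.bin locally the first time.
    from gpt4all import GPT4All

    from khoj.processor.conversation.gpt4all.chat_model import converse_falcon, extract_questions_falcon

    model = GPT4All("ggml-model-gpt4all-falcon-q4_0.bin")

    # By default this just splits the user message into questions;
    # pass run_extraction=True to let Falcon infer search queries.
    queries = extract_questions_falcon("How was my trip to Cambodia?", loaded_model=model)

    # converse_falcon streams the answer through a ThreadedGenerator.
    references = ["Notes from the Cambodia trip: visited Angkor Wat with friends."]
    for token in converse_falcon(references, "How was my trip to Cambodia?", loaded_model=model):
        print(token, end="", flush=True)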
src/khoj/processor/conversation/openai/__init__.py (new file, empty)
@@ -9,11 +9,11 @@ from langchain.schema import ChatMessage
# Internal Packages
from khoj.utils.constants import empty_escape_sequences
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import (
from khoj.processor.conversation.openai.utils import (
    chat_completion_with_backoff,
    completion_with_backoff,
    generate_chatml_messages_with_context,
)
from khoj.processor.conversation.utils import generate_chatml_messages_with_context


logger = logging.getLogger(__name__)
src/khoj/processor/conversation/openai/utils.py (new file, 101 lines)
@@ -0,0 +1,101 @@
|
|||
# Standard Packages
|
||||
import os
|
||||
import logging
|
||||
from typing import Any
|
||||
from threading import Thread
|
||||
|
||||
# External Packages
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
import openai
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
# Internal Packages
|
||||
from khoj.processor.conversation.utils import ThreadedGenerator
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
|
||||
def __init__(self, gen: ThreadedGenerator):
|
||||
super().__init__()
|
||||
self.gen = gen
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs) -> Any:
|
||||
self.gen.send(token)
|
||||
|
||||
|
||||
@retry(
|
||||
retry=(
|
||||
retry_if_exception_type(openai.error.Timeout)
|
||||
| retry_if_exception_type(openai.error.APIError)
|
||||
| retry_if_exception_type(openai.error.APIConnectionError)
|
||||
| retry_if_exception_type(openai.error.RateLimitError)
|
||||
| retry_if_exception_type(openai.error.ServiceUnavailableError)
|
||||
),
|
||||
wait=wait_random_exponential(min=1, max=10),
|
||||
stop=stop_after_attempt(3),
|
||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||
reraise=True,
|
||||
)
|
||||
def completion_with_backoff(**kwargs):
|
||||
messages = kwargs.pop("messages")
|
||||
if not "openai_api_key" in kwargs:
|
||||
kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY")
|
||||
llm = ChatOpenAI(**kwargs, request_timeout=20, max_retries=1)
|
||||
return llm(messages=messages)
|
||||
|
||||
|
||||
@retry(
|
||||
retry=(
|
||||
retry_if_exception_type(openai.error.Timeout)
|
||||
| retry_if_exception_type(openai.error.APIError)
|
||||
| retry_if_exception_type(openai.error.APIConnectionError)
|
||||
| retry_if_exception_type(openai.error.RateLimitError)
|
||||
| retry_if_exception_type(openai.error.ServiceUnavailableError)
|
||||
),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
stop=stop_after_attempt(3),
|
||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||
reraise=True,
|
||||
)
|
||||
def chat_completion_with_backoff(
|
||||
messages, compiled_references, model_name, temperature, openai_api_key=None, completion_func=None
|
||||
):
|
||||
g = ThreadedGenerator(compiled_references, completion_func=completion_func)
|
||||
t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key))
|
||||
t.start()
|
||||
return g
|
||||
|
||||
|
||||
def llm_thread(g, messages, model_name, temperature, openai_api_key=None):
|
||||
callback_handler = StreamingChatCallbackHandler(g)
|
||||
chat = ChatOpenAI(
|
||||
streaming=True,
|
||||
verbose=True,
|
||||
callback_manager=BaseCallbackManager([callback_handler]),
|
||||
model_name=model_name, # type: ignore
|
||||
temperature=temperature,
|
||||
openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
|
||||
request_timeout=20,
|
||||
max_retries=1,
|
||||
client=None,
|
||||
)
|
||||
|
||||
chat(messages=messages)
|
||||
|
||||
g.close()
|
||||
|
||||
|
||||
def extract_summaries(metadata):
|
||||
"""Extract summaries from metadata"""
|
||||
return "".join([f'\n{session["summary"]}' for session in metadata])
@@ -18,6 +18,36 @@ Question: {query}
|
|||
""".strip()
|
||||
)
|
||||
|
||||
general_conversation_falcon = PromptTemplate.from_template(
|
||||
"""
|
||||
Using your general knowledge and our past conversations as context, answer the following question.
|
||||
### Instruct:
|
||||
{query}
|
||||
### Response:
|
||||
""".strip()
|
||||
)
|
||||
|
||||
chat_history_falcon_from_user = PromptTemplate.from_template(
|
||||
"""
|
||||
### Human:
|
||||
{message}
|
||||
""".strip()
|
||||
)
|
||||
|
||||
chat_history_falcon_from_assistant = PromptTemplate.from_template(
|
||||
"""
|
||||
### Assistant:
|
||||
{message}
|
||||
""".strip()
|
||||
)
|
||||
|
||||
conversation_falcon = PromptTemplate.from_template(
|
||||
"""
|
||||
Using our past conversations as context, answer the following question.
|
||||
|
||||
Question: {query}
|
||||
""".strip()
|
||||
)
|
||||
|
||||
## Notes Conversation
|
||||
## --
|
||||
|
@@ -33,6 +63,17 @@ Question: {query}
|
|||
""".strip()
|
||||
)
|
||||
|
||||
notes_conversation_falcon = PromptTemplate.from_template(
|
||||
"""
|
||||
Using the notes and our past conversations as context, answer the following question. If the answer is not contained within the notes, say "I don't know."
|
||||
|
||||
Notes:
|
||||
{references}
|
||||
|
||||
Question: {query}
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
||||
## Summarize Chat
|
||||
## --
|
||||
|
@@ -68,6 +109,40 @@ Question: {user_query}
|
|||
Answer (in second person):"""
|
||||
)
|
||||
|
||||
extract_questions_falcon = PromptTemplate.from_template(
|
||||
"""
|
||||
You are Khoj, an extremely smart and helpful search assistant with the ability to retrieve information from the user's notes.
|
||||
- The user will provide their questions and answers to you for context.
|
||||
- Add as much context from the previous questions and answers as required into your search queries.
|
||||
- Break messages into multiple search queries when required to retrieve the relevant information.
|
||||
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
|
||||
|
||||
What searches, if any, will you need to perform to answer the users question?
|
||||
|
||||
Q: How was my trip to Cambodia?
|
||||
|
||||
["How was my trip to Cambodia?"]
|
||||
|
||||
A: The trip was amazing. I went to the Angkor Wat temple and it was beautiful.
|
||||
|
||||
Q: Who did i visit that temple with?
|
||||
|
||||
["Who did I visit the Angkor Wat Temple in Cambodia with?"]
|
||||
|
||||
A: You visited the Angkor Wat Temple in Cambodia with Pablo, Namita and Xi.
|
||||
|
||||
Q: How many tennis balls fit in the back of a 2002 Honda Civic?
|
||||
|
||||
["What is the size of a tennis ball?", "What is the trunk size of a 2002 Honda Civic?"]
|
||||
|
||||
A: 1085 tennis balls will fit in the trunk of a Honda Civic
|
||||
|
||||
{chat_history}
|
||||
Q: {text}
|
||||
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
## Extract Questions
|
||||
## --
|
||||
|
|
|
@@ -1,35 +1,19 @@
|
|||
# Standard Packages
|
||||
import os
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from time import perf_counter
|
||||
from typing import Any
|
||||
from threading import Thread
|
||||
import json
|
||||
|
||||
# External Packages
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain.schema import ChatMessage
|
||||
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
import openai
|
||||
from datetime import datetime
|
||||
import tiktoken
|
||||
from tenacity import (
|
||||
before_sleep_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_exponential,
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
# External packages
|
||||
from langchain.schema import ChatMessage
|
||||
|
||||
# Internal Packages
|
||||
from khoj.utils.helpers import merge_dicts
|
||||
import queue
|
||||
|
||||
from khoj.utils.helpers import merge_dicts
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192}
|
||||
max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192, "text-davinci-001": 910}
|
||||
|
||||
|
||||
class ThreadedGenerator:
|
||||
|
@@ -49,9 +33,9 @@ class ThreadedGenerator:
|
|||
time_to_response = perf_counter() - self.start_time
|
||||
logger.info(f"Chat streaming took: {time_to_response:.3f} seconds")
|
||||
if self.completion_func:
|
||||
# The completion func effective acts as a callback.
|
||||
# It adds the aggregated response to the conversation history. It's constructed in api.py.
|
||||
self.completion_func(gpt_response=self.response)
|
||||
# The completion func effectively acts as a callback.
|
||||
# It adds the aggregated response to the conversation history.
|
||||
self.completion_func(chat_response=self.response)
|
||||
raise StopIteration
|
||||
return item
|
||||
|
||||
|
@@ -65,75 +49,25 @@ class ThreadedGenerator:
|
|||
self.queue.put(StopIteration)
|
||||
|
||||
|
||||
class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
|
||||
def __init__(self, gen: ThreadedGenerator):
|
||||
super().__init__()
|
||||
self.gen = gen
|
||||
|
||||
def on_llm_new_token(self, token: str, **kwargs) -> Any:
|
||||
self.gen.send(token)
|
||||
|
||||
|
||||
@retry(
|
||||
retry=(
|
||||
retry_if_exception_type(openai.error.Timeout)
|
||||
| retry_if_exception_type(openai.error.APIError)
|
||||
| retry_if_exception_type(openai.error.APIConnectionError)
|
||||
| retry_if_exception_type(openai.error.RateLimitError)
|
||||
| retry_if_exception_type(openai.error.ServiceUnavailableError)
|
||||
),
|
||||
wait=wait_random_exponential(min=1, max=10),
|
||||
stop=stop_after_attempt(3),
|
||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||
reraise=True,
|
||||
)
|
||||
def completion_with_backoff(**kwargs):
|
||||
messages = kwargs.pop("messages")
|
||||
if not "openai_api_key" in kwargs:
|
||||
kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY")
|
||||
llm = ChatOpenAI(**kwargs, request_timeout=20, max_retries=1)
|
||||
return llm(messages=messages)
|
||||
|
||||
|
||||
@retry(
|
||||
retry=(
|
||||
retry_if_exception_type(openai.error.Timeout)
|
||||
| retry_if_exception_type(openai.error.APIError)
|
||||
| retry_if_exception_type(openai.error.APIConnectionError)
|
||||
| retry_if_exception_type(openai.error.RateLimitError)
|
||||
| retry_if_exception_type(openai.error.ServiceUnavailableError)
|
||||
),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
stop=stop_after_attempt(3),
|
||||
before_sleep=before_sleep_log(logger, logging.DEBUG),
|
||||
reraise=True,
|
||||
)
|
||||
def chat_completion_with_backoff(
|
||||
messages, compiled_references, model_name, temperature, openai_api_key=None, completion_func=None
|
||||
def message_to_log(
|
||||
user_message, chat_response, user_message_metadata={}, khoj_message_metadata={}, conversation_log=[]
|
||||
):
|
||||
g = ThreadedGenerator(compiled_references, completion_func=completion_func)
|
||||
t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key))
|
||||
t.start()
|
||||
return g
|
||||
"""Create json logs from messages, metadata for conversation log"""
|
||||
default_khoj_message_metadata = {
|
||||
"intent": {"type": "remember", "memory-type": "notes", "query": user_message},
|
||||
"trigger-emotion": "calm",
|
||||
}
|
||||
khoj_response_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
# Create json log from Human's message
|
||||
human_log = merge_dicts({"message": user_message, "by": "you"}, user_message_metadata)
|
||||
|
||||
def llm_thread(g, messages, model_name, temperature, openai_api_key=None):
|
||||
callback_handler = StreamingChatCallbackHandler(g)
|
||||
chat = ChatOpenAI(
|
||||
streaming=True,
|
||||
verbose=True,
|
||||
callback_manager=BaseCallbackManager([callback_handler]),
|
||||
model_name=model_name, # type: ignore
|
||||
temperature=temperature,
|
||||
openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
|
||||
request_timeout=20,
|
||||
max_retries=1,
|
||||
client=None,
|
||||
)
|
||||
# Create json log from GPT's response
|
||||
khoj_log = merge_dicts(khoj_message_metadata, default_khoj_message_metadata)
|
||||
khoj_log = merge_dicts({"message": chat_response, "by": "khoj", "created": khoj_response_time}, khoj_log)
|
||||
|
||||
chat(messages=messages)
|
||||
|
||||
g.close()
|
||||
conversation_log.extend([human_log, khoj_log])
|
||||
return conversation_log
|
||||
|
||||
|
||||
def generate_chatml_messages_with_context(
|
||||
|
@@ -192,27 +126,3 @@ def truncate_messages(messages, max_prompt_size, model_name):
|
|||
def reciprocal_conversation_to_chatml(message_pair):
|
||||
"""Convert a single back and forth between user and assistant to chatml format"""
|
||||
return [ChatMessage(content=message, role=role) for message, role in zip(message_pair, ["user", "assistant"])]
|
||||
|
||||
|
||||
def message_to_log(user_message, gpt_message, user_message_metadata={}, khoj_message_metadata={}, conversation_log=[]):
|
||||
"""Create json logs from messages, metadata for conversation log"""
|
||||
default_khoj_message_metadata = {
|
||||
"intent": {"type": "remember", "memory-type": "notes", "query": user_message},
|
||||
"trigger-emotion": "calm",
|
||||
}
|
||||
khoj_response_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
# Create json log from Human's message
|
||||
human_log = merge_dicts({"message": user_message, "by": "you"}, user_message_metadata)
|
||||
|
||||
# Create json log from GPT's response
|
||||
khoj_log = merge_dicts(khoj_message_metadata, default_khoj_message_metadata)
|
||||
khoj_log = merge_dicts({"message": gpt_message, "by": "khoj", "created": khoj_response_time}, khoj_log)
|
||||
|
||||
conversation_log.extend([human_log, khoj_log])
|
||||
return conversation_log
|
||||
|
||||
|
||||
def extract_summaries(metadata):
|
||||
"""Extract summaries from metadata"""
|
||||
return "".join([f'\n{session["summary"]}' for session in metadata])
|
||||
|
|
|
@@ -52,7 +52,7 @@ class GithubToJsonl(TextToJsonl):
|
|||
try:
|
||||
markdown_files, org_files = self.get_files(repo_url, repo)
|
||||
except Exception as e:
|
||||
logger.error(f"Unable to download github repo {repo_shorthand}")
|
||||
logger.error(f"Unable to download github repo {repo_shorthand}", exc_info=True)
|
||||
raise e
|
||||
|
||||
logger.info(f"Found {len(markdown_files)} markdown files in github repo {repo_shorthand}")
|
||||
|
|
|
@@ -219,7 +219,7 @@ class NotionToJsonl(TextToJsonl):
|
|||
page = self.get_page(page_id)
|
||||
content = self.get_page_children(page_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting page {page_id}: {e}")
|
||||
logger.error(f"Error getting page {page_id}: {e}", exc_info=True)
|
||||
return None, None
|
||||
properties = page["properties"]
|
||||
title_field = "title"
|
||||
|
|
|
@@ -5,7 +5,7 @@ import time
|
|||
import yaml
|
||||
import logging
|
||||
import json
|
||||
from typing import Iterable, List, Optional, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
# External Packages
|
||||
from fastapi import APIRouter, HTTPException, Header, Request
|
||||
|
@@ -26,16 +26,19 @@ from khoj.utils.rawconfig import (
|
|||
SearchConfig,
|
||||
SearchResponse,
|
||||
TextContentConfig,
|
||||
ConversationProcessorConfig,
|
||||
OpenAIProcessorConfig,
|
||||
GithubContentConfig,
|
||||
NotionContentConfig,
|
||||
ConversationProcessorConfig,
|
||||
)
|
||||
from khoj.utils.helpers import resolve_absolute_path
|
||||
from khoj.utils.state import SearchType
|
||||
from khoj.utils import state, constants
|
||||
from khoj.utils.yaml import save_config_to_file_updated_state
|
||||
from fastapi.responses import StreamingResponse, Response
|
||||
from khoj.routers.helpers import perform_chat_checks, generate_chat_response, update_telemetry_state
|
||||
from khoj.processor.conversation.gpt import extract_questions
|
||||
from khoj.processor.conversation.openai.gpt import extract_questions
|
||||
from khoj.processor.conversation.gpt4all.chat_model import extract_questions_falcon, converse_falcon
|
||||
from fastapi.requests import Request
|
||||
|
||||
|
||||
|
@@ -50,6 +53,8 @@ if not state.demo:
|
|||
if state.config is None:
|
||||
state.config = FullConfig()
|
||||
state.config.search_type = SearchConfig.parse_obj(constants.default_config["search-type"])
|
||||
if state.processor_config is None:
|
||||
state.processor_config = configure_processor(state.config.processor)
|
||||
|
||||
@api.get("/config/data", response_model=FullConfig)
|
||||
def get_config_data():
|
||||
|
@@ -181,22 +186,28 @@ if not state.demo:
|
|||
except Exception as e:
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
@api.post("/delete/config/data/processor/conversation", status_code=200)
|
||||
@api.post("/delete/config/data/processor/conversation/openai", status_code=200)
|
||||
async def remove_processor_conversation_config_data(
|
||||
request: Request,
|
||||
client: Optional[str] = None,
|
||||
):
|
||||
if not state.config or not state.config.processor or not state.config.processor.conversation:
|
||||
if (
|
||||
not state.config
|
||||
or not state.config.processor
|
||||
or not state.config.processor.conversation
|
||||
or not state.config.processor.conversation.openai
|
||||
):
|
||||
return {"status": "ok"}
|
||||
|
||||
state.config.processor.conversation = None
|
||||
state.config.processor.conversation.openai = None
|
||||
state.processor_config = configure_processor(state.config.processor, state.processor_config)
|
||||
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
telemetry_type="api",
|
||||
api="delete_processor_config",
|
||||
api="delete_processor_openai_config",
|
||||
client=client,
|
||||
metadata={"processor_type": "conversation"},
|
||||
metadata={"processor_conversation_type": "openai"},
|
||||
)
|
||||
|
||||
try:
|
||||
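A hedged usage sketch (not part of the diff) of the renamed delete endpoint, assuming the API router is mounted under /api and a local Khoj server listens on localhost:42110:

import requests

# Clears only the OpenAI conversation settings; offline chat config is left untouched.
resp = requests.post("http://localhost:42110/api/delete/config/data/processor/conversation/openai")
print(resp.json())  # {"status": "ok"} on success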
|
@ -233,23 +244,66 @@ if not state.demo:
|
|||
except Exception as e:
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
@api.post("/config/data/processor/conversation", status_code=200)
|
||||
async def set_processor_conversation_config_data(
|
||||
@api.post("/config/data/processor/conversation/openai", status_code=200)
|
||||
async def set_processor_openai_config_data(
|
||||
request: Request,
|
||||
updated_config: Union[ConversationProcessorConfig, None],
|
||||
updated_config: Union[OpenAIProcessorConfig, None],
|
||||
client: Optional[str] = None,
|
||||
):
|
||||
_initialize_config()
|
||||
|
||||
state.config.processor = ProcessorConfig(conversation=updated_config)
|
||||
state.processor_config = configure_processor(state.config.processor)
|
||||
if not state.config.processor or not state.config.processor.conversation:
|
||||
default_config = constants.default_config
|
||||
default_conversation_logfile = resolve_absolute_path(
|
||||
default_config["processor"]["conversation"]["conversation-logfile"] # type: ignore
|
||||
)
|
||||
conversation_logfile = resolve_absolute_path(default_conversation_logfile)
|
||||
state.config.processor = ProcessorConfig(conversation=ConversationProcessorConfig(conversation_logfile=conversation_logfile)) # type: ignore
|
||||
|
||||
assert state.config.processor.conversation is not None
|
||||
state.config.processor.conversation.openai = updated_config
|
||||
state.processor_config = configure_processor(state.config.processor, state.processor_config)
|
||||
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
telemetry_type="api",
|
||||
api="set_content_config",
|
||||
api="set_processor_config",
|
||||
client=client,
|
||||
metadata={"processor_type": "conversation"},
|
||||
metadata={"processor_conversation_type": "conversation"},
|
||||
)
|
||||
|
||||
try:
|
||||
save_config_to_file_updated_state()
|
||||
return {"status": "ok"}
|
||||
except Exception as e:
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
@api.post("/config/data/processor/conversation/enable_offline_chat", status_code=200)
|
||||
async def set_processor_enable_offline_chat_config_data(
|
||||
request: Request,
|
||||
enable_offline_chat: bool,
|
||||
client: Optional[str] = None,
|
||||
):
|
||||
_initialize_config()
|
||||
|
||||
if not state.config.processor or not state.config.processor.conversation:
|
||||
default_config = constants.default_config
|
||||
default_conversation_logfile = resolve_absolute_path(
|
||||
default_config["processor"]["conversation"]["conversation-logfile"] # type: ignore
|
||||
)
|
||||
conversation_logfile = resolve_absolute_path(default_conversation_logfile)
|
||||
state.config.processor = ProcessorConfig(conversation=ConversationProcessorConfig(conversation_logfile=conversation_logfile)) # type: ignore
|
||||
|
||||
assert state.config.processor.conversation is not None
|
||||
state.config.processor.conversation.enable_offline_chat = enable_offline_chat
|
||||
state.processor_config = configure_processor(state.config.processor, state.processor_config)
|
||||
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
telemetry_type="api",
|
||||
api="set_processor_config",
|
||||
client=client,
|
||||
metadata={"processor_conversation_type": f"{'enable' if enable_offline_chat else 'disable'}_local_llm"},
|
||||
)
|
||||
|
||||
try:
|
||||
|
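Similarly, a hedged sketch for the new offline chat toggle; the host, port and /api prefix are assumptions about the local setup:

import requests

# enable_offline_chat is passed as a query parameter on this endpoint.
resp = requests.post(
    "http://localhost:42110/api/config/data/processor/conversation/enable_offline_chat",
    params={"enable_offline_chat": True},
)
print(resp.json())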
@ -569,7 +623,9 @@ def chat_history(
|
|||
perform_chat_checks()
|
||||
|
||||
# Load Conversation History
|
||||
meta_log = state.processor_config.conversation.meta_log
|
||||
meta_log = {}
|
||||
if state.processor_config.conversation:
|
||||
meta_log = state.processor_config.conversation.meta_log
|
||||
|
||||
update_telemetry_state(
|
||||
request=request,
|
||||
|
@ -598,24 +654,25 @@ async def chat(
|
|||
perform_chat_checks()
|
||||
compiled_references, inferred_queries = await extract_references_and_questions(request, q, (n or 5))
|
||||
|
||||
# Get the (streamed) chat response from GPT.
|
||||
gpt_response = generate_chat_response(
|
||||
# Get the (streamed) chat response from the LLM of choice.
|
||||
llm_response = generate_chat_response(
|
||||
q,
|
||||
meta_log=state.processor_config.conversation.meta_log,
|
||||
compiled_references=compiled_references,
|
||||
inferred_queries=inferred_queries,
|
||||
)
|
||||
if gpt_response is None:
|
||||
return Response(content=gpt_response, media_type="text/plain", status_code=500)
|
||||
|
||||
if llm_response is None:
|
||||
return Response(content=llm_response, media_type="text/plain", status_code=500)
|
||||
|
||||
if stream:
|
||||
return StreamingResponse(gpt_response, media_type="text/event-stream", status_code=200)
|
||||
return StreamingResponse(llm_response, media_type="text/event-stream", status_code=200)
|
||||
|
||||
# Get the full response from the generator if the stream is not requested.
|
||||
aggregated_gpt_response = ""
|
||||
while True:
|
||||
try:
|
||||
aggregated_gpt_response += next(gpt_response)
|
||||
aggregated_gpt_response += next(llm_response)
|
||||
except StopIteration:
|
||||
break
|
||||
|
||||
|
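A hedged example of exercising the chat endpoint once either processor is configured; it assumes the GET /api/chat route on the default local address and simply prints the raw response body:

import requests

resp = requests.get(
    "http://localhost:42110/api/chat",
    params={"q": "What did I work on last week?", "stream": False},
)
print(resp.text)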
@ -645,8 +702,6 @@ async def extract_references_and_questions(
|
|||
meta_log = state.processor_config.conversation.meta_log
|
||||
|
||||
# Initialize Variables
|
||||
api_key = state.processor_config.conversation.openai_api_key
|
||||
chat_model = state.processor_config.conversation.chat_model
|
||||
conversation_type = "general" if q.startswith("@general") else "notes"
|
||||
compiled_references = []
|
||||
inferred_queries = []
|
||||
|
@ -654,7 +709,13 @@ async def extract_references_and_questions(
|
|||
if conversation_type == "notes":
|
||||
# Infer search queries from user message
|
||||
with timer("Extracting search queries took", logger):
|
||||
inferred_queries = extract_questions(q, model=chat_model, api_key=api_key, conversation_log=meta_log)
|
||||
if state.processor_config.conversation and state.processor_config.conversation.openai_model:
|
||||
api_key = state.processor_config.conversation.openai_model.api_key
|
||||
chat_model = state.processor_config.conversation.openai_model.chat_model
|
||||
inferred_queries = extract_questions(q, model=chat_model, api_key=api_key, conversation_log=meta_log)
|
||||
else:
|
||||
loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
|
||||
inferred_queries = extract_questions_falcon(q, loaded_model=loaded_model, conversation_log=meta_log)
|
||||
|
||||
# Collate search results as context for GPT
|
||||
with timer("Searching knowledge base took", logger):
|
||||
|
|
|
@ -7,22 +7,23 @@ from fastapi import HTTPException, Request
|
|||
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import timer, log_telemetry
|
||||
from khoj.processor.conversation.gpt import converse
|
||||
from khoj.processor.conversation.utils import message_to_log, reciprocal_conversation_to_chatml
|
||||
|
||||
from khoj.processor.conversation.openai.gpt import converse
|
||||
from khoj.processor.conversation.gpt4all.chat_model import converse_falcon
|
||||
from khoj.processor.conversation.utils import reciprocal_conversation_to_chatml, message_to_log, ThreadedGenerator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def perform_chat_checks():
|
||||
if (
|
||||
state.processor_config is None
|
||||
or state.processor_config.conversation is None
|
||||
or state.processor_config.conversation.openai_api_key is None
|
||||
if state.processor_config.conversation and (
|
||||
state.processor_config.conversation.openai_model
|
||||
or state.processor_config.conversation.gpt4all_model.loaded_model
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Set your OpenAI API key via Khoj settings and restart it to use Khoj Chat."
|
||||
)
|
||||
return
|
||||
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings and restart it."
|
||||
)
|
||||
|
||||
|
||||
def update_telemetry_state(
|
||||
|
@ -57,19 +58,19 @@ def generate_chat_response(
|
|||
meta_log: dict,
|
||||
compiled_references: List[str] = [],
|
||||
inferred_queries: List[str] = [],
|
||||
):
|
||||
) -> ThreadedGenerator:
|
||||
def _save_to_conversation_log(
|
||||
q: str,
|
||||
gpt_response: str,
|
||||
chat_response: str,
|
||||
user_message_time: str,
|
||||
compiled_references: List[str],
|
||||
inferred_queries: List[str],
|
||||
meta_log,
|
||||
):
|
||||
state.processor_config.conversation.chat_session += reciprocal_conversation_to_chatml([q, gpt_response])
|
||||
state.processor_config.conversation.chat_session += reciprocal_conversation_to_chatml([q, chat_response])
|
||||
state.processor_config.conversation.meta_log["chat"] = message_to_log(
|
||||
q,
|
||||
gpt_response,
|
||||
user_message=q,
|
||||
chat_response=chat_response,
|
||||
user_message_metadata={"created": user_message_time},
|
||||
khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}},
|
||||
conversation_log=meta_log.get("chat", []),
|
||||
|
@ -79,8 +80,6 @@ def generate_chat_response(
|
|||
meta_log = state.processor_config.conversation.meta_log
|
||||
|
||||
# Initialize Variables
|
||||
api_key = state.processor_config.conversation.openai_api_key
|
||||
chat_model = state.processor_config.conversation.chat_model
|
||||
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
conversation_type = "general" if q.startswith("@general") else "notes"
|
||||
|
||||
|
@ -99,12 +98,29 @@ def generate_chat_response(
|
|||
meta_log=meta_log,
|
||||
)
|
||||
|
||||
gpt_response = converse(
|
||||
compiled_references, q, meta_log, model=chat_model, api_key=api_key, completion_func=partial_completion
|
||||
)
|
||||
if state.processor_config.conversation.openai_model:
|
||||
api_key = state.processor_config.conversation.openai_model.api_key
|
||||
chat_model = state.processor_config.conversation.openai_model.chat_model
|
||||
chat_response = converse(
|
||||
compiled_references,
|
||||
q,
|
||||
meta_log,
|
||||
model=chat_model,
|
||||
api_key=api_key,
|
||||
completion_func=partial_completion,
|
||||
)
|
||||
else:
|
||||
loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
|
||||
chat_response = converse_falcon(
|
||||
references=compiled_references,
|
||||
user_query=q,
|
||||
loaded_model=loaded_model,
|
||||
conversation_log=meta_log,
|
||||
completion_func=partial_completion,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
logger.error(e, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
return gpt_response
|
||||
return chat_response
|
||||
|
|
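Whichever backend is picked, generate_chat_response() hands back a generator of response chunks, so callers consume it the same way for OpenAI and Falcon. A hedged sketch, assuming a configured conversation processor in state:

from khoj.routers.helpers import generate_chat_response

# Both converse() and converse_falcon() yield chunks through a ThreadedGenerator,
# so the caller just iterates regardless of which LLM produced them.
for chunk in generate_chat_response("What did I eat for dinner?", meta_log={}):
    print(chunk, end="")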
|
@ -3,7 +3,7 @@ from fastapi import APIRouter
|
|||
from fastapi import Request
|
||||
from fastapi.responses import HTMLResponse, FileResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from khoj.utils.rawconfig import TextContentConfig, ConversationProcessorConfig, FullConfig
|
||||
from khoj.utils.rawconfig import TextContentConfig, OpenAIProcessorConfig, FullConfig
|
||||
|
||||
# Internal Packages
|
||||
from khoj.utils import constants, state
|
||||
|
@ -151,28 +151,29 @@ if not state.demo:
|
|||
},
|
||||
)
|
||||
|
||||
@web_client.get("/config/processor/conversation", response_class=HTMLResponse)
|
||||
@web_client.get("/config/processor/conversation/openai", response_class=HTMLResponse)
|
||||
def conversation_processor_config_page(request: Request):
|
||||
default_copy = constants.default_config.copy()
|
||||
default_processor_config = default_copy["processor"]["conversation"] # type: ignore
|
||||
default_processor_config = ConversationProcessorConfig(
|
||||
openai_api_key="",
|
||||
model=default_processor_config["model"],
|
||||
conversation_logfile=default_processor_config["conversation-logfile"],
|
||||
default_processor_config = default_copy["processor"]["conversation"]["openai"] # type: ignore
|
||||
default_openai_config = OpenAIProcessorConfig(
|
||||
api_key="",
|
||||
chat_model=default_processor_config["chat-model"],
|
||||
)
|
||||
|
||||
current_processor_conversation_config = (
|
||||
state.config.processor.conversation
|
||||
if state.config and state.config.processor and state.config.processor.conversation
|
||||
else default_processor_config
|
||||
current_processor_openai_config = (
|
||||
state.config.processor.conversation.openai
|
||||
if state.config
|
||||
and state.config.processor
|
||||
and state.config.processor.conversation
|
||||
and state.config.processor.conversation.openai
|
||||
else default_openai_config
|
||||
)
|
||||
current_processor_conversation_config = json.loads(current_processor_conversation_config.json())
|
||||
current_processor_openai_config = json.loads(current_processor_openai_config.json())
|
||||
|
||||
return templates.TemplateResponse(
|
||||
"processor_conversation_input.html",
|
||||
context={
|
||||
"request": request,
|
||||
"current_config": current_processor_conversation_config,
|
||||
"current_config": current_processor_openai_config,
|
||||
},
|
||||
)
|
||||
|
|
|
@ -5,7 +5,9 @@ from importlib.metadata import version
|
|||
|
||||
# Internal Packages
|
||||
from khoj.utils.helpers import resolve_absolute_path
|
||||
from khoj.utils.yaml import load_config_from_file, parse_config_from_file, save_config_to_file
|
||||
from khoj.utils.yaml import parse_config_from_file
|
||||
from khoj.migrations.migrate_version import migrate_config_to_version
|
||||
from khoj.migrations.migrate_processor_config_openai import migrate_processor_conversation_schema
|
||||
|
||||
|
||||
def cli(args=None):
|
||||
|
@ -46,22 +48,14 @@ def cli(args=None):
|
|||
if not args.config_file.exists():
|
||||
args.config = None
|
||||
else:
|
||||
args = migrate_config(args)
|
||||
args = run_migrations(args)
|
||||
args.config = parse_config_from_file(args.config_file)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def migrate_config(args):
|
||||
raw_config = load_config_from_file(args.config_file)
|
||||
|
||||
# Add version to khoj config schema
|
||||
if "version" not in raw_config:
|
||||
raw_config["version"] = args.version_no
|
||||
save_config_to_file(raw_config, args.config_file)
|
||||
|
||||
# regenerate khoj index on first start of this version
|
||||
# this should refresh index and apply index corruption fixes from #325
|
||||
args.regenerate = True
|
||||
|
||||
def run_migrations(args):
|
||||
migrations = [migrate_config_to_version, migrate_processor_conversation_schema]
|
||||
for migration in migrations:
|
||||
args = migration(args)
|
||||
return args
|
||||
|
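The old inline migrate_config() is replaced by a small migration pipeline, and each migration shares the same shape. A hedged sketch of that contract (the function name below is hypothetical):

def migrate_example(args):
    # A migration receives the parsed CLI args, may read or rewrite args.config_file,
    # and must return args so the next migration in run_migrations() can run.
    return args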
|
|
@ -1,9 +1,12 @@
|
|||
# System Packages
|
||||
from __future__ import annotations # to avoid quoting type hints
|
||||
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Union, Any
|
||||
|
||||
from gpt4all import GPT4All
|
||||
|
||||
# External Packages
|
||||
import torch
|
||||
|
@ -13,7 +16,7 @@ if TYPE_CHECKING:
|
|||
from sentence_transformers import CrossEncoder
|
||||
from khoj.search_filter.base_filter import BaseFilter
|
||||
from khoj.utils.models import BaseEncoder
|
||||
from khoj.utils.rawconfig import ConversationProcessorConfig, Entry
|
||||
from khoj.utils.rawconfig import ConversationProcessorConfig, Entry, OpenAIProcessorConfig
|
||||
|
||||
|
||||
class SearchType(str, Enum):
|
||||
|
@ -74,15 +77,29 @@ class SearchModels:
|
|||
plugin_search: Optional[Dict[str, TextSearchModel]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class GPT4AllProcessorConfig:
|
||||
chat_model: Optional[str] = "ggml-model-gpt4all-falcon-q4_0.bin"
|
||||
loaded_model: Union[Any, None] = None
|
||||
|
||||
|
||||
class ConversationProcessorConfigModel:
|
||||
def __init__(self, processor_config: ConversationProcessorConfig):
|
||||
self.openai_api_key = processor_config.openai_api_key
|
||||
self.model = processor_config.model
|
||||
self.chat_model = processor_config.chat_model
|
||||
self.conversation_logfile = Path(processor_config.conversation_logfile)
|
||||
def __init__(
|
||||
self,
|
||||
conversation_config: ConversationProcessorConfig,
|
||||
):
|
||||
self.openai_model = conversation_config.openai
|
||||
self.gpt4all_model = GPT4AllProcessorConfig()
|
||||
self.enable_offline_chat = conversation_config.enable_offline_chat
|
||||
self.conversation_logfile = Path(conversation_config.conversation_logfile)
|
||||
self.chat_session: List[str] = []
|
||||
self.meta_log: dict = {}
|
||||
|
||||
if not self.openai_model and self.enable_offline_chat:
|
||||
self.gpt4all_model.loaded_model = GPT4All(self.gpt4all_model.chat_model) # type: ignore
|
||||
else:
|
||||
self.gpt4all_model.loaded_model = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProcessorConfigModel:
|
||||
|
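What enabling offline chat does under the hood, as a hedged standalone sketch: gpt4all fetches the quantized Falcon weights named in GPT4AllProcessorConfig on first use, then serves generations locally. The generate() arguments are illustrative.

from gpt4all import GPT4All

# Downloads the model on first run (a few GB), then loads it from the local cache.
loaded_model = GPT4All("ggml-model-gpt4all-falcon-q4_0.bin")
print(loaded_model.generate("Say hello in one sentence.", max_tokens=64))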
|
|
@ -62,10 +62,12 @@ default_config = {
|
|||
},
|
||||
"processor": {
|
||||
"conversation": {
|
||||
"openai-api-key": None,
|
||||
"model": "text-davinci-003",
|
||||
"openai": {
|
||||
"api-key": None,
|
||||
"chat-model": "gpt-3.5-turbo",
|
||||
},
|
||||
"enable-offline-chat": False,
|
||||
"conversation-logfile": "~/.khoj/processor/conversation/conversation_logs.json",
|
||||
"chat-model": "gpt-3.5-turbo",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
|
|
@ -27,12 +27,12 @@ class OpenAI(BaseEncoder):
|
|||
if (
|
||||
not state.processor_config
|
||||
or not state.processor_config.conversation
|
||||
or not state.processor_config.conversation.openai_api_key
|
||||
or not state.processor_config.conversation.openai_model
|
||||
):
|
||||
raise Exception(
|
||||
f"Set OpenAI API key under processor-config > conversation > openai-api-key in config file: {state.config_file}"
|
||||
)
|
||||
openai.api_key = state.processor_config.conversation.openai_api_key
|
||||
openai.api_key = state.processor_config.conversation.openai_model.api_key
|
||||
self.embedding_dimensions = None
|
||||
|
||||
def encode(self, entries, device=None, **kwargs):
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# System Packages
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Optional
|
||||
from typing import List, Dict, Optional, Union, Any
|
||||
|
||||
# External Packages
|
||||
from pydantic import BaseModel, validator
|
||||
|
@ -103,13 +103,17 @@ class SearchConfig(ConfigBase):
|
|||
image: Optional[ImageSearchConfig]
|
||||
|
||||
|
||||
class ConversationProcessorConfig(ConfigBase):
|
||||
openai_api_key: str
|
||||
conversation_logfile: Path
|
||||
model: Optional[str] = "text-davinci-003"
|
||||
class OpenAIProcessorConfig(ConfigBase):
|
||||
api_key: str
|
||||
chat_model: Optional[str] = "gpt-3.5-turbo"
|
||||
|
||||
|
||||
class ConversationProcessorConfig(ConfigBase):
|
||||
conversation_logfile: Path
|
||||
openai: Optional[OpenAIProcessorConfig]
|
||||
enable_offline_chat: Optional[bool] = False
|
||||
|
||||
|
||||
class ProcessorConfig(ConfigBase):
|
||||
conversation: Optional[ConversationProcessorConfig]
|
||||
|
||||
|
|
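A hedged sketch of building the reshaped config objects directly; the field names follow the classes above, the values are placeholders:

from khoj.utils.rawconfig import ConversationProcessorConfig, OpenAIProcessorConfig, ProcessorConfig

processor = ProcessorConfig(
    conversation=ConversationProcessorConfig(
        conversation_logfile="~/.khoj/processor/conversation/conversation_logs.json",
        openai=OpenAIProcessorConfig(api_key="sk-...", chat_model="gpt-3.5-turbo"),
        enable_offline_chat=False,
    )
)
print(processor.json(indent=2))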
|
@ -20,6 +20,7 @@ content-type:
|
|||
embeddings-file: content_plugin_2_embeddings.pt
|
||||
input-filter:
|
||||
- '*2_new.jsonl.gz'
|
||||
enable-offline-chat: false
|
||||
search-type:
|
||||
asymmetric:
|
||||
cross-encoder: cross-encoder/ms-marco-MiniLM-L-6-v2
|
||||
|
|
426
tests/test_gpt4all_chat_actors.py
Normal file
|
@ -0,0 +1,426 @@
|
|||
# Standard Packages
|
||||
from datetime import datetime
|
||||
|
||||
# External Packages
|
||||
import pytest
|
||||
|
||||
SKIP_TESTS = True
|
||||
pytestmark = pytest.mark.skipif(
|
||||
SKIP_TESTS,
|
||||
reason="The GPT4All library has some quirks that make it hard to test in CI. This causes some tests to fail. Hence, disable it in CI.",
|
||||
)
|
||||
|
||||
import freezegun
|
||||
from freezegun import freeze_time
|
||||
|
||||
from gpt4all import GPT4All
|
||||
|
||||
# Internal Packages
|
||||
from khoj.processor.conversation.gpt4all.chat_model import converse_falcon, extract_questions_falcon
|
||||
|
||||
from khoj.processor.conversation.utils import message_to_log
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def loaded_model():
|
||||
return GPT4All("ggml-model-gpt4all-falcon-q4_0.bin")
|
||||
|
||||
|
||||
freezegun.configure(extend_ignore_list=["transformers"])
|
||||
|
||||
|
||||
# Test
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
@freeze_time("1984-04-02")
|
||||
def test_extract_question_with_date_filter_from_relative_day(loaded_model):
|
||||
# Act
|
||||
response = extract_questions_falcon(
|
||||
"Where did I go for dinner yesterday?", loaded_model=loaded_model, run_extraction=True
|
||||
)
|
||||
|
||||
assert len(response) >= 1
|
||||
assert response[-1] == "Where did I go for dinner yesterday?"
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
@freeze_time("1984-04-02")
|
||||
def test_extract_question_with_date_filter_from_relative_month(loaded_model):
|
||||
# Act
|
||||
response = extract_questions_falcon("Which countries did I visit last month?", loaded_model=loaded_model)
|
||||
|
||||
# Assert
|
||||
assert len(response) == 1
|
||||
assert response == ["Which countries did I visit last month?"]
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
@freeze_time("1984-04-02")
|
||||
def test_extract_question_with_date_filter_from_relative_year(loaded_model):
|
||||
# Act
|
||||
response = extract_questions_falcon(
|
||||
"Which countries have I visited this year?", loaded_model=loaded_model, run_extraction=True
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(response) >= 1
|
||||
assert response[-1] == "Which countries have I visited this year?"
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
def test_extract_multiple_explicit_questions_from_message(loaded_model):
|
||||
# Act
|
||||
response = extract_questions_falcon("What is the Sun? What is the Moon?", loaded_model=loaded_model)
|
||||
|
||||
# Assert
|
||||
expected_responses = ["What is the Sun?", "What is the Moon?"]
|
||||
assert len(response) == 2
|
||||
assert expected_responses == response
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
def test_extract_multiple_implicit_questions_from_message(loaded_model):
|
||||
# Act
|
||||
response = extract_questions_falcon("Is Morpheus taller than Neo?", loaded_model=loaded_model, run_extraction=True)
|
||||
|
||||
# Assert
|
||||
expected_responses = [
|
||||
("morpheus", "neo"),
|
||||
]
|
||||
assert len(response) == 2
|
||||
assert any([start in response[0].lower() and end in response[1].lower() for start, end in expected_responses]), (
|
||||
"Expected two search queries in response but got: " + response[0]
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
def test_generate_search_query_using_question_from_chat_history(loaded_model):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("What is the name of Mr. Vader's daughter?", "Princess Leia", []),
|
||||
]
|
||||
|
||||
# Act
|
||||
response = extract_questions_falcon(
|
||||
"Does he have any sons?",
|
||||
conversation_log=populate_chat_history(message_list),
|
||||
loaded_model=loaded_model,
|
||||
run_extraction=True,
|
||||
use_history=True,
|
||||
)
|
||||
|
||||
expected_responses = [
|
||||
"do not have",
|
||||
"clarify",
|
||||
"am sorry",
|
||||
]
|
||||
|
||||
# Assert
|
||||
assert len(response) >= 1
|
||||
assert any([expected_response in response[0] for expected_response in expected_responses]), (
|
||||
"Expected chat actor to ask for clarification in response, but got: " + response[0]
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(reason="Chat actor does not consistently follow template instructions.")
|
||||
@pytest.mark.chatquality
|
||||
def test_generate_search_query_using_answer_from_chat_history(loaded_model):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("What is the name of Mr. Vader's daughter?", "Princess Leia", []),
|
||||
]
|
||||
|
||||
# Act
|
||||
response = extract_questions_falcon(
|
||||
"Is she a Jedi?",
|
||||
conversation_log=populate_chat_history(message_list),
|
||||
loaded_model=loaded_model,
|
||||
run_extraction=True,
|
||||
use_history=True,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(response) == 1
|
||||
assert "Leia" in response[0]
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(reason="Chat actor is not sufficiently date-aware")
|
||||
@pytest.mark.chatquality
|
||||
def test_generate_search_query_with_date_and_context_from_chat_history(loaded_model):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("When did I visit Masai Mara?", "You visited Masai Mara in April 2000", []),
|
||||
]
|
||||
|
||||
# Act
|
||||
response = extract_questions_falcon(
|
||||
"What was the Pizza place we ate at over there?",
|
||||
conversation_log=populate_chat_history(message_list),
|
||||
run_extraction=True,
|
||||
loaded_model=loaded_model,
|
||||
)
|
||||
|
||||
# Assert
|
||||
expected_responses = [
|
||||
("dt>='2000-04-01'", "dt<'2000-05-01'"),
|
||||
("dt>='2000-04-01'", "dt<='2000-04-30'"),
|
||||
('dt>="2000-04-01"', 'dt<"2000-05-01"'),
|
||||
('dt>="2000-04-01"', 'dt<="2000-04-30"'),
|
||||
]
|
||||
assert len(response) == 1
|
||||
assert "Masai Mara" in response[0]
|
||||
assert any([start in response[0] and end in response[0] for start, end in expected_responses]), (
|
||||
"Expected date filter to limit to April 2000 in response but got: " + response[0]
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
def test_chat_with_no_chat_history_or_retrieved_content(loaded_model):
|
||||
# Act
|
||||
response_gen = converse_falcon(
|
||||
references=[], # Assume no context retrieved from notes for the user_query
|
||||
user_query="Hello, my name is Testatron. Who are you?",
|
||||
loaded_model=loaded_model,
|
||||
)
|
||||
response = "".join([response_chunk for response_chunk in response_gen])
|
||||
|
||||
# Assert
|
||||
expected_responses = ["Khoj", "khoj", "khooj", "Khooj", "KHOJ"]
|
||||
assert len(response) > 0
|
||||
assert any([expected_response in response for expected_response in expected_responses]), (
|
||||
"Expected assistants name, [K|k]hoj, in response but got: " + response
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(reason="Chat actor isn't really good at proper nouns yet.")
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_from_chat_history_and_previously_retrieved_content(loaded_model):
|
||||
"Chat actor needs to use context in previous notes and chat history to answer question"
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
(
|
||||
"When was I born?",
|
||||
"You were born on 1st April 1984.",
|
||||
["Testatron was born on 1st April 1984 in Testville."],
|
||||
),
|
||||
]
|
||||
|
||||
# Act
|
||||
response_gen = converse_falcon(
|
||||
references=[], # Assume no context retrieved from notes for the user_query
|
||||
user_query="Where was I born?",
|
||||
conversation_log=populate_chat_history(message_list),
|
||||
loaded_model=loaded_model,
|
||||
)
|
||||
response = "".join([response_chunk for response_chunk in response_gen])
|
||||
|
||||
# Assert
|
||||
assert len(response) > 0
|
||||
# Infer who I am and use that to infer I was born in Testville using chat history and previously retrieved notes
|
||||
assert "Testville" in response
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_from_chat_history_and_currently_retrieved_content(loaded_model):
|
||||
"Chat actor needs to use context across currently retrieved notes and chat history to answer question"
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
]
|
||||
|
||||
# Act
|
||||
response_gen = converse_falcon(
|
||||
references=[
|
||||
"Testatron was born on 1st April 1984 in Testville."
|
||||
], # Assume context retrieved from notes for the user_query
|
||||
user_query="Where was I born?",
|
||||
conversation_log=populate_chat_history(message_list),
|
||||
loaded_model=loaded_model,
|
||||
)
|
||||
response = "".join([response_chunk for response_chunk in response_gen])
|
||||
|
||||
# Assert
|
||||
assert len(response) > 0
|
||||
assert "Testville" in response
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(reason="Chat actor is rather liable to lying.")
|
||||
@pytest.mark.chatquality
|
||||
def test_refuse_answering_unanswerable_question(loaded_model):
|
||||
"Chat actor should not try make up answers to unanswerable questions."
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
]
|
||||
|
||||
# Act
|
||||
response_gen = converse_falcon(
|
||||
references=[], # Assume no context retrieved from notes for the user_query
|
||||
user_query="Where was I born?",
|
||||
conversation_log=populate_chat_history(message_list),
|
||||
loaded_model=loaded_model,
|
||||
)
|
||||
response = "".join([response_chunk for response_chunk in response_gen])
|
||||
|
||||
# Assert
|
||||
expected_responses = [
|
||||
"don't know",
|
||||
"do not know",
|
||||
"no information",
|
||||
"do not have",
|
||||
"don't have",
|
||||
"cannot answer",
|
||||
"I'm sorry",
|
||||
]
|
||||
assert len(response) > 0
|
||||
assert any([expected_response in response for expected_response in expected_responses]), (
|
||||
"Expected chat actor to say they don't know in response, but got: " + response
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_requires_current_date_awareness(loaded_model):
|
||||
"Chat actor should be able to answer questions relative to current date using provided notes"
|
||||
# Arrange
|
||||
context = [
|
||||
f"""{datetime.now().strftime("%Y-%m-%d")} "Naco Taco" "Tacos for Dinner"
|
||||
Expenses:Food:Dining 10.00 USD""",
|
||||
f"""{datetime.now().strftime("%Y-%m-%d")} "Sagar Ratna" "Dosa for Lunch"
|
||||
Expenses:Food:Dining 10.00 USD""",
|
||||
f"""2020-04-01 "SuperMercado" "Bananas"
|
||||
Expenses:Food:Groceries 10.00 USD""",
|
||||
f"""2020-01-01 "Naco Taco" "Burittos for Dinner"
|
||||
Expenses:Food:Dining 10.00 USD""",
|
||||
]
|
||||
|
||||
# Act
|
||||
response_gen = converse_falcon(
|
||||
references=context, # Assume context retrieved from notes for the user_query
|
||||
user_query="What did I have for Dinner today?",
|
||||
loaded_model=loaded_model,
|
||||
)
|
||||
response = "".join([response_chunk for response_chunk in response_gen])
|
||||
|
||||
# Assert
|
||||
expected_responses = ["tacos", "Tacos"]
|
||||
assert len(response) > 0
|
||||
assert any([expected_response in response for expected_response in expected_responses]), (
|
||||
"Expected [T|t]acos in response, but got: " + response
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_requires_date_aware_aggregation_across_provided_notes(loaded_model):
|
||||
"Chat actor should be able to answer questions that require date aware aggregation across multiple notes"
|
||||
# Arrange
|
||||
context = [
|
||||
f"""# {datetime.now().strftime("%Y-%m-%d")} "Naco Taco" "Tacos for Dinner"
|
||||
Expenses:Food:Dining 10.00 USD""",
|
||||
f"""{datetime.now().strftime("%Y-%m-%d")} "Sagar Ratna" "Dosa for Lunch"
|
||||
Expenses:Food:Dining 10.00 USD""",
|
||||
f"""2020-04-01 "SuperMercado" "Bananas"
|
||||
Expenses:Food:Groceries 10.00 USD""",
|
||||
f"""2020-01-01 "Naco Taco" "Burittos for Dinner"
|
||||
Expenses:Food:Dining 10.00 USD""",
|
||||
]
|
||||
|
||||
# Act
|
||||
response_gen = converse_falcon(
|
||||
references=context, # Assume context retrieved from notes for the user_query
|
||||
user_query="How much did I spend on dining this year?",
|
||||
loaded_model=loaded_model,
|
||||
)
|
||||
response = "".join([response_chunk for response_chunk in response_gen])
|
||||
|
||||
# Assert
|
||||
assert len(response) > 0
|
||||
assert "20" in response
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
def test_answer_general_question_not_in_chat_history_or_retrieved_content(loaded_model):
|
||||
"Chat actor should be able to answer general questions not requiring looking at chat history or notes"
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
("Where was I born?", "You were born Testville.", []),
|
||||
]
|
||||
|
||||
# Act
|
||||
response_gen = converse_falcon(
|
||||
references=[], # Assume no context retrieved from notes for the user_query
|
||||
user_query="Write a haiku about unit testing in 3 lines",
|
||||
conversation_log=populate_chat_history(message_list),
|
||||
loaded_model=loaded_model,
|
||||
)
|
||||
response = "".join([response_chunk for response_chunk in response_gen])
|
||||
|
||||
# Assert
|
||||
expected_responses = ["test", "testing"]
|
||||
assert len(response.splitlines()) >= 3 # haikus are 3 lines long, but Falcon tends to add a lot of new lines.
|
||||
assert any([expected_response in response.lower() for expected_response in expected_responses]), (
|
||||
"Expected [T|t]est in response, but got: " + response
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(reason="Chat actor not consistently capable of asking for clarification yet.")
|
||||
@pytest.mark.chatquality
|
||||
def test_ask_for_clarification_if_not_enough_context_in_question(loaded_model):
|
||||
"Chat actor should ask for clarification if question cannot be answered unambiguously with the provided context"
|
||||
# Arrange
|
||||
context = [
|
||||
f"""# Ramya
|
||||
My sister, Ramya, is married to Kali Devi. They have 2 kids, Ravi and Rani.""",
|
||||
f"""# Fang
|
||||
My sister, Fang Liu is married to Xi Li. They have 1 kid, Xiao Li.""",
|
||||
f"""# Aiyla
|
||||
My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""",
|
||||
]
|
||||
|
||||
# Act
|
||||
response_gen = converse_falcon(
|
||||
references=context, # Assume context retrieved from notes for the user_query
|
||||
user_query="How many kids does my older sister have?",
|
||||
loaded_model=loaded_model,
|
||||
)
|
||||
response = "".join([response_chunk for response_chunk in response_gen])
|
||||
|
||||
# Assert
|
||||
expected_responses = ["which sister", "Which sister", "which of your sister", "Which of your sister"]
|
||||
assert any([expected_response in response for expected_response in expected_responses]), (
|
||||
"Expected chat actor to ask for clarification in response, but got: " + response
|
||||
)
|
||||
|
||||
|
||||
# Helpers
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def populate_chat_history(message_list):
|
||||
# Generate conversation logs
|
||||
conversation_log = {"chat": []}
|
||||
for user_message, chat_response, context in message_list:
|
||||
message_to_log(
|
||||
user_message,
|
||||
chat_response,
|
||||
{"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}},
|
||||
conversation_log=conversation_log["chat"],
|
||||
)
|
||||
return conversation_log
|
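To run the new GPT4All chat actor tests locally (they are skipped by default via SKIP_TESTS), a hedged invocation using the chatquality marker defined in the suite:

import pytest

# Flip SKIP_TESTS to False in tests/test_gpt4all_chat_actors.py first, then:
pytest.main(["tests/test_gpt4all_chat_actors.py", "-m", "chatquality"])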
|
@ -8,7 +8,7 @@ import freezegun
|
|||
from freezegun import freeze_time
|
||||
|
||||
# Internal Packages
|
||||
from khoj.processor.conversation.gpt import converse, extract_questions
|
||||
from khoj.processor.conversation.openai.gpt import converse, extract_questions
|
||||
from khoj.processor.conversation.utils import message_to_log
|
||||
|
||||
|