Mirror of https://github.com/khoj-ai/khoj.git (synced 2024-11-27 17:35:07 +01:00)
Add support for our first Local LLM 🤖🏠 (#330)
* Add support for gpt4all's falcon model as an additional conversation processor
  - Update the UI pages to allow the user to point to the new endpoints for GPT
  - Update the internal schemas to support both GPT4All models and OpenAI
  - Add unit tests benchmarking some of the Falcon performance
* Add exc_info to include stack trace in error logs for text processors
* Pull shared functions into utils.py to be used across gpt4all and gpt
* Add migration for new processor conversation schema
* Skip GPT4All actor tests due to typing issues
* Fix Obsidian processor configuration in auto-configure flow
* Rename enable_local_llm to enable_offline_chat
Parent: 23d77ee338 · Commit: 8b2af0b5ef
34 changed files with 1258 additions and 291 deletions
.github/workflows/test.yml (vendored, 1 change)
@@ -30,7 +30,6 @@ jobs:
       fail-fast: false
       matrix:
         python_version:
-          - '3.8'
           - '3.9'
           - '3.10'
           - '3.11'
@@ -58,6 +58,7 @@ dependencies = [
     "pypdf >= 3.9.0",
     "requests >= 2.26.0",
     "bs4 >= 0.0.1",
+    "gpt4all==1.0.5",
 ]
 dynamic = ["version"]
@@ -373,7 +373,7 @@ CONFIG is json obtained from Khoj config API."
          (ignore-error json-end-of-file (json-parse-buffer :object-type 'alist :array-type 'list :null-object json-null :false-object json-false))))
          (default-index-dir (khoj--get-directory-from-config default-config '(content-type org embeddings-file)))
          (default-chat-dir (khoj--get-directory-from-config default-config '(processor conversation conversation-logfile)))
-         (chat-model (or khoj-chat-model (alist-get 'chat-model (alist-get 'conversation (alist-get 'processor default-config)))))
+         (chat-model (or khoj-chat-model (alist-get 'chat-model (alist-get 'openai (alist-get 'conversation (alist-get 'processor default-config))))))
          (default-model (alist-get 'model (alist-get 'conversation (alist-get 'processor default-config))))
          (config (or current-config default-config)))
@@ -423,15 +423,27 @@ CONFIG is json obtained from Khoj config API."
     ;; Configure processors
     (cond
      ((not khoj-openai-api-key)
-      (setq config (delq (assoc 'processor config) config)))
+      (let* ((processor (assoc 'processor config))
+             (conversation (assoc 'conversation processor))
+             (openai (assoc 'openai conversation)))
+        (when openai
+          ;; Unset the `openai' field in the khoj conversation processor config
+          (message "khoj.el: disable Khoj Chat using OpenAI as your OpenAI API key got removed from config")
+          (setcdr conversation (delq openai (cdr conversation)))
+          (setcdr processor (delq conversation (cdr processor)))
+          (setq config (delq processor config))
+          (push conversation (cdr processor))
+          (push processor config))))

      ((not current-config)
       (message "khoj.el: Chat not configured yet.")
       (setq config (delq (assoc 'processor config) config))
       (cl-pushnew `(processor . ((conversation . ((conversation-logfile . ,(format "%s/conversation.json" default-chat-dir))
-                                  (chat-model . ,chat-model)
-                                  (model . ,default-model)
-                                  (openai-api-key . ,khoj-openai-api-key)))))
+                                  (openai . (
+                                             (chat-model . ,chat-model)
+                                             (api-key . ,khoj-openai-api-key)
+                                             ))
+                                  ))))
                   config))

      ((not (alist-get 'conversation (alist-get 'processor config)))
@@ -440,21 +452,19 @@ CONFIG is json obtained from Khoj config API."
        (setq new-processor-type (delq (assoc 'conversation new-processor-type) new-processor-type))
        (cl-pushnew `(conversation . ((conversation-logfile . ,(format "%s/conversation.json" default-chat-dir))
                                      (chat-model . ,chat-model)
-                                     (model . ,default-model)
                                      (openai-api-key . ,khoj-openai-api-key)))
                    new-processor-type)
        (setq config (delq (assoc 'processor config) config))
        (cl-pushnew `(processor . ,new-processor-type) config)))

      ;; Else if khoj is not configured with specified openai api key
-     ((not (and (equal (alist-get 'openai-api-key (alist-get 'conversation (alist-get 'processor config))) khoj-openai-api-key)
-                (equal (alist-get 'chat-model (alist-get 'conversation (alist-get 'processor config))) khoj-chat-model)))
+     ((not (and (equal (alist-get 'api-key (alist-get 'openai (alist-get 'conversation (alist-get 'processor config)))) khoj-openai-api-key)
+                (equal (alist-get 'chat-model (alist-get 'openai (alist-get 'conversation (alist-get 'processor config)))) khoj-chat-model)))
       (message "khoj.el: Chat configuration has gone stale.")
       (let* ((chat-directory (khoj--get-directory-from-config config '(processor conversation conversation-logfile)))
              (new-processor-type (alist-get 'processor config)))
         (setq new-processor-type (delq (assoc 'conversation new-processor-type) new-processor-type))
         (cl-pushnew `(conversation . ((conversation-logfile . ,(format "%s/conversation.json" chat-directory))
-                                      (model . ,default-model)
                                       (chat-model . ,khoj-chat-model)
                                       (openai-api-key . ,khoj-openai-api-key)))
                     new-processor-type)
@@ -36,7 +36,7 @@ export async function configureKhojBackend(vault: Vault, setting: KhojSetting, n
     let khojDefaultMdIndexDirectory = getIndexDirectoryFromBackendConfig(defaultConfig["content-type"]["markdown"]["embeddings-file"]);
     let khojDefaultPdfIndexDirectory = getIndexDirectoryFromBackendConfig(defaultConfig["content-type"]["pdf"]["embeddings-file"]);
     let khojDefaultChatDirectory = getIndexDirectoryFromBackendConfig(defaultConfig["processor"]["conversation"]["conversation-logfile"]);
-    let khojDefaultChatModelName = defaultConfig["processor"]["conversation"]["model"];
+    let khojDefaultChatModelName = defaultConfig["processor"]["conversation"]["openai"]["chat-model"];

     // Get current config if khoj backend configured, else get default config from khoj backend
     await request(khoj_already_configured ? khojConfigUrl : `${khojConfigUrl}/default`)
@@ -142,25 +142,35 @@ export async function configureKhojBackend(vault: Vault, setting: KhojSetting, n
         data["processor"] = {
             "conversation": {
                 "conversation-logfile": `${khojDefaultChatDirectory}/conversation.json`,
-                "model": khojDefaultChatModelName,
-                "openai-api-key": setting.openaiApiKey,
-            }
+                "openai": {
+                    "chat-model": khojDefaultChatModelName,
+                    "api-key": setting.openaiApiKey,
+                }
+            },
         }
     }
     // Else if khoj config has no conversation processor config
-    else if (!data["processor"]["conversation"]) {
-        data["processor"]["conversation"] = {
-            "conversation-logfile": `${khojDefaultChatDirectory}/conversation.json`,
-            "model": khojDefaultChatModelName,
-            "openai-api-key": setting.openaiApiKey,
+    else if (!data["processor"]["conversation"] || !data["processor"]["conversation"]["openai"]) {
+        data["processor"] = {
+            "conversation": {
+                "conversation-logfile": `${khojDefaultChatDirectory}/conversation.json`,
+                "openai": {
+                    "chat-model": khojDefaultChatModelName,
+                    "api-key": setting.openaiApiKey,
+                }
+            },
         }
     }
     // Else if khoj is not configured with OpenAI API key from khoj plugin settings
-    else if (data["processor"]["conversation"]["openai-api-key"] !== setting.openaiApiKey) {
-        data["processor"]["conversation"] = {
-            "conversation-logfile": data["processor"]["conversation"]["conversation-logfile"],
-            "model": data["processor"]["conversation"]["model"],
-            "openai-api-key": setting.openaiApiKey,
+    else if (data["processor"]["conversation"]["openai"]["api-key"] !== setting.openaiApiKey) {
+        data["processor"] = {
+            "conversation": {
+                "conversation-logfile": data["processor"]["conversation"]["conversation-logfile"],
+                "openai": {
+                    "chat-model": data["processor"]["conversation"]["openai"]["chat-model"],
+                    "api-key": setting.openaiApiKey,
+                }
+            },
         }
     }

@@ -11,7 +11,6 @@ import schedule
 from fastapi.staticfiles import StaticFiles

 # Internal Packages
-from khoj.processor.conversation.gpt import summarize
 from khoj.processor.jsonl.jsonl_to_jsonl import JsonlToJsonl
 from khoj.processor.markdown.markdown_to_jsonl import MarkdownToJsonl
 from khoj.processor.org_mode.org_to_jsonl import OrgToJsonl

@@ -28,7 +27,7 @@ from khoj.utils.config import (
     ConversationProcessorConfigModel,
 )
 from khoj.utils.helpers import LRU, resolve_absolute_path, merge_dicts
-from khoj.utils.rawconfig import FullConfig, ProcessorConfig, SearchConfig, ContentConfig
+from khoj.utils.rawconfig import FullConfig, ProcessorConfig, SearchConfig, ContentConfig, ConversationProcessorConfig
 from khoj.search_filter.date_filter import DateFilter
 from khoj.search_filter.word_filter import WordFilter
 from khoj.search_filter.file_filter import FileFilter
@@ -64,7 +63,7 @@ def configure_server(config: FullConfig, regenerate: bool, search_type: Optional
         state.config_lock.acquire()
         state.processor_config = configure_processor(state.config.processor)
     except Exception as e:
-        logger.error(f"🚨 Failed to configure processor")
+        logger.error(f"🚨 Failed to configure processor", exc_info=True)
         raise e
     finally:
         state.config_lock.release()

@@ -75,7 +74,7 @@ def configure_server(config: FullConfig, regenerate: bool, search_type: Optional
         state.SearchType = configure_search_types(state.config)
         state.search_models = configure_search(state.search_models, state.config.search_type)
     except Exception as e:
-        logger.error(f"🚨 Failed to configure search models")
+        logger.error(f"🚨 Failed to configure search models", exc_info=True)
         raise e
     finally:
         state.config_lock.release()

@@ -88,7 +87,7 @@ def configure_server(config: FullConfig, regenerate: bool, search_type: Optional
             state.content_index, state.config.content_type, state.search_models, regenerate, search_type
         )
     except Exception as e:
-        logger.error(f"🚨 Failed to index content")
+        logger.error(f"🚨 Failed to index content", exc_info=True)
         raise e
     finally:
         state.config_lock.release()

@@ -117,7 +116,7 @@ if not state.demo:
             )
             logger.info("📬 Content index updated via Scheduler")
         except Exception as e:
-            logger.error(f"🚨 Error updating content index via Scheduler: {e}")
+            logger.error(f"🚨 Error updating content index via Scheduler: {e}", exc_info=True)
         finally:
             state.config_lock.release()
@@ -258,7 +257,9 @@ def configure_content(
     return content_index


-def configure_processor(processor_config: Optional[ProcessorConfig]):
+def configure_processor(
+    processor_config: Optional[ProcessorConfig], state_processor_config: Optional[ProcessorConfigModel] = None
+):
     if not processor_config:
         logger.warning("🚨 No Processor configuration available.")
         return None

@@ -266,16 +267,47 @@ def configure_processor(processor_config: Optional[ProcessorConfig]):
     processor = ProcessorConfigModel()

     # Initialize Conversation Processor
-    if processor_config.conversation:
-        logger.info("💬 Setting up conversation processor")
-        processor.conversation = configure_conversation_processor(processor_config.conversation)
+    logger.info("💬 Setting up conversation processor")
+    processor.conversation = configure_conversation_processor(processor_config, state_processor_config)

     return processor


-def configure_conversation_processor(conversation_processor_config):
-    conversation_processor = ConversationProcessorConfigModel(conversation_processor_config)
-    conversation_logfile = resolve_absolute_path(conversation_processor.conversation_logfile)
+def configure_conversation_processor(
+    processor_config: Optional[ProcessorConfig], state_processor_config: Optional[ProcessorConfigModel] = None
+):
+    if (
+        not processor_config
+        or not processor_config.conversation
+        or not processor_config.conversation.conversation_logfile
+    ):
+        default_config = constants.default_config
+        default_conversation_logfile = resolve_absolute_path(
+            default_config["processor"]["conversation"]["conversation-logfile"]  # type: ignore
+        )
+        conversation_logfile = resolve_absolute_path(default_conversation_logfile)
+        conversation_config = processor_config.conversation if processor_config else None
+        conversation_processor = ConversationProcessorConfigModel(
+            conversation_config=ConversationProcessorConfig(
+                conversation_logfile=conversation_logfile,
+                openai=(conversation_config.openai if (conversation_config is not None) else None),
+                enable_offline_chat=(
+                    conversation_config.enable_offline_chat if (conversation_config is not None) else False
+                ),
+            )
+        )
+    else:
+        conversation_processor = ConversationProcessorConfigModel(
+            conversation_config=processor_config.conversation,
+        )
+        conversation_logfile = resolve_absolute_path(conversation_processor.conversation_logfile)
+
+    # Load Conversation Logs from Disk
+    if state_processor_config and state_processor_config.conversation and state_processor_config.conversation.meta_log:
+        conversation_processor.meta_log = state_processor_config.conversation.meta_log
+        conversation_processor.chat_session = state_processor_config.conversation.chat_session
+        logger.debug(f"Loaded conversation logs from state")
+        return conversation_processor

     if conversation_logfile.is_file():
         # Load Metadata Logs from Conversation Logfile
@@ -302,12 +334,8 @@ def save_chat_session():
         return

     # Summarize Conversation Logs for this Session
-    chat_session = state.processor_config.conversation.chat_session
-    openai_api_key = state.processor_config.conversation.openai_api_key
     conversation_log = state.processor_config.conversation.meta_log
-    chat_model = state.processor_config.conversation.chat_model
     session = {
-        "summary": summarize(chat_session, model=chat_model, api_key=openai_api_key),
         "session-start": conversation_log.get("session", [{"session-end": 0}])[-1]["session-end"],
         "session-end": len(conversation_log["chat"]),
     }

@@ -344,6 +372,6 @@ def upload_telemetry():
                 log[field] = str(log[field])
         requests.post(constants.telemetry_server, json=state.telemetry)
     except Exception as e:
-        logger.error(f"📡 Error uploading telemetry: {e}")
+        logger.error(f"📡 Error uploading telemetry: {e}", exc_info=True)
     else:
         state.telemetry = []
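The error-logging changes above switch logger.error calls to exc_info=True so the full stack trace is captured alongside the message. A minimal, self-contained sketch of the difference using only the standard library (the failing operation is illustrative):

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    try:
        1 / 0  # stand-in for a failing processor configuration step
    except Exception:
        # Message only; the traceback is lost.
        logger.error("🚨 Failed to configure processor")
        # Message plus the full traceback of the active exception.
        logger.error("🚨 Failed to configure processor", exc_info=True)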
src/khoj/interface/web/assets/icons/openai-logomark.svg (new file, 1 line)
@@ -0,0 +1 @@
+<svg viewBox="0 0 320 320" xmlns="http://www.w3.org/2000/svg"><path d="m297.06 130.97c7.26-21.79 4.76-45.66-6.85-65.48-17.46-30.4-52.56-46.04-86.84-38.68-15.25-17.18-37.16-26.95-60.13-26.81-35.04-.08-66.13 22.48-76.91 55.82-22.51 4.61-41.94 18.7-53.31 38.67-17.59 30.32-13.58 68.54 9.92 94.54-7.26 21.79-4.76 45.66 6.85 65.48 17.46 30.4 52.56 46.04 86.84 38.68 15.24 17.18 37.16 26.95 60.13 26.8 35.06.09 66.16-22.49 76.94-55.86 22.51-4.61 41.94-18.7 53.31-38.67 17.57-30.32 13.55-68.51-9.94-94.51zm-120.28 168.11c-14.03.02-27.62-4.89-38.39-13.88.49-.26 1.34-.73 1.89-1.07l63.72-36.8c3.26-1.85 5.26-5.32 5.24-9.07v-89.83l26.93 15.55c.29.14.48.42.52.74v74.39c-.04 33.08-26.83 59.9-59.91 59.97zm-128.84-55.03c-7.03-12.14-9.56-26.37-7.15-40.18.47.28 1.3.79 1.89 1.13l63.72 36.8c3.23 1.89 7.23 1.89 10.47 0l77.79-44.92v31.1c.02.32-.13.63-.38.83l-64.41 37.19c-28.69 16.52-65.33 6.7-81.92-21.95zm-16.77-139.09c7-12.16 18.05-21.46 31.21-26.29 0 .55-.03 1.52-.03 2.2v73.61c-.02 3.74 1.98 7.21 5.23 9.06l77.79 44.91-26.93 15.55c-.27.18-.61.21-.91.08l-64.42-37.22c-28.63-16.58-38.45-53.21-21.95-81.89zm221.26 51.49-77.79-44.92 26.93-15.54c.27-.18.61-.21.91-.08l64.42 37.19c28.68 16.57 38.51 53.26 21.94 81.94-7.01 12.14-18.05 21.44-31.2 26.28v-75.81c.03-3.74-1.96-7.2-5.2-9.06zm26.8-40.34c-.47-.29-1.3-.79-1.89-1.13l-63.72-36.8c-3.23-1.89-7.23-1.89-10.47 0l-77.79 44.92v-31.1c-.02-.32.13-.63.38-.83l64.41-37.16c28.69-16.55 65.37-6.7 81.91 22 6.99 12.12 9.52 26.31 7.15 40.1zm-168.51 55.43-26.94-15.55c-.29-.14-.48-.42-.52-.74v-74.39c.02-33.12 26.89-59.96 60.01-59.94 14.01 0 27.57 4.92 38.34 13.88-.49.26-1.33.73-1.89 1.07l-63.72 36.8c-3.26 1.85-5.26 5.31-5.24 9.06l-.04 89.79zm14.63-31.54 34.65-20.01 34.65 20v40.01l-34.65 20-34.65-20z"/></svg>
@@ -167,10 +167,42 @@
     text-align: left;
 }

+button.card-button.happy {
+    color: rgb(0, 146, 0);
+}
+
 img.configured-icon {
     max-width: 16px;
 }
+
+div.card-action-row.enabled {
+    display: block;
+}
+
+img.configured-icon.enabled {
+    display: inline;
+}
+
+div.card-action-row.disabled,
+img.configured-icon.disabled {
+    display: none;
+}
+
+.loader {
+    border: 16px solid #f3f3f3; /* Light grey */
+    border-top: 16px solid var(--primary);
+    border-radius: 50%;
+    width: 16px;
+    height: 16px;
+    animation: spin 2s linear infinite;
+}
+
+@keyframes spin {
+    0% { transform: rotate(0deg); }
+    100% { transform: rotate(360deg); }
+}
+
 div.finalize-actions {
     grid-auto-flow: column;
     grid-gap: 24px;
@@ -135,8 +135,8 @@
     .then(response => response.json())
     .then(data => {
         if (data.detail) {
-            // If the server returns a 500 error with detail, render it as a message.
-            renderMessage("Hi 👋🏾, to get started <br/>1. Get your <a class='inline-chat-link' href='https://platform.openai.com/account/api-keys'>OpenAI API key</a><br/>2. Save it in the Khoj <a class='inline-chat-link' href='/config/processor/conversation'>chat settings</a> <br/>3. Click Configure on the Khoj <a class='inline-chat-link' href='/config'>settings page</a>", "khoj");
+            // If the server returns a 500 error with detail, render a setup hint.
+            renderMessage("Hi 👋🏾, to get started <br/>1. Get your <a class='inline-chat-link' href='https://platform.openai.com/account/api-keys'>OpenAI API key</a><br/>2. Save it in the Khoj <a class='inline-chat-link' href='/config/processor/conversation/openai'>chat settings</a> <br/>3. Click Configure on the Khoj <a class='inline-chat-link' href='/config'>settings page</a>", "khoj");

             // Disable chat input field and update placeholder text
             document.getElementById("chat-input").setAttribute("disabled", "disabled");
@@ -20,7 +20,7 @@
             </h3>
         </div>
         <div class="card-description-row">
-            <p class="card-description">Set repositories for Khoj to index</p>
+            <p class="card-description">Set repositories to index</p>
         </div>
         <div class="card-action-row">
             <a class="card-button" href="/config/content_type/github">

@@ -90,7 +90,7 @@
             </h3>
         </div>
         <div class="card-description-row">
-            <p class="card-description">Set markdown files for Khoj to index</p>
+            <p class="card-description">Set markdown files to index</p>
         </div>
         <div class="card-action-row">
             <a class="card-button" href="/config/content_type/markdown">

@@ -125,7 +125,7 @@
             </h3>
         </div>
         <div class="card-description-row">
-            <p class="card-description">Set org files for Khoj to index</p>
+            <p class="card-description">Set org files to index</p>
         </div>
         <div class="card-action-row">
             <a class="card-button" href="/config/content_type/org">

@@ -160,7 +160,7 @@
             </h3>
         </div>
         <div class="card-description-row">
-            <p class="card-description">Set PDF files for Khoj to index</p>
+            <p class="card-description">Set PDF files to index</p>
         </div>
         <div class="card-action-row">
             <a class="card-button" href="/config/content_type/pdf">
@@ -187,10 +187,10 @@
         <div class="section-cards">
             <div class="card">
                 <div class="card-title-row">
-                    <img class="card-icon" src="/static/assets/icons/chat.svg" alt="Chat">
+                    <img class="card-icon" src="/static/assets/icons/openai-logomark.svg" alt="Chat">
                     <h3 class="card-title">
                         Chat
-                        {% if current_config.processor and current_config.processor.conversation %}
+                        {% if current_config.processor and current_config.processor.conversation.openai %}
                             {% if current_model_state.conversation == False %}
                                 <img id="misconfigured-icon-conversation-processor" class="configured-icon" src="/static/assets/icons/question-mark-icon.svg" alt="Not Configured" title="Embeddings have not been generated yet for this content type. Either the configuration is invalid, or you just need to click Configure.">
                             {% else %}

@@ -200,11 +200,11 @@
                     </h3>
                 </div>
                 <div class="card-description-row">
-                    <p class="card-description">Setup Khoj Chat with OpenAI</p>
+                    <p class="card-description">Setup chat using OpenAI</p>
                 </div>
                 <div class="card-action-row">
-                    <a class="card-button" href="/config/processor/conversation">
-                        {% if current_config.processor and current_config.processor.conversation %}
+                    <a class="card-button" href="/config/processor/conversation/openai">
+                        {% if current_config.processor and current_config.processor.conversation.openai %}
                             Update
                         {% else %}
                             Setup

@@ -212,7 +212,7 @@
                         <svg xmlns="http://www.w3.org/2000/svg" width="1em" height="1em" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M5 12h14M12 5l7 7-7 7"></path></svg>
                     </a>
                 </div>
-                {% if current_config.processor and current_config.processor.conversation %}
+                {% if current_config.processor and current_config.processor.conversation.openai %}
                 <div id="clear-conversation" class="card-action-row">
                     <button class="card-button" onclick="clearConversationProcessor()">
                         Disable
@@ -220,6 +220,31 @@
                 </div>
                 {% endif %}
             </div>
+            <div class="card">
+                <div class="card-title-row">
+                    <img class="card-icon" src="/static/assets/icons/chat.svg" alt="Chat">
+                    <h3 class="card-title">
+                        Offline Chat
+                        <img id="configured-icon-conversation-enable-offline-chat" class="configured-icon {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat %}enabled{% else %}disabled{% endif %}" src="/static/assets/icons/confirm-icon.svg" alt="Configured">
+                    </h3>
+                </div>
+                <div class="card-description-row">
+                    <p class="card-description">Setup offline chat (Falcon 7B)</p>
+                </div>
+                <div id="clear-enable-offline-chat" class="card-action-row {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat %}enabled{% else %}disabled{% endif %}">
+                    <button class="card-button" onclick="toggleEnableLocalLLLM(false)">
+                        Disable
+                    </button>
+                </div>
+                <div id="set-enable-offline-chat" class="card-action-row {% if current_config.processor and current_config.processor.conversation and current_config.processor.conversation.enable_offline_chat %}disabled{% else %}enabled{% endif %}">
+                    <button class="card-button happy" onclick="toggleEnableLocalLLLM(true)">
+                        Enable
+                    </button>
+                </div>
+                <div id="toggle-enable-offline-chat" class="card-action-row disabled">
+                    <div class="loader"></div>
+                </div>
+            </div>
         </div>
     </div>
     <div class="section">
@@ -263,9 +288,59 @@
             })
         };

+        function toggleEnableLocalLLLM(enable) {
+            const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
+            var toggleEnableLocalLLLMButton = document.getElementById("toggle-enable-offline-chat");
+            toggleEnableLocalLLLMButton.classList.remove("disabled");
+            toggleEnableLocalLLLMButton.classList.add("enabled");
+
+            fetch('/api/config/data/processor/conversation/enable_offline_chat' + '?enable_offline_chat=' + enable, {
+                method: 'POST',
+                headers: {
+                    'Content-Type': 'application/json',
+                    'X-CSRFToken': csrfToken
+                },
+            })
+            .then(response => response.json())
+            .then(data => {
+                if (data.status == "ok") {
+                    // Toggle the Enabled/Disabled UI based on the action/response.
+                    var enableLocalLLLMButton = document.getElementById("set-enable-offline-chat");
+                    var disableLocalLLLMButton = document.getElementById("clear-enable-offline-chat");
+                    var configuredIcon = document.getElementById("configured-icon-conversation-enable-offline-chat");
+                    var toggleEnableLocalLLLMButton = document.getElementById("toggle-enable-offline-chat");
+
+                    toggleEnableLocalLLLMButton.classList.remove("enabled");
+                    toggleEnableLocalLLLMButton.classList.add("disabled");
+
+                    if (enable) {
+                        enableLocalLLLMButton.classList.add("disabled");
+                        enableLocalLLLMButton.classList.remove("enabled");
+
+                        configuredIcon.classList.add("enabled");
+                        configuredIcon.classList.remove("disabled");
+
+                        disableLocalLLLMButton.classList.remove("disabled");
+                        disableLocalLLLMButton.classList.add("enabled");
+                    } else {
+                        enableLocalLLLMButton.classList.remove("disabled");
+                        enableLocalLLLMButton.classList.add("enabled");
+
+                        configuredIcon.classList.remove("enabled");
+                        configuredIcon.classList.add("disabled");
+
+                        disableLocalLLLMButton.classList.add("disabled");
+                        disableLocalLLLMButton.classList.remove("enabled");
+                    }
+                }
+            })
+        }
+
         function clearConversationProcessor() {
             const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
-            fetch('/api/delete/config/data/processor/conversation', {
+            fetch('/api/delete/config/data/processor/conversation/openai', {
                 method: 'POST',
                 headers: {
                     'Content-Type': 'application/json',
@@ -319,7 +394,7 @@
         function updateIndex(force, successText, errorText, button, loadingText, emoji) {
             const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
             button.disabled = true;
-            button.innerHTML = emoji + loadingText;
+            button.innerHTML = emoji + " " + loadingText;
             fetch('/api/update?&client=web&force=' + force, {
                 method: 'GET',
                 headers: {
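The settings page above toggles offline chat by POST-ing to the new enable_offline_chat route with a query parameter. A rough sketch of calling the same route outside the browser, assuming a locally running Khoj server (the base URL is a placeholder; adjust it to wherever your server listens). requests is already a project dependency:

    import requests

    khoj_base_url = "http://localhost:8000"  # assumption: address of your local Khoj server

    # Same route and query parameter the settings page JavaScript uses above.
    response = requests.post(
        f"{khoj_base_url}/api/config/data/processor/conversation/enable_offline_chat",
        params={"enable_offline_chat": True},
    )
    print(response.json())  # the UI expects {"status": "ok"} on success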
@@ -13,7 +13,7 @@
                 <label for="openai-api-key" title="Get your OpenAI key from https://platform.openai.com/account/api-keys">OpenAI API key</label>
             </td>
             <td>
-                <input type="text" id="openai-api-key" name="openai-api-key" value="{{ current_config['openai_api_key'] }}">
+                <input type="text" id="openai-api-key" name="openai-api-key" value="{{ current_config['api_key'] }}">
             </td>
         </tr>
         <tr>

@@ -25,24 +25,6 @@
             </td>
         </tr>
     </table>
-    <table style="display: none;">
-        <tr>
-            <td>
-                <label for="conversation-logfile">Conversation Logfile</label>
-            </td>
-            <td>
-                <input type="text" id="conversation-logfile" name="conversation-logfile" value="{{ current_config['conversation_logfile'] }}">
-            </td>
-        </tr>
-        <tr>
-            <td>
-                <label for="model">Model</label>
-            </td>
-            <td>
-                <input type="text" id="model" name="model" value="{{ current_config['model'] }}">
-            </td>
-        </tr>
-    </table>
     <div class="section">
         <div id="success" style="display: none;" ></div>
         <button id="submit" type="submit">Save</button>

@@ -54,21 +36,23 @@
         submit.addEventListener("click", function(event) {
             event.preventDefault();
             var openai_api_key = document.getElementById("openai-api-key").value;
-            var conversation_logfile = document.getElementById("conversation-logfile").value;
-            var model = document.getElementById("model").value;
             var chat_model = document.getElementById("chat-model").value;
+
+            if (openai_api_key == "" || chat_model == "") {
+                document.getElementById("success").innerHTML = "⚠️ Please fill all the fields.";
+                document.getElementById("success").style.display = "block";
+                return;
+            }
+
             const csrfToken = document.cookie.split('; ').find(row => row.startsWith('csrftoken'))?.split('=')[1];
-            fetch('/api/config/data/processor/conversation', {
+            fetch('/api/config/data/processor/conversation/openai', {
                 method: 'POST',
                 headers: {
                     'Content-Type': 'application/json',
                     'X-CSRFToken': csrfToken
                 },
                 body: JSON.stringify({
-                    "openai_api_key": openai_api_key,
-                    "conversation_logfile": conversation_logfile,
-                    "model": model,
+                    "api_key": openai_api_key,
                     "chat_model": chat_model
                 })
             })
src/khoj/migrations/__init__.py (new file, 0 lines)

src/khoj/migrations/migrate_processor_config_openai.py (new file, 66 lines)
@@ -0,0 +1,66 @@
+"""
+Current format of khoj.yml
+---
+app:
+  should-log-telemetry: true
+content-type:
+  ...
+processor:
+  conversation:
+    chat-model: gpt-3.5-turbo
+    conversation-logfile: ~/.khoj/processor/conversation/conversation_logs.json
+    model: text-davinci-003
+    openai-api-key: sk-secret-key
+search-type:
+  ...
+
+New format of khoj.yml
+---
+app:
+  should-log-telemetry: true
+content-type:
+  ...
+processor:
+  conversation:
+    openai:
+      chat-model: gpt-3.5-turbo
+      openai-api-key: sk-secret-key
+    conversation-logfile: ~/.khoj/processor/conversation/conversation_logs.json
+    enable-offline-chat: false
+search-type:
+  ...
+"""
+from khoj.utils.yaml import load_config_from_file, save_config_to_file
+
+
+def migrate_processor_conversation_schema(args):
+    raw_config = load_config_from_file(args.config_file)
+
+    raw_config["version"] = args.version_no
+
+    if "processor" not in raw_config:
+        return args
+    if raw_config["processor"] is None:
+        return args
+    if "conversation" not in raw_config["processor"]:
+        return args
+
+    # Add enable_offline_chat to khoj config schema
+    if "enable-offline-chat" not in raw_config["processor"]["conversation"]:
+        raw_config["processor"]["conversation"]["enable-offline-chat"] = False
+        save_config_to_file(raw_config, args.config_file)
+
+    current_openai_api_key = raw_config["processor"]["conversation"].get("openai-api-key", None)
+    current_chat_model = raw_config["processor"]["conversation"].get("chat-model", None)
+    if current_openai_api_key is None and current_chat_model is None:
+        return args
+
+    conversation_logfile = raw_config["processor"]["conversation"].get("conversation-logfile", None)
+
+    raw_config["processor"]["conversation"] = {
+        "openai": {"chat-model": current_chat_model, "api-key": current_openai_api_key},
+        "conversation-logfile": conversation_logfile,
+        "enable-offline-chat": False,
+    }
+
+    save_config_to_file(raw_config, args.config_file)
+    return args
src/khoj/migrations/migrate_version.py (new file, 16 lines)
@@ -0,0 +1,16 @@
+from khoj.utils.yaml import load_config_from_file, save_config_to_file
+
+
+def migrate_config_to_version(args):
+    raw_config = load_config_from_file(args.config_file)
+
+    # Add version to khoj config schema
+    if "version" not in raw_config:
+        raw_config["version"] = args.version_no
+        save_config_to_file(raw_config, args.config_file)
+
+        # regenerate khoj index on first start of this version
+        # this should refresh index and apply index corruption fixes from #325
+        args.regenerate = True
+
+    return args
|
0
src/khoj/processor/conversation/gpt4all/__init__.py
Normal file
0
src/khoj/processor/conversation/gpt4all/__init__.py
Normal file
137
src/khoj/processor/conversation/gpt4all/chat_model.py
Normal file
137
src/khoj/processor/conversation/gpt4all/chat_model.py
Normal file
|
@ -0,0 +1,137 @@
|
||||||
|
from typing import Union, List
|
||||||
|
from datetime import datetime
|
||||||
|
import sys
|
||||||
|
import logging
|
||||||
|
from threading import Thread
|
||||||
|
|
||||||
|
from langchain.schema import ChatMessage
|
||||||
|
|
||||||
|
from gpt4all import GPT4All
|
||||||
|
|
||||||
|
|
||||||
|
from khoj.processor.conversation.utils import ThreadedGenerator, generate_chatml_messages_with_context
|
||||||
|
from khoj.processor.conversation import prompts
|
||||||
|
from khoj.utils.constants import empty_escape_sequences
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_questions_falcon(
|
||||||
|
text: str,
|
||||||
|
model: str = "ggml-model-gpt4all-falcon-q4_0.bin",
|
||||||
|
loaded_model: Union[GPT4All, None] = None,
|
||||||
|
conversation_log={},
|
||||||
|
use_history: bool = False,
|
||||||
|
run_extraction: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Infer search queries to retrieve relevant notes to answer user query
|
||||||
|
"""
|
||||||
|
all_questions = text.split("? ")
|
||||||
|
all_questions = [q + "?" for q in all_questions[:-1]] + [all_questions[-1]]
|
||||||
|
if not run_extraction:
|
||||||
|
return all_questions
|
||||||
|
|
||||||
|
gpt4all_model = loaded_model or GPT4All(model)
|
||||||
|
|
||||||
|
# Extract Past User Message and Inferred Questions from Conversation Log
|
||||||
|
chat_history = ""
|
||||||
|
|
||||||
|
if use_history:
|
||||||
|
chat_history = "".join(
|
||||||
|
[
|
||||||
|
f'Q: {chat["intent"]["query"]}\n\n{chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}\n\nA: {chat["message"]}\n\n'
|
||||||
|
for chat in conversation_log.get("chat", [])[-4:]
|
||||||
|
if chat["by"] == "khoj"
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = prompts.extract_questions_falcon.format(
|
||||||
|
chat_history=chat_history,
|
||||||
|
text=text,
|
||||||
|
)
|
||||||
|
message = prompts.general_conversation_falcon.format(query=prompt)
|
||||||
|
response = gpt4all_model.generate(message, max_tokens=200, top_k=2)
|
||||||
|
|
||||||
|
# Extract, Clean Message from GPT's Response
|
||||||
|
try:
|
||||||
|
questions = (
|
||||||
|
str(response)
|
||||||
|
.strip(empty_escape_sequences)
|
||||||
|
.replace("['", '["')
|
||||||
|
.replace("']", '"]')
|
||||||
|
.replace("', '", '", "')
|
||||||
|
.replace('["', "")
|
||||||
|
.replace('"]', "")
|
||||||
|
.split('", "')
|
||||||
|
)
|
||||||
|
except:
|
||||||
|
logger.warning(f"Falcon returned invalid JSON. Falling back to using user message as search query.\n{response}")
|
||||||
|
return all_questions
|
||||||
|
logger.debug(f"Extracted Questions by Falcon: {questions}")
|
||||||
|
questions.extend(all_questions)
|
||||||
|
return questions
|
||||||
|
|
||||||
|
|
||||||
|
def converse_falcon(
|
||||||
|
references,
|
||||||
|
user_query,
|
||||||
|
conversation_log={},
|
||||||
|
model: str = "ggml-model-gpt4all-falcon-q4_0.bin",
|
||||||
|
loaded_model: Union[GPT4All, None] = None,
|
||||||
|
completion_func=None,
|
||||||
|
) -> ThreadedGenerator:
|
||||||
|
"""
|
||||||
|
Converse with user using Falcon
|
||||||
|
"""
|
||||||
|
gpt4all_model = loaded_model or GPT4All(model)
|
||||||
|
# Initialize Variables
|
||||||
|
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||||
|
compiled_references_message = "\n\n".join({f"{item}" for item in references})
|
||||||
|
|
||||||
|
# Get Conversation Primer appropriate to Conversation Type
|
||||||
|
# TODO If compiled_references_message is too long, we need to truncate it.
|
||||||
|
if compiled_references_message == "":
|
||||||
|
conversation_primer = prompts.conversation_falcon.format(query=user_query)
|
||||||
|
else:
|
||||||
|
conversation_primer = prompts.notes_conversation.format(
|
||||||
|
current_date=current_date, query=user_query, references=compiled_references_message
|
||||||
|
)
|
||||||
|
|
||||||
|
# Setup Prompt with Primer or Conversation History
|
||||||
|
messages = generate_chatml_messages_with_context(
|
||||||
|
conversation_primer,
|
||||||
|
prompts.personality.format(),
|
||||||
|
conversation_log,
|
||||||
|
model_name="text-davinci-001", # This isn't actually the model, but this helps us get an approximate encoding to run message truncation.
|
||||||
|
)
|
||||||
|
|
||||||
|
g = ThreadedGenerator(references, completion_func=completion_func)
|
||||||
|
t = Thread(target=llm_thread, args=(g, messages, gpt4all_model))
|
||||||
|
t.start()
|
||||||
|
return g
|
||||||
|
|
||||||
|
|
||||||
|
def llm_thread(g, messages: List[ChatMessage], model: GPT4All):
|
||||||
|
user_message = messages[0]
|
||||||
|
system_message = messages[-1]
|
||||||
|
conversation_history = messages[1:-1]
|
||||||
|
|
||||||
|
formatted_messages = [
|
||||||
|
prompts.chat_history_falcon_from_assistant.format(message=system_message)
|
||||||
|
if message.role == "assistant"
|
||||||
|
else prompts.chat_history_falcon_from_user.format(message=message.content)
|
||||||
|
for message in conversation_history
|
||||||
|
]
|
||||||
|
|
||||||
|
chat_history = "".join(formatted_messages)
|
||||||
|
full_message = system_message.content + chat_history + user_message.content
|
||||||
|
|
||||||
|
prompted_message = prompts.general_conversation_falcon.format(query=full_message)
|
||||||
|
response_iterator = model.generate(
|
||||||
|
prompted_message, streaming=True, max_tokens=256, top_k=1, temp=0, repeat_penalty=2.0
|
||||||
|
)
|
||||||
|
for response in response_iterator:
|
||||||
|
logger.info(response)
|
||||||
|
g.send(response)
|
||||||
|
g.close()
|
src/khoj/processor/conversation/openai/__init__.py (new file, 0 lines)

@@ -9,11 +9,11 @@ from langchain.schema import ChatMessage
 # Internal Packages
 from khoj.utils.constants import empty_escape_sequences
 from khoj.processor.conversation import prompts
-from khoj.processor.conversation.utils import (
+from khoj.processor.conversation.openai.utils import (
     chat_completion_with_backoff,
     completion_with_backoff,
-    generate_chatml_messages_with_context,
 )
+from khoj.processor.conversation.utils import generate_chatml_messages_with_context


 logger = logging.getLogger(__name__)
src/khoj/processor/conversation/openai/utils.py (new file, 101 lines)
@@ -0,0 +1,101 @@
+# Standard Packages
+import os
+import logging
+from typing import Any
+from threading import Thread
+
+# External Packages
+from langchain.chat_models import ChatOpenAI
+from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+from langchain.callbacks.base import BaseCallbackManager
+import openai
+from tenacity import (
+    before_sleep_log,
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_exponential,
+    wait_random_exponential,
+)
+
+# Internal Packages
+from khoj.processor.conversation.utils import ThreadedGenerator
+
+
+logger = logging.getLogger(__name__)
+
+
+class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
+    def __init__(self, gen: ThreadedGenerator):
+        super().__init__()
+        self.gen = gen
+
+    def on_llm_new_token(self, token: str, **kwargs) -> Any:
+        self.gen.send(token)
+
+
+@retry(
+    retry=(
+        retry_if_exception_type(openai.error.Timeout)
+        | retry_if_exception_type(openai.error.APIError)
+        | retry_if_exception_type(openai.error.APIConnectionError)
+        | retry_if_exception_type(openai.error.RateLimitError)
+        | retry_if_exception_type(openai.error.ServiceUnavailableError)
+    ),
+    wait=wait_random_exponential(min=1, max=10),
+    stop=stop_after_attempt(3),
+    before_sleep=before_sleep_log(logger, logging.DEBUG),
+    reraise=True,
+)
+def completion_with_backoff(**kwargs):
+    messages = kwargs.pop("messages")
+    if not "openai_api_key" in kwargs:
+        kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY")
+    llm = ChatOpenAI(**kwargs, request_timeout=20, max_retries=1)
+    return llm(messages=messages)
+
+
+@retry(
+    retry=(
+        retry_if_exception_type(openai.error.Timeout)
+        | retry_if_exception_type(openai.error.APIError)
+        | retry_if_exception_type(openai.error.APIConnectionError)
+        | retry_if_exception_type(openai.error.RateLimitError)
+        | retry_if_exception_type(openai.error.ServiceUnavailableError)
+    ),
+    wait=wait_exponential(multiplier=1, min=4, max=10),
+    stop=stop_after_attempt(3),
+    before_sleep=before_sleep_log(logger, logging.DEBUG),
+    reraise=True,
+)
+def chat_completion_with_backoff(
+    messages, compiled_references, model_name, temperature, openai_api_key=None, completion_func=None
+):
+    g = ThreadedGenerator(compiled_references, completion_func=completion_func)
+    t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key))
+    t.start()
+    return g
+
+
+def llm_thread(g, messages, model_name, temperature, openai_api_key=None):
+    callback_handler = StreamingChatCallbackHandler(g)
+    chat = ChatOpenAI(
+        streaming=True,
+        verbose=True,
+        callback_manager=BaseCallbackManager([callback_handler]),
+        model_name=model_name,  # type: ignore
+        temperature=temperature,
+        openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
+        request_timeout=20,
+        max_retries=1,
+        client=None,
+    )
+
+    chat(messages=messages)
+
+    g.close()
+
+
+def extract_summaries(metadata):
+    """Extract summaries from metadata"""
+    return "".join([f'\n{session["summary"]}' for session in metadata])
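Note that chat_completion_with_backoff returns a ThreadedGenerator rather than a finished string: the OpenAI call runs on a background thread and streamed tokens are pushed onto a queue. A rough sketch of how a caller might consume it, assuming OPENAI_API_KEY is set in the environment (Khoj itself builds the messages via generate_chatml_messages_with_context; this hand-rolled message list is illustrative):

    from langchain.schema import ChatMessage

    from khoj.processor.conversation.openai.utils import chat_completion_with_backoff

    messages = [ChatMessage(content="What is a local LLM?", role="user")]
    generator = chat_completion_with_backoff(
        messages=messages,
        compiled_references=[],
        model_name="gpt-3.5-turbo",
        temperature=0.2,
    )
    # Iterating the ThreadedGenerator yields tokens as the background thread streams them in.
    for token in generator:
        print(token, end="", flush=True)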
@ -18,6 +18,36 @@ Question: {query}
|
||||||
""".strip()
|
""".strip()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
general_conversation_falcon = PromptTemplate.from_template(
|
||||||
|
"""
|
||||||
|
Using your general knowledge and our past conversations as context, answer the following question.
|
||||||
|
### Instruct:
|
||||||
|
{query}
|
||||||
|
### Response:
|
||||||
|
""".strip()
|
||||||
|
)
|
||||||
|
|
||||||
|
chat_history_falcon_from_user = PromptTemplate.from_template(
|
||||||
|
"""
|
||||||
|
### Human:
|
||||||
|
{message}
|
||||||
|
""".strip()
|
||||||
|
)
|
||||||
|
|
||||||
|
chat_history_falcon_from_assistant = PromptTemplate.from_template(
|
||||||
|
"""
|
||||||
|
### Assistant:
|
||||||
|
{message}
|
||||||
|
""".strip()
|
||||||
|
)
|
||||||
|
|
||||||
|
conversation_falcon = PromptTemplate.from_template(
|
||||||
|
"""
|
||||||
|
Using our past conversations as context, answer the following question.
|
||||||
|
|
||||||
|
Question: {query}
|
||||||
|
""".strip()
|
||||||
|
)
|
||||||
|
|
||||||
## Notes Conversation
|
## Notes Conversation
|
||||||
## --
|
## --
|
||||||
|
@ -33,6 +63,17 @@ Question: {query}
|
||||||
""".strip()
|
""".strip()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
notes_conversation_falcon = PromptTemplate.from_template(
|
||||||
|
"""
|
||||||
|
Using the notes and our past conversations as context, answer the following question. If the answer is not contained within the notes, say "I don't know."
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
{references}
|
||||||
|
|
||||||
|
Question: {query}
|
||||||
|
""".strip()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
## Summarize Chat
|
## Summarize Chat
|
||||||
## --
|
## --
|
||||||
|
@@ -68,6 +109,40 @@ Question: {user_query}
 Answer (in second person):"""
 )
 
+
+extract_questions_falcon = PromptTemplate.from_template(
+    """
+You are Khoj, an extremely smart and helpful search assistant with the ability to retrieve information from the user's notes.
+- The user will provide their questions and answers to you for context.
+- Add as much context from the previous questions and answers as required into your search queries.
+- Break messages into multiple search queries when required to retrieve the relevant information.
+- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
+
+What searches, if any, will you need to perform to answer the users question?
+
+Q: How was my trip to Cambodia?
+
+["How was my trip to Cambodia?"]
+
+A: The trip was amazing. I went to the Angkor Wat temple and it was beautiful.
+
+Q: Who did i visit that temple with?
+
+["Who did I visit the Angkor Wat Temple in Cambodia with?"]
+
+A: You visited the Angkor Wat Temple in Cambodia with Pablo, Namita and Xi.
+
+Q: How many tennis balls fit in the back of a 2002 Honda Civic?
+
+["What is the size of a tennis ball?", "What is the trunk size of a 2002 Honda Civic?"]
+
+A: 1085 tennis balls will fit in the trunk of a Honda Civic
+
+{chat_history}
+Q: {text}
+
+"""
+)
+
 
 ## Extract Questions
 ## --
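For orientation, the PromptTemplate objects added above are rendered into plain strings before being handed to the model. A minimal sketch of that rendering, assuming the templates live in khoj.processor.conversation.prompts as before; the query and message values are illustrative:

# Illustrative sketch only; the values below are made up for demonstration.
from khoj.processor.conversation import prompts

general_prompt = prompts.general_conversation_falcon.format(query="What is the capital of France?")
user_turn = prompts.chat_history_falcon_from_user.format(message="Hello Khoj")
assistant_turn = prompts.chat_history_falcon_from_assistant.format(message="Hi! How can I help?")
print(general_prompt)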
@@ -1,35 +1,19 @@
 # Standard Packages
-import os
 import logging
-from datetime import datetime
 from time import perf_counter
-from typing import Any
-from threading import Thread
 import json
+from datetime import datetime
 
-# External Packages
-from langchain.chat_models import ChatOpenAI
-from langchain.schema import ChatMessage
-from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
-from langchain.callbacks.base import BaseCallbackManager
-import openai
 import tiktoken
-from tenacity import (
-    before_sleep_log,
-    retry,
-    retry_if_exception_type,
-    stop_after_attempt,
-    wait_exponential,
-    wait_random_exponential,
-)
+# External packages
+from langchain.schema import ChatMessage
 
 # Internal Packages
-from khoj.utils.helpers import merge_dicts
 import queue
+from khoj.utils.helpers import merge_dicts
 
 logger = logging.getLogger(__name__)
-max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192}
+max_prompt_size = {"gpt-3.5-turbo": 4096, "gpt-4": 8192, "text-davinci-001": 910}
 
 
 class ThreadedGenerator:
@@ -49,9 +33,9 @@ class ThreadedGenerator:
             time_to_response = perf_counter() - self.start_time
             logger.info(f"Chat streaming took: {time_to_response:.3f} seconds")
             if self.completion_func:
-                # The completion func effective acts as a callback.
-                # It adds the aggregated response to the conversation history. It's constructed in api.py.
-                self.completion_func(gpt_response=self.response)
+                # The completion func effectively acts as a callback.
+                # It adds the aggregated response to the conversation history.
+                self.completion_func(chat_response=self.response)
             raise StopIteration
         return item
 
@@ -65,75 +49,25 @@ class ThreadedGenerator:
         self.queue.put(StopIteration)
 
 
-class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
-    def __init__(self, gen: ThreadedGenerator):
-        super().__init__()
-        self.gen = gen
-
-    def on_llm_new_token(self, token: str, **kwargs) -> Any:
-        self.gen.send(token)
-
-
-@retry(
-    retry=(
-        retry_if_exception_type(openai.error.Timeout)
-        | retry_if_exception_type(openai.error.APIError)
-        | retry_if_exception_type(openai.error.APIConnectionError)
-        | retry_if_exception_type(openai.error.RateLimitError)
-        | retry_if_exception_type(openai.error.ServiceUnavailableError)
-    ),
-    wait=wait_random_exponential(min=1, max=10),
-    stop=stop_after_attempt(3),
-    before_sleep=before_sleep_log(logger, logging.DEBUG),
-    reraise=True,
-)
-def completion_with_backoff(**kwargs):
-    messages = kwargs.pop("messages")
-    if not "openai_api_key" in kwargs:
-        kwargs["openai_api_key"] = os.getenv("OPENAI_API_KEY")
-    llm = ChatOpenAI(**kwargs, request_timeout=20, max_retries=1)
-    return llm(messages=messages)
-
-
-@retry(
-    retry=(
-        retry_if_exception_type(openai.error.Timeout)
-        | retry_if_exception_type(openai.error.APIError)
-        | retry_if_exception_type(openai.error.APIConnectionError)
-        | retry_if_exception_type(openai.error.RateLimitError)
-        | retry_if_exception_type(openai.error.ServiceUnavailableError)
-    ),
-    wait=wait_exponential(multiplier=1, min=4, max=10),
-    stop=stop_after_attempt(3),
-    before_sleep=before_sleep_log(logger, logging.DEBUG),
-    reraise=True,
-)
-def chat_completion_with_backoff(
-    messages, compiled_references, model_name, temperature, openai_api_key=None, completion_func=None
-):
-    g = ThreadedGenerator(compiled_references, completion_func=completion_func)
-    t = Thread(target=llm_thread, args=(g, messages, model_name, temperature, openai_api_key))
-    t.start()
-    return g
-
-
-def llm_thread(g, messages, model_name, temperature, openai_api_key=None):
-    callback_handler = StreamingChatCallbackHandler(g)
-    chat = ChatOpenAI(
-        streaming=True,
-        verbose=True,
-        callback_manager=BaseCallbackManager([callback_handler]),
-        model_name=model_name,  # type: ignore
-        temperature=temperature,
-        openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"),
-        request_timeout=20,
-        max_retries=1,
-        client=None,
-    )
-
-    chat(messages=messages)
-
-    g.close()
+def message_to_log(
+    user_message, chat_response, user_message_metadata={}, khoj_message_metadata={}, conversation_log=[]
+):
+    """Create json logs from messages, metadata for conversation log"""
+    default_khoj_message_metadata = {
+        "intent": {"type": "remember", "memory-type": "notes", "query": user_message},
+        "trigger-emotion": "calm",
+    }
+    khoj_response_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+
+    # Create json log from Human's message
+    human_log = merge_dicts({"message": user_message, "by": "you"}, user_message_metadata)
+
+    # Create json log from GPT's response
+    khoj_log = merge_dicts(khoj_message_metadata, default_khoj_message_metadata)
+    khoj_log = merge_dicts({"message": chat_response, "by": "khoj", "created": khoj_response_time}, khoj_log)
+
+    conversation_log.extend([human_log, khoj_log])
+    return conversation_log
 
 
 def generate_chatml_messages_with_context(
@@ -192,27 +126,3 @@ def truncate_messages(messages, max_prompt_size, model_name):
 def reciprocal_conversation_to_chatml(message_pair):
     """Convert a single back and forth between user and assistant to chatml format"""
     return [ChatMessage(content=message, role=role) for message, role in zip(message_pair, ["user", "assistant"])]
-
-
-def message_to_log(user_message, gpt_message, user_message_metadata={}, khoj_message_metadata={}, conversation_log=[]):
-    """Create json logs from messages, metadata for conversation log"""
-    default_khoj_message_metadata = {
-        "intent": {"type": "remember", "memory-type": "notes", "query": user_message},
-        "trigger-emotion": "calm",
-    }
-    khoj_response_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-
-    # Create json log from Human's message
-    human_log = merge_dicts({"message": user_message, "by": "you"}, user_message_metadata)
-
-    # Create json log from GPT's response
-    khoj_log = merge_dicts(khoj_message_metadata, default_khoj_message_metadata)
-    khoj_log = merge_dicts({"message": gpt_message, "by": "khoj", "created": khoj_response_time}, khoj_log)
-
-    conversation_log.extend([human_log, khoj_log])
-    return conversation_log
-
-
-def extract_summaries(metadata):
-    """Extract summaries from metadata"""
-    return "".join([f'\n{session["summary"]}' for session in metadata])
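The reworked message_to_log now takes the assistant reply as chat_response and is called with keyword arguments. A small sketch of building a conversation log with it, using made-up messages:

# Illustrative sketch of the updated message_to_log signature; message values are placeholders.
from khoj.processor.conversation.utils import message_to_log

conversation_log = {"chat": []}
message_to_log(
    user_message="When was I born?",
    chat_response="You were born on 1st April 1984.",
    user_message_metadata={"created": "1984-04-02 10:00:00"},
    khoj_message_metadata={"context": [], "intent": {"inferred-queries": '["When was I born?"]'}},
    conversation_log=conversation_log["chat"],
)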
@@ -52,7 +52,7 @@ class GithubToJsonl(TextToJsonl):
         try:
             markdown_files, org_files = self.get_files(repo_url, repo)
         except Exception as e:
-            logger.error(f"Unable to download github repo {repo_shorthand}")
+            logger.error(f"Unable to download github repo {repo_shorthand}", exc_info=True)
             raise e
 
         logger.info(f"Found {len(markdown_files)} markdown files in github repo {repo_shorthand}")
@@ -219,7 +219,7 @@ class NotionToJsonl(TextToJsonl):
             page = self.get_page(page_id)
             content = self.get_page_children(page_id)
         except Exception as e:
-            logger.error(f"Error getting page {page_id}: {e}")
+            logger.error(f"Error getting page {page_id}: {e}", exc_info=True)
             return None, None
         properties = page["properties"]
         title_field = "title"
@@ -5,7 +5,7 @@ import time
 import yaml
 import logging
 import json
-from typing import Iterable, List, Optional, Union
+from typing import List, Optional, Union
 
 # External Packages
 from fastapi import APIRouter, HTTPException, Header, Request
@@ -26,16 +26,19 @@ from khoj.utils.rawconfig import (
     SearchConfig,
     SearchResponse,
     TextContentConfig,
-    ConversationProcessorConfig,
+    OpenAIProcessorConfig,
     GithubContentConfig,
     NotionContentConfig,
+    ConversationProcessorConfig,
 )
+from khoj.utils.helpers import resolve_absolute_path
 from khoj.utils.state import SearchType
 from khoj.utils import state, constants
 from khoj.utils.yaml import save_config_to_file_updated_state
 from fastapi.responses import StreamingResponse, Response
 from khoj.routers.helpers import perform_chat_checks, generate_chat_response, update_telemetry_state
-from khoj.processor.conversation.gpt import extract_questions
+from khoj.processor.conversation.openai.gpt import extract_questions
+from khoj.processor.conversation.gpt4all.chat_model import extract_questions_falcon, converse_falcon
 from fastapi.requests import Request
 
 
@@ -50,6 +53,8 @@ if not state.demo:
     if state.config is None:
         state.config = FullConfig()
         state.config.search_type = SearchConfig.parse_obj(constants.default_config["search-type"])
+    if state.processor_config is None:
+        state.processor_config = configure_processor(state.config.processor)
 
 
 @api.get("/config/data", response_model=FullConfig)
 def get_config_data():
@@ -181,22 +186,28 @@ if not state.demo:
         except Exception as e:
             return {"status": "error", "message": str(e)}
 
-    @api.post("/delete/config/data/processor/conversation", status_code=200)
+    @api.post("/delete/config/data/processor/conversation/openai", status_code=200)
     async def remove_processor_conversation_config_data(
         request: Request,
         client: Optional[str] = None,
     ):
-        if not state.config or not state.config.processor or not state.config.processor.conversation:
+        if (
+            not state.config
+            or not state.config.processor
+            or not state.config.processor.conversation
+            or not state.config.processor.conversation.openai
+        ):
             return {"status": "ok"}
 
-        state.config.processor.conversation = None
+        state.config.processor.conversation.openai = None
+        state.processor_config = configure_processor(state.config.processor, state.processor_config)
 
         update_telemetry_state(
             request=request,
             telemetry_type="api",
-            api="delete_processor_config",
+            api="delete_processor_openai_config",
             client=client,
-            metadata={"processor_type": "conversation"},
+            metadata={"processor_conversation_type": "openai"},
         )
 
         try:
@@ -233,23 +244,66 @@ if not state.demo:
         except Exception as e:
             return {"status": "error", "message": str(e)}
 
-    @api.post("/config/data/processor/conversation", status_code=200)
-    async def set_processor_conversation_config_data(
+    @api.post("/config/data/processor/conversation/openai", status_code=200)
+    async def set_processor_openai_config_data(
         request: Request,
-        updated_config: Union[ConversationProcessorConfig, None],
+        updated_config: Union[OpenAIProcessorConfig, None],
         client: Optional[str] = None,
     ):
         _initialize_config()
 
-        state.config.processor = ProcessorConfig(conversation=updated_config)
-        state.processor_config = configure_processor(state.config.processor)
+        if not state.config.processor or not state.config.processor.conversation:
+            default_config = constants.default_config
+            default_conversation_logfile = resolve_absolute_path(
+                default_config["processor"]["conversation"]["conversation-logfile"]  # type: ignore
+            )
+            conversation_logfile = resolve_absolute_path(default_conversation_logfile)
+            state.config.processor = ProcessorConfig(conversation=ConversationProcessorConfig(conversation_logfile=conversation_logfile))  # type: ignore
+
+        assert state.config.processor.conversation is not None
+        state.config.processor.conversation.openai = updated_config
+        state.processor_config = configure_processor(state.config.processor, state.processor_config)
 
         update_telemetry_state(
             request=request,
             telemetry_type="api",
-            api="set_content_config",
+            api="set_processor_config",
             client=client,
-            metadata={"processor_type": "conversation"},
+            metadata={"processor_conversation_type": "conversation"},
         )
+
+        try:
+            save_config_to_file_updated_state()
+            return {"status": "ok"}
+        except Exception as e:
+            return {"status": "error", "message": str(e)}
+
+    @api.post("/config/data/processor/conversation/enable_offline_chat", status_code=200)
+    async def set_processor_enable_offline_chat_config_data(
+        request: Request,
+        enable_offline_chat: bool,
+        client: Optional[str] = None,
+    ):
+        _initialize_config()
+
+        if not state.config.processor or not state.config.processor.conversation:
+            default_config = constants.default_config
+            default_conversation_logfile = resolve_absolute_path(
+                default_config["processor"]["conversation"]["conversation-logfile"]  # type: ignore
+            )
+            conversation_logfile = resolve_absolute_path(default_conversation_logfile)
+            state.config.processor = ProcessorConfig(conversation=ConversationProcessorConfig(conversation_logfile=conversation_logfile))  # type: ignore
+
+        assert state.config.processor.conversation is not None
+        state.config.processor.conversation.enable_offline_chat = enable_offline_chat
+        state.processor_config = configure_processor(state.config.processor, state.processor_config)
+
+        update_telemetry_state(
+            request=request,
+            telemetry_type="api",
+            api="set_processor_config",
+            client=client,
+            metadata={"processor_conversation_type": f"{'enable' if enable_offline_chat else 'disable'}_local_llm"},
+        )
 
         try:
@@ -569,7 +623,9 @@ def chat_history(
     perform_chat_checks()
 
     # Load Conversation History
-    meta_log = state.processor_config.conversation.meta_log
+    meta_log = {}
+    if state.processor_config.conversation:
+        meta_log = state.processor_config.conversation.meta_log
 
     update_telemetry_state(
         request=request,
@@ -598,24 +654,25 @@ async def chat(
     perform_chat_checks()
     compiled_references, inferred_queries = await extract_references_and_questions(request, q, (n or 5))
 
-    # Get the (streamed) chat response from GPT.
-    gpt_response = generate_chat_response(
+    # Get the (streamed) chat response from the LLM of choice.
+    llm_response = generate_chat_response(
         q,
         meta_log=state.processor_config.conversation.meta_log,
         compiled_references=compiled_references,
        inferred_queries=inferred_queries,
     )
-    if gpt_response is None:
-        return Response(content=gpt_response, media_type="text/plain", status_code=500)
+    if llm_response is None:
+        return Response(content=llm_response, media_type="text/plain", status_code=500)
 
     if stream:
-        return StreamingResponse(gpt_response, media_type="text/event-stream", status_code=200)
+        return StreamingResponse(llm_response, media_type="text/event-stream", status_code=200)
 
     # Get the full response from the generator if the stream is not requested.
     aggregated_gpt_response = ""
     while True:
         try:
-            aggregated_gpt_response += next(gpt_response)
+            aggregated_gpt_response += next(llm_response)
         except StopIteration:
             break
@@ -645,8 +702,6 @@ async def extract_references_and_questions(
     meta_log = state.processor_config.conversation.meta_log
 
     # Initialize Variables
-    api_key = state.processor_config.conversation.openai_api_key
-    chat_model = state.processor_config.conversation.chat_model
     conversation_type = "general" if q.startswith("@general") else "notes"
     compiled_references = []
     inferred_queries = []
@@ -654,7 +709,13 @@ async def extract_references_and_questions(
     if conversation_type == "notes":
         # Infer search queries from user message
         with timer("Extracting search queries took", logger):
-            inferred_queries = extract_questions(q, model=chat_model, api_key=api_key, conversation_log=meta_log)
+            if state.processor_config.conversation and state.processor_config.conversation.openai_model:
+                api_key = state.processor_config.conversation.openai_model.api_key
+                chat_model = state.processor_config.conversation.openai_model.chat_model
+                inferred_queries = extract_questions(q, model=chat_model, api_key=api_key, conversation_log=meta_log)
+            else:
+                loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
+                inferred_queries = extract_questions_falcon(q, loaded_model=loaded_model, conversation_log=meta_log)
 
         # Collate search results as context for GPT
         with timer("Searching knowledge base took", logger):
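A rough sketch of exercising the reorganized conversation processor endpoints over HTTP. The base URL and API key below are assumptions for illustration, not part of this change; the api router is assumed to be mounted under /api as elsewhere in the server:

# Illustrative sketch only; adjust the base URL to wherever your Khoj server runs.
import requests

base_url = "http://localhost:8000"  # assumed host/port, may differ for your setup

# Configure the OpenAI-backed conversation processor
requests.post(
    f"{base_url}/api/config/data/processor/conversation/openai",
    json={"api_key": "sk-placeholder", "chat_model": "gpt-3.5-turbo"},  # placeholder credentials
)

# Toggle offline (GPT4All) chat on
requests.post(
    f"{base_url}/api/config/data/processor/conversation/enable_offline_chat",
    params={"enable_offline_chat": True},
)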
@@ -7,22 +7,23 @@ from fastapi import HTTPException, Request
 
 from khoj.utils import state
 from khoj.utils.helpers import timer, log_telemetry
-from khoj.processor.conversation.gpt import converse
-from khoj.processor.conversation.utils import message_to_log, reciprocal_conversation_to_chatml
+from khoj.processor.conversation.openai.gpt import converse
+from khoj.processor.conversation.gpt4all.chat_model import converse_falcon
+from khoj.processor.conversation.utils import reciprocal_conversation_to_chatml, message_to_log, ThreadedGenerator
 
 logger = logging.getLogger(__name__)
 
 
 def perform_chat_checks():
-    if (
-        state.processor_config is None
-        or state.processor_config.conversation is None
-        or state.processor_config.conversation.openai_api_key is None
+    if state.processor_config.conversation and (
+        state.processor_config.conversation.openai_model
+        or state.processor_config.conversation.gpt4all_model.loaded_model
     ):
-        raise HTTPException(
-            status_code=500, detail="Set your OpenAI API key via Khoj settings and restart it to use Khoj Chat."
-        )
+        return
+
+    raise HTTPException(
+        status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings and restart it."
+    )
 
 
 def update_telemetry_state(
@@ -57,19 +58,19 @@ def generate_chat_response(
     meta_log: dict,
     compiled_references: List[str] = [],
     inferred_queries: List[str] = [],
-):
+) -> ThreadedGenerator:
     def _save_to_conversation_log(
         q: str,
-        gpt_response: str,
+        chat_response: str,
         user_message_time: str,
         compiled_references: List[str],
         inferred_queries: List[str],
         meta_log,
     ):
-        state.processor_config.conversation.chat_session += reciprocal_conversation_to_chatml([q, gpt_response])
+        state.processor_config.conversation.chat_session += reciprocal_conversation_to_chatml([q, chat_response])
         state.processor_config.conversation.meta_log["chat"] = message_to_log(
-            q,
-            gpt_response,
+            user_message=q,
+            chat_response=chat_response,
             user_message_metadata={"created": user_message_time},
             khoj_message_metadata={"context": compiled_references, "intent": {"inferred-queries": inferred_queries}},
             conversation_log=meta_log.get("chat", []),
@@ -79,8 +80,6 @@ def generate_chat_response(
     meta_log = state.processor_config.conversation.meta_log
 
     # Initialize Variables
-    api_key = state.processor_config.conversation.openai_api_key
-    chat_model = state.processor_config.conversation.chat_model
     user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
     conversation_type = "general" if q.startswith("@general") else "notes"
 
@@ -99,12 +98,29 @@ def generate_chat_response(
             meta_log=meta_log,
         )
 
-        gpt_response = converse(
-            compiled_references, q, meta_log, model=chat_model, api_key=api_key, completion_func=partial_completion
-        )
+        if state.processor_config.conversation.openai_model:
+            api_key = state.processor_config.conversation.openai_model.api_key
+            chat_model = state.processor_config.conversation.openai_model.chat_model
+            chat_response = converse(
+                compiled_references,
+                q,
+                meta_log,
+                model=chat_model,
+                api_key=api_key,
+                completion_func=partial_completion,
+            )
+        else:
+            loaded_model = state.processor_config.conversation.gpt4all_model.loaded_model
+            chat_response = converse_falcon(
+                references=compiled_references,
+                user_query=q,
+                loaded_model=loaded_model,
+                conversation_log=meta_log,
+                completion_func=partial_completion,
+            )
 
     except Exception as e:
-        logger.error(e)
+        logger.error(e, exc_info=True)
         raise HTTPException(status_code=500, detail=str(e))
 
-    return gpt_response
+    return chat_response
@@ -3,7 +3,7 @@ from fastapi import APIRouter
 from fastapi import Request
 from fastapi.responses import HTMLResponse, FileResponse
 from fastapi.templating import Jinja2Templates
-from khoj.utils.rawconfig import TextContentConfig, ConversationProcessorConfig, FullConfig
+from khoj.utils.rawconfig import TextContentConfig, OpenAIProcessorConfig, FullConfig
 
 # Internal Packages
 from khoj.utils import constants, state
@@ -151,28 +151,29 @@ if not state.demo:
             },
         )
 
-    @web_client.get("/config/processor/conversation", response_class=HTMLResponse)
+    @web_client.get("/config/processor/conversation/openai", response_class=HTMLResponse)
     def conversation_processor_config_page(request: Request):
         default_copy = constants.default_config.copy()
-        default_processor_config = default_copy["processor"]["conversation"]  # type: ignore
-        default_processor_config = ConversationProcessorConfig(
-            openai_api_key="",
-            model=default_processor_config["model"],
-            conversation_logfile=default_processor_config["conversation-logfile"],
+        default_processor_config = default_copy["processor"]["conversation"]["openai"]  # type: ignore
+        default_openai_config = OpenAIProcessorConfig(
+            api_key="",
             chat_model=default_processor_config["chat-model"],
         )
 
-        current_processor_conversation_config = (
-            state.config.processor.conversation
-            if state.config and state.config.processor and state.config.processor.conversation
-            else default_processor_config
+        current_processor_openai_config = (
+            state.config.processor.conversation.openai
+            if state.config
+            and state.config.processor
+            and state.config.processor.conversation
+            and state.config.processor.conversation.openai
+            else default_openai_config
         )
-        current_processor_conversation_config = json.loads(current_processor_conversation_config.json())
+        current_processor_openai_config = json.loads(current_processor_openai_config.json())
 
         return templates.TemplateResponse(
             "processor_conversation_input.html",
             context={
                 "request": request,
-                "current_config": current_processor_conversation_config,
+                "current_config": current_processor_openai_config,
             },
         )
@@ -5,7 +5,9 @@ from importlib.metadata import version
 
 # Internal Packages
 from khoj.utils.helpers import resolve_absolute_path
-from khoj.utils.yaml import load_config_from_file, parse_config_from_file, save_config_to_file
+from khoj.utils.yaml import parse_config_from_file
+from khoj.migrations.migrate_version import migrate_config_to_version
+from khoj.migrations.migrate_processor_config_openai import migrate_processor_conversation_schema
 
 
 def cli(args=None):
@@ -46,22 +48,14 @@ def cli(args=None):
     if not args.config_file.exists():
         args.config = None
     else:
-        args = migrate_config(args)
+        args = run_migrations(args)
         args.config = parse_config_from_file(args.config_file)
 
     return args
 
 
-def migrate_config(args):
-    raw_config = load_config_from_file(args.config_file)
-
-    # Add version to khoj config schema
-    if "version" not in raw_config:
-        raw_config["version"] = args.version_no
-        save_config_to_file(raw_config, args.config_file)
-
-        # regenerate khoj index on first start of this version
-        # this should refresh index and apply index corruption fixes from #325
-        args.regenerate = True
-
+def run_migrations(args):
+    migrations = [migrate_config_to_version, migrate_processor_conversation_schema]
+    for migration in migrations:
+        args = migration(args)
     return args
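Each migration registered in run_migrations is expected to accept the parsed CLI args and return them, optionally rewriting the config file they point to along the way. A hypothetical migration following that shape; the name and body are illustrative, not part of this commit:

# Hypothetical example of a migration usable in the run_migrations() pipeline.
def migrate_example_setting(args):
    # Inspect and, if needed, rewrite the config file referenced by args.config_file,
    # then return args so the next migration in the list can run.
    return args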
@@ -1,9 +1,12 @@
 # System Packages
 from __future__ import annotations  # to avoid quoting type hints
 
 from enum import Enum
 from dataclasses import dataclass
 from pathlib import Path
-from typing import TYPE_CHECKING, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union, Any
+
+from gpt4all import GPT4All
 
 # External Packages
 import torch
@@ -13,7 +16,7 @@ if TYPE_CHECKING:
     from sentence_transformers import CrossEncoder
     from khoj.search_filter.base_filter import BaseFilter
     from khoj.utils.models import BaseEncoder
-    from khoj.utils.rawconfig import ConversationProcessorConfig, Entry
+    from khoj.utils.rawconfig import ConversationProcessorConfig, Entry, OpenAIProcessorConfig
 
 
 class SearchType(str, Enum):
@@ -74,15 +77,29 @@ class SearchModels:
     plugin_search: Optional[Dict[str, TextSearchModel]] = None
 
 
+@dataclass
+class GPT4AllProcessorConfig:
+    chat_model: Optional[str] = "ggml-model-gpt4all-falcon-q4_0.bin"
+    loaded_model: Union[Any, None] = None
+
+
 class ConversationProcessorConfigModel:
-    def __init__(self, processor_config: ConversationProcessorConfig):
-        self.openai_api_key = processor_config.openai_api_key
-        self.model = processor_config.model
-        self.chat_model = processor_config.chat_model
-        self.conversation_logfile = Path(processor_config.conversation_logfile)
+    def __init__(
+        self,
+        conversation_config: ConversationProcessorConfig,
+    ):
+        self.openai_model = conversation_config.openai
+        self.gpt4all_model = GPT4AllProcessorConfig()
+        self.enable_offline_chat = conversation_config.enable_offline_chat
+        self.conversation_logfile = Path(conversation_config.conversation_logfile)
         self.chat_session: List[str] = []
         self.meta_log: dict = {}
+
+        if not self.openai_model and self.enable_offline_chat:
+            self.gpt4all_model.loaded_model = GPT4All(self.gpt4all_model.chat_model)  # type: ignore
+        else:
+            self.gpt4all_model.loaded_model = None
 
 
 @dataclass
 class ProcessorConfigModel:
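The Falcon weights are only loaded when no OpenAI processor is configured and offline chat is enabled. A minimal sketch of that load and a direct generation call with the gpt4all package; the generate keyword arguments are an assumption and may differ across gpt4all versions:

# Illustrative sketch; gpt4all downloads the Falcon weights on first use.
from gpt4all import GPT4All

loaded_model = GPT4All("ggml-model-gpt4all-falcon-q4_0.bin")
reply = loaded_model.generate("### Instruct:\nWhat is Khoj?\n\n### Response:\n", max_tokens=200)
print(reply)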
@@ -62,10 +62,12 @@ default_config = {
     },
     "processor": {
         "conversation": {
-            "openai-api-key": None,
-            "model": "text-davinci-003",
+            "openai": {
+                "api-key": None,
+                "chat-model": "gpt-3.5-turbo",
+            },
+            "enable-offline-chat": False,
             "conversation-logfile": "~/.khoj/processor/conversation/conversation_logs.json",
-            "chat-model": "gpt-3.5-turbo",
         }
     },
 }
@@ -27,12 +27,12 @@ class OpenAI(BaseEncoder):
         if (
             not state.processor_config
             or not state.processor_config.conversation
-            or not state.processor_config.conversation.openai_api_key
+            or not state.processor_config.conversation.openai_model
         ):
             raise Exception(
                 f"Set OpenAI API key under processor-config > conversation > openai-api-key in config file: {state.config_file}"
             )
-        openai.api_key = state.processor_config.conversation.openai_api_key
+        openai.api_key = state.processor_config.conversation.openai_model.api_key
         self.embedding_dimensions = None
 
     def encode(self, entries, device=None, **kwargs):
@@ -1,7 +1,7 @@
 # System Packages
 import json
 from pathlib import Path
-from typing import List, Dict, Optional
+from typing import List, Dict, Optional, Union, Any
 
 # External Packages
 from pydantic import BaseModel, validator
@@ -103,13 +103,17 @@ class SearchConfig(ConfigBase):
     image: Optional[ImageSearchConfig]
 
 
-class ConversationProcessorConfig(ConfigBase):
-    openai_api_key: str
-    conversation_logfile: Path
-    model: Optional[str] = "text-davinci-003"
+class OpenAIProcessorConfig(ConfigBase):
+    api_key: str
     chat_model: Optional[str] = "gpt-3.5-turbo"
 
 
+class ConversationProcessorConfig(ConfigBase):
+    conversation_logfile: Path
+    openai: Optional[OpenAIProcessorConfig]
+    enable_offline_chat: Optional[bool] = False
+
+
 class ProcessorConfig(ConfigBase):
     conversation: Optional[ConversationProcessorConfig]
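With the split schema above, the OpenAI credentials become optional and offline chat is just a boolean on the conversation processor. A small sketch constructing both variants; the paths and key values are placeholders:

# Illustrative sketch of the new conversation processor schema.
from pathlib import Path
from khoj.utils.rawconfig import ConversationProcessorConfig, OpenAIProcessorConfig

# Offline chat only: no OpenAI credentials required
offline_config = ConversationProcessorConfig(
    conversation_logfile=Path("~/.khoj/processor/conversation/conversation_logs.json"),
    openai=None,
    enable_offline_chat=True,
)

# OpenAI-backed chat
openai_config = ConversationProcessorConfig(
    conversation_logfile=Path("~/.khoj/processor/conversation/conversation_logs.json"),
    openai=OpenAIProcessorConfig(api_key="sk-placeholder", chat_model="gpt-3.5-turbo"),
    enable_offline_chat=False,
)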
@@ -20,6 +20,7 @@ content-type:
     embeddings-file: content_plugin_2_embeddings.pt
     input-filter:
       - '*2_new.jsonl.gz'
+  enable-offline-chat: false
 search-type:
   asymmetric:
     cross-encoder: cross-encoder/ms-marco-MiniLM-L-6-v2
426 tests/test_gpt4all_chat_actors.py Normal file
@@ -0,0 +1,426 @@
# Standard Packages
from datetime import datetime

# External Packages
import pytest

SKIP_TESTS = True
pytestmark = pytest.mark.skipif(
    SKIP_TESTS,
    reason="The GPT4All library has some quirks that make it hard to test in CI. This causes some tests to fail. Hence, disable it in CI.",
)

import freezegun
from freezegun import freeze_time

from gpt4all import GPT4All

# Internal Packages
from khoj.processor.conversation.gpt4all.chat_model import converse_falcon, extract_questions_falcon

from khoj.processor.conversation.utils import message_to_log


@pytest.fixture(scope="session")
def loaded_model():
    return GPT4All("ggml-model-gpt4all-falcon-q4_0.bin")


freezegun.configure(extend_ignore_list=["transformers"])


# Test
# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@freeze_time("1984-04-02")
def test_extract_question_with_date_filter_from_relative_day(loaded_model):
    # Act
    response = extract_questions_falcon(
        "Where did I go for dinner yesterday?", loaded_model=loaded_model, run_extraction=True
    )

    assert len(response) >= 1
    assert response[-1] == "Where did I go for dinner yesterday?"


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@freeze_time("1984-04-02")
def test_extract_question_with_date_filter_from_relative_month(loaded_model):
    # Act
    response = extract_questions_falcon("Which countries did I visit last month?", loaded_model=loaded_model)

    # Assert
    assert len(response) == 1
    assert response == ["Which countries did I visit last month?"]


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
@freeze_time("1984-04-02")
def test_extract_question_with_date_filter_from_relative_year(loaded_model):
    # Act
    response = extract_questions_falcon(
        "Which countries have I visited this year?", loaded_model=loaded_model, run_extraction=True
    )

    # Assert
    assert len(response) >= 1
    assert response[-1] == "Which countries have I visited this year?"


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_extract_multiple_explicit_questions_from_message(loaded_model):
    # Act
    response = extract_questions_falcon("What is the Sun? What is the Moon?", loaded_model=loaded_model)

    # Assert
    expected_responses = ["What is the Sun?", "What is the Moon?"]
    assert len(response) == 2
    assert expected_responses == response


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_extract_multiple_implicit_questions_from_message(loaded_model):
    # Act
    response = extract_questions_falcon("Is Morpheus taller than Neo?", loaded_model=loaded_model, run_extraction=True)

    # Assert
    expected_responses = [
        ("morpheus", "neo"),
    ]
    assert len(response) == 2
    assert any([start in response[0].lower() and end in response[1].lower() for start, end in expected_responses]), (
        "Expected two search queries in response but got: " + response[0]
    )


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_generate_search_query_using_question_from_chat_history(loaded_model):
    # Arrange
    message_list = [
        ("What is the name of Mr. Vader's daughter?", "Princess Leia", []),
    ]

    # Act
    response = extract_questions_falcon(
        "Does he have any sons?",
        conversation_log=populate_chat_history(message_list),
        loaded_model=loaded_model,
        run_extraction=True,
        use_history=True,
    )

    expected_responses = [
        "do not have",
        "clarify",
        "am sorry",
    ]

    # Assert
    assert len(response) >= 1
    assert any([expected_response in response[0] for expected_response in expected_responses]), (
        "Expected chat actor to ask for clarification in response, but got: " + response[0]
    )


# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat actor does not consistently follow template instructions.")
@pytest.mark.chatquality
def test_generate_search_query_using_answer_from_chat_history(loaded_model):
    # Arrange
    message_list = [
        ("What is the name of Mr. Vader's daughter?", "Princess Leia", []),
    ]

    # Act
    response = extract_questions_falcon(
        "Is she a Jedi?",
        conversation_log=populate_chat_history(message_list),
        loaded_model=loaded_model,
        run_extraction=True,
        use_history=True,
    )

    # Assert
    assert len(response) == 1
    assert "Leia" in response[0]


# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat actor is not sufficiently date-aware")
@pytest.mark.chatquality
def test_generate_search_query_with_date_and_context_from_chat_history(loaded_model):
    # Arrange
    message_list = [
        ("When did I visit Masai Mara?", "You visited Masai Mara in April 2000", []),
    ]

    # Act
    response = extract_questions_falcon(
        "What was the Pizza place we ate at over there?",
        conversation_log=populate_chat_history(message_list),
        run_extraction=True,
        loaded_model=loaded_model,
    )

    # Assert
    expected_responses = [
        ("dt>='2000-04-01'", "dt<'2000-05-01'"),
        ("dt>='2000-04-01'", "dt<='2000-04-30'"),
        ('dt>="2000-04-01"', 'dt<"2000-05-01"'),
        ('dt>="2000-04-01"', 'dt<="2000-04-30"'),
    ]
    assert len(response) == 1
    assert "Masai Mara" in response[0]
    assert any([start in response[0] and end in response[0] for start, end in expected_responses]), (
        "Expected date filter to limit to April 2000 in response but got: " + response[0]
    )


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_chat_with_no_chat_history_or_retrieved_content(loaded_model):
    # Act
    response_gen = converse_falcon(
        references=[],  # Assume no context retrieved from notes for the user_query
        user_query="Hello, my name is Testatron. Who are you?",
        loaded_model=loaded_model,
    )
    response = "".join([response_chunk for response_chunk in response_gen])

    # Assert
    expected_responses = ["Khoj", "khoj", "khooj", "Khooj", "KHOJ"]
    assert len(response) > 0
    assert any([expected_response in response for expected_response in expected_responses]), (
        "Expected assistants name, [K|k]hoj, in response but got: " + response
    )


# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat actor isn't really good at proper nouns yet.")
@pytest.mark.chatquality
def test_answer_from_chat_history_and_previously_retrieved_content(loaded_model):
    "Chat actor needs to use context in previous notes and chat history to answer question"
    # Arrange
    message_list = [
        ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
        (
            "When was I born?",
            "You were born on 1st April 1984.",
            ["Testatron was born on 1st April 1984 in Testville."],
        ),
    ]

    # Act
    response_gen = converse_falcon(
        references=[],  # Assume no context retrieved from notes for the user_query
        user_query="Where was I born?",
        conversation_log=populate_chat_history(message_list),
        loaded_model=loaded_model,
    )
    response = "".join([response_chunk for response_chunk in response_gen])

    # Assert
    assert len(response) > 0
    # Infer who I am and use that to infer I was born in Testville using chat history and previously retrieved notes
    assert "Testville" in response


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_from_chat_history_and_currently_retrieved_content(loaded_model):
    "Chat actor needs to use context across currently retrieved notes and chat history to answer question"
    # Arrange
    message_list = [
        ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
        ("When was I born?", "You were born on 1st April 1984.", []),
    ]

    # Act
    response_gen = converse_falcon(
        references=[
            "Testatron was born on 1st April 1984 in Testville."
        ],  # Assume context retrieved from notes for the user_query
        user_query="Where was I born?",
        conversation_log=populate_chat_history(message_list),
        loaded_model=loaded_model,
    )
    response = "".join([response_chunk for response_chunk in response_gen])

    # Assert
    assert len(response) > 0
    assert "Testville" in response


# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat actor is rather liable to lying.")
@pytest.mark.chatquality
def test_refuse_answering_unanswerable_question(loaded_model):
    "Chat actor should not try make up answers to unanswerable questions."
    # Arrange
    message_list = [
        ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
        ("When was I born?", "You were born on 1st April 1984.", []),
    ]

    # Act
    response_gen = converse_falcon(
        references=[],  # Assume no context retrieved from notes for the user_query
        user_query="Where was I born?",
        conversation_log=populate_chat_history(message_list),
        loaded_model=loaded_model,
    )
    response = "".join([response_chunk for response_chunk in response_gen])

    # Assert
    expected_responses = [
        "don't know",
        "do not know",
        "no information",
        "do not have",
        "don't have",
        "cannot answer",
        "I'm sorry",
    ]
    assert len(response) > 0
    assert any([expected_response in response for expected_response in expected_responses]), (
        "Expected chat actor to say they don't know in response, but got: " + response
    )


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_requires_current_date_awareness(loaded_model):
    "Chat actor should be able to answer questions relative to current date using provided notes"
    # Arrange
    context = [
        f"""{datetime.now().strftime("%Y-%m-%d")} "Naco Taco" "Tacos for Dinner"
Expenses:Food:Dining  10.00 USD""",
        f"""{datetime.now().strftime("%Y-%m-%d")} "Sagar Ratna" "Dosa for Lunch"
Expenses:Food:Dining  10.00 USD""",
        f"""2020-04-01 "SuperMercado" "Bananas"
Expenses:Food:Groceries  10.00 USD""",
        f"""2020-01-01 "Naco Taco" "Burittos for Dinner"
Expenses:Food:Dining  10.00 USD""",
    ]

    # Act
    response_gen = converse_falcon(
        references=context,  # Assume context retrieved from notes for the user_query
        user_query="What did I have for Dinner today?",
        loaded_model=loaded_model,
    )
    response = "".join([response_chunk for response_chunk in response_gen])

    # Assert
    expected_responses = ["tacos", "Tacos"]
    assert len(response) > 0
    assert any([expected_response in response for expected_response in expected_responses]), (
        "Expected [T|t]acos in response, but got: " + response
    )


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_requires_date_aware_aggregation_across_provided_notes(loaded_model):
    "Chat actor should be able to answer questions that require date aware aggregation across multiple notes"
    # Arrange
    context = [
        f"""# {datetime.now().strftime("%Y-%m-%d")} "Naco Taco" "Tacos for Dinner"
Expenses:Food:Dining  10.00 USD""",
        f"""{datetime.now().strftime("%Y-%m-%d")} "Sagar Ratna" "Dosa for Lunch"
Expenses:Food:Dining  10.00 USD""",
        f"""2020-04-01 "SuperMercado" "Bananas"
Expenses:Food:Groceries  10.00 USD""",
        f"""2020-01-01 "Naco Taco" "Burittos for Dinner"
Expenses:Food:Dining  10.00 USD""",
    ]

    # Act
    response_gen = converse_falcon(
        references=context,  # Assume context retrieved from notes for the user_query
        user_query="How much did I spend on dining this year?",
        loaded_model=loaded_model,
    )
    response = "".join([response_chunk for response_chunk in response_gen])

    # Assert
    assert len(response) > 0
    assert "20" in response


# ----------------------------------------------------------------------------------------------------
@pytest.mark.chatquality
def test_answer_general_question_not_in_chat_history_or_retrieved_content(loaded_model):
    "Chat actor should be able to answer general questions not requiring looking at chat history or notes"
    # Arrange
    message_list = [
        ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
        ("When was I born?", "You were born on 1st April 1984.", []),
        ("Where was I born?", "You were born Testville.", []),
    ]

    # Act
    response_gen = converse_falcon(
        references=[],  # Assume no context retrieved from notes for the user_query
        user_query="Write a haiku about unit testing in 3 lines",
        conversation_log=populate_chat_history(message_list),
        loaded_model=loaded_model,
    )
    response = "".join([response_chunk for response_chunk in response_gen])

    # Assert
    expected_responses = ["test", "testing"]
    assert len(response.splitlines()) >= 3  # haikus are 3 lines long, but Falcon tends to add a lot of new lines.
    assert any([expected_response in response.lower() for expected_response in expected_responses]), (
        "Expected [T|t]est in response, but got: " + response
    )


# ----------------------------------------------------------------------------------------------------
@pytest.mark.xfail(reason="Chat actor not consistently capable of asking for clarification yet.")
@pytest.mark.chatquality
def test_ask_for_clarification_if_not_enough_context_in_question(loaded_model):
    "Chat actor should ask for clarification if question cannot be answered unambiguously with the provided context"
    # Arrange
    context = [
        f"""# Ramya
My sister, Ramya, is married to Kali Devi. They have 2 kids, Ravi and Rani.""",
        f"""# Fang
My sister, Fang Liu is married to Xi Li. They have 1 kid, Xiao Li.""",
        f"""# Aiyla
My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""",
    ]

    # Act
    response_gen = converse_falcon(
        references=context,  # Assume context retrieved from notes for the user_query
        user_query="How many kids does my older sister have?",
        loaded_model=loaded_model,
    )
    response = "".join([response_chunk for response_chunk in response_gen])

    # Assert
    expected_responses = ["which sister", "Which sister", "which of your sister", "Which of your sister"]
    assert any([expected_response in response for expected_response in expected_responses]), (
        "Expected chat actor to ask for clarification in response, but got: " + response
    )


# Helpers
# ----------------------------------------------------------------------------------------------------
def populate_chat_history(message_list):
    # Generate conversation logs
    conversation_log = {"chat": []}
    for user_message, chat_response, context in message_list:
        message_to_log(
            user_message,
            chat_response,
            {"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}},
            conversation_log=conversation_log["chat"],
        )
    return conversation_log
@@ -8,7 +8,7 @@ import freezegun
 from freezegun import freeze_time
 
 # Internal Packages
-from khoj.processor.conversation.gpt import converse, extract_questions
+from khoj.processor.conversation.openai.gpt import converse, extract_questions
 from khoj.processor.conversation.utils import message_to_log