mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Merge branch 'master' of github.com:khoj-ai/khoj into features/advanced-reasoning
This commit is contained in:
commit
30f9225021
85 changed files with 5132 additions and 3202 deletions
|
@ -38,8 +38,8 @@
|
||||||
- Chat with any local or online LLM (e.g llama3, qwen, gemma, mistral, gpt, claude, gemini).
|
- Chat with any local or online LLM (e.g llama3, qwen, gemma, mistral, gpt, claude, gemini).
|
||||||
- Get answers from the internet and your docs (including image, pdf, markdown, org-mode, word, notion files).
|
- Get answers from the internet and your docs (including image, pdf, markdown, org-mode, word, notion files).
|
||||||
- Access it from your Browser, Obsidian, Emacs, Desktop, Phone or Whatsapp.
|
- Access it from your Browser, Obsidian, Emacs, Desktop, Phone or Whatsapp.
|
||||||
- Build agents with custom knowledge bases and tools.
|
- Create agents with custom knowledge, persona, chat model and tools to take on any role.
|
||||||
- Create automations to get personal newsletters and smart notifications.
|
- Automate away repetitive research. Get personal newsletters and smart notifications delivered to your inbox.
|
||||||
- Find relevant docs quickly and easily using our advanced semantic search.
|
- Find relevant docs quickly and easily using our advanced semantic search.
|
||||||
- Generate images, talk out loud, play your messages.
|
- Generate images, talk out loud, play your messages.
|
||||||
- Khoj is open-source, self-hostable. Always.
|
- Khoj is open-source, self-hostable. Always.
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
# Admin Panel
|
# Admin Panel
|
||||||
> Describes the Khoj settings configurable via the admin panel
|
> Describes the Khoj settings configurable via the admin panel
|
||||||
|
|
||||||
|
By default, you admin panel is available at `http://localhost:42110/server/admin/`. You can access the admin panel by logging in with your admin credentials (this would be your `KHOJ_ADMIN_EMAIL` and `KHOJ_ADMIN_PASSWORD`). The admin panel allows you to configure various settings for your Khoj server.
|
||||||
|
|
||||||
## App Settings
|
## App Settings
|
||||||
### Agents
|
### Agents
|
||||||
Add all the agents you want to use for your different use-cases like Writer, Researcher, Therapist etc.
|
Add all the agents you want to use for your different use-cases like Writer, Researcher, Therapist etc.
|
||||||
|
|
|
@ -12,7 +12,7 @@ Without any desktop clients, you can start chatting with Khoj on WhatsApp. Bear
|
||||||
|
|
||||||
In order to use Khoj on WhatsApp with your own data, you need to setup a Khoj Cloud account and connect your WhatsApp account to it. This is a one time setup and you can do it from the [Khoj Cloud config page](https://app.khoj.dev/settings).
|
In order to use Khoj on WhatsApp with your own data, you need to setup a Khoj Cloud account and connect your WhatsApp account to it. This is a one time setup and you can do it from the [Khoj Cloud config page](https://app.khoj.dev/settings).
|
||||||
|
|
||||||
If you hit usage limits for the WhatsApp bot, upgrade to [a paid plan](https://khoj.dev/pricing) on Khoj Cloud.
|
If you hit usage limits for the WhatsApp bot, upgrade to [a paid plan](https://khoj.dev/#pricing) on Khoj Cloud.
|
||||||
|
|
||||||
<img src="https://khoj-web-bucket.s3.amazonaws.com/khojwhatsapp.png" alt="WhatsApp QR Code" width="300" height="300" />
|
<img src="https://khoj-web-bucket.s3.amazonaws.com/khojwhatsapp.png" alt="WhatsApp QR Code" width="300" height="300" />
|
||||||
|
|
||||||
|
|
|
@ -102,7 +102,19 @@ sudo -u postgres createdb khoj --password
|
||||||
</TabItem>
|
</TabItem>
|
||||||
</Tabs>
|
</Tabs>
|
||||||
|
|
||||||
#### 3. Run
|
#### 3. Build the front-end assets
|
||||||
|
|
||||||
|
```shell
|
||||||
|
cd src/interface/web/
|
||||||
|
yarn install
|
||||||
|
yarn export
|
||||||
|
```
|
||||||
|
|
||||||
|
You can optionally use `yarn dev` to start a development server for the front-end which will be available at http://localhost:3000. This is especially useful if you're making changes to the front-end code, but not necessary for running Khoj. Note that streaming does not work on the dev server due to how it is handled with SSR in Next.js.
|
||||||
|
|
||||||
|
Always run `yarn export` to test your front-end changes on http://localhost:42110 before creating a PR.
|
||||||
|
|
||||||
|
#### 4. Run
|
||||||
1. Start Khoj
|
1. Start Khoj
|
||||||
```bash
|
```bash
|
||||||
khoj -vv
|
khoj -vv
|
||||||
|
|
|
@ -40,7 +40,7 @@ If you want to use the offline chat model and you have a GPU, you should use Ins
|
||||||
</TabItem>
|
</TabItem>
|
||||||
<TabItem value="linux" label="Linux">
|
<TabItem value="linux" label="Linux">
|
||||||
<h3>Prerequisites</h3>
|
<h3>Prerequisites</h3>
|
||||||
Install [Docker Desktop](https://docs.docker.com/desktop/install/windows-install/).
|
Install [Docker Desktop](https://docs.docker.com/desktop/install/linux/).
|
||||||
You can also use your package manager to install Docker Engine & Docker Compose.
|
You can also use your package manager to install Docker Engine & Docker Compose.
|
||||||
</TabItem>
|
</TabItem>
|
||||||
</Tabs>
|
</Tabs>
|
||||||
|
@ -240,7 +240,7 @@ Ensure you are using **localhost, not 127.0.0.1**, to access the admin panel to
|
||||||
|
|
||||||
:::info[DISALLOWED HOST or Bad Request (400) Error]
|
:::info[DISALLOWED HOST or Bad Request (400) Error]
|
||||||
You may hit this if you try access Khoj exposed on a custom domain (e.g. 192.168.12.3 or example.com) or over HTTP.
|
You may hit this if you try access Khoj exposed on a custom domain (e.g. 192.168.12.3 or example.com) or over HTTP.
|
||||||
Set the environment variables KHOJ_DOMAIN=your-domain and KHOJ_NO_HTTPS=false if required to avoid this error.
|
Set the environment variables KHOJ_DOMAIN=your-domain and KHOJ_NO_HTTPS=True if required to avoid this error.
|
||||||
:::
|
:::
|
||||||
|
|
||||||
:::tip[Note]
|
:::tip[Note]
|
||||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -1,7 +1,7 @@
|
||||||
{
|
{
|
||||||
"id": "khoj",
|
"id": "khoj",
|
||||||
"name": "Khoj",
|
"name": "Khoj",
|
||||||
"version": "1.25.0",
|
"version": "1.26.4",
|
||||||
"minAppVersion": "0.15.0",
|
"minAppVersion": "0.15.0",
|
||||||
"description": "Your Second Brain",
|
"description": "Your Second Brain",
|
||||||
"author": "Khoj Inc.",
|
"author": "Khoj Inc.",
|
||||||
|
|
|
@ -62,8 +62,8 @@ dependencies = [
|
||||||
"requests >= 2.26.0",
|
"requests >= 2.26.0",
|
||||||
"tenacity == 8.3.0",
|
"tenacity == 8.3.0",
|
||||||
"anyio == 3.7.1",
|
"anyio == 3.7.1",
|
||||||
"pymupdf >= 1.23.5",
|
"pymupdf == 1.24.11",
|
||||||
"django == 5.0.8",
|
"django == 5.0.9",
|
||||||
"authlib == 1.2.1",
|
"authlib == 1.2.1",
|
||||||
"llama-cpp-python == 0.2.88",
|
"llama-cpp-python == 0.2.88",
|
||||||
"itsdangerous == 2.1.2",
|
"itsdangerous == 2.1.2",
|
||||||
|
|
|
@ -326,7 +326,7 @@
|
||||||
entries.forEach(entry => {
|
entries.forEach(entry => {
|
||||||
// If the element is in the viewport, fetch the remaining message and unobserve the element
|
// If the element is in the viewport, fetch the remaining message and unobserve the element
|
||||||
if (entry.isIntersecting) {
|
if (entry.isIntersecting) {
|
||||||
fetchRemainingChatMessages(chatHistoryUrl, headers);
|
fetchRemainingChatMessages(chatHistoryUrl, headers, chatBody.dataset.conversation_id, hostURL);
|
||||||
observer.unobserve(entry.target);
|
observer.unobserve(entry.target);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
@ -342,7 +342,11 @@
|
||||||
new Date(chat_log.created),
|
new Date(chat_log.created),
|
||||||
chat_log.onlineContext,
|
chat_log.onlineContext,
|
||||||
chat_log.intent?.type,
|
chat_log.intent?.type,
|
||||||
chat_log.intent?.["inferred-queries"]);
|
chat_log.intent?.["inferred-queries"],
|
||||||
|
chatBody.dataset.conversationId ?? "",
|
||||||
|
hostURL,
|
||||||
|
);
|
||||||
|
|
||||||
chatBody.appendChild(messageElement);
|
chatBody.appendChild(messageElement);
|
||||||
|
|
||||||
// When the 4th oldest message is within viewing distance (~60% scrolled up)
|
// When the 4th oldest message is within viewing distance (~60% scrolled up)
|
||||||
|
@ -421,7 +425,7 @@
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function fetchRemainingChatMessages(chatHistoryUrl, headers) {
|
function fetchRemainingChatMessages(chatHistoryUrl, headers, conversationId, hostURL) {
|
||||||
// Create a new IntersectionObserver
|
// Create a new IntersectionObserver
|
||||||
let observer = new IntersectionObserver((entries, observer) => {
|
let observer = new IntersectionObserver((entries, observer) => {
|
||||||
entries.forEach(entry => {
|
entries.forEach(entry => {
|
||||||
|
@ -435,7 +439,9 @@
|
||||||
new Date(chat_log.created),
|
new Date(chat_log.created),
|
||||||
chat_log.onlineContext,
|
chat_log.onlineContext,
|
||||||
chat_log.intent?.type,
|
chat_log.intent?.type,
|
||||||
chat_log.intent?.["inferred-queries"]
|
chat_log.intent?.["inferred-queries"],
|
||||||
|
chatBody.dataset.conversationId ?? "",
|
||||||
|
hostURL,
|
||||||
);
|
);
|
||||||
entry.target.replaceWith(messageElement);
|
entry.target.replaceWith(messageElement);
|
||||||
|
|
||||||
|
|
|
@ -189,11 +189,19 @@ function processOnlineReferences(referenceSection, onlineContext) { //same
|
||||||
return numOnlineReferences;
|
return numOnlineReferences;
|
||||||
}
|
}
|
||||||
|
|
||||||
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) { //same
|
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null, conversationId=null, hostURL=null) {
|
||||||
let chatEl;
|
let chatEl;
|
||||||
if (intentType?.includes("text-to-image")) {
|
if (intentType?.includes("text-to-image")) {
|
||||||
let imageMarkdown = generateImageMarkdown(message, intentType, inferredQueries);
|
let imageMarkdown = generateImageMarkdown(message, intentType, inferredQueries);
|
||||||
chatEl = renderMessage(imageMarkdown, by, dt, null, false, "return");
|
chatEl = renderMessage(imageMarkdown, by, dt, null, false, "return");
|
||||||
|
} else if (intentType === "excalidraw") {
|
||||||
|
let domain = hostURL ?? "https://app.khoj.dev/";
|
||||||
|
|
||||||
|
if (!domain.endsWith("/")) domain += "/";
|
||||||
|
|
||||||
|
let excalidrawMessage = `Hey, I'm not ready to show you diagrams yet here. But you can view it in the web app at ${domain}chat?conversationId=${conversationId}`;
|
||||||
|
|
||||||
|
chatEl = renderMessage(excalidrawMessage, by, dt, null, false, "return");
|
||||||
} else {
|
} else {
|
||||||
chatEl = renderMessage(message, by, dt, null, false, "return");
|
chatEl = renderMessage(message, by, dt, null, false, "return");
|
||||||
}
|
}
|
||||||
|
@ -312,7 +320,6 @@ function formatHTMLMessage(message, raw=false, willReplace=true) { //same
|
||||||
}
|
}
|
||||||
|
|
||||||
function createReferenceSection(references, createLinkerSection=false) {
|
function createReferenceSection(references, createLinkerSection=false) {
|
||||||
console.log("linker data: ", createLinkerSection);
|
|
||||||
let referenceSection = document.createElement('div');
|
let referenceSection = document.createElement('div');
|
||||||
referenceSection.classList.add("reference-section");
|
referenceSection.classList.add("reference-section");
|
||||||
referenceSection.classList.add("collapsed");
|
referenceSection.classList.add("collapsed");
|
||||||
|
@ -417,7 +424,11 @@ function handleImageResponse(imageJson, rawResponse) {
|
||||||
rawResponse += `![generated_image](${imageJson.image})`;
|
rawResponse += `![generated_image](${imageJson.image})`;
|
||||||
} else if (imageJson.intentType === "text-to-image-v3") {
|
} else if (imageJson.intentType === "text-to-image-v3") {
|
||||||
rawResponse = `![](data:image/webp;base64,${imageJson.image})`;
|
rawResponse = `![](data:image/webp;base64,${imageJson.image})`;
|
||||||
|
} else if (imageJson.intentType === "excalidraw") {
|
||||||
|
const redirectMessage = `Hey, I'm not ready to show you diagrams yet here. But you can view it in the web app`;
|
||||||
|
rawResponse += redirectMessage;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (inferredQuery) {
|
if (inferredQuery) {
|
||||||
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
|
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,7 +19,7 @@ const textFileTypes = [
|
||||||
'org', 'md', 'markdown', 'txt', 'html', 'xml',
|
'org', 'md', 'markdown', 'txt', 'html', 'xml',
|
||||||
// Other valid text file extensions from https://google.github.io/magika/model/config.json
|
// Other valid text file extensions from https://google.github.io/magika/model/config.json
|
||||||
'appleplist', 'asm', 'asp', 'batch', 'c', 'cs', 'css', 'csv', 'eml', 'go', 'html', 'ini', 'internetshortcut', 'java', 'javascript', 'json', 'latex', 'lisp', 'makefile', 'markdown', 'mht', 'mum', 'pem', 'perl', 'php', 'powershell', 'python', 'rdf', 'rst', 'rtf', 'ruby', 'rust', 'scala', 'shell', 'smali', 'sql', 'svg', 'symlinktext', 'txt', 'vba', 'winregistry', 'xml', 'yaml']
|
'appleplist', 'asm', 'asp', 'batch', 'c', 'cs', 'css', 'csv', 'eml', 'go', 'html', 'ini', 'internetshortcut', 'java', 'javascript', 'json', 'latex', 'lisp', 'makefile', 'markdown', 'mht', 'mum', 'pem', 'perl', 'php', 'powershell', 'python', 'rdf', 'rst', 'rtf', 'ruby', 'rust', 'scala', 'shell', 'smali', 'sql', 'svg', 'symlinktext', 'txt', 'vba', 'winregistry', 'xml', 'yaml']
|
||||||
const binaryFileTypes = ['pdf', 'jpg', 'jpeg', 'png']
|
const binaryFileTypes = ['pdf', 'jpg', 'jpeg', 'png', 'webp']
|
||||||
const validFileTypes = textFileTypes.concat(binaryFileTypes);
|
const validFileTypes = textFileTypes.concat(binaryFileTypes);
|
||||||
|
|
||||||
const schema = {
|
const schema = {
|
||||||
|
@ -104,6 +104,8 @@ function filenameToMimeType (filename) {
|
||||||
case 'jpg':
|
case 'jpg':
|
||||||
case 'jpeg':
|
case 'jpeg':
|
||||||
return 'image/jpeg';
|
return 'image/jpeg';
|
||||||
|
case 'webp':
|
||||||
|
return 'image/webp';
|
||||||
case 'md':
|
case 'md':
|
||||||
case 'markdown':
|
case 'markdown':
|
||||||
return 'text/markdown';
|
return 'text/markdown';
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
{
|
{
|
||||||
"name": "Khoj",
|
"name": "Khoj",
|
||||||
"version": "1.25.0",
|
"version": "1.26.4",
|
||||||
"description": "Your Second Brain",
|
"description": "Your Second Brain",
|
||||||
"author": "Khoj Inc. <team@khoj.dev>",
|
"author": "Khoj Inc. <team@khoj.dev>",
|
||||||
"license": "GPL-3.0-or-later",
|
"license": "GPL-3.0-or-later",
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
;; Saba Imran <saba@khoj.dev>
|
;; Saba Imran <saba@khoj.dev>
|
||||||
;; Description: Your Second Brain
|
;; Description: Your Second Brain
|
||||||
;; Keywords: search, chat, ai, org-mode, outlines, markdown, pdf, image
|
;; Keywords: search, chat, ai, org-mode, outlines, markdown, pdf, image
|
||||||
;; Version: 1.25.0
|
;; Version: 1.26.4
|
||||||
;; Package-Requires: ((emacs "27.1") (transient "0.3.0") (dash "2.19.1"))
|
;; Package-Requires: ((emacs "27.1") (transient "0.3.0") (dash "2.19.1"))
|
||||||
;; URL: https://github.com/khoj-ai/khoj/tree/master/src/interface/emacs
|
;; URL: https://github.com/khoj-ai/khoj/tree/master/src/interface/emacs
|
||||||
|
|
||||||
|
@ -127,6 +127,11 @@
|
||||||
(const "image")
|
(const "image")
|
||||||
(const "pdf")))
|
(const "pdf")))
|
||||||
|
|
||||||
|
(defcustom khoj-default-agent "khoj"
|
||||||
|
"The default agent to chat with. See https://app.khoj.dev/agents for available options."
|
||||||
|
:group 'khoj
|
||||||
|
:type 'string)
|
||||||
|
|
||||||
|
|
||||||
;; --------------------------
|
;; --------------------------
|
||||||
;; Khoj Dynamic Configuration
|
;; Khoj Dynamic Configuration
|
||||||
|
@ -144,6 +149,9 @@
|
||||||
(defconst khoj--chat-buffer-name "*🏮 Khoj Chat*"
|
(defconst khoj--chat-buffer-name "*🏮 Khoj Chat*"
|
||||||
"Name of chat buffer for Khoj.")
|
"Name of chat buffer for Khoj.")
|
||||||
|
|
||||||
|
(defvar khoj--selected-agent khoj-default-agent
|
||||||
|
"Currently selected Khoj agent.")
|
||||||
|
|
||||||
(defvar khoj--content-type "org"
|
(defvar khoj--content-type "org"
|
||||||
"The type of content to perform search on.")
|
"The type of content to perform search on.")
|
||||||
|
|
||||||
|
@ -656,13 +664,15 @@ Simplified fork of `org-cycle-content' from Emacs 29.1 to work with >=27.1."
|
||||||
;; --------------
|
;; --------------
|
||||||
;; Query Khoj API
|
;; Query Khoj API
|
||||||
;; --------------
|
;; --------------
|
||||||
(defun khoj--call-api (path &optional method params callback &rest cbargs)
|
(defun khoj--call-api (path &optional method params body callback &rest cbargs)
|
||||||
"Sync call API at PATH with METHOD and query PARAMS as kv assoc list.
|
"Sync call API at PATH with METHOD, query PARAMS and BODY as kv assoc list.
|
||||||
Optionally apply CALLBACK with JSON parsed response and CBARGS."
|
Optionally apply CALLBACK with JSON parsed response and CBARGS."
|
||||||
(let* ((url-request-method (or method "GET"))
|
(let* ((url-request-method (or method "GET"))
|
||||||
(url-request-extra-headers `(("Authorization" . ,(format "Bearer %s" khoj-api-key))))
|
(url-request-extra-headers `(("Authorization" . ,(format "Bearer %s" khoj-api-key))))
|
||||||
(param-string (if params (url-build-query-string params) ""))
|
(url-request-extra-headers `(("Authorization" . ,(format "Bearer %s" khoj-api-key)) ("Content-Type" . "application/json")))
|
||||||
(query-url (format "%s%s?%s&client=emacs" khoj-server-url path param-string))
|
(url-request-data (if body (json-encode body) nil))
|
||||||
|
(param-string (url-build-query-string (append params '((client "emacs")))))
|
||||||
|
(query-url (format "%s%s?%s" khoj-server-url path param-string))
|
||||||
(cbargs (if (and (listp cbargs) (listp (car cbargs))) (car cbargs) cbargs))) ; normalize cbargs to (a b) from ((a b)) if required
|
(cbargs (if (and (listp cbargs) (listp (car cbargs))) (car cbargs) cbargs))) ; normalize cbargs to (a b) from ((a b)) if required
|
||||||
(with-temp-buffer
|
(with-temp-buffer
|
||||||
(condition-case ex
|
(condition-case ex
|
||||||
|
@ -682,8 +692,8 @@ Optionally apply CALLBACK with JSON parsed response and CBARGS."
|
||||||
(url-request-extra-headers `(("Authorization" . ,(format "Bearer %s" khoj-api-key)) ("Content-Type" . "application/json")))
|
(url-request-extra-headers `(("Authorization" . ,(format "Bearer %s" khoj-api-key)) ("Content-Type" . "application/json")))
|
||||||
(url-request-data (if body (json-encode body) nil))
|
(url-request-data (if body (json-encode body) nil))
|
||||||
(param-string (url-build-query-string (append params '((client "emacs")))))
|
(param-string (url-build-query-string (append params '((client "emacs")))))
|
||||||
(cbargs (if (and (listp cbargs) (listp (car cbargs))) (car cbargs) cbargs)) ; normalize cbargs to (a b) from ((a b)) if required
|
(query-url (format "%s%s?%s" khoj-server-url path param-string))
|
||||||
(query-url (format "%s%s?%s" khoj-server-url path param-string)))
|
(cbargs (if (and (listp cbargs) (listp (car cbargs))) (car cbargs) cbargs))) ; normalize cbargs to (a b) from ((a b)) if required
|
||||||
(url-retrieve query-url
|
(url-retrieve query-url
|
||||||
(lambda (status)
|
(lambda (status)
|
||||||
(if (plist-get status :error)
|
(if (plist-get status :error)
|
||||||
|
@ -699,7 +709,7 @@ Optionally apply CALLBACK with JSON parsed response and CBARGS."
|
||||||
|
|
||||||
(defun khoj--get-enabled-content-types ()
|
(defun khoj--get-enabled-content-types ()
|
||||||
"Get content types enabled for search from API."
|
"Get content types enabled for search from API."
|
||||||
(khoj--call-api "/api/content/types" "GET" nil `(lambda (item) (mapcar #'intern item))))
|
(khoj--call-api "/api/content/types" "GET" nil nil `(lambda (item) (mapcar #'intern item))))
|
||||||
|
|
||||||
(defun khoj--query-search-api-and-render-results (query content-type buffer-name &optional rerank is-find-similar)
|
(defun khoj--query-search-api-and-render-results (query content-type buffer-name &optional rerank is-find-similar)
|
||||||
"Query Khoj Search API with QUERY, CONTENT-TYPE and RERANK as query params.
|
"Query Khoj Search API with QUERY, CONTENT-TYPE and RERANK as query params.
|
||||||
|
@ -913,14 +923,16 @@ Call CALLBACK func with response and CBARGS."
|
||||||
(let ((selected-session-id (khoj--select-conversation-session "Open")))
|
(let ((selected-session-id (khoj--select-conversation-session "Open")))
|
||||||
(khoj--load-chat-session khoj--chat-buffer-name selected-session-id)))
|
(khoj--load-chat-session khoj--chat-buffer-name selected-session-id)))
|
||||||
|
|
||||||
(defun khoj--create-chat-session ()
|
(defun khoj--create-chat-session (&optional agent)
|
||||||
"Create new chat session."
|
"Create new chat session with AGENT."
|
||||||
(khoj--call-api "/api/chat/sessions" "POST"))
|
(khoj--call-api "/api/chat/sessions"
|
||||||
|
"POST"
|
||||||
|
(when agent `(("agent_slug" ,agent)))))
|
||||||
|
|
||||||
(defun khoj--new-conversation-session ()
|
(defun khoj--new-conversation-session (&optional agent)
|
||||||
"Create new Khoj conversation session."
|
"Create new Khoj conversation session with AGENT."
|
||||||
(thread-last
|
(thread-last
|
||||||
(khoj--create-chat-session)
|
(khoj--create-chat-session agent)
|
||||||
(assoc 'conversation_id)
|
(assoc 'conversation_id)
|
||||||
(cdr)
|
(cdr)
|
||||||
(khoj--chat)))
|
(khoj--chat)))
|
||||||
|
@ -935,6 +947,15 @@ Call CALLBACK func with response and CBARGS."
|
||||||
(khoj--select-conversation-session "Delete")
|
(khoj--select-conversation-session "Delete")
|
||||||
(khoj--delete-chat-session)))
|
(khoj--delete-chat-session)))
|
||||||
|
|
||||||
|
(defun khoj--get-agents ()
|
||||||
|
"Get list of available Khoj agents."
|
||||||
|
(let* ((response (khoj--call-api "/api/agents" "GET"))
|
||||||
|
(agents (mapcar (lambda (agent)
|
||||||
|
(cons (cdr (assoc 'name agent))
|
||||||
|
(cdr (assoc 'slug agent))))
|
||||||
|
response)))
|
||||||
|
agents))
|
||||||
|
|
||||||
(defun khoj--render-chat-message (message sender &optional receive-date)
|
(defun khoj--render-chat-message (message sender &optional receive-date)
|
||||||
"Render chat messages as `org-mode' list item.
|
"Render chat messages as `org-mode' list item.
|
||||||
MESSAGE is the text of the chat message.
|
MESSAGE is the text of the chat message.
|
||||||
|
@ -1246,6 +1267,20 @@ Paragraph only starts at first text after blank line."
|
||||||
;; dynamically set choices to content types enabled on khoj backend
|
;; dynamically set choices to content types enabled on khoj backend
|
||||||
:choices (or (ignore-errors (mapcar #'symbol-name (khoj--get-enabled-content-types))) '("all" "org" "markdown" "pdf" "image")))
|
:choices (or (ignore-errors (mapcar #'symbol-name (khoj--get-enabled-content-types))) '("all" "org" "markdown" "pdf" "image")))
|
||||||
|
|
||||||
|
(transient-define-argument khoj--agent-switch ()
|
||||||
|
:class 'transient-switches
|
||||||
|
:argument-format "--agent=%s"
|
||||||
|
:argument-regexp ".+"
|
||||||
|
:init-value (lambda (obj)
|
||||||
|
(oset obj value (format "--agent=%s" khoj--selected-agent)))
|
||||||
|
:choices (or (ignore-errors (mapcar #'cdr (khoj--get-agents))) '("khoj"))
|
||||||
|
:reader (lambda (prompt initial-input history)
|
||||||
|
(let* ((agents (khoj--get-agents))
|
||||||
|
(selected (completing-read prompt agents nil t initial-input history))
|
||||||
|
(slug (cdr (assoc selected agents))))
|
||||||
|
(setq khoj--selected-agent slug)
|
||||||
|
slug)))
|
||||||
|
|
||||||
(transient-define-suffix khoj--search-command (&optional args)
|
(transient-define-suffix khoj--search-command (&optional args)
|
||||||
(interactive (list (transient-args transient-current-command)))
|
(interactive (list (transient-args transient-current-command)))
|
||||||
(progn
|
(progn
|
||||||
|
@ -1287,10 +1322,11 @@ Paragraph only starts at first text after blank line."
|
||||||
(interactive (list (transient-args transient-current-command)))
|
(interactive (list (transient-args transient-current-command)))
|
||||||
(khoj--open-conversation-session))
|
(khoj--open-conversation-session))
|
||||||
|
|
||||||
(transient-define-suffix khoj--new-conversation-session-command (&optional _)
|
(transient-define-suffix khoj--new-conversation-session-command (&optional args)
|
||||||
"Command to select Khoj conversation sessions to open."
|
"Command to select Khoj conversation sessions to open."
|
||||||
(interactive (list (transient-args transient-current-command)))
|
(interactive (list (transient-args transient-current-command)))
|
||||||
(khoj--new-conversation-session))
|
(let ((agent-slug (transient-arg-value "--agent=" args)))
|
||||||
|
(khoj--new-conversation-session agent-slug)))
|
||||||
|
|
||||||
(transient-define-suffix khoj--delete-conversation-session-command (&optional _)
|
(transient-define-suffix khoj--delete-conversation-session-command (&optional _)
|
||||||
"Command to select Khoj conversation sessions to delete."
|
"Command to select Khoj conversation sessions to delete."
|
||||||
|
@ -1298,14 +1334,15 @@ Paragraph only starts at first text after blank line."
|
||||||
(khoj--delete-conversation-session))
|
(khoj--delete-conversation-session))
|
||||||
|
|
||||||
(transient-define-prefix khoj--chat-menu ()
|
(transient-define-prefix khoj--chat-menu ()
|
||||||
"Open the Khoj chat menu."
|
"Create the Khoj Chat Menu and Execute Commands."
|
||||||
["Act"
|
[["Configure"
|
||||||
("c" "Chat" khoj--chat-command)
|
("a" "Select Agent" khoj--agent-switch)]]
|
||||||
("o" "Open Conversation" khoj--open-conversation-session-command)
|
[["Act"
|
||||||
("n" "New Conversation" khoj--new-conversation-session-command)
|
("c" "Chat" khoj--chat-command)
|
||||||
("d" "Delete Conversation" khoj--delete-conversation-session-command)
|
("o" "Open Conversation" khoj--open-conversation-session-command)
|
||||||
("q" "Quit" transient-quit-one)
|
("n" "New Conversation" khoj--new-conversation-session-command)
|
||||||
])
|
("d" "Delete Conversation" khoj--delete-conversation-session-command)
|
||||||
|
("q" "Quit" transient-quit-one)]])
|
||||||
|
|
||||||
(transient-define-prefix khoj--menu ()
|
(transient-define-prefix khoj--menu ()
|
||||||
"Create Khoj Menu to Configure and Execute Commands."
|
"Create Khoj Menu to Configure and Execute Commands."
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
{
|
{
|
||||||
"id": "khoj",
|
"id": "khoj",
|
||||||
"name": "Khoj",
|
"name": "Khoj",
|
||||||
"version": "1.25.0",
|
"version": "1.26.4",
|
||||||
"minAppVersion": "0.15.0",
|
"minAppVersion": "0.15.0",
|
||||||
"description": "Your Second Brain",
|
"description": "Your Second Brain",
|
||||||
"author": "Khoj Inc.",
|
"author": "Khoj Inc.",
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
{
|
{
|
||||||
"name": "Khoj",
|
"name": "Khoj",
|
||||||
"version": "1.25.0",
|
"version": "1.26.4",
|
||||||
"description": "Your Second Brain",
|
"description": "Your Second Brain",
|
||||||
"author": "Debanjum Singh Solanky, Saba Imran <team@khoj.dev>",
|
"author": "Debanjum Singh Solanky, Saba Imran <team@khoj.dev>",
|
||||||
"license": "GPL-3.0-or-later",
|
"license": "GPL-3.0-or-later",
|
||||||
|
|
|
@ -484,12 +484,13 @@ export class KhojChatView extends KhojPaneView {
|
||||||
dt?: Date,
|
dt?: Date,
|
||||||
intentType?: string,
|
intentType?: string,
|
||||||
inferredQueries?: string[],
|
inferredQueries?: string[],
|
||||||
|
conversationId?: string,
|
||||||
) {
|
) {
|
||||||
if (!message) return;
|
if (!message) return;
|
||||||
|
|
||||||
let chatMessageEl;
|
let chatMessageEl;
|
||||||
if (intentType?.includes("text-to-image")) {
|
if (intentType?.includes("text-to-image") || intentType === "excalidraw") {
|
||||||
let imageMarkdown = this.generateImageMarkdown(message, intentType, inferredQueries);
|
let imageMarkdown = this.generateImageMarkdown(message, intentType, inferredQueries, conversationId);
|
||||||
chatMessageEl = this.renderMessage(chatEl, imageMarkdown, sender, dt);
|
chatMessageEl = this.renderMessage(chatEl, imageMarkdown, sender, dt);
|
||||||
} else {
|
} else {
|
||||||
chatMessageEl = this.renderMessage(chatEl, message, sender, dt);
|
chatMessageEl = this.renderMessage(chatEl, message, sender, dt);
|
||||||
|
@ -509,7 +510,7 @@ export class KhojChatView extends KhojPaneView {
|
||||||
chatMessageBodyEl.appendChild(this.createReferenceSection(references));
|
chatMessageBodyEl.appendChild(this.createReferenceSection(references));
|
||||||
}
|
}
|
||||||
|
|
||||||
generateImageMarkdown(message: string, intentType: string, inferredQueries?: string[]) {
|
generateImageMarkdown(message: string, intentType: string, inferredQueries?: string[], conversationId?: string): string {
|
||||||
let imageMarkdown = "";
|
let imageMarkdown = "";
|
||||||
if (intentType === "text-to-image") {
|
if (intentType === "text-to-image") {
|
||||||
imageMarkdown = `![](data:image/png;base64,${message})`;
|
imageMarkdown = `![](data:image/png;base64,${message})`;
|
||||||
|
@ -517,6 +518,10 @@ export class KhojChatView extends KhojPaneView {
|
||||||
imageMarkdown = `![](${message})`;
|
imageMarkdown = `![](${message})`;
|
||||||
} else if (intentType === "text-to-image-v3") {
|
} else if (intentType === "text-to-image-v3") {
|
||||||
imageMarkdown = `![](data:image/webp;base64,${message})`;
|
imageMarkdown = `![](data:image/webp;base64,${message})`;
|
||||||
|
} else if (intentType === "excalidraw") {
|
||||||
|
const domain = this.setting.khojUrl.endsWith("/") ? this.setting.khojUrl : `${this.setting.khojUrl}/`;
|
||||||
|
const redirectMessage = `Hey, I'm not ready to show you diagrams yet here. But you can view it in ${domain}chat?conversationId=${conversationId}`;
|
||||||
|
imageMarkdown = redirectMessage;
|
||||||
}
|
}
|
||||||
if (inferredQueries) {
|
if (inferredQueries) {
|
||||||
imageMarkdown += "\n\n**Inferred Query**:";
|
imageMarkdown += "\n\n**Inferred Query**:";
|
||||||
|
@ -884,6 +889,7 @@ export class KhojChatView extends KhojPaneView {
|
||||||
new Date(chatLog.created),
|
new Date(chatLog.created),
|
||||||
chatLog.intent?.type,
|
chatLog.intent?.type,
|
||||||
chatLog.intent?.["inferred-queries"],
|
chatLog.intent?.["inferred-queries"],
|
||||||
|
chatBodyEl.dataset.conversationId ?? "",
|
||||||
);
|
);
|
||||||
// push the user messages to the chat history
|
// push the user messages to the chat history
|
||||||
if(chatLog.by === "you"){
|
if(chatLog.by === "you"){
|
||||||
|
@ -1354,6 +1360,10 @@ export class KhojChatView extends KhojPaneView {
|
||||||
rawResponse += `![generated_image](${imageJson.image})`;
|
rawResponse += `![generated_image](${imageJson.image})`;
|
||||||
} else if (imageJson.intentType === "text-to-image-v3") {
|
} else if (imageJson.intentType === "text-to-image-v3") {
|
||||||
rawResponse = `![](data:image/webp;base64,${imageJson.image})`;
|
rawResponse = `![](data:image/webp;base64,${imageJson.image})`;
|
||||||
|
} else if (imageJson.intentType === "excalidraw") {
|
||||||
|
const domain = this.setting.khojUrl.endsWith("/") ? this.setting.khojUrl : `${this.setting.khojUrl}/`;
|
||||||
|
const redirectMessage = `Hey, I'm not ready to show you diagrams yet here. But you can view it in ${domain}`;
|
||||||
|
rawResponse += redirectMessage;
|
||||||
}
|
}
|
||||||
if (inferredQuery) {
|
if (inferredQuery) {
|
||||||
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
|
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
|
||||||
|
|
|
@ -37,6 +37,8 @@ function filenameToMimeType (filename: TFile): string {
|
||||||
case 'jpg':
|
case 'jpg':
|
||||||
case 'jpeg':
|
case 'jpeg':
|
||||||
return 'image/jpeg';
|
return 'image/jpeg';
|
||||||
|
case 'webp':
|
||||||
|
return 'image/webp';
|
||||||
case 'md':
|
case 'md':
|
||||||
case 'markdown':
|
case 'markdown':
|
||||||
return 'text/markdown';
|
return 'text/markdown';
|
||||||
|
@ -50,7 +52,7 @@ function filenameToMimeType (filename: TFile): string {
|
||||||
|
|
||||||
export const fileTypeToExtension = {
|
export const fileTypeToExtension = {
|
||||||
'pdf': ['pdf'],
|
'pdf': ['pdf'],
|
||||||
'image': ['png', 'jpg', 'jpeg'],
|
'image': ['png', 'jpg', 'jpeg', 'webp'],
|
||||||
'markdown': ['md', 'markdown'],
|
'markdown': ['md', 'markdown'],
|
||||||
};
|
};
|
||||||
export const supportedImageFilesTypes = fileTypeToExtension.image;
|
export const supportedImageFilesTypes = fileTypeToExtension.image;
|
||||||
|
|
|
@ -77,5 +77,10 @@
|
||||||
"1.23.3": "0.15.0",
|
"1.23.3": "0.15.0",
|
||||||
"1.24.0": "0.15.0",
|
"1.24.0": "0.15.0",
|
||||||
"1.24.1": "0.15.0",
|
"1.24.1": "0.15.0",
|
||||||
"1.25.0": "0.15.0"
|
"1.25.0": "0.15.0",
|
||||||
|
"1.26.0": "0.15.0",
|
||||||
|
"1.26.1": "0.15.0",
|
||||||
|
"1.26.2": "0.15.0",
|
||||||
|
"1.26.3": "0.15.0",
|
||||||
|
"1.26.4": "0.15.0"
|
||||||
}
|
}
|
||||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -79,7 +79,7 @@ div.titleBar {
|
||||||
div.chatBoxBody {
|
div.chatBoxBody {
|
||||||
display: grid;
|
display: grid;
|
||||||
height: 100%;
|
height: 100%;
|
||||||
width: 70%;
|
width: 95%;
|
||||||
margin: auto;
|
margin: auto;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -47,7 +47,14 @@ export default function RootLayout({
|
||||||
child-src 'none';
|
child-src 'none';
|
||||||
object-src 'none';"
|
object-src 'none';"
|
||||||
></meta>
|
></meta>
|
||||||
<body className={inter.className}>{children}</body>
|
<body className={inter.className}>
|
||||||
|
{children}
|
||||||
|
<script
|
||||||
|
dangerouslySetInnerHTML={{
|
||||||
|
__html: `window.EXCALIDRAW_ASSET_PATH = 'https://assets.khoj.dev/@excalidraw/excalidraw/dist/';`,
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</body>
|
||||||
</html>
|
</html>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import styles from "./chat.module.css";
|
import styles from "./chat.module.css";
|
||||||
import React, { Suspense, useEffect, useState } from "react";
|
import React, { Suspense, useEffect, useRef, useState } from "react";
|
||||||
|
|
||||||
import SidePanel, { ChatSessionActionMenu } from "../components/sidePanel/chatHistorySidePanel";
|
import SidePanel, { ChatSessionActionMenu } from "../components/sidePanel/chatHistorySidePanel";
|
||||||
import ChatHistory from "../components/chatHistory/chatHistory";
|
import ChatHistory from "../components/chatHistory/chatHistory";
|
||||||
|
@ -19,11 +19,9 @@ import {
|
||||||
StreamMessage,
|
StreamMessage,
|
||||||
} from "../components/chatMessage/chatMessage";
|
} from "../components/chatMessage/chatMessage";
|
||||||
import { useIPLocationData, useIsMobileWidth, welcomeConsole } from "../common/utils";
|
import { useIPLocationData, useIsMobileWidth, welcomeConsole } from "../common/utils";
|
||||||
import ChatInputArea, { ChatOptions } from "../components/chatInputArea/chatInputArea";
|
import { ChatInputArea, ChatOptions } from "../components/chatInputArea/chatInputArea";
|
||||||
import { useAuthenticatedData } from "../common/auth";
|
import { useAuthenticatedData } from "../common/auth";
|
||||||
import { AgentData } from "../agents/page";
|
import { AgentData } from "../agents/page";
|
||||||
import { DotsThreeVertical } from "@phosphor-icons/react";
|
|
||||||
import { Button } from "@/components/ui/button";
|
|
||||||
|
|
||||||
interface ChatBodyDataProps {
|
interface ChatBodyDataProps {
|
||||||
chatOptionsData: ChatOptions | null;
|
chatOptionsData: ChatOptions | null;
|
||||||
|
@ -34,32 +32,38 @@ interface ChatBodyDataProps {
|
||||||
setUploadedFiles: (files: string[]) => void;
|
setUploadedFiles: (files: string[]) => void;
|
||||||
isMobileWidth?: boolean;
|
isMobileWidth?: boolean;
|
||||||
isLoggedIn: boolean;
|
isLoggedIn: boolean;
|
||||||
setImage64: (image64: string) => void;
|
setImages: (images: string[]) => void;
|
||||||
}
|
}
|
||||||
|
|
||||||
function ChatBodyData(props: ChatBodyDataProps) {
|
function ChatBodyData(props: ChatBodyDataProps) {
|
||||||
const searchParams = useSearchParams();
|
const searchParams = useSearchParams();
|
||||||
const conversationId = searchParams.get("conversationId");
|
const conversationId = searchParams.get("conversationId");
|
||||||
const [message, setMessage] = useState("");
|
const [message, setMessage] = useState("");
|
||||||
const [image, setImage] = useState<string | null>(null);
|
const [images, setImages] = useState<string[]>([]);
|
||||||
const [processingMessage, setProcessingMessage] = useState(false);
|
const [processingMessage, setProcessingMessage] = useState(false);
|
||||||
const [agentMetadata, setAgentMetadata] = useState<AgentData | null>(null);
|
const [agentMetadata, setAgentMetadata] = useState<AgentData | null>(null);
|
||||||
|
const chatInputRef = useRef<HTMLTextAreaElement>(null);
|
||||||
|
|
||||||
const setQueryToProcess = props.setQueryToProcess;
|
const setQueryToProcess = props.setQueryToProcess;
|
||||||
const onConversationIdChange = props.onConversationIdChange;
|
const onConversationIdChange = props.onConversationIdChange;
|
||||||
|
|
||||||
useEffect(() => {
|
const chatHistoryCustomClassName = props.isMobileWidth ? "w-full" : "w-4/6";
|
||||||
if (image) {
|
|
||||||
props.setImage64(encodeURIComponent(image));
|
|
||||||
}
|
|
||||||
}, [image, props.setImage64]);
|
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const storedImage = localStorage.getItem("image");
|
if (images.length > 0) {
|
||||||
if (storedImage) {
|
const encodedImages = images.map((image) => encodeURIComponent(image));
|
||||||
setImage(storedImage);
|
props.setImages(encodedImages);
|
||||||
props.setImage64(encodeURIComponent(storedImage));
|
}
|
||||||
localStorage.removeItem("image");
|
}, [images, props.setImages]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
const storedImages = localStorage.getItem("images");
|
||||||
|
if (storedImages) {
|
||||||
|
const parsedImages: string[] = JSON.parse(storedImages);
|
||||||
|
setImages(parsedImages);
|
||||||
|
const encodedImages = parsedImages.map((img: string) => encodeURIComponent(img));
|
||||||
|
props.setImages(encodedImages);
|
||||||
|
localStorage.removeItem("images");
|
||||||
}
|
}
|
||||||
|
|
||||||
const storedMessage = localStorage.getItem("message");
|
const storedMessage = localStorage.getItem("message");
|
||||||
|
@ -67,7 +71,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
|
||||||
setProcessingMessage(true);
|
setProcessingMessage(true);
|
||||||
setQueryToProcess(storedMessage);
|
setQueryToProcess(storedMessage);
|
||||||
}
|
}
|
||||||
}, [setQueryToProcess]);
|
}, [setQueryToProcess, props.setImages]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (message) {
|
if (message) {
|
||||||
|
@ -89,6 +93,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
|
||||||
props.streamedMessages[props.streamedMessages.length - 1].completed
|
props.streamedMessages[props.streamedMessages.length - 1].completed
|
||||||
) {
|
) {
|
||||||
setProcessingMessage(false);
|
setProcessingMessage(false);
|
||||||
|
setImages([]); // Reset images after processing
|
||||||
} else {
|
} else {
|
||||||
setMessage("");
|
setMessage("");
|
||||||
}
|
}
|
||||||
|
@ -108,21 +113,23 @@ function ChatBodyData(props: ChatBodyDataProps) {
|
||||||
setAgent={setAgentMetadata}
|
setAgent={setAgentMetadata}
|
||||||
pendingMessage={processingMessage ? message : ""}
|
pendingMessage={processingMessage ? message : ""}
|
||||||
incomingMessages={props.streamedMessages}
|
incomingMessages={props.streamedMessages}
|
||||||
|
customClassName={chatHistoryCustomClassName}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
<div
|
<div
|
||||||
className={`${styles.inputBox} p-1 md:px-2 shadow-md bg-background align-middle items-center justify-center dark:bg-neutral-700 dark:border-0 dark:shadow-sm rounded-t-2xl rounded-b-none md:rounded-xl h-fit`}
|
className={`${styles.inputBox} p-1 md:px-2 shadow-md bg-background align-middle items-center justify-center dark:bg-neutral-700 dark:border-0 dark:shadow-sm rounded-t-2xl rounded-b-none md:rounded-xl h-fit ${chatHistoryCustomClassName} mr-auto ml-auto`}
|
||||||
>
|
>
|
||||||
<ChatInputArea
|
<ChatInputArea
|
||||||
agentColor={agentMetadata?.color}
|
agentColor={agentMetadata?.color}
|
||||||
isLoggedIn={props.isLoggedIn}
|
isLoggedIn={props.isLoggedIn}
|
||||||
sendMessage={(message) => setMessage(message)}
|
sendMessage={(message) => setMessage(message)}
|
||||||
sendImage={(image) => setImage(image)}
|
sendImage={(image) => setImages((prevImages) => [...prevImages, image])}
|
||||||
sendDisabled={processingMessage}
|
sendDisabled={processingMessage}
|
||||||
chatOptionsData={props.chatOptionsData}
|
chatOptionsData={props.chatOptionsData}
|
||||||
conversationId={conversationId}
|
conversationId={conversationId}
|
||||||
isMobileWidth={props.isMobileWidth}
|
isMobileWidth={props.isMobileWidth}
|
||||||
setUploadedFiles={props.setUploadedFiles}
|
setUploadedFiles={props.setUploadedFiles}
|
||||||
|
ref={chatInputRef}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
</>
|
</>
|
||||||
|
@ -139,7 +146,7 @@ export default function Chat() {
|
||||||
const [queryToProcess, setQueryToProcess] = useState<string>("");
|
const [queryToProcess, setQueryToProcess] = useState<string>("");
|
||||||
const [processQuerySignal, setProcessQuerySignal] = useState(false);
|
const [processQuerySignal, setProcessQuerySignal] = useState(false);
|
||||||
const [uploadedFiles, setUploadedFiles] = useState<string[]>([]);
|
const [uploadedFiles, setUploadedFiles] = useState<string[]>([]);
|
||||||
const [image64, setImage64] = useState<string>("");
|
const [images, setImages] = useState<string[]>([]);
|
||||||
|
|
||||||
const locationData = useIPLocationData() || {
|
const locationData = useIPLocationData() || {
|
||||||
timezone: Intl.DateTimeFormat().resolvedOptions().timeZone,
|
timezone: Intl.DateTimeFormat().resolvedOptions().timeZone,
|
||||||
|
@ -176,7 +183,7 @@ export default function Chat() {
|
||||||
completed: false,
|
completed: false,
|
||||||
timestamp: new Date().toISOString(),
|
timestamp: new Date().toISOString(),
|
||||||
rawQuery: queryToProcess || "",
|
rawQuery: queryToProcess || "",
|
||||||
uploadedImageData: decodeURIComponent(image64),
|
images: images,
|
||||||
};
|
};
|
||||||
setMessages((prevMessages) => [...prevMessages, newStreamMessage]);
|
setMessages((prevMessages) => [...prevMessages, newStreamMessage]);
|
||||||
setProcessQuerySignal(true);
|
setProcessQuerySignal(true);
|
||||||
|
@ -208,7 +215,7 @@ export default function Chat() {
|
||||||
if (done) {
|
if (done) {
|
||||||
setQueryToProcess("");
|
setQueryToProcess("");
|
||||||
setProcessQuerySignal(false);
|
setProcessQuerySignal(false);
|
||||||
setImage64("");
|
setImages([]);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -257,7 +264,7 @@ export default function Chat() {
|
||||||
country_code: locationData.countryCode,
|
country_code: locationData.countryCode,
|
||||||
timezone: locationData.timezone,
|
timezone: locationData.timezone,
|
||||||
}),
|
}),
|
||||||
...(image64 && { image: image64 }),
|
...(images.length > 0 && { images: images }),
|
||||||
};
|
};
|
||||||
|
|
||||||
const response = await fetch(chatAPI, {
|
const response = await fetch(chatAPI, {
|
||||||
|
@ -271,7 +278,8 @@ export default function Chat() {
|
||||||
try {
|
try {
|
||||||
await readChatStream(response);
|
await readChatStream(response);
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.error(err);
|
const apiError = await response.json();
|
||||||
|
console.error(apiError);
|
||||||
// Retrieve latest message being processed
|
// Retrieve latest message being processed
|
||||||
const currentMessage = messages.find((message) => !message.completed);
|
const currentMessage = messages.find((message) => !message.completed);
|
||||||
if (!currentMessage) return;
|
if (!currentMessage) return;
|
||||||
|
@ -280,7 +288,11 @@ export default function Chat() {
|
||||||
const errorMessage = (err as Error).message;
|
const errorMessage = (err as Error).message;
|
||||||
if (errorMessage.includes("Error in input stream"))
|
if (errorMessage.includes("Error in input stream"))
|
||||||
currentMessage.rawResponse = `Woops! The connection broke while I was writing my thoughts down. Maybe try again in a bit or dislike this message if the issue persists?`;
|
currentMessage.rawResponse = `Woops! The connection broke while I was writing my thoughts down. Maybe try again in a bit or dislike this message if the issue persists?`;
|
||||||
else
|
else if (response.status === 429) {
|
||||||
|
"detail" in apiError
|
||||||
|
? (currentMessage.rawResponse = `${apiError.detail}`)
|
||||||
|
: (currentMessage.rawResponse = `I'm a bit overwhelmed at the moment. Could you try again in a bit or dislike this message if the issue persists?`);
|
||||||
|
} else
|
||||||
currentMessage.rawResponse = `Umm, not sure what just happened. I see this error message: ${errorMessage}. Could you try again or dislike this message if the issue persists?`;
|
currentMessage.rawResponse = `Umm, not sure what just happened. I see this error message: ${errorMessage}. Could you try again or dislike this message if the issue persists?`;
|
||||||
|
|
||||||
// Complete message streaming teardown properly
|
// Complete message streaming teardown properly
|
||||||
|
@ -339,7 +351,7 @@ export default function Chat() {
|
||||||
setUploadedFiles={setUploadedFiles}
|
setUploadedFiles={setUploadedFiles}
|
||||||
isMobileWidth={isMobileWidth}
|
isMobileWidth={isMobileWidth}
|
||||||
onConversationIdChange={handleConversationIdChange}
|
onConversationIdChange={handleConversationIdChange}
|
||||||
setImage64={setImage64}
|
setImages={setImages}
|
||||||
/>
|
/>
|
||||||
</Suspense>
|
</Suspense>
|
||||||
</div>
|
</div>
|
||||||
|
|
|
@ -68,7 +68,8 @@ export interface UserConfig {
|
||||||
selected_voice_model_config: number;
|
selected_voice_model_config: number;
|
||||||
// user billing info
|
// user billing info
|
||||||
subscription_state: SubscriptionStates;
|
subscription_state: SubscriptionStates;
|
||||||
subscription_renewal_date: string;
|
subscription_renewal_date: string | undefined;
|
||||||
|
subscription_enabled_trial_at: string | undefined;
|
||||||
// server settings
|
// server settings
|
||||||
khoj_cloud_subscription_url: string | undefined;
|
khoj_cloud_subscription_url: string | undefined;
|
||||||
billing_enabled: boolean;
|
billing_enabled: boolean;
|
||||||
|
@ -78,6 +79,7 @@ export interface UserConfig {
|
||||||
anonymous_mode: boolean;
|
anonymous_mode: boolean;
|
||||||
notion_oauth_url: string;
|
notion_oauth_url: string;
|
||||||
detail: string;
|
detail: string;
|
||||||
|
length_of_free_trial: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function useUserConfig(detailed: boolean = false) {
|
export function useUserConfig(detailed: boolean = false) {
|
||||||
|
@ -93,3 +95,15 @@ export function useUserConfig(detailed: boolean = false) {
|
||||||
|
|
||||||
return { userConfig, isLoadingUserConfig };
|
return { userConfig, isLoadingUserConfig };
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function isUserSubscribed(userConfig: UserConfig | null): boolean {
|
||||||
|
return (
|
||||||
|
(userConfig?.subscription_state &&
|
||||||
|
[
|
||||||
|
SubscriptionStates.SUBSCRIBED.valueOf(),
|
||||||
|
SubscriptionStates.TRIAL.valueOf(),
|
||||||
|
SubscriptionStates.UNSUBSCRIBED.valueOf(),
|
||||||
|
].includes(userConfig.subscription_state)) ||
|
||||||
|
false
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
|
@ -11,11 +11,10 @@ export interface RawReferenceData {
|
||||||
codeContext?: CodeContext;
|
codeContext?: CodeContext;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface ResponseWithReferences {
|
export interface ResponseWithIntent {
|
||||||
context?: Context[];
|
intentType: string;
|
||||||
online?: OnlineContext;
|
response: string;
|
||||||
codeContext?: CodeContext;
|
inferredQueries?: string[];
|
||||||
response?: string;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
interface MessageChunk {
|
interface MessageChunk {
|
||||||
|
@ -56,10 +55,14 @@ export function convertMessageChunkToJson(chunk: string): MessageChunk {
|
||||||
function handleJsonResponse(chunkData: any) {
|
function handleJsonResponse(chunkData: any) {
|
||||||
const jsonData = chunkData as any;
|
const jsonData = chunkData as any;
|
||||||
if (jsonData.image || jsonData.detail) {
|
if (jsonData.image || jsonData.detail) {
|
||||||
let responseWithReference = handleImageResponse(chunkData, true);
|
let responseWithIntent = handleImageResponse(chunkData, true);
|
||||||
if (responseWithReference.response) return responseWithReference.response;
|
return responseWithIntent;
|
||||||
} else if (jsonData.response) {
|
} else if (jsonData.response) {
|
||||||
return jsonData.response;
|
return {
|
||||||
|
response: jsonData.response,
|
||||||
|
intentType: "",
|
||||||
|
inferredQueries: [],
|
||||||
|
};
|
||||||
} else {
|
} else {
|
||||||
throw new Error("Invalid JSON response");
|
throw new Error("Invalid JSON response");
|
||||||
}
|
}
|
||||||
|
@ -89,8 +92,18 @@ export function processMessageChunk(
|
||||||
return { context, onlineContext, codeContext };
|
return { context, onlineContext, codeContext };
|
||||||
} else if (chunk.type === "message") {
|
} else if (chunk.type === "message") {
|
||||||
const chunkData = chunk.data;
|
const chunkData = chunk.data;
|
||||||
|
// Here, handle if the response is a JSON response with an image, but the intentType is excalidraw
|
||||||
if (chunkData !== null && typeof chunkData === "object") {
|
if (chunkData !== null && typeof chunkData === "object") {
|
||||||
currentMessage.rawResponse += handleJsonResponse(chunkData);
|
let responseWithIntent = handleJsonResponse(chunkData);
|
||||||
|
|
||||||
|
if (responseWithIntent.intentType && responseWithIntent.intentType === "excalidraw") {
|
||||||
|
currentMessage.rawResponse = responseWithIntent.response;
|
||||||
|
} else {
|
||||||
|
currentMessage.rawResponse += responseWithIntent.response;
|
||||||
|
}
|
||||||
|
|
||||||
|
currentMessage.intentType = responseWithIntent.intentType;
|
||||||
|
currentMessage.inferredQueries = responseWithIntent.inferredQueries;
|
||||||
} else if (
|
} else if (
|
||||||
typeof chunkData === "string" &&
|
typeof chunkData === "string" &&
|
||||||
chunkData.trim()?.startsWith("{") &&
|
chunkData.trim()?.startsWith("{") &&
|
||||||
|
@ -98,7 +111,10 @@ export function processMessageChunk(
|
||||||
) {
|
) {
|
||||||
try {
|
try {
|
||||||
const jsonData = JSON.parse(chunkData.trim());
|
const jsonData = JSON.parse(chunkData.trim());
|
||||||
currentMessage.rawResponse += handleJsonResponse(jsonData);
|
let responseWithIntent = handleJsonResponse(jsonData);
|
||||||
|
currentMessage.rawResponse += responseWithIntent.response;
|
||||||
|
currentMessage.intentType = responseWithIntent.intentType;
|
||||||
|
currentMessage.inferredQueries = responseWithIntent.inferredQueries;
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
currentMessage.rawResponse += JSON.stringify(chunkData);
|
currentMessage.rawResponse += JSON.stringify(chunkData);
|
||||||
}
|
}
|
||||||
|
@ -148,42 +164,26 @@ export function processMessageChunk(
|
||||||
return { context, onlineContext, codeContext };
|
return { context, onlineContext, codeContext };
|
||||||
}
|
}
|
||||||
|
|
||||||
export function handleImageResponse(imageJson: any, liveStream: boolean): ResponseWithReferences {
|
export function handleImageResponse(imageJson: any, liveStream: boolean): ResponseWithIntent {
|
||||||
let rawResponse = "";
|
let rawResponse = "";
|
||||||
|
|
||||||
if (imageJson.image) {
|
if (imageJson.image) {
|
||||||
const inferredQuery = imageJson.inferredQueries?.[0] ?? "generated image";
|
// If response has image field, response may be a generated image
|
||||||
|
rawResponse = imageJson.image;
|
||||||
// If response has image field, response is a generated image.
|
|
||||||
if (imageJson.intentType === "text-to-image") {
|
|
||||||
rawResponse += `![generated_image](data:image/png;base64,${imageJson.image})`;
|
|
||||||
} else if (imageJson.intentType === "text-to-image2") {
|
|
||||||
rawResponse += `![generated_image](${imageJson.image})`;
|
|
||||||
} else if (imageJson.intentType === "text-to-image-v3") {
|
|
||||||
rawResponse = `![](data:image/webp;base64,${imageJson.image})`;
|
|
||||||
}
|
|
||||||
if (inferredQuery && !liveStream) {
|
|
||||||
rawResponse += `\n\n${inferredQuery}`;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let reference: ResponseWithReferences = {};
|
let responseWithIntent: ResponseWithIntent = {
|
||||||
|
intentType: imageJson.intentType,
|
||||||
|
response: rawResponse,
|
||||||
|
inferredQueries: imageJson.inferredQueries,
|
||||||
|
};
|
||||||
|
|
||||||
if (imageJson.context && imageJson.context.length > 0) {
|
|
||||||
const rawReferenceAsJson = imageJson.context;
|
|
||||||
if (rawReferenceAsJson instanceof Array) {
|
|
||||||
reference.context = rawReferenceAsJson;
|
|
||||||
} else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) {
|
|
||||||
reference.online = rawReferenceAsJson;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (imageJson.detail) {
|
if (imageJson.detail) {
|
||||||
// The detail field contains the improved image prompt
|
// The detail field contains the improved image prompt
|
||||||
rawResponse += imageJson.detail;
|
rawResponse += imageJson.detail;
|
||||||
}
|
}
|
||||||
|
|
||||||
reference.response = rawResponse;
|
return responseWithIntent;
|
||||||
return reference;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export function renderCodeGenImageInline(message: string, codeContext: CodeContext) {
|
export function renderCodeGenImageInline(message: string, codeContext: CodeContext) {
|
||||||
|
@ -228,7 +228,11 @@ export function modifyFileFilterForConversation(
|
||||||
},
|
},
|
||||||
body: JSON.stringify(body),
|
body: JSON.stringify(body),
|
||||||
})
|
})
|
||||||
.then((response) => response.json())
|
.then((res) => {
|
||||||
|
if (!res.ok)
|
||||||
|
throw new Error(`Failed to call API at ${addUrl} with error ${res.statusText}`);
|
||||||
|
return res.json();
|
||||||
|
})
|
||||||
.then((data) => {
|
.then((data) => {
|
||||||
setAddedFiles(data);
|
setAddedFiles(data);
|
||||||
})
|
})
|
||||||
|
|
|
@ -48,6 +48,7 @@ import {
|
||||||
Oven,
|
Oven,
|
||||||
Gavel,
|
Gavel,
|
||||||
Broadcast,
|
Broadcast,
|
||||||
|
KeyReturn,
|
||||||
} from "@phosphor-icons/react";
|
} from "@phosphor-icons/react";
|
||||||
import { Markdown, OrgMode, Pdf, Word } from "@/app/components/logo/fileLogo";
|
import { Markdown, OrgMode, Pdf, Word } from "@/app/components/logo/fileLogo";
|
||||||
|
|
||||||
|
@ -193,6 +194,10 @@ export function getIconForSlashCommand(command: string, customClassName: string
|
||||||
}
|
}
|
||||||
|
|
||||||
if (command.includes("default")) {
|
if (command.includes("default")) {
|
||||||
|
return <KeyReturn className={className} />;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (command.includes("diagram")) {
|
||||||
return <Shapes className={className} />;
|
return <Shapes className={className} />;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -241,6 +246,7 @@ function getIconFromFilename(
|
||||||
case "jpg":
|
case "jpg":
|
||||||
case "jpeg":
|
case "jpeg":
|
||||||
case "png":
|
case "png":
|
||||||
|
case "webp":
|
||||||
return <Image className={className} weight="fill" />;
|
return <Image className={className} weight="fill" />;
|
||||||
default:
|
default:
|
||||||
return <File className={className} weight="fill" />;
|
return <File className={className} weight="fill" />;
|
||||||
|
|
|
@ -70,3 +70,19 @@ export function useIsMobileWidth() {
|
||||||
|
|
||||||
return isMobileWidth;
|
return isMobileWidth;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function useDebounce<T>(value: T, delay: number): T {
|
||||||
|
const [debouncedValue, setDebouncedValue] = useState<T>(value);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
const handler = setTimeout(() => {
|
||||||
|
setDebouncedValue(value);
|
||||||
|
}, delay);
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
clearTimeout(handler);
|
||||||
|
};
|
||||||
|
}, [value, delay]);
|
||||||
|
|
||||||
|
return debouncedValue;
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,20 @@
|
||||||
|
.agentPersonality p {
|
||||||
|
white-space: inherit;
|
||||||
|
overflow: hidden;
|
||||||
|
height: 77px;
|
||||||
|
line-height: 1.5;
|
||||||
|
}
|
||||||
|
|
||||||
|
div.agentPersonality {
|
||||||
|
text-align: left;
|
||||||
|
grid-column: span 3;
|
||||||
|
overflow: hidden;
|
||||||
|
}
|
||||||
|
|
||||||
|
button.infoButton {
|
||||||
|
border: none;
|
||||||
|
background-color: transparent !important;
|
||||||
|
text-align: left;
|
||||||
|
font-family: inherit;
|
||||||
|
font-size: medium;
|
||||||
|
}
|
1297
src/interface/web/app/components/agentCard/agentCard.tsx
Normal file
1297
src/interface/web/app/components/agentCard/agentCard.tsx
Normal file
File diff suppressed because it is too large
Load diff
|
@ -2,12 +2,7 @@ div.chatHistory {
|
||||||
display: flex;
|
display: flex;
|
||||||
flex-direction: column;
|
flex-direction: column;
|
||||||
height: 100%;
|
height: 100%;
|
||||||
}
|
margin: auto;
|
||||||
|
|
||||||
div.chatLayout {
|
|
||||||
height: 80vh;
|
|
||||||
overflow-y: auto;
|
|
||||||
margin: 0 auto;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
div.agentIndicator a {
|
div.agentIndicator a {
|
||||||
|
|
|
@ -37,6 +37,7 @@ interface ChatHistoryProps {
|
||||||
pendingMessage?: string;
|
pendingMessage?: string;
|
||||||
publicConversationSlug?: string;
|
publicConversationSlug?: string;
|
||||||
setAgent: (agent: AgentData) => void;
|
setAgent: (agent: AgentData) => void;
|
||||||
|
customClassName?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
function constructTrainOfThought(
|
function constructTrainOfThought(
|
||||||
|
@ -255,7 +256,7 @@ export default function ChatHistory(props: ChatHistoryProps) {
|
||||||
return (
|
return (
|
||||||
<ScrollArea className={`h-[80vh] relative`} ref={scrollAreaRef}>
|
<ScrollArea className={`h-[80vh] relative`} ref={scrollAreaRef}>
|
||||||
<div>
|
<div>
|
||||||
<div className={styles.chatHistory}>
|
<div className={`${styles.chatHistory} ${props.customClassName}`}>
|
||||||
<div ref={sentinelRef} style={{ height: "1px" }}>
|
<div ref={sentinelRef} style={{ height: "1px" }}>
|
||||||
{fetchingData && (
|
{fetchingData && (
|
||||||
<InlineLoading message="Loading Conversation" className="opacity-50" />
|
<InlineLoading message="Loading Conversation" className="opacity-50" />
|
||||||
|
@ -299,7 +300,7 @@ export default function ChatHistory(props: ChatHistoryProps) {
|
||||||
created: message.timestamp,
|
created: message.timestamp,
|
||||||
by: "you",
|
by: "you",
|
||||||
automationId: "",
|
automationId: "",
|
||||||
uploadedImageData: message.uploadedImageData,
|
images: message.images,
|
||||||
}}
|
}}
|
||||||
customClassName="fullHistory"
|
customClassName="fullHistory"
|
||||||
borderLeftColor={`${data?.agent?.color}-500`}
|
borderLeftColor={`${data?.agent?.color}-500`}
|
||||||
|
@ -324,6 +325,12 @@ export default function ChatHistory(props: ChatHistoryProps) {
|
||||||
by: "khoj",
|
by: "khoj",
|
||||||
automationId: "",
|
automationId: "",
|
||||||
rawQuery: message.rawQuery,
|
rawQuery: message.rawQuery,
|
||||||
|
intent: {
|
||||||
|
type: message.intentType || "",
|
||||||
|
query: message.rawQuery,
|
||||||
|
"memory-type": "",
|
||||||
|
"inferred-queries": message.inferredQueries || [],
|
||||||
|
},
|
||||||
}}
|
}}
|
||||||
customClassName="fullHistory"
|
customClassName="fullHistory"
|
||||||
borderLeftColor={`${data?.agent?.color}-500`}
|
borderLeftColor={`${data?.agent?.color}-500`}
|
||||||
|
@ -344,7 +351,6 @@ export default function ChatHistory(props: ChatHistoryProps) {
|
||||||
created: new Date().getTime().toString(),
|
created: new Date().getTime().toString(),
|
||||||
by: "you",
|
by: "you",
|
||||||
automationId: "",
|
automationId: "",
|
||||||
uploadedImageData: props.pendingMessage,
|
|
||||||
}}
|
}}
|
||||||
customClassName="fullHistory"
|
customClassName="fullHistory"
|
||||||
borderLeftColor={`${data?.agent?.color}-500`}
|
borderLeftColor={`${data?.agent?.color}-500`}
|
||||||
|
@ -369,18 +375,20 @@ export default function ChatHistory(props: ChatHistoryProps) {
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
{!isNearBottom && (
|
<div className={`${props.customClassName} fixed bottom-[15%] z-10`}>
|
||||||
<button
|
{!isNearBottom && (
|
||||||
title="Scroll to bottom"
|
<button
|
||||||
className="absolute bottom-4 right-5 bg-white dark:bg-[hsl(var(--background))] text-neutral-500 dark:text-white p-2 rounded-full shadow-xl"
|
title="Scroll to bottom"
|
||||||
onClick={() => {
|
className="absolute bottom-0 right-0 bg-white dark:bg-[hsl(var(--background))] text-neutral-500 dark:text-white p-2 rounded-full shadow-xl"
|
||||||
scrollToBottom();
|
onClick={() => {
|
||||||
setIsNearBottom(true);
|
scrollToBottom();
|
||||||
}}
|
setIsNearBottom(true);
|
||||||
>
|
}}
|
||||||
<ArrowDown size={24} />
|
>
|
||||||
</button>
|
<ArrowDown size={24} />
|
||||||
)}
|
</button>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</ScrollArea>
|
</ScrollArea>
|
||||||
);
|
);
|
||||||
|
|
|
@ -1,25 +1,9 @@
|
||||||
import styles from "./chatInputArea.module.css";
|
import styles from "./chatInputArea.module.css";
|
||||||
import React, { useEffect, useRef, useState } from "react";
|
import React, { useEffect, useRef, useState, forwardRef } from "react";
|
||||||
|
|
||||||
import DOMPurify from "dompurify";
|
import DOMPurify from "dompurify";
|
||||||
import "katex/dist/katex.min.css";
|
import "katex/dist/katex.min.css";
|
||||||
import {
|
import { ArrowUp, Microphone, Paperclip, X, Stop } from "@phosphor-icons/react";
|
||||||
ArrowRight,
|
|
||||||
ArrowUp,
|
|
||||||
Browser,
|
|
||||||
ChatsTeardrop,
|
|
||||||
GlobeSimple,
|
|
||||||
Gps,
|
|
||||||
Image,
|
|
||||||
Microphone,
|
|
||||||
Notebook,
|
|
||||||
Paperclip,
|
|
||||||
X,
|
|
||||||
Question,
|
|
||||||
Robot,
|
|
||||||
Shapes,
|
|
||||||
Stop,
|
|
||||||
} from "@phosphor-icons/react";
|
|
||||||
|
|
||||||
import {
|
import {
|
||||||
Command,
|
Command,
|
||||||
|
@ -68,7 +52,7 @@ interface ChatInputProps {
|
||||||
agentColor?: string;
|
agentColor?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
export default function ChatInputArea(props: ChatInputProps) {
|
export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((props, ref) => {
|
||||||
const [message, setMessage] = useState("");
|
const [message, setMessage] = useState("");
|
||||||
const fileInputRef = useRef<HTMLInputElement>(null);
|
const fileInputRef = useRef<HTMLInputElement>(null);
|
||||||
|
|
||||||
|
@ -78,15 +62,17 @@ export default function ChatInputArea(props: ChatInputProps) {
|
||||||
const [loginRedirectMessage, setLoginRedirectMessage] = useState<string | null>(null);
|
const [loginRedirectMessage, setLoginRedirectMessage] = useState<string | null>(null);
|
||||||
const [showLoginPrompt, setShowLoginPrompt] = useState(false);
|
const [showLoginPrompt, setShowLoginPrompt] = useState(false);
|
||||||
|
|
||||||
const [recording, setRecording] = useState(false);
|
|
||||||
const [imageUploaded, setImageUploaded] = useState(false);
|
const [imageUploaded, setImageUploaded] = useState(false);
|
||||||
const [imagePath, setImagePath] = useState<string>("");
|
const [imagePaths, setImagePaths] = useState<string[]>([]);
|
||||||
const [imageData, setImageData] = useState<string | null>(null);
|
const [imageData, setImageData] = useState<string[]>([]);
|
||||||
|
|
||||||
|
const [recording, setRecording] = useState(false);
|
||||||
const [mediaRecorder, setMediaRecorder] = useState<MediaRecorder | null>(null);
|
const [mediaRecorder, setMediaRecorder] = useState<MediaRecorder | null>(null);
|
||||||
|
|
||||||
const [progressValue, setProgressValue] = useState(0);
|
const [progressValue, setProgressValue] = useState(0);
|
||||||
const [isDragAndDropping, setIsDragAndDropping] = useState(false);
|
const [isDragAndDropping, setIsDragAndDropping] = useState(false);
|
||||||
|
|
||||||
|
const chatInputRef = ref as React.MutableRefObject<HTMLTextAreaElement>;
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (!uploading) {
|
if (!uploading) {
|
||||||
setProgressValue(0);
|
setProgressValue(0);
|
||||||
|
@ -106,27 +92,31 @@ export default function ChatInputArea(props: ChatInputProps) {
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
async function fetchImageData() {
|
async function fetchImageData() {
|
||||||
if (imagePath) {
|
if (imagePaths.length > 0) {
|
||||||
const response = await fetch(imagePath);
|
const newImageData = await Promise.all(
|
||||||
const blob = await response.blob();
|
imagePaths.map(async (path) => {
|
||||||
const reader = new FileReader();
|
const response = await fetch(path);
|
||||||
reader.onload = function () {
|
const blob = await response.blob();
|
||||||
const base64data = reader.result;
|
return new Promise<string>((resolve) => {
|
||||||
setImageData(base64data as string);
|
const reader = new FileReader();
|
||||||
};
|
reader.onload = () => resolve(reader.result as string);
|
||||||
reader.readAsDataURL(blob);
|
reader.readAsDataURL(blob);
|
||||||
|
});
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
setImageData(newImageData);
|
||||||
}
|
}
|
||||||
setUploading(false);
|
setUploading(false);
|
||||||
}
|
}
|
||||||
setUploading(true);
|
setUploading(true);
|
||||||
fetchImageData();
|
fetchImageData();
|
||||||
}, [imagePath]);
|
}, [imagePaths]);
|
||||||
|
|
||||||
function onSendMessage() {
|
function onSendMessage() {
|
||||||
if (imageUploaded) {
|
if (imageUploaded) {
|
||||||
setImageUploaded(false);
|
setImageUploaded(false);
|
||||||
setImagePath("");
|
setImagePaths([]);
|
||||||
props.sendImage(imageData || "");
|
imageData.forEach((data) => props.sendImage(data));
|
||||||
}
|
}
|
||||||
if (!message.trim()) return;
|
if (!message.trim()) return;
|
||||||
|
|
||||||
|
@ -168,22 +158,29 @@ export default function ChatInputArea(props: ChatInputProps) {
|
||||||
|
|
||||||
function uploadFiles(files: FileList) {
|
function uploadFiles(files: FileList) {
|
||||||
if (!props.isLoggedIn) {
|
if (!props.isLoggedIn) {
|
||||||
setLoginRedirectMessage("Whoa! You need to login to upload files");
|
setLoginRedirectMessage("Please login to chat with your files");
|
||||||
setShowLoginPrompt(true);
|
setShowLoginPrompt(true);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
// check for image file
|
// check for image files
|
||||||
const image_endings = ["jpg", "jpeg", "png"];
|
const image_endings = ["jpg", "jpeg", "png", "webp"];
|
||||||
|
const newImagePaths: string[] = [];
|
||||||
for (let i = 0; i < files.length; i++) {
|
for (let i = 0; i < files.length; i++) {
|
||||||
const file = files[i];
|
const file = files[i];
|
||||||
const file_extension = file.name.split(".").pop();
|
const file_extension = file.name.split(".").pop();
|
||||||
if (image_endings.includes(file_extension || "")) {
|
if (image_endings.includes(file_extension || "")) {
|
||||||
setImageUploaded(true);
|
newImagePaths.push(DOMPurify.sanitize(URL.createObjectURL(file)));
|
||||||
setImagePath(DOMPurify.sanitize(URL.createObjectURL(file)));
|
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (newImagePaths.length > 0) {
|
||||||
|
setImageUploaded(true);
|
||||||
|
setImagePaths((prevPaths) => [...prevPaths, ...newImagePaths]);
|
||||||
|
// Set focus to the input for user message after uploading files
|
||||||
|
chatInputRef?.current?.focus();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
uploadDataForIndexing(
|
uploadDataForIndexing(
|
||||||
files,
|
files,
|
||||||
setWarning,
|
setWarning,
|
||||||
|
@ -192,6 +189,9 @@ export default function ChatInputArea(props: ChatInputProps) {
|
||||||
props.setUploadedFiles,
|
props.setUploadedFiles,
|
||||||
props.conversationId,
|
props.conversationId,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Set focus to the input for user message after uploading files
|
||||||
|
chatInputRef?.current?.focus();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Assuming this function is added within the same context as the provided excerpt
|
// Assuming this function is added within the same context as the provided excerpt
|
||||||
|
@ -270,9 +270,8 @@ export default function ChatInputArea(props: ChatInputProps) {
|
||||||
}
|
}
|
||||||
}, [recording, mediaRecorder]);
|
}, [recording, mediaRecorder]);
|
||||||
|
|
||||||
const chatInputRef = useRef<HTMLTextAreaElement>(null);
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (!chatInputRef.current) return;
|
if (!chatInputRef?.current) return;
|
||||||
chatInputRef.current.style.height = "auto";
|
chatInputRef.current.style.height = "auto";
|
||||||
chatInputRef.current.style.height =
|
chatInputRef.current.style.height =
|
||||||
Math.max(chatInputRef.current.scrollHeight - 24, 64) + "px";
|
Math.max(chatInputRef.current.scrollHeight - 24, 64) + "px";
|
||||||
|
@ -288,9 +287,12 @@ export default function ChatInputArea(props: ChatInputProps) {
|
||||||
setIsDragAndDropping(false);
|
setIsDragAndDropping(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
function removeImageUpload() {
|
function removeImageUpload(index: number) {
|
||||||
setImageUploaded(false);
|
setImagePaths((prevPaths) => prevPaths.filter((_, i) => i !== index));
|
||||||
setImagePath("");
|
setImageData((prevData) => prevData.filter((_, i) => i !== index));
|
||||||
|
if (imagePaths.length === 1) {
|
||||||
|
setImageUploaded(false);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
@ -407,24 +409,11 @@ export default function ChatInputArea(props: ChatInputProps) {
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
<div
|
<div
|
||||||
className={`${styles.actualInputArea} items-center justify-between dark:bg-neutral-700 relative ${isDragAndDropping && "animate-pulse"}`}
|
className={`${styles.actualInputArea} justify-between dark:bg-neutral-700 relative ${isDragAndDropping && "animate-pulse"}`}
|
||||||
onDragOver={handleDragOver}
|
onDragOver={handleDragOver}
|
||||||
onDragLeave={handleDragLeave}
|
onDragLeave={handleDragLeave}
|
||||||
onDrop={handleDragAndDropFiles}
|
onDrop={handleDragAndDropFiles}
|
||||||
>
|
>
|
||||||
{imageUploaded && (
|
|
||||||
<div className="absolute bottom-[80px] left-0 right-0 dark:bg-neutral-700 bg-white pt-5 pb-5 w-full rounded-lg border dark:border-none grid grid-cols-2">
|
|
||||||
<div className="pl-4 pr-4">
|
|
||||||
<img src={imagePath} alt="img" className="w-auto max-h-[100px]" />
|
|
||||||
</div>
|
|
||||||
<div className="pl-4 pr-4">
|
|
||||||
<X
|
|
||||||
className="w-6 h-6 float-right dark:hover:bg-[hsl(var(--background))] hover:bg-neutral-100 rounded-sm"
|
|
||||||
onClick={removeImageUpload}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
<input
|
<input
|
||||||
type="file"
|
type="file"
|
||||||
multiple={true}
|
multiple={true}
|
||||||
|
@ -432,15 +421,37 @@ export default function ChatInputArea(props: ChatInputProps) {
|
||||||
onChange={handleFileChange}
|
onChange={handleFileChange}
|
||||||
style={{ display: "none" }}
|
style={{ display: "none" }}
|
||||||
/>
|
/>
|
||||||
<Button
|
<div className="flex items-end pb-4">
|
||||||
variant={"ghost"}
|
<Button
|
||||||
className="!bg-none p-0 m-2 h-auto text-3xl rounded-full text-gray-300 hover:text-gray-500"
|
variant={"ghost"}
|
||||||
disabled={props.sendDisabled}
|
className="!bg-none p-0 m-2 h-auto text-3xl rounded-full text-gray-300 hover:text-gray-500"
|
||||||
onClick={handleFileButtonClick}
|
disabled={props.sendDisabled}
|
||||||
>
|
onClick={handleFileButtonClick}
|
||||||
<Paperclip className="w-8 h-8" />
|
>
|
||||||
</Button>
|
<Paperclip className="w-8 h-8" />
|
||||||
<div className="grid w-full gap-1.5 relative">
|
</Button>
|
||||||
|
</div>
|
||||||
|
<div className="flex-grow flex flex-col w-full gap-1.5 relative pb-2">
|
||||||
|
<div className="flex items-center gap-2 overflow-x-auto">
|
||||||
|
{imageUploaded &&
|
||||||
|
imagePaths.map((path, index) => (
|
||||||
|
<div key={index} className="relative flex-shrink-0 pb-3 pt-2 group">
|
||||||
|
<img
|
||||||
|
src={path}
|
||||||
|
alt={`img-${index}`}
|
||||||
|
className="w-auto h-16 object-cover rounded-xl"
|
||||||
|
/>
|
||||||
|
<Button
|
||||||
|
variant="ghost"
|
||||||
|
size="icon"
|
||||||
|
className="absolute -top-0 -right-2 h-5 w-5 rounded-full bg-neutral-200 dark:bg-neutral-600 hover:bg-neutral-300 dark:hover:bg-neutral-500 opacity-0 group-hover:opacity-100 transition-opacity"
|
||||||
|
onClick={() => removeImageUpload(index)}
|
||||||
|
>
|
||||||
|
<X className="h-3 w-3" />
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
<Textarea
|
<Textarea
|
||||||
ref={chatInputRef}
|
ref={chatInputRef}
|
||||||
className={`border-none w-full h-16 min-h-16 max-h-[128px] md:py-4 rounded-lg resize-none dark:bg-neutral-700 ${props.isMobileWidth ? "text-md" : "text-lg"}`}
|
className={`border-none w-full h-16 min-h-16 max-h-[128px] md:py-4 rounded-lg resize-none dark:bg-neutral-700 ${props.isMobileWidth ? "text-md" : "text-lg"}`}
|
||||||
|
@ -449,9 +460,9 @@ export default function ChatInputArea(props: ChatInputProps) {
|
||||||
autoFocus={true}
|
autoFocus={true}
|
||||||
value={message}
|
value={message}
|
||||||
onKeyDown={(e) => {
|
onKeyDown={(e) => {
|
||||||
if (e.key === "Enter" && !e.shiftKey) {
|
if (e.key === "Enter" && !e.shiftKey && !props.isMobileWidth) {
|
||||||
setImageUploaded(false);
|
setImageUploaded(false);
|
||||||
setImagePath("");
|
setImagePaths([]);
|
||||||
e.preventDefault();
|
e.preventDefault();
|
||||||
onSendMessage();
|
onSendMessage();
|
||||||
}
|
}
|
||||||
|
@ -460,58 +471,62 @@ export default function ChatInputArea(props: ChatInputProps) {
|
||||||
disabled={props.sendDisabled || recording}
|
disabled={props.sendDisabled || recording}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
{recording ? (
|
<div className="flex items-end pb-4">
|
||||||
<TooltipProvider>
|
{recording ? (
|
||||||
<Tooltip>
|
<TooltipProvider>
|
||||||
<TooltipTrigger asChild>
|
<Tooltip>
|
||||||
<Button
|
<TooltipTrigger asChild>
|
||||||
variant="default"
|
<Button
|
||||||
className={`${!recording && "hidden"} ${props.agentColor ? convertToBGClass(props.agentColor) : "bg-orange-300 hover:bg-orange-500"} rounded-full p-1 m-2 h-auto text-3xl transition transform md:hover:-translate-y-1`}
|
variant="default"
|
||||||
onClick={() => {
|
className={`${!recording && "hidden"} ${props.agentColor ? convertToBGClass(props.agentColor) : "bg-orange-300 hover:bg-orange-500"} rounded-full p-1 m-2 h-auto text-3xl transition transform md:hover:-translate-y-1`}
|
||||||
setRecording(!recording);
|
onClick={() => {
|
||||||
}}
|
setRecording(!recording);
|
||||||
disabled={props.sendDisabled}
|
}}
|
||||||
>
|
disabled={props.sendDisabled}
|
||||||
<Stop weight="fill" className="w-6 h-6" />
|
>
|
||||||
</Button>
|
<Stop weight="fill" className="w-6 h-6" />
|
||||||
</TooltipTrigger>
|
</Button>
|
||||||
<TooltipContent>
|
</TooltipTrigger>
|
||||||
Click to stop recording and transcribe your voice.
|
<TooltipContent>
|
||||||
</TooltipContent>
|
Click to stop recording and transcribe your voice.
|
||||||
</Tooltip>
|
</TooltipContent>
|
||||||
</TooltipProvider>
|
</Tooltip>
|
||||||
) : mediaRecorder ? (
|
</TooltipProvider>
|
||||||
<InlineLoading />
|
) : mediaRecorder ? (
|
||||||
) : (
|
<InlineLoading />
|
||||||
<TooltipProvider>
|
) : (
|
||||||
<Tooltip>
|
<TooltipProvider>
|
||||||
<TooltipTrigger asChild>
|
<Tooltip>
|
||||||
<Button
|
<TooltipTrigger asChild>
|
||||||
variant="default"
|
<Button
|
||||||
className={`${!message || recording || "hidden"} ${props.agentColor ? convertToBGClass(props.agentColor) : "bg-orange-300 hover:bg-orange-500"} rounded-full p-1 m-2 h-auto text-3xl transition transform md:hover:-translate-y-1`}
|
variant="default"
|
||||||
onClick={() => {
|
className={`${!message || recording || "hidden"} ${props.agentColor ? convertToBGClass(props.agentColor) : "bg-orange-300 hover:bg-orange-500"} rounded-full p-1 m-2 h-auto text-3xl transition transform md:hover:-translate-y-1`}
|
||||||
setMessage("Listening...");
|
onClick={() => {
|
||||||
setRecording(!recording);
|
setMessage("Listening...");
|
||||||
}}
|
setRecording(!recording);
|
||||||
disabled={props.sendDisabled}
|
}}
|
||||||
>
|
disabled={props.sendDisabled}
|
||||||
<Microphone weight="fill" className="w-6 h-6" />
|
>
|
||||||
</Button>
|
<Microphone weight="fill" className="w-6 h-6" />
|
||||||
</TooltipTrigger>
|
</Button>
|
||||||
<TooltipContent>
|
</TooltipTrigger>
|
||||||
Click to transcribe your message with voice.
|
<TooltipContent>
|
||||||
</TooltipContent>
|
Click to transcribe your message with voice.
|
||||||
</Tooltip>
|
</TooltipContent>
|
||||||
</TooltipProvider>
|
</Tooltip>
|
||||||
)}
|
</TooltipProvider>
|
||||||
<Button
|
)}
|
||||||
className={`${(!message || recording) && "hidden"} ${props.agentColor ? convertToBGClass(props.agentColor) : "bg-orange-300 hover:bg-orange-500"} rounded-full p-1 m-2 h-auto text-3xl transition transform md:hover:-translate-y-1`}
|
<Button
|
||||||
onClick={onSendMessage}
|
className={`${(!message || recording) && "hidden"} ${props.agentColor ? convertToBGClass(props.agentColor) : "bg-orange-300 hover:bg-orange-500"} rounded-full p-1 m-2 h-auto text-3xl transition transform md:hover:-translate-y-1`}
|
||||||
disabled={props.sendDisabled}
|
onClick={onSendMessage}
|
||||||
>
|
disabled={props.sendDisabled}
|
||||||
<ArrowUp className="w-6 h-6" weight="bold" />
|
>
|
||||||
</Button>
|
<ArrowUp className="w-6 h-6" weight="bold" />
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</>
|
</>
|
||||||
);
|
);
|
||||||
}
|
});
|
||||||
|
|
||||||
|
ChatInputArea.displayName = "ChatInputArea";
|
||||||
|
|
|
@ -57,7 +57,26 @@ div.emptyChatMessage {
|
||||||
display: none;
|
display: none;
|
||||||
}
|
}
|
||||||
|
|
||||||
div.chatMessageContainer img {
|
div.imagesContainer {
|
||||||
|
display: flex;
|
||||||
|
overflow-x: auto;
|
||||||
|
padding-bottom: 8px;
|
||||||
|
margin-bottom: 8px;
|
||||||
|
}
|
||||||
|
|
||||||
|
div.imageWrapper {
|
||||||
|
flex: 0 0 auto;
|
||||||
|
margin-right: 8px;
|
||||||
|
}
|
||||||
|
|
||||||
|
div.imageWrapper img {
|
||||||
|
width: auto;
|
||||||
|
height: 128px;
|
||||||
|
object-fit: cover;
|
||||||
|
border-radius: 8px;
|
||||||
|
}
|
||||||
|
|
||||||
|
div.chatMessageContainer > img {
|
||||||
width: auto;
|
width: auto;
|
||||||
height: auto;
|
height: auto;
|
||||||
max-width: 100%;
|
max-width: 100%;
|
||||||
|
|
|
@ -28,6 +28,7 @@ import {
|
||||||
ClipboardText,
|
ClipboardText,
|
||||||
Check,
|
Check,
|
||||||
Code,
|
Code,
|
||||||
|
Shapes,
|
||||||
} from "@phosphor-icons/react";
|
} from "@phosphor-icons/react";
|
||||||
|
|
||||||
import DOMPurify from "dompurify";
|
import DOMPurify from "dompurify";
|
||||||
|
@ -37,6 +38,7 @@ import { AgentData } from "@/app/agents/page";
|
||||||
|
|
||||||
import renderMathInElement from "katex/contrib/auto-render";
|
import renderMathInElement from "katex/contrib/auto-render";
|
||||||
import "katex/dist/katex.min.css";
|
import "katex/dist/katex.min.css";
|
||||||
|
import ExcalidrawComponent from "../excalidraw/excalidraw";
|
||||||
|
|
||||||
const md = new markdownIt({
|
const md = new markdownIt({
|
||||||
html: true,
|
html: true,
|
||||||
|
@ -137,7 +139,7 @@ export interface SingleChatMessage {
|
||||||
rawQuery?: string;
|
rawQuery?: string;
|
||||||
intent?: Intent;
|
intent?: Intent;
|
||||||
agent?: AgentData;
|
agent?: AgentData;
|
||||||
uploadedImageData?: string;
|
images?: string[];
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface StreamMessage {
|
export interface StreamMessage {
|
||||||
|
@ -150,7 +152,9 @@ export interface StreamMessage {
|
||||||
rawQuery: string;
|
rawQuery: string;
|
||||||
timestamp: string;
|
timestamp: string;
|
||||||
agent?: AgentData;
|
agent?: AgentData;
|
||||||
uploadedImageData?: string;
|
images?: string[];
|
||||||
|
intentType?: string;
|
||||||
|
inferredQueries?: string[];
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface ChatHistoryData {
|
export interface ChatHistoryData {
|
||||||
|
@ -232,7 +236,6 @@ interface ChatMessageProps {
|
||||||
borderLeftColor?: string;
|
borderLeftColor?: string;
|
||||||
isLastMessage?: boolean;
|
isLastMessage?: boolean;
|
||||||
agent?: AgentData;
|
agent?: AgentData;
|
||||||
uploadedImageData?: string;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
interface TrainOfThoughtProps {
|
interface TrainOfThoughtProps {
|
||||||
|
@ -276,6 +279,10 @@ function chooseIconFromHeader(header: string, iconColor: string) {
|
||||||
return <Aperture className={`${classNames}`} />;
|
return <Aperture className={`${classNames}`} />;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (compareHeader.includes("diagram")) {
|
||||||
|
return <Shapes className={`${classNames}`} />;
|
||||||
|
}
|
||||||
|
|
||||||
if (compareHeader.includes("paint")) {
|
if (compareHeader.includes("paint")) {
|
||||||
return <Palette className={`${classNames}`} />;
|
return <Palette className={`${classNames}`} />;
|
||||||
}
|
}
|
||||||
|
@ -311,6 +318,7 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
|
||||||
const [markdownRendered, setMarkdownRendered] = useState<string>("");
|
const [markdownRendered, setMarkdownRendered] = useState<string>("");
|
||||||
const [isPlaying, setIsPlaying] = useState<boolean>(false);
|
const [isPlaying, setIsPlaying] = useState<boolean>(false);
|
||||||
const [interrupted, setInterrupted] = useState<boolean>(false);
|
const [interrupted, setInterrupted] = useState<boolean>(false);
|
||||||
|
const [excalidrawData, setExcalidrawData] = useState<string>("");
|
||||||
|
|
||||||
const interruptedRef = useRef<boolean>(false);
|
const interruptedRef = useRef<boolean>(false);
|
||||||
const messageRef = useRef<HTMLDivElement>(null);
|
const messageRef = useRef<HTMLDivElement>(null);
|
||||||
|
@ -347,8 +355,14 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
|
||||||
}, [messageRef.current]);
|
}, [messageRef.current]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
|
// Prepare initial message for rendering
|
||||||
let message = props.chatMessage.message;
|
let message = props.chatMessage.message;
|
||||||
|
|
||||||
|
if (props.chatMessage.intent && props.chatMessage.intent.type == "excalidraw") {
|
||||||
|
message = props.chatMessage.intent["inferred-queries"][0];
|
||||||
|
setExcalidrawData(props.chatMessage.message);
|
||||||
|
}
|
||||||
|
|
||||||
// Replace LaTeX delimiters with placeholders
|
// Replace LaTeX delimiters with placeholders
|
||||||
message = message
|
message = message
|
||||||
.replace(/\\\(/g, "LEFTPAREN")
|
.replace(/\\\(/g, "LEFTPAREN")
|
||||||
|
@ -356,8 +370,50 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
|
||||||
.replace(/\\\[/g, "LEFTBRACKET")
|
.replace(/\\\[/g, "LEFTBRACKET")
|
||||||
.replace(/\\\]/g, "RIGHTBRACKET");
|
.replace(/\\\]/g, "RIGHTBRACKET");
|
||||||
|
|
||||||
if (props.chatMessage.uploadedImageData) {
|
const intentTypeHandlers = {
|
||||||
message = `![uploaded image](${props.chatMessage.uploadedImageData})\n\n${message}`;
|
"text-to-image": (msg: string) => `![generated image](data:image/png;base64,${msg})`,
|
||||||
|
"text-to-image2": (msg: string) => `![generated image](${msg})`,
|
||||||
|
"text-to-image-v3": (msg: string) =>
|
||||||
|
`![generated image](data:image/webp;base64,${msg})`,
|
||||||
|
excalidraw: (msg: string) => msg,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Handle intent-specific rendering
|
||||||
|
if (props.chatMessage.intent) {
|
||||||
|
const { type, "inferred-queries": inferredQueries } = props.chatMessage.intent;
|
||||||
|
|
||||||
|
console.log("intent type", type);
|
||||||
|
if (type in intentTypeHandlers) {
|
||||||
|
message = intentTypeHandlers[type as keyof typeof intentTypeHandlers](message);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (type.includes("text-to-image") && inferredQueries?.length > 0) {
|
||||||
|
message += `\n\n${inferredQueries[0]}`;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Handle user attached images rendering
|
||||||
|
let messageForClipboard = message;
|
||||||
|
let messageToRender = message;
|
||||||
|
if (props.chatMessage.images && props.chatMessage.images.length > 0) {
|
||||||
|
const sanitizedImages = props.chatMessage.images.map((image) => {
|
||||||
|
const decodedImage = image.startsWith("data%3Aimage")
|
||||||
|
? decodeURIComponent(image)
|
||||||
|
: image;
|
||||||
|
return DOMPurify.sanitize(decodedImage);
|
||||||
|
});
|
||||||
|
const imagesInMd = sanitizedImages
|
||||||
|
.map((sanitizedImage, index) => {
|
||||||
|
return `![uploaded image ${index + 1}](${sanitizedImage})`;
|
||||||
|
})
|
||||||
|
.join("\n");
|
||||||
|
const imagesInHtml = sanitizedImages
|
||||||
|
.map((sanitizedImage, index) => {
|
||||||
|
return `<div class="${styles.imageWrapper}"><img src="${sanitizedImage}" alt="uploaded image ${index + 1}" /></div>`;
|
||||||
|
})
|
||||||
|
.join("");
|
||||||
|
const userImagesInHtml = `<div class="${styles.imagesContainer}">${imagesInHtml}</div>`;
|
||||||
|
messageForClipboard = `${imagesInMd}\n\n${messageForClipboard}`;
|
||||||
|
messageToRender = `${userImagesInHtml}${messageToRender}`;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (props.chatMessage.intent && props.chatMessage.intent.type == "text-to-image") {
|
if (props.chatMessage.intent && props.chatMessage.intent.type == "text-to-image") {
|
||||||
|
@ -402,10 +458,11 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
setTextRendered(message);
|
// Set the message text
|
||||||
|
setTextRendered(messageForClipboard);
|
||||||
|
|
||||||
// Render the markdown
|
// Render the markdown
|
||||||
let markdownRendered = md.render(message);
|
let markdownRendered = md.render(messageToRender);
|
||||||
|
|
||||||
// Replace placeholders with LaTeX delimiters
|
// Replace placeholders with LaTeX delimiters
|
||||||
markdownRendered = markdownRendered
|
markdownRendered = markdownRendered
|
||||||
|
@ -416,7 +473,7 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
|
||||||
|
|
||||||
// Sanitize and set the rendered markdown
|
// Sanitize and set the rendered markdown
|
||||||
setMarkdownRendered(DOMPurify.sanitize(markdownRendered));
|
setMarkdownRendered(DOMPurify.sanitize(markdownRendered));
|
||||||
}, [props.chatMessage.message, props.chatMessage.intent]);
|
}, [props.chatMessage.message, props.chatMessage.images, props.chatMessage.intent]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (copySuccess) {
|
if (copySuccess) {
|
||||||
|
@ -607,6 +664,7 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
|
||||||
className={styles.chatMessage}
|
className={styles.chatMessage}
|
||||||
dangerouslySetInnerHTML={{ __html: markdownRendered }}
|
dangerouslySetInnerHTML={{ __html: markdownRendered }}
|
||||||
/>
|
/>
|
||||||
|
{excalidrawData && <ExcalidrawComponent data={excalidrawData} />}
|
||||||
</div>
|
</div>
|
||||||
<div className={styles.teaserReferencesContainer}>
|
<div className={styles.teaserReferencesContainer}>
|
||||||
<TeaserReferencesSection
|
<TeaserReferencesSection
|
||||||
|
|
24
src/interface/web/app/components/excalidraw/excalidraw.tsx
Normal file
24
src/interface/web/app/components/excalidraw/excalidraw.tsx
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
"use client";
|
||||||
|
|
||||||
|
import dynamic from "next/dynamic";
|
||||||
|
import { Suspense } from "react";
|
||||||
|
import Loading from "../../components/loading/loading";
|
||||||
|
|
||||||
|
// Since client components get prerenderd on server as well hence importing
|
||||||
|
// the excalidraw stuff dynamically with ssr false
|
||||||
|
|
||||||
|
const ExcalidrawWrapper = dynamic(() => import("./excalidrawWrapper").then((mod) => mod.default), {
|
||||||
|
ssr: false,
|
||||||
|
});
|
||||||
|
|
||||||
|
interface ExcalidrawComponentProps {
|
||||||
|
data: any;
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function ExcalidrawComponent(props: ExcalidrawComponentProps) {
|
||||||
|
return (
|
||||||
|
<Suspense fallback={<Loading />}>
|
||||||
|
<ExcalidrawWrapper data={props.data} />
|
||||||
|
</Suspense>
|
||||||
|
);
|
||||||
|
}
|
|
@ -0,0 +1,149 @@
|
||||||
|
"use client";
|
||||||
|
|
||||||
|
import { useState, useEffect } from "react";
|
||||||
|
|
||||||
|
import dynamic from "next/dynamic";
|
||||||
|
|
||||||
|
import { ExcalidrawProps } from "@excalidraw/excalidraw/types/types";
|
||||||
|
import { ExcalidrawElement } from "@excalidraw/excalidraw/types/element/types";
|
||||||
|
import { ExcalidrawElementSkeleton } from "@excalidraw/excalidraw/types/data/transform";
|
||||||
|
|
||||||
|
const Excalidraw = dynamic<ExcalidrawProps>(
|
||||||
|
async () => (await import("@excalidraw/excalidraw")).Excalidraw,
|
||||||
|
{
|
||||||
|
ssr: false,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
import { convertToExcalidrawElements } from "@excalidraw/excalidraw";
|
||||||
|
|
||||||
|
import { Button } from "@/components/ui/button";
|
||||||
|
|
||||||
|
import { ArrowsInSimple, ArrowsOutSimple } from "@phosphor-icons/react";
|
||||||
|
|
||||||
|
interface ExcalidrawWrapperProps {
|
||||||
|
data: ExcalidrawElementSkeleton[];
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function ExcalidrawWrapper(props: ExcalidrawWrapperProps) {
|
||||||
|
const [excalidrawElements, setExcalidrawElements] = useState<ExcalidrawElement[]>([]);
|
||||||
|
const [expanded, setExpanded] = useState<boolean>(false);
|
||||||
|
|
||||||
|
const isValidExcalidrawElement = (element: ExcalidrawElementSkeleton): boolean => {
|
||||||
|
return (
|
||||||
|
element.x !== undefined &&
|
||||||
|
element.y !== undefined &&
|
||||||
|
element.id !== undefined &&
|
||||||
|
element.type !== undefined
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (expanded) {
|
||||||
|
onkeydown = (e) => {
|
||||||
|
if (e.key === "Escape") {
|
||||||
|
setExpanded(false);
|
||||||
|
// Trigger a resize event to make Excalidraw adjust its size
|
||||||
|
window.dispatchEvent(new Event("resize"));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} else {
|
||||||
|
onkeydown = null;
|
||||||
|
}
|
||||||
|
}, [expanded]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
// Do some basic validation
|
||||||
|
const basicValidSkeletons: ExcalidrawElementSkeleton[] = [];
|
||||||
|
|
||||||
|
for (const element of props.data) {
|
||||||
|
if (isValidExcalidrawElement(element as ExcalidrawElementSkeleton)) {
|
||||||
|
basicValidSkeletons.push(element as ExcalidrawElementSkeleton);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const validSkeletons: ExcalidrawElementSkeleton[] = [];
|
||||||
|
for (const element of basicValidSkeletons) {
|
||||||
|
if (element.type === "frame") {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (element.type === "arrow") {
|
||||||
|
const start = basicValidSkeletons.find((child) => child.id === element.start?.id);
|
||||||
|
const end = basicValidSkeletons.find((child) => child.id === element.end?.id);
|
||||||
|
if (start && end) {
|
||||||
|
validSkeletons.push(element);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
validSkeletons.push(element);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const element of basicValidSkeletons) {
|
||||||
|
if (element.type === "frame") {
|
||||||
|
const children = element.children?.map((childId) => {
|
||||||
|
return validSkeletons.find((child) => child.id === childId);
|
||||||
|
});
|
||||||
|
// Get the valid children, filter out any undefined values
|
||||||
|
const validChildrenIds: readonly string[] = children
|
||||||
|
?.map((child) => child?.id)
|
||||||
|
.filter((id) => id !== undefined) as string[];
|
||||||
|
|
||||||
|
if (validChildrenIds === undefined || validChildrenIds.length === 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
validSkeletons.push({
|
||||||
|
...element,
|
||||||
|
children: validChildrenIds,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const elements = convertToExcalidrawElements(validSkeletons);
|
||||||
|
setExcalidrawElements(elements);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="relative">
|
||||||
|
<div
|
||||||
|
className={`${expanded ? "fixed inset-0 bg-black bg-opacity-50 backdrop-blur-sm z-50 flex items-center justify-center" : ""}`}
|
||||||
|
>
|
||||||
|
<Button
|
||||||
|
onClick={() => {
|
||||||
|
setExpanded(!expanded);
|
||||||
|
// Trigger a resize event to make Excalidraw adjust its size
|
||||||
|
window.dispatchEvent(new Event("resize"));
|
||||||
|
}}
|
||||||
|
variant={"outline"}
|
||||||
|
className={`${expanded ? "absolute top-2 left-2 z-[60]" : ""}`}
|
||||||
|
>
|
||||||
|
{expanded ? (
|
||||||
|
<ArrowsInSimple className="h-4 w-4" />
|
||||||
|
) : (
|
||||||
|
<ArrowsOutSimple className="h-4 w-4" />
|
||||||
|
)}
|
||||||
|
</Button>
|
||||||
|
<div
|
||||||
|
className={`
|
||||||
|
${expanded ? "w-[80vw] h-[80vh]" : "w-full h-[500px]"}
|
||||||
|
bg-white overflow-hidden rounded-lg relative
|
||||||
|
`}
|
||||||
|
>
|
||||||
|
<Excalidraw
|
||||||
|
initialData={{
|
||||||
|
elements: excalidrawElements,
|
||||||
|
appState: { zenModeEnabled: true },
|
||||||
|
scrollToContent: true,
|
||||||
|
}}
|
||||||
|
// TODO - Create a common function to detect if the theme is dark?
|
||||||
|
theme={localStorage.getItem("theme") === "dark" ? "dark" : "light"}
|
||||||
|
validateEmbeddable={true}
|
||||||
|
renderTopRightUI={(isMobile, appState) => {
|
||||||
|
return <></>;
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
|
@ -98,7 +98,11 @@ import { KhojLogoType } from "@/app/components/logo/khojLogo";
|
||||||
import NavMenu from "@/app/components/navMenu/navMenu";
|
import NavMenu from "@/app/components/navMenu/navMenu";
|
||||||
|
|
||||||
// Define a fetcher function
|
// Define a fetcher function
|
||||||
const fetcher = (url: string) => fetch(url).then((res) => res.json());
|
const fetcher = (url: string) =>
|
||||||
|
fetch(url).then((res) => {
|
||||||
|
if (!res.ok) throw new Error(`Failed to call API at ${url} with error ${res.statusText}`);
|
||||||
|
return res.json();
|
||||||
|
});
|
||||||
|
|
||||||
interface GroupedChatHistory {
|
interface GroupedChatHistory {
|
||||||
[key: string]: ChatHistory[];
|
[key: string]: ChatHistory[];
|
||||||
|
@ -181,20 +185,15 @@ function FilesMenu(props: FilesMenuProps) {
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (!files) return;
|
if (!files) return;
|
||||||
|
|
||||||
const uniqueFiles = Array.from(new Set(files));
|
let sortedUniqueFiles = Array.from(new Set(files)).sort();
|
||||||
|
|
||||||
// First, sort lexically
|
if (Array.isArray(addedFiles)) {
|
||||||
uniqueFiles.sort();
|
sortedUniqueFiles = addedFiles.concat(
|
||||||
|
sortedUniqueFiles.filter((filename: string) => !addedFiles.includes(filename)),
|
||||||
let sortedFiles = uniqueFiles;
|
|
||||||
|
|
||||||
if (addedFiles) {
|
|
||||||
sortedFiles = addedFiles.concat(
|
|
||||||
sortedFiles.filter((filename: string) => !addedFiles.includes(filename)),
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
setUnfilteredFiles(sortedFiles);
|
setUnfilteredFiles(sortedUniqueFiles);
|
||||||
}, [files, addedFiles]);
|
}, [files, addedFiles]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
|
@ -204,8 +203,10 @@ function FilesMenu(props: FilesMenuProps) {
|
||||||
}, [props.uploadedFiles]);
|
}, [props.uploadedFiles]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (selectedFiles) {
|
if (Array.isArray(selectedFiles)) {
|
||||||
setAddedFiles(selectedFiles);
|
setAddedFiles(selectedFiles);
|
||||||
|
} else {
|
||||||
|
setAddedFiles([]);
|
||||||
}
|
}
|
||||||
}, [selectedFiles]);
|
}, [selectedFiles]);
|
||||||
|
|
||||||
|
@ -269,7 +270,7 @@ function FilesMenu(props: FilesMenuProps) {
|
||||||
</CommandItem>
|
</CommandItem>
|
||||||
)}
|
)}
|
||||||
{unfilteredFiles.map((filename: string) =>
|
{unfilteredFiles.map((filename: string) =>
|
||||||
addedFiles && addedFiles.includes(filename) ? (
|
Array.isArray(addedFiles) && addedFiles.includes(filename) ? (
|
||||||
<CommandItem
|
<CommandItem
|
||||||
key={filename}
|
key={filename}
|
||||||
value={filename}
|
value={filename}
|
||||||
|
|
|
@ -3,26 +3,33 @@ import "./globals.css";
|
||||||
import styles from "./page.module.css";
|
import styles from "./page.module.css";
|
||||||
import "katex/dist/katex.min.css";
|
import "katex/dist/katex.min.css";
|
||||||
|
|
||||||
import React, { useEffect, useState } from "react";
|
import React, { useEffect, useRef, useState } from "react";
|
||||||
import useSWR from "swr";
|
import useSWR from "swr";
|
||||||
import Image from "next/image";
|
|
||||||
import { ArrowCounterClockwise } from "@phosphor-icons/react";
|
import { ArrowCounterClockwise } from "@phosphor-icons/react";
|
||||||
|
|
||||||
import { Card, CardTitle } from "@/components/ui/card";
|
import { Card, CardTitle } from "@/components/ui/card";
|
||||||
import SuggestionCard from "@/app/components/suggestions/suggestionCard";
|
import SuggestionCard from "@/app/components/suggestions/suggestionCard";
|
||||||
import SidePanel from "@/app/components/sidePanel/chatHistorySidePanel";
|
import SidePanel from "@/app/components/sidePanel/chatHistorySidePanel";
|
||||||
import Loading from "@/app/components/loading/loading";
|
import Loading from "@/app/components/loading/loading";
|
||||||
import ChatInputArea, { ChatOptions } from "@/app/components/chatInputArea/chatInputArea";
|
import { ChatInputArea, ChatOptions } from "@/app/components/chatInputArea/chatInputArea";
|
||||||
import { Suggestion, suggestionsData } from "@/app/components/suggestions/suggestionsData";
|
import { Suggestion, suggestionsData } from "@/app/components/suggestions/suggestionsData";
|
||||||
import LoginPrompt from "@/app/components/loginPrompt/loginPrompt";
|
import LoginPrompt from "@/app/components/loginPrompt/loginPrompt";
|
||||||
|
|
||||||
import { useAuthenticatedData, UserConfig, useUserConfig } from "@/app/common/auth";
|
import {
|
||||||
|
isUserSubscribed,
|
||||||
|
useAuthenticatedData,
|
||||||
|
UserConfig,
|
||||||
|
useUserConfig,
|
||||||
|
} from "@/app/common/auth";
|
||||||
import { convertColorToBorderClass } from "@/app/common/colorUtils";
|
import { convertColorToBorderClass } from "@/app/common/colorUtils";
|
||||||
import { getIconFromIconName } from "@/app/common/iconUtils";
|
import { getIconFromIconName } from "@/app/common/iconUtils";
|
||||||
import { AgentData } from "@/app/agents/page";
|
import { AgentData } from "@/app/agents/page";
|
||||||
import { createNewConversation } from "./common/chatFunctions";
|
import { createNewConversation } from "./common/chatFunctions";
|
||||||
import { useIsMobileWidth } from "./common/utils";
|
import { useDebounce, useIsMobileWidth } from "./common/utils";
|
||||||
import { useSearchParams } from "next/navigation";
|
import { useRouter, useSearchParams } from "next/navigation";
|
||||||
|
import { ScrollArea, ScrollBar } from "@/components/ui/scroll-area";
|
||||||
|
import { AgentCard } from "@/app/components/agentCard/agentCard";
|
||||||
|
import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover";
|
||||||
|
|
||||||
interface ChatBodyDataProps {
|
interface ChatBodyDataProps {
|
||||||
chatOptionsData: ChatOptions | null;
|
chatOptionsData: ChatOptions | null;
|
||||||
|
@ -44,14 +51,19 @@ function FisherYatesShuffle(array: any[]) {
|
||||||
|
|
||||||
function ChatBodyData(props: ChatBodyDataProps) {
|
function ChatBodyData(props: ChatBodyDataProps) {
|
||||||
const [message, setMessage] = useState("");
|
const [message, setMessage] = useState("");
|
||||||
const [image, setImage] = useState<string | null>(null);
|
const [images, setImages] = useState<string[]>([]);
|
||||||
const [processingMessage, setProcessingMessage] = useState(false);
|
const [processingMessage, setProcessingMessage] = useState(false);
|
||||||
const [greeting, setGreeting] = useState("");
|
const [greeting, setGreeting] = useState("");
|
||||||
const [shuffledOptions, setShuffledOptions] = useState<Suggestion[]>([]);
|
const [shuffledOptions, setShuffledOptions] = useState<Suggestion[]>([]);
|
||||||
|
const [hoveredAgent, setHoveredAgent] = useState<string | null>(null);
|
||||||
|
const debouncedHoveredAgent = useDebounce(hoveredAgent, 500);
|
||||||
|
const [isPopoverOpen, setIsPopoverOpen] = useState(false);
|
||||||
const [selectedAgent, setSelectedAgent] = useState<string | null>("khoj");
|
const [selectedAgent, setSelectedAgent] = useState<string | null>("khoj");
|
||||||
const [agentIcons, setAgentIcons] = useState<JSX.Element[]>([]);
|
const [agentIcons, setAgentIcons] = useState<JSX.Element[]>([]);
|
||||||
const [agents, setAgents] = useState<AgentData[]>([]);
|
const [agents, setAgents] = useState<AgentData[]>([]);
|
||||||
|
const chatInputRef = useRef<HTMLTextAreaElement>(null);
|
||||||
const [showLoginPrompt, setShowLoginPrompt] = useState(false);
|
const [showLoginPrompt, setShowLoginPrompt] = useState(false);
|
||||||
|
const router = useRouter();
|
||||||
const searchParams = useSearchParams();
|
const searchParams = useSearchParams();
|
||||||
const queryParam = searchParams.get("q");
|
const queryParam = searchParams.get("q");
|
||||||
|
|
||||||
|
@ -61,6 +73,12 @@ function ChatBodyData(props: ChatBodyDataProps) {
|
||||||
}
|
}
|
||||||
}, [queryParam]);
|
}, [queryParam]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (debouncedHoveredAgent) {
|
||||||
|
setIsPopoverOpen(true);
|
||||||
|
}
|
||||||
|
}, [debouncedHoveredAgent]);
|
||||||
|
|
||||||
const onConversationIdChange = props.onConversationIdChange;
|
const onConversationIdChange = props.onConversationIdChange;
|
||||||
|
|
||||||
const agentsFetcher = () =>
|
const agentsFetcher = () =>
|
||||||
|
@ -72,6 +90,10 @@ function ChatBodyData(props: ChatBodyDataProps) {
|
||||||
revalidateOnFocus: false,
|
revalidateOnFocus: false,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
const openAgentEditCard = (agentSlug: string) => {
|
||||||
|
router.push(`/agents?agent=${agentSlug}`);
|
||||||
|
};
|
||||||
|
|
||||||
function shuffleAndSetOptions() {
|
function shuffleAndSetOptions() {
|
||||||
const shuffled = FisherYatesShuffle(suggestionsData);
|
const shuffled = FisherYatesShuffle(suggestionsData);
|
||||||
setShuffledOptions(shuffled.slice(0, 3));
|
setShuffledOptions(shuffled.slice(0, 3));
|
||||||
|
@ -108,22 +130,13 @@ function ChatBodyData(props: ChatBodyDataProps) {
|
||||||
}, [props.chatOptionsData]);
|
}, [props.chatOptionsData]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const nSlice = props.isMobileWidth ? 2 : 4;
|
const agents = (agentsData || []).filter((agent) => agent !== null && agent !== undefined);
|
||||||
const shuffledAgents = agentsData ? [...agentsData].sort(() => 0.5 - Math.random()) : [];
|
|
||||||
const agents = agentsData ? [agentsData[0]] : []; // Always add the first/default agent.
|
|
||||||
|
|
||||||
shuffledAgents.slice(0, nSlice - 1).forEach((agent) => {
|
|
||||||
if (!agents.find((a) => a.slug === agent.slug)) {
|
|
||||||
agents.push(agent);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
setAgents(agents);
|
setAgents(agents);
|
||||||
|
// set the first agent, which is always the default agent, as the default for chat
|
||||||
|
setSelectedAgent(agents.length > 1 ? agents[0].slug : "khoj");
|
||||||
|
|
||||||
//generate colored icons for the selected agents
|
// generate colored icons for the available agents
|
||||||
const agentIcons = agents
|
const agentIcons = agents.map((agent) => getIconFromIconName(agent.icon, agent.color)!);
|
||||||
.filter((agent) => agent !== null && agent !== undefined)
|
|
||||||
.map((agent) => getIconFromIconName(agent.icon, agent.color)!);
|
|
||||||
setAgentIcons(agentIcons);
|
setAgentIcons(agentIcons);
|
||||||
}, [agentsData, props.isMobileWidth]);
|
}, [agentsData, props.isMobileWidth]);
|
||||||
|
|
||||||
|
@ -138,24 +151,39 @@ function ChatBodyData(props: ChatBodyDataProps) {
|
||||||
try {
|
try {
|
||||||
const newConversationId = await createNewConversation(selectedAgent || "khoj");
|
const newConversationId = await createNewConversation(selectedAgent || "khoj");
|
||||||
onConversationIdChange?.(newConversationId);
|
onConversationIdChange?.(newConversationId);
|
||||||
window.location.href = `/chat?conversationId=${newConversationId}`;
|
|
||||||
localStorage.setItem("message", message);
|
localStorage.setItem("message", message);
|
||||||
if (image) {
|
if (images.length > 0) {
|
||||||
localStorage.setItem("image", image);
|
localStorage.setItem("images", JSON.stringify(images));
|
||||||
}
|
}
|
||||||
|
window.location.href = `/chat?conversationId=${newConversationId}`;
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error("Error creating new conversation:", error);
|
console.error("Error creating new conversation:", error);
|
||||||
setProcessingMessage(false);
|
setProcessingMessage(false);
|
||||||
}
|
}
|
||||||
setMessage("");
|
setMessage("");
|
||||||
|
setImages([]);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
processMessage();
|
processMessage();
|
||||||
if (message) {
|
if (message || images.length > 0) {
|
||||||
setProcessingMessage(true);
|
setProcessingMessage(true);
|
||||||
}
|
}
|
||||||
}, [selectedAgent, message, processingMessage, onConversationIdChange]);
|
}, [selectedAgent, message, processingMessage, onConversationIdChange]);
|
||||||
|
|
||||||
|
// Close the agent detail hover card when scroll on agent pane
|
||||||
|
useEffect(() => {
|
||||||
|
const scrollAreaSelector = "[data-radix-scroll-area-viewport]";
|
||||||
|
const scrollAreaEl = document.querySelector<HTMLElement>(scrollAreaSelector);
|
||||||
|
const handleScroll = () => {
|
||||||
|
setHoveredAgent(null);
|
||||||
|
setIsPopoverOpen(false);
|
||||||
|
};
|
||||||
|
|
||||||
|
scrollAreaEl?.addEventListener("scroll", handleScroll);
|
||||||
|
|
||||||
|
return () => scrollAreaEl?.removeEventListener("scroll", handleScroll);
|
||||||
|
}, []);
|
||||||
|
|
||||||
function fillArea(link: string, type: string, prompt: string) {
|
function fillArea(link: string, type: string, prompt: string) {
|
||||||
if (!link) {
|
if (!link) {
|
||||||
let message_str = "";
|
let message_str = "";
|
||||||
|
@ -194,37 +222,76 @@ function ChatBodyData(props: ChatBodyDataProps) {
|
||||||
</h1>
|
</h1>
|
||||||
</div>
|
</div>
|
||||||
{!props.isMobileWidth && (
|
{!props.isMobileWidth && (
|
||||||
<div className="flex pb-6 gap-2 items-center justify-center">
|
<ScrollArea className="w-full max-w-[600px] mx-auto">
|
||||||
{agentIcons.map((icon, index) => (
|
<div className="flex pb-2 gap-2 items-center justify-center">
|
||||||
<Card
|
{agents.map((agent, index) => (
|
||||||
key={`${index}-${agents[index].slug}`}
|
<Popover
|
||||||
className={`${
|
key={`${index}-${agent.slug}`}
|
||||||
selectedAgent === agents[index].slug
|
open={isPopoverOpen && debouncedHoveredAgent === agent.slug}
|
||||||
? convertColorToBorderClass(agents[index].color)
|
onOpenChange={(open) => {
|
||||||
: "border-stone-100 dark:border-neutral-700 text-muted-foreground"
|
if (!open) {
|
||||||
}
|
setHoveredAgent(null);
|
||||||
hover:cursor-pointer rounded-lg px-2 py-2`}
|
setIsPopoverOpen(false);
|
||||||
>
|
}
|
||||||
<CardTitle
|
}}
|
||||||
className="text-center text-md font-medium flex justify-center items-center"
|
|
||||||
onClick={() => setSelectedAgent(agents[index].slug)}
|
|
||||||
>
|
>
|
||||||
{icon} {agents[index].name}
|
<PopoverTrigger asChild>
|
||||||
</CardTitle>
|
<Card
|
||||||
</Card>
|
className={`${
|
||||||
))}
|
selectedAgent === agent.slug
|
||||||
<Card
|
? convertColorToBorderClass(agent.color)
|
||||||
className="border-none shadow-none flex justify-center items-center hover:cursor-pointer"
|
: "border-stone-100 dark:border-neutral-700 text-muted-foreground"
|
||||||
onClick={() => (window.location.href = "/agents")}
|
}
|
||||||
>
|
hover:cursor-pointer rounded-lg px-2 py-2`}
|
||||||
<CardTitle className="text-center text-md font-normal flex justify-center items-center px-1.5 py-2">
|
onDoubleClick={() => openAgentEditCard(agent.slug)}
|
||||||
See All →
|
onClick={() => {
|
||||||
</CardTitle>
|
setSelectedAgent(agent.slug);
|
||||||
</Card>
|
chatInputRef.current?.focus();
|
||||||
</div>
|
setHoveredAgent(null);
|
||||||
|
setIsPopoverOpen(false);
|
||||||
|
}}
|
||||||
|
onMouseEnter={() => setHoveredAgent(agent.slug)}
|
||||||
|
onMouseLeave={() => {
|
||||||
|
setHoveredAgent(null);
|
||||||
|
setIsPopoverOpen(false);
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<CardTitle className="text-center text-md font-medium flex justify-center items-center whitespace-nowrap">
|
||||||
|
{agentIcons[index]} {agent.name}
|
||||||
|
</CardTitle>
|
||||||
|
</Card>
|
||||||
|
</PopoverTrigger>
|
||||||
|
<PopoverContent
|
||||||
|
className="w-80 p-0 border-none bg-transparent shadow-none"
|
||||||
|
onMouseLeave={() => {
|
||||||
|
setHoveredAgent(null);
|
||||||
|
setIsPopoverOpen(false);
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
<AgentCard
|
||||||
|
data={agent}
|
||||||
|
userProfile={null}
|
||||||
|
isMobileWidth={props.isMobileWidth || false}
|
||||||
|
showChatButton={false}
|
||||||
|
editCard={false}
|
||||||
|
filesOptions={[]}
|
||||||
|
selectedChatModelOption=""
|
||||||
|
agentSlug=""
|
||||||
|
isSubscribed={isUserSubscribed(props.userConfig)}
|
||||||
|
setAgentChangeTriggered={() => {}}
|
||||||
|
modelOptions={[]}
|
||||||
|
inputToolOptions={{}}
|
||||||
|
outputModeOptions={{}}
|
||||||
|
/>
|
||||||
|
</PopoverContent>
|
||||||
|
</Popover>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
<ScrollBar orientation="horizontal" />
|
||||||
|
</ScrollArea>
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
<div className={`mx-auto ${props.isMobileWidth ? "w-full" : "w-fit"}`}>
|
<div className={`mx-auto ${props.isMobileWidth ? "w-full" : "w-fit max-w-screen-md"}`}>
|
||||||
{!props.isMobileWidth && (
|
{!props.isMobileWidth && (
|
||||||
<div
|
<div
|
||||||
className={`w-full ${styles.inputBox} shadow-lg bg-background align-middle items-center justify-center px-3 py-1 dark:bg-neutral-700 border-stone-100 dark:border-none dark:shadow-none rounded-2xl`}
|
className={`w-full ${styles.inputBox} shadow-lg bg-background align-middle items-center justify-center px-3 py-1 dark:bg-neutral-700 border-stone-100 dark:border-none dark:shadow-none rounded-2xl`}
|
||||||
|
@ -232,12 +299,13 @@ function ChatBodyData(props: ChatBodyDataProps) {
|
||||||
<ChatInputArea
|
<ChatInputArea
|
||||||
isLoggedIn={props.isLoggedIn}
|
isLoggedIn={props.isLoggedIn}
|
||||||
sendMessage={(message) => setMessage(message)}
|
sendMessage={(message) => setMessage(message)}
|
||||||
sendImage={(image) => setImage(image)}
|
sendImage={(image) => setImages((prevImages) => [...prevImages, image])}
|
||||||
sendDisabled={processingMessage}
|
sendDisabled={processingMessage}
|
||||||
chatOptionsData={props.chatOptionsData}
|
chatOptionsData={props.chatOptionsData}
|
||||||
conversationId={null}
|
conversationId={null}
|
||||||
isMobileWidth={props.isMobileWidth}
|
isMobileWidth={props.isMobileWidth}
|
||||||
setUploadedFiles={props.setUploadedFiles}
|
setUploadedFiles={props.setUploadedFiles}
|
||||||
|
ref={chatInputRef}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
@ -285,40 +353,40 @@ function ChatBodyData(props: ChatBodyDataProps) {
|
||||||
<div
|
<div
|
||||||
className={`${styles.inputBox} pt-1 shadow-[0_-20px_25px_-5px_rgba(0,0,0,0.1)] dark:bg-neutral-700 bg-background align-middle items-center justify-center pb-3 mx-1 rounded-t-2xl rounded-b-none`}
|
className={`${styles.inputBox} pt-1 shadow-[0_-20px_25px_-5px_rgba(0,0,0,0.1)] dark:bg-neutral-700 bg-background align-middle items-center justify-center pb-3 mx-1 rounded-t-2xl rounded-b-none`}
|
||||||
>
|
>
|
||||||
<div className="flex gap-2 items-center justify-left pt-1 pb-2 px-12">
|
<ScrollArea className="w-full max-w-[85vw]">
|
||||||
{agentIcons.map((icon, index) => (
|
<div className="flex gap-2 items-center justify-left pt-1 pb-2 px-12">
|
||||||
<Card
|
{agentIcons.map((icon, index) => (
|
||||||
key={`${index}-${agents[index].slug}`}
|
<Card
|
||||||
className={`${selectedAgent === agents[index].slug ? convertColorToBorderClass(agents[index].color) : "border-muted text-muted-foreground"} hover:cursor-pointer`}
|
key={`${index}-${agents[index].slug}`}
|
||||||
>
|
className={`${selectedAgent === agents[index].slug ? convertColorToBorderClass(agents[index].color) : "border-muted text-muted-foreground"} hover:cursor-pointer`}
|
||||||
<CardTitle
|
|
||||||
className="text-center text-xs font-medium flex justify-center items-center px-1.5 py-1"
|
|
||||||
onClick={() => setSelectedAgent(agents[index].slug)}
|
|
||||||
>
|
>
|
||||||
{icon} {agents[index].name}
|
<CardTitle
|
||||||
</CardTitle>
|
className="text-center text-xs font-medium flex justify-center items-center px-1.5 py-1"
|
||||||
</Card>
|
onDoubleClick={() =>
|
||||||
))}
|
openAgentEditCard(agents[index].slug)
|
||||||
<Card
|
}
|
||||||
className="border-none shadow-none flex justify-center items-center hover:cursor-pointer"
|
onClick={() => {
|
||||||
onClick={() => (window.location.href = "/agents")}
|
setSelectedAgent(agents[index].slug);
|
||||||
>
|
chatInputRef.current?.focus();
|
||||||
<CardTitle
|
}}
|
||||||
className={`text-center ${props.isMobileWidth ? "text-xs" : "text-md"} font-normal flex justify-center items-center px-1.5 py-2`}
|
>
|
||||||
>
|
{icon} {agents[index].name}
|
||||||
See All →
|
</CardTitle>
|
||||||
</CardTitle>
|
</Card>
|
||||||
</Card>
|
))}
|
||||||
</div>
|
</div>
|
||||||
|
<ScrollBar orientation="horizontal" />
|
||||||
|
</ScrollArea>
|
||||||
<ChatInputArea
|
<ChatInputArea
|
||||||
isLoggedIn={props.isLoggedIn}
|
isLoggedIn={props.isLoggedIn}
|
||||||
sendMessage={(message) => setMessage(message)}
|
sendMessage={(message) => setMessage(message)}
|
||||||
sendImage={(image) => setImage(image)}
|
sendImage={(image) => setImages((prevImages) => [...prevImages, image])}
|
||||||
sendDisabled={processingMessage}
|
sendDisabled={processingMessage}
|
||||||
chatOptionsData={props.chatOptionsData}
|
chatOptionsData={props.chatOptionsData}
|
||||||
conversationId={null}
|
conversationId={null}
|
||||||
isMobileWidth={props.isMobileWidth}
|
isMobileWidth={props.isMobileWidth}
|
||||||
setUploadedFiles={props.setUploadedFiles}
|
setUploadedFiles={props.setUploadedFiles}
|
||||||
|
ref={chatInputRef}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
</>
|
</>
|
||||||
|
|
|
@ -513,7 +513,7 @@ export default function SettingsView() {
|
||||||
const isMobileWidth = useIsMobileWidth();
|
const isMobileWidth = useIsMobileWidth();
|
||||||
|
|
||||||
const cardClassName =
|
const cardClassName =
|
||||||
"w-full lg:w-1/3 grid grid-flow-column border border-gray-300 shadow-md rounded-lg bg-gradient-to-b from-background to-gray-50 dark:to-gray-950";
|
"w-full lg:w-1/3 grid grid-flow-column border border-gray-300 shadow-md rounded-lg bg-gradient-to-b from-background to-gray-50 dark:to-gray-950 border border-opacity-50";
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
setUserConfig(initialUserConfig);
|
setUserConfig(initialUserConfig);
|
||||||
|
@ -640,6 +640,51 @@ export default function SettingsView() {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const enableFreeTrial = async () => {
|
||||||
|
const formatDate = (dateString: Date) => {
|
||||||
|
const date = new Date(dateString);
|
||||||
|
return new Intl.DateTimeFormat("en-US", {
|
||||||
|
day: "2-digit",
|
||||||
|
month: "short",
|
||||||
|
year: "numeric",
|
||||||
|
}).format(date);
|
||||||
|
};
|
||||||
|
|
||||||
|
try {
|
||||||
|
const response = await fetch(`/api/subscription/trial`, {
|
||||||
|
method: "POST",
|
||||||
|
});
|
||||||
|
if (!response.ok) throw new Error("Failed to enable free trial");
|
||||||
|
|
||||||
|
const responseBody = await response.json();
|
||||||
|
|
||||||
|
// Set updated user settings
|
||||||
|
if (responseBody.trial_enabled && userConfig) {
|
||||||
|
let newUserConfig = userConfig;
|
||||||
|
newUserConfig.subscription_state = SubscriptionStates.TRIAL;
|
||||||
|
const renewalDate = new Date(
|
||||||
|
Date.now() + userConfig.length_of_free_trial * 24 * 60 * 60 * 1000,
|
||||||
|
);
|
||||||
|
newUserConfig.subscription_renewal_date = formatDate(renewalDate);
|
||||||
|
newUserConfig.subscription_enabled_trial_at = new Date().toISOString();
|
||||||
|
setUserConfig(newUserConfig);
|
||||||
|
|
||||||
|
// Notify user of free trial
|
||||||
|
toast({
|
||||||
|
title: "🎉 Trial Enabled",
|
||||||
|
description: `Your free trial will end on ${newUserConfig.subscription_renewal_date}`,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Error enabling free trial:", error);
|
||||||
|
toast({
|
||||||
|
title: "⚠️ Failed to Enable Free Trial",
|
||||||
|
description:
|
||||||
|
"Failed to enable free trial. Try again or contact us at team@khoj.dev",
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
const saveName = async () => {
|
const saveName = async () => {
|
||||||
if (!name) return;
|
if (!name) return;
|
||||||
try {
|
try {
|
||||||
|
@ -673,7 +718,7 @@ export default function SettingsView() {
|
||||||
};
|
};
|
||||||
|
|
||||||
const updateModel = (name: string) => async (id: string) => {
|
const updateModel = (name: string) => async (id: string) => {
|
||||||
if (!userConfig?.is_active && name !== "search") {
|
if (!userConfig?.is_active) {
|
||||||
toast({
|
toast({
|
||||||
title: `Model Update`,
|
title: `Model Update`,
|
||||||
description: `You need to be subscribed to update ${name} models`,
|
description: `You need to be subscribed to update ${name} models`,
|
||||||
|
@ -866,10 +911,13 @@ export default function SettingsView() {
|
||||||
Futurist (Trial)
|
Futurist (Trial)
|
||||||
</p>
|
</p>
|
||||||
<p className="text-gray-400">
|
<p className="text-gray-400">
|
||||||
You are on a 14 day trial of the Khoj
|
You are on a{" "}
|
||||||
Futurist plan. Check{" "}
|
{userConfig.length_of_free_trial} day trial
|
||||||
|
of the Khoj Futurist plan. Your trial ends
|
||||||
|
on {userConfig.subscription_renewal_date}.
|
||||||
|
Check{" "}
|
||||||
<a
|
<a
|
||||||
href="https://khoj.dev/pricing"
|
href="https://khoj.dev/#pricing"
|
||||||
target="_blank"
|
target="_blank"
|
||||||
>
|
>
|
||||||
pricing page
|
pricing page
|
||||||
|
@ -909,7 +957,7 @@ export default function SettingsView() {
|
||||||
)) ||
|
)) ||
|
||||||
(userConfig.subscription_state === "expired" && (
|
(userConfig.subscription_state === "expired" && (
|
||||||
<>
|
<>
|
||||||
<p className="text-xl">Free Plan</p>
|
<p className="text-xl">Humanist</p>
|
||||||
{(userConfig.subscription_renewal_date && (
|
{(userConfig.subscription_renewal_date && (
|
||||||
<p className="text-gray-400">
|
<p className="text-gray-400">
|
||||||
Subscription <b>expired</b> on{" "}
|
Subscription <b>expired</b> on{" "}
|
||||||
|
@ -923,7 +971,7 @@ export default function SettingsView() {
|
||||||
<p className="text-gray-400">
|
<p className="text-gray-400">
|
||||||
Check{" "}
|
Check{" "}
|
||||||
<a
|
<a
|
||||||
href="https://khoj.dev/pricing"
|
href="https://khoj.dev/#pricing"
|
||||||
target="_blank"
|
target="_blank"
|
||||||
>
|
>
|
||||||
pricing page
|
pricing page
|
||||||
|
@ -960,7 +1008,8 @@ export default function SettingsView() {
|
||||||
/>
|
/>
|
||||||
Resubscribe
|
Resubscribe
|
||||||
</Button>
|
</Button>
|
||||||
)) || (
|
)) ||
|
||||||
|
(userConfig.subscription_enabled_trial_at && (
|
||||||
<Button
|
<Button
|
||||||
variant="outline"
|
variant="outline"
|
||||||
className="text-primary/80 hover:text-primary"
|
className="text-primary/80 hover:text-primary"
|
||||||
|
@ -978,6 +1027,18 @@ export default function SettingsView() {
|
||||||
/>
|
/>
|
||||||
Subscribe
|
Subscribe
|
||||||
</Button>
|
</Button>
|
||||||
|
)) || (
|
||||||
|
<Button
|
||||||
|
variant="outline"
|
||||||
|
className="text-primary/80 hover:text-primary"
|
||||||
|
onClick={enableFreeTrial}
|
||||||
|
>
|
||||||
|
<ArrowCircleUp
|
||||||
|
weight="bold"
|
||||||
|
className="h-5 w-5 mr-2"
|
||||||
|
/>
|
||||||
|
Enable Trial
|
||||||
|
</Button>
|
||||||
)}
|
)}
|
||||||
</CardFooter>
|
</CardFooter>
|
||||||
</Card>
|
</Card>
|
||||||
|
@ -1172,27 +1233,6 @@ export default function SettingsView() {
|
||||||
</CardFooter>
|
</CardFooter>
|
||||||
</Card>
|
</Card>
|
||||||
)}
|
)}
|
||||||
{userConfig.search_model_options.length > 0 && (
|
|
||||||
<Card className={cardClassName}>
|
|
||||||
<CardHeader className="text-xl flex flex-row">
|
|
||||||
<FileMagnifyingGlass className="h-7 w-7 mr-2" />
|
|
||||||
Search
|
|
||||||
</CardHeader>
|
|
||||||
<CardContent className="overflow-hidden pb-12 grid gap-8 h-fit">
|
|
||||||
<p className="text-gray-400">
|
|
||||||
Pick the search model to find your documents
|
|
||||||
</p>
|
|
||||||
<DropdownComponent
|
|
||||||
items={userConfig.search_model_options}
|
|
||||||
selected={
|
|
||||||
userConfig.selected_search_model_config
|
|
||||||
}
|
|
||||||
callbackFunc={updateModel("search")}
|
|
||||||
/>
|
|
||||||
</CardContent>
|
|
||||||
<CardFooter className="flex flex-wrap gap-4"></CardFooter>
|
|
||||||
</Card>
|
|
||||||
)}
|
|
||||||
{userConfig.paint_model_options.length > 0 && (
|
{userConfig.paint_model_options.length > 0 && (
|
||||||
<Card className={cardClassName}>
|
<Card className={cardClassName}>
|
||||||
<CardHeader className="text-xl flex flex-row">
|
<CardHeader className="text-xl flex flex-row">
|
||||||
|
|
|
@ -27,7 +27,14 @@ export default function RootLayout({
|
||||||
child-src 'none';
|
child-src 'none';
|
||||||
object-src 'none';"
|
object-src 'none';"
|
||||||
></meta>
|
></meta>
|
||||||
<body className={inter.className}>{children}</body>
|
<body className={inter.className}>
|
||||||
|
{children}
|
||||||
|
<script
|
||||||
|
dangerouslySetInnerHTML={{
|
||||||
|
__html: `window.EXCALIDRAW_ASSET_PATH = 'https://assets.khoj.dev/@excalidraw/excalidraw/dist/';`,
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</body>
|
||||||
</html>
|
</html>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,7 +13,7 @@ import "katex/dist/katex.min.css";
|
||||||
import { useIPLocationData, useIsMobileWidth, welcomeConsole } from "../../common/utils";
|
import { useIPLocationData, useIsMobileWidth, welcomeConsole } from "../../common/utils";
|
||||||
import { useAuthenticatedData } from "@/app/common/auth";
|
import { useAuthenticatedData } from "@/app/common/auth";
|
||||||
|
|
||||||
import ChatInputArea, { ChatOptions } from "@/app/components/chatInputArea/chatInputArea";
|
import { ChatInputArea, ChatOptions } from "@/app/components/chatInputArea/chatInputArea";
|
||||||
import { StreamMessage } from "@/app/components/chatMessage/chatMessage";
|
import { StreamMessage } from "@/app/components/chatMessage/chatMessage";
|
||||||
import { processMessageChunk } from "@/app/common/chatFunctions";
|
import { processMessageChunk } from "@/app/common/chatFunctions";
|
||||||
import { AgentData } from "@/app/agents/page";
|
import { AgentData } from "@/app/agents/page";
|
||||||
|
@ -28,22 +28,44 @@ interface ChatBodyDataProps {
|
||||||
isLoggedIn: boolean;
|
isLoggedIn: boolean;
|
||||||
conversationId?: string;
|
conversationId?: string;
|
||||||
setQueryToProcess: (query: string) => void;
|
setQueryToProcess: (query: string) => void;
|
||||||
setImage64: (image64: string) => void;
|
setImages: (images: string[]) => void;
|
||||||
}
|
}
|
||||||
|
|
||||||
function ChatBodyData(props: ChatBodyDataProps) {
|
function ChatBodyData(props: ChatBodyDataProps) {
|
||||||
const [message, setMessage] = useState("");
|
const [message, setMessage] = useState("");
|
||||||
const [image, setImage] = useState<string | null>(null);
|
const [images, setImages] = useState<string[]>([]);
|
||||||
const [processingMessage, setProcessingMessage] = useState(false);
|
const [processingMessage, setProcessingMessage] = useState(false);
|
||||||
const [agentMetadata, setAgentMetadata] = useState<AgentData | null>(null);
|
const [agentMetadata, setAgentMetadata] = useState<AgentData | null>(null);
|
||||||
|
const chatInputRef = useRef<HTMLTextAreaElement>(null);
|
||||||
|
|
||||||
const setQueryToProcess = props.setQueryToProcess;
|
const setQueryToProcess = props.setQueryToProcess;
|
||||||
const streamedMessages = props.streamedMessages;
|
const streamedMessages = props.streamedMessages;
|
||||||
|
|
||||||
|
const chatHistoryCustomClassName = props.isMobileWidth ? "w-full" : "w-4/6";
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (image) {
|
if (images.length > 0) {
|
||||||
props.setImage64(encodeURIComponent(image));
|
const encodedImages = images.map((image) => encodeURIComponent(image));
|
||||||
|
props.setImages(encodedImages);
|
||||||
}
|
}
|
||||||
}, [image, props.setImage64]);
|
}, [images, props.setImages]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
const storedImages = localStorage.getItem("images");
|
||||||
|
if (storedImages) {
|
||||||
|
const parsedImages: string[] = JSON.parse(storedImages);
|
||||||
|
setImages(parsedImages);
|
||||||
|
const encodedImages = parsedImages.map((img: string) => encodeURIComponent(img));
|
||||||
|
props.setImages(encodedImages);
|
||||||
|
localStorage.removeItem("images");
|
||||||
|
}
|
||||||
|
|
||||||
|
const storedMessage = localStorage.getItem("message");
|
||||||
|
if (storedMessage) {
|
||||||
|
setProcessingMessage(true);
|
||||||
|
setQueryToProcess(storedMessage);
|
||||||
|
}
|
||||||
|
}, [setQueryToProcess, props.setImages]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (message) {
|
if (message) {
|
||||||
|
@ -78,21 +100,23 @@ function ChatBodyData(props: ChatBodyDataProps) {
|
||||||
setTitle={props.setTitle}
|
setTitle={props.setTitle}
|
||||||
pendingMessage={processingMessage ? message : ""}
|
pendingMessage={processingMessage ? message : ""}
|
||||||
incomingMessages={props.streamedMessages}
|
incomingMessages={props.streamedMessages}
|
||||||
|
customClassName={chatHistoryCustomClassName}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
<div
|
<div
|
||||||
className={`${styles.inputBox} p-1 md:px-2 shadow-md bg-background align-middle items-center justify-center dark:bg-neutral-700 dark:border-0 dark:shadow-sm rounded-t-2xl rounded-b-none md:rounded-xl`}
|
className={`${styles.inputBox} p-1 md:px-2 shadow-md bg-background align-middle items-center justify-center dark:bg-neutral-700 dark:border-0 dark:shadow-sm rounded-t-2xl rounded-b-none md:rounded-xl h-fit ${chatHistoryCustomClassName} mr-auto ml-auto`}
|
||||||
>
|
>
|
||||||
<ChatInputArea
|
<ChatInputArea
|
||||||
isLoggedIn={props.isLoggedIn}
|
isLoggedIn={props.isLoggedIn}
|
||||||
sendMessage={(message) => setMessage(message)}
|
sendMessage={(message) => setMessage(message)}
|
||||||
sendImage={(image) => setImage(image)}
|
sendImage={(image) => setImages((prevImages) => [...prevImages, image])}
|
||||||
sendDisabled={processingMessage}
|
sendDisabled={processingMessage}
|
||||||
chatOptionsData={props.chatOptionsData}
|
chatOptionsData={props.chatOptionsData}
|
||||||
conversationId={props.conversationId}
|
conversationId={props.conversationId}
|
||||||
agentColor={agentMetadata?.color}
|
agentColor={agentMetadata?.color}
|
||||||
isMobileWidth={props.isMobileWidth}
|
isMobileWidth={props.isMobileWidth}
|
||||||
setUploadedFiles={props.setUploadedFiles}
|
setUploadedFiles={props.setUploadedFiles}
|
||||||
|
ref={chatInputRef}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
</>
|
</>
|
||||||
|
@ -109,7 +133,7 @@ export default function SharedChat() {
|
||||||
const [processQuerySignal, setProcessQuerySignal] = useState(false);
|
const [processQuerySignal, setProcessQuerySignal] = useState(false);
|
||||||
const [uploadedFiles, setUploadedFiles] = useState<string[]>([]);
|
const [uploadedFiles, setUploadedFiles] = useState<string[]>([]);
|
||||||
const [paramSlug, setParamSlug] = useState<string | undefined>(undefined);
|
const [paramSlug, setParamSlug] = useState<string | undefined>(undefined);
|
||||||
const [image64, setImage64] = useState<string>("");
|
const [images, setImages] = useState<string[]>([]);
|
||||||
|
|
||||||
const locationData = useIPLocationData() || {
|
const locationData = useIPLocationData() || {
|
||||||
timezone: Intl.DateTimeFormat().resolvedOptions().timeZone,
|
timezone: Intl.DateTimeFormat().resolvedOptions().timeZone,
|
||||||
|
@ -168,7 +192,7 @@ export default function SharedChat() {
|
||||||
completed: false,
|
completed: false,
|
||||||
timestamp: new Date().toISOString(),
|
timestamp: new Date().toISOString(),
|
||||||
rawQuery: queryToProcess || "",
|
rawQuery: queryToProcess || "",
|
||||||
uploadedImageData: decodeURIComponent(image64),
|
images: images,
|
||||||
};
|
};
|
||||||
setMessages((prevMessages) => [...prevMessages, newStreamMessage]);
|
setMessages((prevMessages) => [...prevMessages, newStreamMessage]);
|
||||||
setProcessQuerySignal(true);
|
setProcessQuerySignal(true);
|
||||||
|
@ -195,7 +219,7 @@ export default function SharedChat() {
|
||||||
if (done) {
|
if (done) {
|
||||||
setQueryToProcess("");
|
setQueryToProcess("");
|
||||||
setProcessQuerySignal(false);
|
setProcessQuerySignal(false);
|
||||||
setImage64("");
|
setImages([]);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -237,7 +261,7 @@ export default function SharedChat() {
|
||||||
country_code: locationData.countryCode,
|
country_code: locationData.countryCode,
|
||||||
timezone: locationData.timezone,
|
timezone: locationData.timezone,
|
||||||
}),
|
}),
|
||||||
...(image64 && { image: image64 }),
|
...(images.length > 0 && { image: images }),
|
||||||
};
|
};
|
||||||
|
|
||||||
const response = await fetch(chatAPI, {
|
const response = await fetch(chatAPI, {
|
||||||
|
@ -276,6 +300,19 @@ export default function SharedChat() {
|
||||||
|
|
||||||
<div className={styles.chatBox}>
|
<div className={styles.chatBox}>
|
||||||
<div className={styles.chatBoxBody}>
|
<div className={styles.chatBoxBody}>
|
||||||
|
{!isMobileWidth && title && (
|
||||||
|
<div
|
||||||
|
className={`${styles.chatTitleWrapper} text-nowrap text-ellipsis overflow-hidden max-w-screen-md grid items-top font-bold mr-8 pt-6 col-auto h-fit`}
|
||||||
|
>
|
||||||
|
{title && (
|
||||||
|
<h2
|
||||||
|
className={`text-lg text-ellipsis whitespace-nowrap overflow-x-hidden`}
|
||||||
|
>
|
||||||
|
{title}
|
||||||
|
</h2>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
<Suspense fallback={<Loading />}>
|
<Suspense fallback={<Loading />}>
|
||||||
<ChatBodyData
|
<ChatBodyData
|
||||||
conversationId={conversationId}
|
conversationId={conversationId}
|
||||||
|
@ -287,7 +324,7 @@ export default function SharedChat() {
|
||||||
setTitle={setTitle}
|
setTitle={setTitle}
|
||||||
setUploadedFiles={setUploadedFiles}
|
setUploadedFiles={setUploadedFiles}
|
||||||
isMobileWidth={isMobileWidth}
|
isMobileWidth={isMobileWidth}
|
||||||
setImage64={setImage64}
|
setImages={setImages}
|
||||||
/>
|
/>
|
||||||
</Suspense>
|
</Suspense>
|
||||||
</div>
|
</div>
|
||||||
|
|
|
@ -75,7 +75,7 @@ div.titleBar {
|
||||||
div.chatBoxBody {
|
div.chatBoxBody {
|
||||||
display: grid;
|
display: grid;
|
||||||
height: 100%;
|
height: 100%;
|
||||||
width: 70%;
|
width: 95%;
|
||||||
margin: auto;
|
margin: auto;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
{
|
{
|
||||||
"name": "khoj-ai",
|
"name": "khoj-ai",
|
||||||
"version": "1.25.0",
|
"version": "1.26.4",
|
||||||
"private": true,
|
"private": true,
|
||||||
"scripts": {
|
"scripts": {
|
||||||
"dev": "next dev",
|
"dev": "next dev",
|
||||||
|
@ -63,7 +63,8 @@
|
||||||
"swr": "^2.2.5",
|
"swr": "^2.2.5",
|
||||||
"typescript": "^5",
|
"typescript": "^5",
|
||||||
"vaul": "^0.9.1",
|
"vaul": "^0.9.1",
|
||||||
"zod": "^3.23.8"
|
"zod": "^3.23.8",
|
||||||
|
"@excalidraw/excalidraw": "^0.17.6"
|
||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
"@types/dompurify": "^3.0.5",
|
"@types/dompurify": "^3.0.5",
|
||||||
|
|
|
@ -286,6 +286,11 @@
|
||||||
resolved "https://registry.yarnpkg.com/@eslint/js/-/js-8.57.1.tgz#de633db3ec2ef6a3c89e2f19038063e8a122e2c2"
|
resolved "https://registry.yarnpkg.com/@eslint/js/-/js-8.57.1.tgz#de633db3ec2ef6a3c89e2f19038063e8a122e2c2"
|
||||||
integrity sha512-d9zaMRSTIKDLhctzH12MtXvJKSSUhaHcjV+2Z+GK+EEY7XKpP5yR4x+N3TAcHTcu963nIr+TMcCb4DBCYX1z6Q==
|
integrity sha512-d9zaMRSTIKDLhctzH12MtXvJKSSUhaHcjV+2Z+GK+EEY7XKpP5yR4x+N3TAcHTcu963nIr+TMcCb4DBCYX1z6Q==
|
||||||
|
|
||||||
|
"@excalidraw/excalidraw@^0.17.6":
|
||||||
|
version "0.17.6"
|
||||||
|
resolved "https://registry.yarnpkg.com/@excalidraw/excalidraw/-/excalidraw-0.17.6.tgz#5fd208ce69d33ca712d1804b50d7d06d5c46ac4d"
|
||||||
|
integrity sha512-fyCl+zG/Z5yhHDh5Fq2ZGmphcrALmuOdtITm8gN4d8w4ntnaopTXcTfnAAaU3VleDC6LhTkoLOTG6P5kgREiIg==
|
||||||
|
|
||||||
"@floating-ui/core@^1.6.0":
|
"@floating-ui/core@^1.6.0":
|
||||||
version "1.6.8"
|
version "1.6.8"
|
||||||
resolved "https://registry.yarnpkg.com/@floating-ui/core/-/core-1.6.8.tgz#aa43561be075815879305965020f492cdb43da12"
|
resolved "https://registry.yarnpkg.com/@floating-ui/core/-/core-1.6.8.tgz#aa43561be075815879305965020f492cdb43da12"
|
||||||
|
|
|
@ -108,7 +108,7 @@ class UserAuthenticationBackend(AuthenticationBackend):
|
||||||
password="default",
|
password="default",
|
||||||
)
|
)
|
||||||
renewal_date = make_aware(datetime.strptime("2100-04-01", "%Y-%m-%d"))
|
renewal_date = make_aware(datetime.strptime("2100-04-01", "%Y-%m-%d"))
|
||||||
Subscription.objects.create(user=default_user, type="standard", renewal_date=renewal_date)
|
Subscription.objects.create(user=default_user, type=Subscription.Type.STANDARD, renewal_date=renewal_date)
|
||||||
|
|
||||||
async def authenticate(self, request: HTTPConnection):
|
async def authenticate(self, request: HTTPConnection):
|
||||||
current_user = request.session.get("user")
|
current_user = request.session.get("user")
|
||||||
|
@ -172,7 +172,7 @@ class UserAuthenticationBackend(AuthenticationBackend):
|
||||||
request=request,
|
request=request,
|
||||||
telemetry_type="api",
|
telemetry_type="api",
|
||||||
api="create_user",
|
api="create_user",
|
||||||
metadata={"user_id": str(user.uuid)},
|
metadata={"server_id": str(user.uuid)},
|
||||||
)
|
)
|
||||||
logger.log(logging.INFO, f"🥳 New User Created: {user.uuid}")
|
logger.log(logging.INFO, f"🥳 New User Created: {user.uuid}")
|
||||||
else:
|
else:
|
||||||
|
@ -312,7 +312,7 @@ def configure_routes(app):
|
||||||
logger.info("🔑 Enabled Authentication")
|
logger.info("🔑 Enabled Authentication")
|
||||||
|
|
||||||
if state.billing_enabled:
|
if state.billing_enabled:
|
||||||
from khoj.routers.subscription import subscription_router
|
from khoj.routers.api_subscription import subscription_router
|
||||||
|
|
||||||
app.include_router(subscription_router, prefix="/api/subscription")
|
app.include_router(subscription_router, prefix="/api/subscription")
|
||||||
logger.info("💳 Enabled Billing")
|
logger.info("💳 Enabled Billing")
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
|
import os
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
import secrets
|
import secrets
|
||||||
|
@ -10,7 +11,6 @@ from enum import Enum
|
||||||
from typing import Callable, Iterable, List, Optional, Type
|
from typing import Callable, Iterable, List, Optional, Type
|
||||||
|
|
||||||
import cron_descriptor
|
import cron_descriptor
|
||||||
import django
|
|
||||||
from apscheduler.job import Job
|
from apscheduler.job import Job
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
from django.contrib.sessions.backends.db import SessionStore
|
from django.contrib.sessions.backends.db import SessionStore
|
||||||
|
@ -52,6 +52,7 @@ from khoj.database.models import (
|
||||||
UserTextToImageModelConfig,
|
UserTextToImageModelConfig,
|
||||||
UserVoiceModelConfig,
|
UserVoiceModelConfig,
|
||||||
VoiceModelOption,
|
VoiceModelOption,
|
||||||
|
WebScraper,
|
||||||
)
|
)
|
||||||
from khoj.processor.conversation import prompts
|
from khoj.processor.conversation import prompts
|
||||||
from khoj.search_filter.date_filter import DateFilter
|
from khoj.search_filter.date_filter import DateFilter
|
||||||
|
@ -59,11 +60,19 @@ from khoj.search_filter.file_filter import FileFilter
|
||||||
from khoj.search_filter.word_filter import WordFilter
|
from khoj.search_filter.word_filter import WordFilter
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
from khoj.utils.config import OfflineChatProcessorModel
|
from khoj.utils.config import OfflineChatProcessorModel
|
||||||
from khoj.utils.helpers import generate_random_name, is_none_or_empty, timer
|
from khoj.utils.helpers import (
|
||||||
|
generate_random_name,
|
||||||
|
in_debug_mode,
|
||||||
|
is_none_or_empty,
|
||||||
|
timer,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
LENGTH_OF_FREE_TRIAL = 7 #
|
||||||
|
|
||||||
|
|
||||||
class SubscriptionState(Enum):
|
class SubscriptionState(Enum):
|
||||||
TRIAL = "trial"
|
TRIAL = "trial"
|
||||||
SUBSCRIBED = "subscribed"
|
SUBSCRIBED = "subscribed"
|
||||||
|
@ -162,7 +171,7 @@ async def acreate_user_by_phone_number(phone_number: str) -> KhojUser:
|
||||||
)
|
)
|
||||||
await user.asave()
|
await user.asave()
|
||||||
|
|
||||||
await Subscription.objects.acreate(user=user, type="trial")
|
await Subscription.objects.acreate(user=user, type=Subscription.Type.STANDARD)
|
||||||
|
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
@ -179,11 +188,29 @@ async def aget_or_create_user_by_email(email: str) -> tuple[KhojUser, bool]:
|
||||||
|
|
||||||
user_subscription = await Subscription.objects.filter(user=user).afirst()
|
user_subscription = await Subscription.objects.filter(user=user).afirst()
|
||||||
if not user_subscription:
|
if not user_subscription:
|
||||||
await Subscription.objects.acreate(user=user, type="trial")
|
await Subscription.objects.acreate(user=user, type=Subscription.Type.STANDARD)
|
||||||
|
|
||||||
return user, is_new
|
return user, is_new
|
||||||
|
|
||||||
|
|
||||||
|
async def astart_trial_subscription(user: KhojUser) -> Subscription:
|
||||||
|
subscription = await Subscription.objects.filter(user=user).afirst()
|
||||||
|
if not subscription:
|
||||||
|
raise HTTPException(status_code=400, detail="User does not have a subscription")
|
||||||
|
|
||||||
|
if subscription.type == Subscription.Type.TRIAL:
|
||||||
|
raise HTTPException(status_code=400, detail="User already has a trial subscription")
|
||||||
|
|
||||||
|
if subscription.enabled_trial_at:
|
||||||
|
raise HTTPException(status_code=400, detail="User already has a trial subscription")
|
||||||
|
|
||||||
|
subscription.type = Subscription.Type.TRIAL
|
||||||
|
subscription.enabled_trial_at = datetime.now(tz=timezone.utc)
|
||||||
|
subscription.renewal_date = datetime.now(tz=timezone.utc) + timedelta(days=LENGTH_OF_FREE_TRIAL)
|
||||||
|
await subscription.asave()
|
||||||
|
return subscription
|
||||||
|
|
||||||
|
|
||||||
async def aget_user_validated_by_email_verification_code(code: str) -> KhojUser:
|
async def aget_user_validated_by_email_verification_code(code: str) -> KhojUser:
|
||||||
user = await KhojUser.objects.filter(email_verification_code=code).afirst()
|
user = await KhojUser.objects.filter(email_verification_code=code).afirst()
|
||||||
if not user:
|
if not user:
|
||||||
|
@ -215,7 +242,7 @@ async def create_user_by_google_token(token: dict) -> KhojUser:
|
||||||
user=user,
|
user=user,
|
||||||
)
|
)
|
||||||
|
|
||||||
await Subscription.objects.acreate(user=user, type="trial")
|
await Subscription.objects.acreate(user=user, type=Subscription.Type.STANDARD)
|
||||||
|
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
@ -273,16 +300,15 @@ def subscription_to_state(subscription: Subscription) -> str:
|
||||||
if not subscription:
|
if not subscription:
|
||||||
return SubscriptionState.INVALID.value
|
return SubscriptionState.INVALID.value
|
||||||
elif subscription.type == Subscription.Type.TRIAL:
|
elif subscription.type == Subscription.Type.TRIAL:
|
||||||
# Trial subscription is valid for 7 days
|
# Check if the trial has expired
|
||||||
if datetime.now(tz=timezone.utc) - subscription.created_at > timedelta(days=14):
|
if datetime.now(tz=timezone.utc) > subscription.renewal_date:
|
||||||
return SubscriptionState.EXPIRED.value
|
return SubscriptionState.EXPIRED.value
|
||||||
|
|
||||||
return SubscriptionState.TRIAL.value
|
return SubscriptionState.TRIAL.value
|
||||||
elif subscription.is_recurring and subscription.renewal_date >= datetime.now(tz=timezone.utc):
|
elif subscription.is_recurring and subscription.renewal_date > datetime.now(tz=timezone.utc):
|
||||||
return SubscriptionState.SUBSCRIBED.value
|
return SubscriptionState.SUBSCRIBED.value
|
||||||
elif not subscription.is_recurring and subscription.renewal_date is None:
|
elif not subscription.is_recurring and subscription.renewal_date is None:
|
||||||
return SubscriptionState.EXPIRED.value
|
return SubscriptionState.EXPIRED.value
|
||||||
elif not subscription.is_recurring and subscription.renewal_date >= datetime.now(tz=timezone.utc):
|
elif not subscription.is_recurring and subscription.renewal_date > datetime.now(tz=timezone.utc):
|
||||||
return SubscriptionState.UNSUBSCRIBED.value
|
return SubscriptionState.UNSUBSCRIBED.value
|
||||||
elif not subscription.is_recurring and subscription.renewal_date < datetime.now(tz=timezone.utc):
|
elif not subscription.is_recurring and subscription.renewal_date < datetime.now(tz=timezone.utc):
|
||||||
return SubscriptionState.EXPIRED.value
|
return SubscriptionState.EXPIRED.value
|
||||||
|
@ -440,18 +466,26 @@ async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
def get_user_search_model_or_default(user=None):
|
def get_default_search_model() -> SearchModelConfig:
|
||||||
if user and UserSearchModelConfig.objects.filter(user=user).exists():
|
default_search_model = SearchModelConfig.objects.filter(name="default").first()
|
||||||
return UserSearchModelConfig.objects.filter(user=user).first().setting
|
|
||||||
|
|
||||||
if SearchModelConfig.objects.filter(name="default").exists():
|
if default_search_model:
|
||||||
return SearchModelConfig.objects.filter(name="default").first()
|
return default_search_model
|
||||||
else:
|
else:
|
||||||
SearchModelConfig.objects.create()
|
SearchModelConfig.objects.create()
|
||||||
|
|
||||||
return SearchModelConfig.objects.first()
|
return SearchModelConfig.objects.first()
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_default_search_model(user: KhojUser = None) -> SearchModelConfig:
|
||||||
|
if user:
|
||||||
|
user_search_model = UserSearchModelConfig.objects.filter(user=user).first()
|
||||||
|
if user_search_model:
|
||||||
|
return user_search_model.setting
|
||||||
|
|
||||||
|
return get_default_search_model()
|
||||||
|
|
||||||
|
|
||||||
def get_or_create_search_models():
|
def get_or_create_search_models():
|
||||||
search_models = SearchModelConfig.objects.all()
|
search_models = SearchModelConfig.objects.all()
|
||||||
if search_models.count() == 0:
|
if search_models.count() == 0:
|
||||||
|
@ -461,21 +495,6 @@ def get_or_create_search_models():
|
||||||
return search_models
|
return search_models
|
||||||
|
|
||||||
|
|
||||||
async def aset_user_search_model(user: KhojUser, search_model_config_id: int):
|
|
||||||
config = await SearchModelConfig.objects.filter(id=search_model_config_id).afirst()
|
|
||||||
if not config:
|
|
||||||
return None
|
|
||||||
new_config, _ = await UserSearchModelConfig.objects.aupdate_or_create(user=user, defaults={"setting": config})
|
|
||||||
return new_config
|
|
||||||
|
|
||||||
|
|
||||||
async def aget_user_search_model(user: KhojUser):
|
|
||||||
config = await UserSearchModelConfig.objects.filter(user=user).prefetch_related("setting").afirst()
|
|
||||||
if not config:
|
|
||||||
return None
|
|
||||||
return config.setting
|
|
||||||
|
|
||||||
|
|
||||||
class ProcessLockAdapters:
|
class ProcessLockAdapters:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_process_lock(process_name: str):
|
def get_process_lock(process_name: str):
|
||||||
|
@ -616,6 +635,8 @@ class AgentAdapters:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_all_accessible_agents(user: KhojUser = None):
|
def get_all_accessible_agents(user: KhojUser = None):
|
||||||
public_query = Q(privacy_level=Agent.PrivacyLevel.PUBLIC)
|
public_query = Q(privacy_level=Agent.PrivacyLevel.PUBLIC)
|
||||||
|
# TODO Update this to allow any public agent that's officially approved once that experience is launched
|
||||||
|
public_query &= Q(managed_by_admin=True)
|
||||||
if user:
|
if user:
|
||||||
return (
|
return (
|
||||||
Agent.objects.filter(public_query | Q(creator=user))
|
Agent.objects.filter(public_query | Q(creator=user))
|
||||||
|
@ -634,6 +655,16 @@ class AgentAdapters:
|
||||||
agents = await sync_to_async(AgentAdapters.get_all_accessible_agents)(user)
|
agents = await sync_to_async(AgentAdapters.get_all_accessible_agents)(user)
|
||||||
return await sync_to_async(list)(agents)
|
return await sync_to_async(list)(agents)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def ais_agent_accessible(agent: Agent, user: KhojUser) -> bool:
|
||||||
|
if agent.privacy_level == Agent.PrivacyLevel.PUBLIC:
|
||||||
|
return True
|
||||||
|
if agent.creator == user:
|
||||||
|
return True
|
||||||
|
if agent.privacy_level == Agent.PrivacyLevel.PROTECTED:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_conversation_agent_by_id(agent_id: int):
|
def get_conversation_agent_by_id(agent_id: int):
|
||||||
agent = Agent.objects.filter(id=agent_id).first()
|
agent = Agent.objects.filter(id=agent_id).first()
|
||||||
|
@ -1031,6 +1062,70 @@ class ConversationAdapters:
|
||||||
return server_chat_settings.chat_advanced
|
return server_chat_settings.chat_advanced
|
||||||
return await ConversationAdapters.aget_default_conversation_config(user)
|
return await ConversationAdapters.aget_default_conversation_config(user)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def aget_server_webscraper():
|
||||||
|
server_chat_settings = await ServerChatSettings.objects.filter().prefetch_related("web_scraper").afirst()
|
||||||
|
if server_chat_settings is not None and server_chat_settings.web_scraper is not None:
|
||||||
|
return server_chat_settings.web_scraper
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def aget_enabled_webscrapers() -> list[WebScraper]:
|
||||||
|
enabled_scrapers: list[WebScraper] = []
|
||||||
|
server_webscraper = await ConversationAdapters.aget_server_webscraper()
|
||||||
|
if server_webscraper:
|
||||||
|
# Only use the webscraper set in the server chat settings
|
||||||
|
enabled_scrapers = [server_webscraper]
|
||||||
|
if not enabled_scrapers:
|
||||||
|
# Use the enabled web scrapers, ordered by priority, until get web page content
|
||||||
|
enabled_scrapers = [scraper async for scraper in WebScraper.objects.all().order_by("priority").aiterator()]
|
||||||
|
if not enabled_scrapers:
|
||||||
|
# Use scrapers enabled via environment variables
|
||||||
|
if os.getenv("FIRECRAWL_API_KEY"):
|
||||||
|
api_url = os.getenv("FIRECRAWL_API_URL", "https://api.firecrawl.dev")
|
||||||
|
enabled_scrapers.append(
|
||||||
|
WebScraper(
|
||||||
|
type=WebScraper.WebScraperType.FIRECRAWL,
|
||||||
|
name=WebScraper.WebScraperType.FIRECRAWL.capitalize(),
|
||||||
|
api_key=os.getenv("FIRECRAWL_API_KEY"),
|
||||||
|
api_url=api_url,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if os.getenv("OLOSTEP_API_KEY"):
|
||||||
|
api_url = os.getenv("OLOSTEP_API_URL", "https://agent.olostep.com/olostep-p2p-incomingAPI")
|
||||||
|
enabled_scrapers.append(
|
||||||
|
WebScraper(
|
||||||
|
type=WebScraper.WebScraperType.OLOSTEP,
|
||||||
|
name=WebScraper.WebScraperType.OLOSTEP.capitalize(),
|
||||||
|
api_key=os.getenv("OLOSTEP_API_KEY"),
|
||||||
|
api_url=api_url,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Jina is the default fallback scrapers to use as it does not require an API key
|
||||||
|
api_url = os.getenv("JINA_READER_API_URL", "https://r.jina.ai/")
|
||||||
|
enabled_scrapers.append(
|
||||||
|
WebScraper(
|
||||||
|
type=WebScraper.WebScraperType.JINA,
|
||||||
|
name=WebScraper.WebScraperType.JINA.capitalize(),
|
||||||
|
api_key=os.getenv("JINA_API_KEY"),
|
||||||
|
api_url=api_url,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only enable the direct web page scraper by default in self-hosted single user setups.
|
||||||
|
# Useful for reading webpages on your intranet.
|
||||||
|
if state.anonymous_mode or in_debug_mode():
|
||||||
|
enabled_scrapers.append(
|
||||||
|
WebScraper(
|
||||||
|
type=WebScraper.WebScraperType.DIRECT,
|
||||||
|
name=WebScraper.WebScraperType.DIRECT.capitalize(),
|
||||||
|
api_key=None,
|
||||||
|
api_url=None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return enabled_scrapers
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_conversation_from_public_conversation(
|
def create_conversation_from_public_conversation(
|
||||||
user: KhojUser, public_conversation: PublicConversation, client_app: ClientApplication
|
user: KhojUser, public_conversation: PublicConversation, client_app: ClientApplication
|
||||||
|
@ -1393,12 +1488,15 @@ class EntryAdapters:
|
||||||
file_filters = EntryAdapters.file_filter.get_filter_terms(query)
|
file_filters = EntryAdapters.file_filter.get_filter_terms(query)
|
||||||
date_filters = EntryAdapters.date_filter.get_query_date_range(query)
|
date_filters = EntryAdapters.date_filter.get_query_date_range(query)
|
||||||
|
|
||||||
user_or_agent = Q(user=user)
|
owner_filter = Q()
|
||||||
|
|
||||||
|
if user != None:
|
||||||
|
owner_filter = Q(user=user)
|
||||||
if agent != None:
|
if agent != None:
|
||||||
user_or_agent |= Q(agent=agent)
|
owner_filter |= Q(agent=agent)
|
||||||
|
|
||||||
if len(word_filters) == 0 and len(file_filters) == 0 and len(date_filters) == 0:
|
if len(word_filters) == 0 and len(file_filters) == 0 and len(date_filters) == 0:
|
||||||
return Entry.objects.filter(user_or_agent)
|
return Entry.objects.filter(owner_filter)
|
||||||
|
|
||||||
for term in word_filters:
|
for term in word_filters:
|
||||||
if term.startswith("+"):
|
if term.startswith("+"):
|
||||||
|
@ -1434,7 +1532,7 @@ class EntryAdapters:
|
||||||
formatted_max_date = date.fromtimestamp(max_date).strftime("%Y-%m-%d")
|
formatted_max_date = date.fromtimestamp(max_date).strftime("%Y-%m-%d")
|
||||||
q_filter_terms &= Q(embeddings_dates__date__lte=formatted_max_date)
|
q_filter_terms &= Q(embeddings_dates__date__lte=formatted_max_date)
|
||||||
|
|
||||||
relevant_entries = Entry.objects.filter(user_or_agent).filter(q_filter_terms)
|
relevant_entries = Entry.objects.filter(owner_filter).filter(q_filter_terms)
|
||||||
if file_type_filter:
|
if file_type_filter:
|
||||||
relevant_entries = relevant_entries.filter(file_type=file_type_filter)
|
relevant_entries = relevant_entries.filter(file_type=file_type_filter)
|
||||||
return relevant_entries
|
return relevant_entries
|
||||||
|
@ -1449,13 +1547,18 @@ class EntryAdapters:
|
||||||
max_distance: float = math.inf,
|
max_distance: float = math.inf,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
):
|
):
|
||||||
user_or_agent = Q(user=user)
|
owner_filter = Q()
|
||||||
|
|
||||||
|
if user != None:
|
||||||
|
owner_filter = Q(user=user)
|
||||||
if agent != None:
|
if agent != None:
|
||||||
user_or_agent |= Q(agent=agent)
|
owner_filter |= Q(agent=agent)
|
||||||
|
|
||||||
|
if owner_filter == Q():
|
||||||
|
return Entry.objects.none()
|
||||||
|
|
||||||
relevant_entries = EntryAdapters.apply_filters(user, raw_query, file_type_filter, agent)
|
relevant_entries = EntryAdapters.apply_filters(user, raw_query, file_type_filter, agent)
|
||||||
relevant_entries = relevant_entries.filter(user_or_agent).annotate(
|
relevant_entries = relevant_entries.filter(owner_filter).annotate(
|
||||||
distance=CosineDistance("embeddings", embeddings)
|
distance=CosineDistance("embeddings", embeddings)
|
||||||
)
|
)
|
||||||
relevant_entries = relevant_entries.filter(distance__lte=max_distance)
|
relevant_entries = relevant_entries.filter(distance__lte=max_distance)
|
||||||
|
|
|
@ -31,6 +31,7 @@ from khoj.database.models import (
|
||||||
UserSearchModelConfig,
|
UserSearchModelConfig,
|
||||||
UserVoiceModelConfig,
|
UserVoiceModelConfig,
|
||||||
VoiceModelOption,
|
VoiceModelOption,
|
||||||
|
WebScraper,
|
||||||
)
|
)
|
||||||
from khoj.utils.helpers import ImageIntentType
|
from khoj.utils.helpers import ImageIntentType
|
||||||
|
|
||||||
|
@ -69,10 +70,11 @@ class KhojUserAdmin(UserAdmin):
|
||||||
"id",
|
"id",
|
||||||
"email",
|
"email",
|
||||||
"username",
|
"username",
|
||||||
|
"phone_number",
|
||||||
"is_active",
|
"is_active",
|
||||||
|
"uuid",
|
||||||
"is_staff",
|
"is_staff",
|
||||||
"is_superuser",
|
"is_superuser",
|
||||||
"phone_number",
|
|
||||||
)
|
)
|
||||||
search_fields = ("email", "username", "phone_number", "uuid")
|
search_fields = ("email", "username", "phone_number", "uuid")
|
||||||
filter_horizontal = ("groups", "user_permissions")
|
filter_horizontal = ("groups", "user_permissions")
|
||||||
|
@ -124,6 +126,7 @@ class EntryAdmin(admin.ModelAdmin):
|
||||||
"created_at",
|
"created_at",
|
||||||
"updated_at",
|
"updated_at",
|
||||||
"user",
|
"user",
|
||||||
|
"agent",
|
||||||
"file_source",
|
"file_source",
|
||||||
"file_type",
|
"file_type",
|
||||||
"file_name",
|
"file_name",
|
||||||
|
@ -133,6 +136,7 @@ class EntryAdmin(admin.ModelAdmin):
|
||||||
list_filter = (
|
list_filter = (
|
||||||
"file_type",
|
"file_type",
|
||||||
"user__email",
|
"user__email",
|
||||||
|
"search_model__name",
|
||||||
)
|
)
|
||||||
ordering = ("-created_at",)
|
ordering = ("-created_at",)
|
||||||
|
|
||||||
|
@ -197,9 +201,24 @@ class ServerChatSettingsAdmin(admin.ModelAdmin):
|
||||||
list_display = (
|
list_display = (
|
||||||
"chat_default",
|
"chat_default",
|
||||||
"chat_advanced",
|
"chat_advanced",
|
||||||
|
"web_scraper",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@admin.register(WebScraper)
|
||||||
|
class WebScraperAdmin(admin.ModelAdmin):
|
||||||
|
list_display = (
|
||||||
|
"priority",
|
||||||
|
"name",
|
||||||
|
"type",
|
||||||
|
"api_key",
|
||||||
|
"api_url",
|
||||||
|
"created_at",
|
||||||
|
)
|
||||||
|
search_fields = ("name", "api_key", "api_url", "type")
|
||||||
|
ordering = ("priority",)
|
||||||
|
|
||||||
|
|
||||||
@admin.register(Conversation)
|
@admin.register(Conversation)
|
||||||
class ConversationAdmin(admin.ModelAdmin):
|
class ConversationAdmin(admin.ModelAdmin):
|
||||||
list_display = (
|
list_display = (
|
||||||
|
|
182
src/khoj/database/management/commands/change_default_model.py
Normal file
182
src/khoj/database/management/commands/change_default_model.py
Normal file
|
@ -0,0 +1,182 @@
|
||||||
|
import logging
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from django.core.management.base import BaseCommand
|
||||||
|
from django.db import transaction
|
||||||
|
from django.db.models import Count, Q
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from khoj.database.adapters import get_default_search_model
|
||||||
|
from khoj.database.models import (
|
||||||
|
Agent,
|
||||||
|
Entry,
|
||||||
|
KhojUser,
|
||||||
|
SearchModelConfig,
|
||||||
|
UserSearchModelConfig,
|
||||||
|
)
|
||||||
|
from khoj.processor.embeddings import EmbeddingsModel
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Command(BaseCommand):
|
||||||
|
help = "Convert all existing Entry objects to use a new default Search model."
|
||||||
|
|
||||||
|
def add_arguments(self, parser):
|
||||||
|
# Pass default SearchModelConfig ID
|
||||||
|
parser.add_argument(
|
||||||
|
"--search_model_id",
|
||||||
|
action="store",
|
||||||
|
help="ID of the SearchModelConfig object to set as the default search model for all existing Entry objects and UserSearchModelConfig objects.",
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set the apply flag to apply the new default Search model to all existing Entry objects and UserSearchModelConfig objects.
|
||||||
|
parser.add_argument(
|
||||||
|
"--apply",
|
||||||
|
action="store_true",
|
||||||
|
help="Apply the new default Search model to all existing Entry objects and UserSearchModelConfig objects. Otherwise, only display the number of Entry objects and UserSearchModelConfig objects that will be affected.",
|
||||||
|
)
|
||||||
|
|
||||||
|
def handle(self, *args, **options):
|
||||||
|
@transaction.atomic
|
||||||
|
def regenerate_entries(entry_filter: Q, embeddings_model: EmbeddingsModel, search_model: SearchModelConfig):
|
||||||
|
entries = Entry.objects.filter(entry_filter).all()
|
||||||
|
compiled_entries = [entry.compiled for entry in entries]
|
||||||
|
updated_entries: List[Entry] = []
|
||||||
|
try:
|
||||||
|
embeddings = embeddings_model.embed_documents(compiled_entries)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error embedding documents: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
for i, entry in enumerate(tqdm(entries)):
|
||||||
|
entry.embeddings = embeddings[i]
|
||||||
|
entry.search_model_id = search_model.id
|
||||||
|
updated_entries.append(entry)
|
||||||
|
|
||||||
|
Entry.objects.bulk_update(updated_entries, ["embeddings", "search_model_id", "file_path"])
|
||||||
|
|
||||||
|
search_model_config_id = options.get("search_model_id")
|
||||||
|
apply = options.get("apply")
|
||||||
|
|
||||||
|
logger.info(f"SearchModelConfig ID: {search_model_config_id}")
|
||||||
|
logger.info(f"Apply: {apply}")
|
||||||
|
|
||||||
|
embeddings_model = dict()
|
||||||
|
|
||||||
|
search_models = SearchModelConfig.objects.all()
|
||||||
|
for model in search_models:
|
||||||
|
embeddings_model.update(
|
||||||
|
{
|
||||||
|
model.name: EmbeddingsModel(
|
||||||
|
model.bi_encoder,
|
||||||
|
model.embeddings_inference_endpoint,
|
||||||
|
model.embeddings_inference_endpoint_api_key,
|
||||||
|
query_encode_kwargs=model.bi_encoder_query_encode_config,
|
||||||
|
docs_encode_kwargs=model.bi_encoder_docs_encode_config,
|
||||||
|
model_kwargs=model.bi_encoder_model_config,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
new_default_search_model_config = SearchModelConfig.objects.get(id=search_model_config_id)
|
||||||
|
logger.info(f"New default Search model: {new_default_search_model_config}")
|
||||||
|
user_search_model_configs_to_update = UserSearchModelConfig.objects.exclude(
|
||||||
|
setting_id=search_model_config_id
|
||||||
|
).all()
|
||||||
|
logger.info(f"Number of UserSearchModelConfig objects to update: {user_search_model_configs_to_update.count()}")
|
||||||
|
|
||||||
|
for user_config in user_search_model_configs_to_update:
|
||||||
|
affected_user = user_config.user
|
||||||
|
entry_filter = Q(user=affected_user)
|
||||||
|
relevant_entries = Entry.objects.filter(entry_filter).all()
|
||||||
|
logger.info(f"Number of Entry objects to update for user {affected_user}: {relevant_entries.count()}")
|
||||||
|
|
||||||
|
if apply:
|
||||||
|
try:
|
||||||
|
regenerate_entries(
|
||||||
|
entry_filter,
|
||||||
|
embeddings_model[new_default_search_model_config.name],
|
||||||
|
new_default_search_model_config,
|
||||||
|
)
|
||||||
|
user_config.setting = new_default_search_model_config
|
||||||
|
user_config.save()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Updated UserSearchModelConfig object for user {affected_user} to use the new default Search model."
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Updated {relevant_entries.count()} Entry objects for user {affected_user} to use the new default Search model."
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error embedding documents: {e}")
|
||||||
|
|
||||||
|
logger.info("----")
|
||||||
|
|
||||||
|
# There are also plenty of users who have indexed documents without explicitly creating a UserSearchModelConfig object. You would have to migrate these users as well, if the default is different from search_model_config_id.
|
||||||
|
current_default = get_default_search_model()
|
||||||
|
if current_default.id != new_default_search_model_config.id:
|
||||||
|
users_without_user_search_model_config = KhojUser.objects.annotate(
|
||||||
|
user_search_model_config_count=Count("usersearchmodelconfig")
|
||||||
|
).filter(user_search_model_config_count=0)
|
||||||
|
|
||||||
|
logger.info(f"Number of User objects to update: {users_without_user_search_model_config.count()}")
|
||||||
|
for user in users_without_user_search_model_config:
|
||||||
|
entry_filter = Q(user=user)
|
||||||
|
relevant_entries = Entry.objects.filter(entry_filter).all()
|
||||||
|
logger.info(f"Number of Entry objects to update for user {user}: {relevant_entries.count()}")
|
||||||
|
|
||||||
|
if apply:
|
||||||
|
try:
|
||||||
|
regenerate_entries(
|
||||||
|
entry_filter,
|
||||||
|
embeddings_model[new_default_search_model_config.name],
|
||||||
|
new_default_search_model_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
UserSearchModelConfig.objects.create(user=user, setting=new_default_search_model_config)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Created UserSearchModelConfig object for user {user} to use the new default Search model."
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Updated {relevant_entries.count()} Entry objects for user {user} to use the new default Search model."
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error embedding documents: {e}")
|
||||||
|
else:
|
||||||
|
logger.info("Default is the same as search_model_config_id.")
|
||||||
|
|
||||||
|
all_agents = Agent.objects.all()
|
||||||
|
logger.info(f"Number of Agent objects to update: {all_agents.count()}")
|
||||||
|
for agent in all_agents:
|
||||||
|
entry_filter = Q(agent=agent)
|
||||||
|
relevant_entries = Entry.objects.filter(entry_filter).all()
|
||||||
|
logger.info(f"Number of Entry objects to update for agent {agent}: {relevant_entries.count()}")
|
||||||
|
|
||||||
|
if apply:
|
||||||
|
try:
|
||||||
|
regenerate_entries(
|
||||||
|
entry_filter,
|
||||||
|
embeddings_model[new_default_search_model_config.name],
|
||||||
|
new_default_search_model_config,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Updated {relevant_entries.count()} Entry objects for agent {agent} to use the new default Search model."
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error embedding documents: {e}")
|
||||||
|
if apply and current_default.id != new_default_search_model_config.id:
|
||||||
|
# Get the existing default SearchModelConfig object and update its name
|
||||||
|
current_default.name = f"prev_default_{current_default.id}"
|
||||||
|
current_default.save()
|
||||||
|
|
||||||
|
# Update the new default SearchModelConfig object's name
|
||||||
|
new_default_search_model_config.name = "default"
|
||||||
|
new_default_search_model_config.save()
|
||||||
|
if not apply:
|
||||||
|
logger.info("Run the command with the --apply flag to apply the new default Search model.")
|
|
@ -0,0 +1,24 @@
|
||||||
|
# Generated by Django 5.0.8 on 2024-10-17 18:13
|
||||||
|
|
||||||
|
import django.contrib.postgres.fields
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0067_alter_agent_style_icon"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AlterField(
|
||||||
|
model_name="agent",
|
||||||
|
name="output_modes",
|
||||||
|
field=django.contrib.postgres.fields.ArrayField(
|
||||||
|
base_field=models.CharField(
|
||||||
|
choices=[("text", "Text"), ("image", "Image"), ("automation", "Automation")], max_length=200
|
||||||
|
),
|
||||||
|
default=list,
|
||||||
|
size=None,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
|
@ -0,0 +1,89 @@
|
||||||
|
# Generated by Django 5.0.8 on 2024-10-18 00:41
|
||||||
|
|
||||||
|
import django.db.models.deletion
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0068_alter_agent_output_modes"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.CreateModel(
|
||||||
|
name="WebScraper",
|
||||||
|
fields=[
|
||||||
|
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||||
|
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||||
|
("updated_at", models.DateTimeField(auto_now=True)),
|
||||||
|
(
|
||||||
|
"name",
|
||||||
|
models.CharField(
|
||||||
|
blank=True,
|
||||||
|
default=None,
|
||||||
|
help_text="Friendly name. If not set, it will be set to the type of the scraper.",
|
||||||
|
max_length=200,
|
||||||
|
null=True,
|
||||||
|
unique=True,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"type",
|
||||||
|
models.CharField(
|
||||||
|
choices=[
|
||||||
|
("Firecrawl", "Firecrawl"),
|
||||||
|
("Olostep", "Olostep"),
|
||||||
|
("Jina", "Jina"),
|
||||||
|
("Direct", "Direct"),
|
||||||
|
],
|
||||||
|
default="Jina",
|
||||||
|
max_length=20,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"api_key",
|
||||||
|
models.CharField(
|
||||||
|
blank=True,
|
||||||
|
default=None,
|
||||||
|
help_text="API key of the web scraper. Only set if scraper service requires an API key. Default is set from env var.",
|
||||||
|
max_length=200,
|
||||||
|
null=True,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"api_url",
|
||||||
|
models.URLField(
|
||||||
|
blank=True,
|
||||||
|
default=None,
|
||||||
|
help_text="API URL of the web scraper. Only set if scraper service on non-default URL.",
|
||||||
|
null=True,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"priority",
|
||||||
|
models.IntegerField(
|
||||||
|
blank=True,
|
||||||
|
default=None,
|
||||||
|
help_text="Priority of the web scraper. Lower numbers run first.",
|
||||||
|
null=True,
|
||||||
|
unique=True,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
options={
|
||||||
|
"abstract": False,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="serverchatsettings",
|
||||||
|
name="web_scraper",
|
||||||
|
field=models.ForeignKey(
|
||||||
|
blank=True,
|
||||||
|
default=None,
|
||||||
|
null=True,
|
||||||
|
on_delete=django.db.models.deletion.CASCADE,
|
||||||
|
related_name="web_scraper",
|
||||||
|
to="database.webscraper",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
|
@ -0,0 +1,46 @@
|
||||||
|
# Generated by Django 5.0.8 on 2024-10-21 05:16
|
||||||
|
|
||||||
|
import django.contrib.postgres.fields
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0069_webscraper_serverchatsettings_web_scraper"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AlterField(
|
||||||
|
model_name="agent",
|
||||||
|
name="input_tools",
|
||||||
|
field=django.contrib.postgres.fields.ArrayField(
|
||||||
|
base_field=models.CharField(
|
||||||
|
choices=[
|
||||||
|
("general", "General"),
|
||||||
|
("online", "Online"),
|
||||||
|
("notes", "Notes"),
|
||||||
|
("summarize", "Summarize"),
|
||||||
|
("webpage", "Webpage"),
|
||||||
|
],
|
||||||
|
max_length=200,
|
||||||
|
),
|
||||||
|
blank=True,
|
||||||
|
default=list,
|
||||||
|
null=True,
|
||||||
|
size=None,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
migrations.AlterField(
|
||||||
|
model_name="agent",
|
||||||
|
name="output_modes",
|
||||||
|
field=django.contrib.postgres.fields.ArrayField(
|
||||||
|
base_field=models.CharField(
|
||||||
|
choices=[("text", "Text"), ("image", "Image"), ("automation", "Automation")], max_length=200
|
||||||
|
),
|
||||||
|
blank=True,
|
||||||
|
default=list,
|
||||||
|
null=True,
|
||||||
|
size=None,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
|
@ -0,0 +1,32 @@
|
||||||
|
# Generated by Django 5.0.8 on 2024-10-20 19:24
|
||||||
|
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
def set_enabled_trial_at(apps, schema_editor):
|
||||||
|
Subscription = apps.get_model("database", "Subscription")
|
||||||
|
for subscription in Subscription.objects.all():
|
||||||
|
subscription.enabled_trial_at = subscription.created_at
|
||||||
|
subscription.save()
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0070_alter_agent_input_tools_alter_agent_output_modes"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="subscription",
|
||||||
|
name="enabled_trial_at",
|
||||||
|
field=models.DateTimeField(blank=True, default=None, null=True),
|
||||||
|
),
|
||||||
|
migrations.AlterField(
|
||||||
|
model_name="subscription",
|
||||||
|
name="type",
|
||||||
|
field=models.CharField(
|
||||||
|
choices=[("trial", "Trial"), ("standard", "Standard")], default="standard", max_length=20
|
||||||
|
),
|
||||||
|
),
|
||||||
|
migrations.RunPython(set_enabled_trial_at),
|
||||||
|
]
|
24
src/khoj/database/migrations/0072_entry_search_model.py
Normal file
24
src/khoj/database/migrations/0072_entry_search_model.py
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
# Generated by Django 5.0.8 on 2024-10-21 21:09
|
||||||
|
|
||||||
|
import django.db.models.deletion
|
||||||
|
from django.db import migrations, models
|
||||||
|
|
||||||
|
|
||||||
|
class Migration(migrations.Migration):
|
||||||
|
dependencies = [
|
||||||
|
("database", "0071_subscription_enabled_trial_at_and_more"),
|
||||||
|
]
|
||||||
|
|
||||||
|
operations = [
|
||||||
|
migrations.AddField(
|
||||||
|
model_name="entry",
|
||||||
|
name="search_model",
|
||||||
|
field=models.ForeignKey(
|
||||||
|
blank=True,
|
||||||
|
default=None,
|
||||||
|
null=True,
|
||||||
|
on_delete=django.db.models.deletion.SET_NULL,
|
||||||
|
to="database.searchmodelconfig",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
|
@ -1,3 +1,4 @@
|
||||||
|
import os
|
||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
from random import choice
|
from random import choice
|
||||||
|
@ -11,8 +12,6 @@ from django.dispatch import receiver
|
||||||
from pgvector.django import VectorField
|
from pgvector.django import VectorField
|
||||||
from phonenumber_field.modelfields import PhoneNumberField
|
from phonenumber_field.modelfields import PhoneNumberField
|
||||||
|
|
||||||
from khoj.utils.helpers import ConversationCommand
|
|
||||||
|
|
||||||
|
|
||||||
class BaseModel(models.Model):
|
class BaseModel(models.Model):
|
||||||
created_at = models.DateTimeField(auto_now_add=True)
|
created_at = models.DateTimeField(auto_now_add=True)
|
||||||
|
@ -74,9 +73,10 @@ class Subscription(BaseModel):
|
||||||
STANDARD = "standard"
|
STANDARD = "standard"
|
||||||
|
|
||||||
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE, related_name="subscription")
|
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE, related_name="subscription")
|
||||||
type = models.CharField(max_length=20, choices=Type.choices, default=Type.TRIAL)
|
type = models.CharField(max_length=20, choices=Type.choices, default=Type.STANDARD)
|
||||||
is_recurring = models.BooleanField(default=False)
|
is_recurring = models.BooleanField(default=False)
|
||||||
renewal_date = models.DateTimeField(null=True, default=None, blank=True)
|
renewal_date = models.DateTimeField(null=True, default=None, blank=True)
|
||||||
|
enabled_trial_at = models.DateTimeField(null=True, default=None, blank=True)
|
||||||
|
|
||||||
|
|
||||||
class OpenAIProcessorConversationConfig(BaseModel):
|
class OpenAIProcessorConversationConfig(BaseModel):
|
||||||
|
@ -174,14 +174,19 @@ class Agent(BaseModel):
|
||||||
# These map to various ConversationCommand types
|
# These map to various ConversationCommand types
|
||||||
TEXT = "text"
|
TEXT = "text"
|
||||||
IMAGE = "image"
|
IMAGE = "image"
|
||||||
|
AUTOMATION = "automation"
|
||||||
|
|
||||||
creator = models.ForeignKey(
|
creator = models.ForeignKey(
|
||||||
KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True
|
KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True
|
||||||
) # Creator will only be null when the agents are managed by admin
|
) # Creator will only be null when the agents are managed by admin
|
||||||
name = models.CharField(max_length=200)
|
name = models.CharField(max_length=200)
|
||||||
personality = models.TextField()
|
personality = models.TextField()
|
||||||
input_tools = ArrayField(models.CharField(max_length=200, choices=InputToolOptions.choices), default=list)
|
input_tools = ArrayField(
|
||||||
output_modes = ArrayField(models.CharField(max_length=200, choices=OutputModeOptions.choices), default=list)
|
models.CharField(max_length=200, choices=InputToolOptions.choices), default=list, null=True, blank=True
|
||||||
|
)
|
||||||
|
output_modes = ArrayField(
|
||||||
|
models.CharField(max_length=200, choices=OutputModeOptions.choices), default=list, null=True, blank=True
|
||||||
|
)
|
||||||
managed_by_admin = models.BooleanField(default=False)
|
managed_by_admin = models.BooleanField(default=False)
|
||||||
chat_model = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE)
|
chat_model = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE)
|
||||||
slug = models.CharField(max_length=200, unique=True)
|
slug = models.CharField(max_length=200, unique=True)
|
||||||
|
@ -243,6 +248,79 @@ class GithubRepoConfig(BaseModel):
|
||||||
github_config = models.ForeignKey(GithubConfig, on_delete=models.CASCADE, related_name="githubrepoconfig")
|
github_config = models.ForeignKey(GithubConfig, on_delete=models.CASCADE, related_name="githubrepoconfig")
|
||||||
|
|
||||||
|
|
||||||
|
class WebScraper(BaseModel):
|
||||||
|
class WebScraperType(models.TextChoices):
|
||||||
|
FIRECRAWL = "Firecrawl"
|
||||||
|
OLOSTEP = "Olostep"
|
||||||
|
JINA = "Jina"
|
||||||
|
DIRECT = "Direct"
|
||||||
|
|
||||||
|
name = models.CharField(
|
||||||
|
max_length=200,
|
||||||
|
default=None,
|
||||||
|
null=True,
|
||||||
|
blank=True,
|
||||||
|
unique=True,
|
||||||
|
help_text="Friendly name. If not set, it will be set to the type of the scraper.",
|
||||||
|
)
|
||||||
|
type = models.CharField(max_length=20, choices=WebScraperType.choices, default=WebScraperType.JINA)
|
||||||
|
api_key = models.CharField(
|
||||||
|
max_length=200,
|
||||||
|
default=None,
|
||||||
|
null=True,
|
||||||
|
blank=True,
|
||||||
|
help_text="API key of the web scraper. Only set if scraper service requires an API key. Default is set from env var.",
|
||||||
|
)
|
||||||
|
api_url = models.URLField(
|
||||||
|
max_length=200,
|
||||||
|
default=None,
|
||||||
|
null=True,
|
||||||
|
blank=True,
|
||||||
|
help_text="API URL of the web scraper. Only set if scraper service on non-default URL.",
|
||||||
|
)
|
||||||
|
priority = models.IntegerField(
|
||||||
|
default=None,
|
||||||
|
null=True,
|
||||||
|
blank=True,
|
||||||
|
unique=True,
|
||||||
|
help_text="Priority of the web scraper. Lower numbers run first.",
|
||||||
|
)
|
||||||
|
|
||||||
|
def clean(self):
|
||||||
|
error = {}
|
||||||
|
if self.name is None:
|
||||||
|
self.name = self.type.capitalize()
|
||||||
|
if self.api_url is None:
|
||||||
|
if self.type == self.WebScraperType.FIRECRAWL:
|
||||||
|
self.api_url = os.getenv("FIRECRAWL_API_URL", "https://api.firecrawl.dev")
|
||||||
|
elif self.type == self.WebScraperType.OLOSTEP:
|
||||||
|
self.api_url = os.getenv("OLOSTEP_API_URL", "https://agent.olostep.com/olostep-p2p-incomingAPI")
|
||||||
|
elif self.type == self.WebScraperType.JINA:
|
||||||
|
self.api_url = os.getenv("JINA_READER_API_URL", "https://r.jina.ai/")
|
||||||
|
if self.api_key is None:
|
||||||
|
if self.type == self.WebScraperType.FIRECRAWL:
|
||||||
|
self.api_key = os.getenv("FIRECRAWL_API_KEY")
|
||||||
|
if not self.api_key and self.api_url == "https://api.firecrawl.dev":
|
||||||
|
error["api_key"] = "Set API key to use default Firecrawl. Get API key from https://firecrawl.dev."
|
||||||
|
elif self.type == self.WebScraperType.OLOSTEP:
|
||||||
|
self.api_key = os.getenv("OLOSTEP_API_KEY")
|
||||||
|
if self.api_key is None:
|
||||||
|
error["api_key"] = "Set API key to use Olostep. Get API key from https://olostep.com/."
|
||||||
|
elif self.type == self.WebScraperType.JINA:
|
||||||
|
self.api_key = os.getenv("JINA_API_KEY")
|
||||||
|
if error:
|
||||||
|
raise ValidationError(error)
|
||||||
|
|
||||||
|
def save(self, *args, **kwargs):
|
||||||
|
self.clean()
|
||||||
|
|
||||||
|
if self.priority is None:
|
||||||
|
max_priority = WebScraper.objects.aggregate(models.Max("priority"))["priority__max"]
|
||||||
|
self.priority = max_priority + 1 if max_priority else 1
|
||||||
|
|
||||||
|
super().save(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class ServerChatSettings(BaseModel):
|
class ServerChatSettings(BaseModel):
|
||||||
chat_default = models.ForeignKey(
|
chat_default = models.ForeignKey(
|
||||||
ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_default"
|
ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_default"
|
||||||
|
@ -250,6 +328,9 @@ class ServerChatSettings(BaseModel):
|
||||||
chat_advanced = models.ForeignKey(
|
chat_advanced = models.ForeignKey(
|
||||||
ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_advanced"
|
ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="chat_advanced"
|
||||||
)
|
)
|
||||||
|
web_scraper = models.ForeignKey(
|
||||||
|
WebScraper, on_delete=models.CASCADE, default=None, null=True, blank=True, related_name="web_scraper"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LocalOrgConfig(BaseModel):
|
class LocalOrgConfig(BaseModel):
|
||||||
|
@ -368,6 +449,7 @@ class UserVoiceModelConfig(BaseModel):
|
||||||
setting = models.ForeignKey(VoiceModelOption, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
setting = models.ForeignKey(VoiceModelOption, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO Delete this model once all users have been migrated to the server's default settings
|
||||||
class UserSearchModelConfig(BaseModel):
|
class UserSearchModelConfig(BaseModel):
|
||||||
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
|
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
|
||||||
setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE)
|
setting = models.ForeignKey(SearchModelConfig, on_delete=models.CASCADE)
|
||||||
|
@ -454,6 +536,7 @@ class Entry(BaseModel):
|
||||||
url = models.URLField(max_length=400, default=None, null=True, blank=True)
|
url = models.URLField(max_length=400, default=None, null=True, blank=True)
|
||||||
hashed_value = models.CharField(max_length=100)
|
hashed_value = models.CharField(max_length=100)
|
||||||
corpus_id = models.UUIDField(default=uuid.uuid4, editable=False)
|
corpus_id = models.UUIDField(default=uuid.uuid4, editable=False)
|
||||||
|
search_model = models.ForeignKey(SearchModelConfig, on_delete=models.SET_NULL, default=None, null=True, blank=True)
|
||||||
|
|
||||||
def save(self, *args, **kwargs):
|
def save(self, *args, **kwargs):
|
||||||
if self.user and self.agent:
|
if self.user and self.agent:
|
||||||
|
|
|
@ -14,5 +14,6 @@
|
||||||
clip-rule="evenodd"
|
clip-rule="evenodd"
|
||||||
fill-rule="evenodd"
|
fill-rule="evenodd"
|
||||||
fill="currentColor"
|
fill="currentColor"
|
||||||
|
stroke="currentColor"
|
||||||
stroke-width="0.95844" />
|
stroke-width="0.95844" />
|
||||||
</svg>
|
</svg>
|
||||||
|
|
Before Width: | Height: | Size: 2.5 KiB After Width: | Height: | Size: 2.5 KiB |
|
@ -12,6 +12,7 @@
|
||||||
fill-rule="evenodd"
|
fill-rule="evenodd"
|
||||||
clip-rule="evenodd"
|
clip-rule="evenodd"
|
||||||
fill-opacity="1"
|
fill-opacity="1"
|
||||||
|
stroke="currentColor"
|
||||||
stroke-width="1.16584"
|
stroke-width="1.16584"
|
||||||
stroke-dasharray="none"
|
stroke-dasharray="none"
|
||||||
/>
|
/>
|
||||||
|
|
Before Width: | Height: | Size: 1.5 KiB After Width: | Height: | Size: 1.5 KiB |
24
src/khoj/interface/web/assets/icons/chat.svg
Normal file
24
src/khoj/interface/web/assets/icons/chat.svg
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
<svg xmlns="http://www.w3.org/2000/svg"
|
||||||
|
width="800px"
|
||||||
|
height="800px"
|
||||||
|
viewBox="0 0 24 24"
|
||||||
|
fill="none"
|
||||||
|
version="1.1">
|
||||||
|
<path
|
||||||
|
d="m 14.024348,9.8497703 0.04627,1.9750167"
|
||||||
|
stroke="currentColor"
|
||||||
|
stroke-width="1.77073"
|
||||||
|
stroke-linecap="round" />
|
||||||
|
<path
|
||||||
|
d="m 9.6453624,9.7953624 0.046275,1.9750166"
|
||||||
|
stroke="currentColor"
|
||||||
|
stroke-width="1.77072"
|
||||||
|
stroke-linecap="round" />
|
||||||
|
<path
|
||||||
|
d="m 11.90538,2.3619994 c -5.4939109,0 -9.6890976,4.0608185 -9.6890976,9.8578926 0,1.477202 0.2658016,2.542848 0.6989332,3.331408 0.433559,0.789293 1.0740097,1.372483 1.9230615,1.798517 1.7362861,0.87132 4.1946007,1.018626 7.0671029,1.018626 0.317997,0 0.593711,0.167879 0.784844,0.458501 0.166463,0.253124 0.238617,0.552748 0.275566,0.787233 0.07263,0.460801 0.05871,1.030165 0.04785,1.474824 v 4.8e-5 l -2.26e-4,0.0091 c -0.0085,0.348246 -0.01538,0.634247 -0.0085,0.861186 0.105589,-0.07971 0.227925,-0.185287 0.36735,-0.31735 0.348613,-0.330307 0.743513,-0.767362 1.176607,-1.246635 l 0.07837,-0.08673 c 0.452675,-0.500762 0.941688,-1.037938 1.41216,-1.473209 0.453774,-0.419787 0.969948,-0.822472 1.476003,-0.953853 1.323661,-0.343655 2.330132,-0.904027 3.005749,-1.76381 0.658957,-0.838568 1.073167,-2.051868 1.073167,-3.898667 0,-5.7970748 -4.195186,-9.8578946 -9.689097,-9.8578946 z M 0.92440678,12.219892 c 0,-7.0067939 5.05909412,-11.47090892 10.98097322,-11.47090892 5.921878,0 10.980972,4.46411502 10.980972,11.47090892 0,2.172259 -0.497596,3.825405 -1.442862,5.028357 -0.928601,1.181693 -2.218843,1.837914 -3.664937,2.213334 -0.211641,0.05502 -0.53529,0.268579 -0.969874,0.670658 -0.417861,0.386604 -0.865628,0.876836 -1.324566,1.384504 l -0.09131,0.101202 c -0.419252,0.464136 -0.849637,0.94059 -1.239338,1.309807 -0.210187,0.199169 -0.425281,0.383422 -0.635348,0.523424 -0.200911,0.133819 -0.449635,0.263369 -0.716376,0.281474 -0.327812,0.02226 -0.61539,-0.149209 -0.804998,-0.457293 -0.157614,-0.255993 -0.217622,-0.557143 -0.246564,-0.778198 -0.0542,-0.414027 -0.04101,-0.933065 -0.03027,-1.355183 l 0.0024,-0.0922 c 0.01099,-0.463865 0.01489,-0.820507 -0.01611,-1.06842 C 8.9434608,19.975238 6.3139711,19.828758 4.356743,18.84659 3.3355029,18.334136 2.4624526,17.578678 1.8500164,16.463713 1.2372016,15.348029 0.92459928,13.943803 0.92459928,12.219967 Z"
|
||||||
|
clip-rule="evenodd"
|
||||||
|
stroke-width="0.360886"
|
||||||
|
fill="currentColor"
|
||||||
|
fill-rule="evenodd"
|
||||||
|
fill-opacity="1" />
|
||||||
|
</svg>
|
After Width: | Height: | Size: 2.4 KiB |
|
@ -46,33 +46,16 @@
|
||||||
<p>Transform the way you think, create, and remember</p>
|
<p>Transform the way you think, create, and remember</p>
|
||||||
<div class="features">
|
<div class="features">
|
||||||
<div class="feature">
|
<div class="feature">
|
||||||
<svg viewBox="0 0 24 24" width="24" height="24" stroke="currentColor" stroke-width="2"
|
<img src="/static/assets/icons/chat.svg" alt="Chat" width="24" height="24">
|
||||||
fill="none">
|
|
||||||
<path d="M14 2H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V8z" />
|
|
||||||
<path d="M14 2v6h6" />
|
|
||||||
<path d="M16 13H8" />
|
|
||||||
<path d="M16 17H8" />
|
|
||||||
<path d="M10 9H8" />
|
|
||||||
</svg>
|
|
||||||
<span>Get answers across your documents and the internet</span>
|
<span>Get answers across your documents and the internet</span>
|
||||||
</div>
|
</div>
|
||||||
<div class="feature">
|
<div class="feature">
|
||||||
<svg viewBox="0 0 24 24" width="24" height="24" stroke="currentColor" stroke-width="2"
|
<img src="/static/assets/icons/agents.svg" alt="Agents" width="24" height="24">
|
||||||
fill="none">
|
<span>Create agents with the knowledge and tools to take on any role</span>
|
||||||
<path
|
|
||||||
d="M21 16V8a2 2 0 0 0-1-1.73l-7-4a2 2 0 0 0-2 0l-7 4A2 2 0 0 0 3 8v8a2 2 0 0 0 1 1.73l7 4a2 2 0 0 0 2 0l7-4A2 2 0 0 0 21 16z" />
|
|
||||||
<path d="M3.3 7l8.7 5 8.7-5" />
|
|
||||||
</svg>
|
|
||||||
<span>Go deeper in the topics personal to you</span>
|
|
||||||
</div>
|
</div>
|
||||||
<div class="feature">
|
<div class="feature">
|
||||||
<svg viewBox="0 0 24 24" width="24" height="24" stroke="currentColor" stroke-width="2"
|
<img src="/static/assets/icons/automation.svg" alt="Automations" width="24" height="24">
|
||||||
fill="none">
|
<span>Automate away repetitive research</span>
|
||||||
<path d="M12 2L2 7l10 5 10-5-10-5z" />
|
|
||||||
<path d="M2 17l10 5 10-5" />
|
|
||||||
<path d="M2 12l10 5 10-5" />
|
|
||||||
</svg>
|
|
||||||
<span>Use specialized agents</span>
|
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
@ -160,6 +143,12 @@
|
||||||
height: 24px;
|
height: 24px;
|
||||||
stroke: white;
|
stroke: white;
|
||||||
}
|
}
|
||||||
|
.feature img {
|
||||||
|
width: 24px;
|
||||||
|
height: 24px;
|
||||||
|
filter: invert(100%) sepia(0%) saturate(0%) hue-rotate(0deg) brightness(100%) contrast(100%);
|
||||||
|
stroke: white;
|
||||||
|
}
|
||||||
|
|
||||||
#login-modal {
|
#login-modal {
|
||||||
display: grid;
|
display: grid;
|
||||||
|
|
|
@ -64,6 +64,8 @@ class ImageToEntries(TextToEntries):
|
||||||
tmp_file = f"tmp_image_file_{timestamp_now}.png"
|
tmp_file = f"tmp_image_file_{timestamp_now}.png"
|
||||||
elif image_file.endswith(".jpg") or image_file.endswith(".jpeg"):
|
elif image_file.endswith(".jpg") or image_file.endswith(".jpeg"):
|
||||||
tmp_file = f"tmp_image_file_{timestamp_now}.jpg"
|
tmp_file = f"tmp_image_file_{timestamp_now}.jpg"
|
||||||
|
elif image_file.endswith(".webp"):
|
||||||
|
tmp_file = f"tmp_image_file_{timestamp_now}.webp"
|
||||||
with open(tmp_file, "wb") as f:
|
with open(tmp_file, "wb") as f:
|
||||||
bytes = image_files[image_file]
|
bytes = image_files[image_file]
|
||||||
f.write(bytes)
|
f.write(bytes)
|
||||||
|
|
|
@ -67,7 +67,7 @@ class PdfToEntries(TextToEntries):
|
||||||
bytes = pdf_files[pdf_file]
|
bytes = pdf_files[pdf_file]
|
||||||
f.write(bytes)
|
f.write(bytes)
|
||||||
try:
|
try:
|
||||||
loader = PyMuPDFLoader(f"{tmp_file}", extract_images=True)
|
loader = PyMuPDFLoader(f"{tmp_file}", extract_images=False)
|
||||||
pdf_entries_per_file = [page.page_content for page in loader.load()]
|
pdf_entries_per_file = [page.page_content for page in loader.load()]
|
||||||
except ImportError:
|
except ImportError:
|
||||||
loader = PyMuPDFLoader(f"{tmp_file}")
|
loader = PyMuPDFLoader(f"{tmp_file}")
|
||||||
|
|
|
@ -12,7 +12,8 @@ from tqdm import tqdm
|
||||||
from khoj.database.adapters import (
|
from khoj.database.adapters import (
|
||||||
EntryAdapters,
|
EntryAdapters,
|
||||||
FileObjectAdapters,
|
FileObjectAdapters,
|
||||||
get_user_search_model_or_default,
|
get_default_search_model,
|
||||||
|
get_user_default_search_model,
|
||||||
)
|
)
|
||||||
from khoj.database.models import Entry as DbEntry
|
from khoj.database.models import Entry as DbEntry
|
||||||
from khoj.database.models import EntryDates, KhojUser
|
from khoj.database.models import EntryDates, KhojUser
|
||||||
|
@ -148,10 +149,10 @@ class TextToEntries(ABC):
|
||||||
hashes_to_process |= hashes_for_file - existing_entry_hashes
|
hashes_to_process |= hashes_for_file - existing_entry_hashes
|
||||||
|
|
||||||
embeddings = []
|
embeddings = []
|
||||||
|
model = get_user_default_search_model(user=user)
|
||||||
with timer("Generated embeddings for entries to add to database in", logger):
|
with timer("Generated embeddings for entries to add to database in", logger):
|
||||||
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
|
entries_to_process = [hash_to_current_entries[hashed_val] for hashed_val in hashes_to_process]
|
||||||
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
|
data_to_embed = [getattr(entry, key) for entry in entries_to_process]
|
||||||
model = get_user_search_model_or_default(user)
|
|
||||||
embeddings += self.embeddings_model[model.name].embed_documents(data_to_embed)
|
embeddings += self.embeddings_model[model.name].embed_documents(data_to_embed)
|
||||||
|
|
||||||
added_entries: list[DbEntry] = []
|
added_entries: list[DbEntry] = []
|
||||||
|
@ -177,6 +178,7 @@ class TextToEntries(ABC):
|
||||||
file_type=file_type,
|
file_type=file_type,
|
||||||
hashed_value=entry_hash,
|
hashed_value=entry_hash,
|
||||||
corpus_id=entry.corpus_id,
|
corpus_id=entry.corpus_id,
|
||||||
|
search_model=model,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -6,14 +6,17 @@ from typing import Dict, Optional
|
||||||
|
|
||||||
from langchain.schema import ChatMessage
|
from langchain.schema import ChatMessage
|
||||||
|
|
||||||
from khoj.database.models import Agent, KhojUser
|
from khoj.database.models import Agent, ChatModelOptions, KhojUser
|
||||||
from khoj.processor.conversation import prompts
|
from khoj.processor.conversation import prompts
|
||||||
from khoj.processor.conversation.google.utils import (
|
from khoj.processor.conversation.google.utils import (
|
||||||
format_messages_for_gemini,
|
format_messages_for_gemini,
|
||||||
gemini_chat_completion_with_backoff,
|
gemini_chat_completion_with_backoff,
|
||||||
gemini_completion_with_backoff,
|
gemini_completion_with_backoff,
|
||||||
)
|
)
|
||||||
from khoj.processor.conversation.utils import generate_chatml_messages_with_context
|
from khoj.processor.conversation.utils import (
|
||||||
|
construct_structured_message,
|
||||||
|
generate_chatml_messages_with_context,
|
||||||
|
)
|
||||||
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
|
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
|
||||||
from khoj.utils.rawconfig import LocationData
|
from khoj.utils.rawconfig import LocationData
|
||||||
|
|
||||||
|
@ -29,6 +32,8 @@ def extract_questions_gemini(
|
||||||
max_tokens=None,
|
max_tokens=None,
|
||||||
location_data: LocationData = None,
|
location_data: LocationData = None,
|
||||||
user: KhojUser = None,
|
user: KhojUser = None,
|
||||||
|
query_images: Optional[list[str]] = None,
|
||||||
|
vision_enabled: bool = False,
|
||||||
personality_context: Optional[str] = None,
|
personality_context: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
@ -70,17 +75,17 @@ def extract_questions_gemini(
|
||||||
text=text,
|
text=text,
|
||||||
)
|
)
|
||||||
|
|
||||||
messages = [ChatMessage(content=prompt, role="user")]
|
prompt = construct_structured_message(
|
||||||
|
message=prompt,
|
||||||
|
images=query_images,
|
||||||
|
model_type=ChatModelOptions.ModelType.GOOGLE,
|
||||||
|
vision_enabled=vision_enabled,
|
||||||
|
)
|
||||||
|
|
||||||
model_kwargs = {"response_mime_type": "application/json"}
|
messages = [ChatMessage(content=prompt, role="user"), ChatMessage(content=system_prompt, role="system")]
|
||||||
|
|
||||||
response = gemini_completion_with_backoff(
|
response = gemini_send_message_to_model(
|
||||||
messages=messages,
|
messages, api_key, model, response_type="json_object", temperature=temperature
|
||||||
system_prompt=system_prompt,
|
|
||||||
model_name=model,
|
|
||||||
temperature=temperature,
|
|
||||||
api_key=api_key,
|
|
||||||
model_kwargs=model_kwargs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract, Clean Message from Gemini's Response
|
# Extract, Clean Message from Gemini's Response
|
||||||
|
@ -102,7 +107,7 @@ def extract_questions_gemini(
|
||||||
return questions
|
return questions
|
||||||
|
|
||||||
|
|
||||||
def gemini_send_message_to_model(messages, api_key, model, response_type="text"):
|
def gemini_send_message_to_model(messages, api_key, model, response_type="text", temperature=0, model_kwargs=None):
|
||||||
"""
|
"""
|
||||||
Send message to model
|
Send message to model
|
||||||
"""
|
"""
|
||||||
|
@ -114,7 +119,12 @@ def gemini_send_message_to_model(messages, api_key, model, response_type="text")
|
||||||
|
|
||||||
# Get Response from Gemini
|
# Get Response from Gemini
|
||||||
return gemini_completion_with_backoff(
|
return gemini_completion_with_backoff(
|
||||||
messages=messages, system_prompt=system_prompt, model_name=model, api_key=api_key, model_kwargs=model_kwargs
|
messages=messages,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
model_name=model,
|
||||||
|
api_key=api_key,
|
||||||
|
temperature=temperature,
|
||||||
|
model_kwargs=model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -134,6 +144,8 @@ def converse_gemini(
|
||||||
location_data: LocationData = None,
|
location_data: LocationData = None,
|
||||||
user_name: str = None,
|
user_name: str = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
|
query_images: Optional[list[str]] = None,
|
||||||
|
vision_available: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Converse with user using Google's Gemini
|
Converse with user using Google's Gemini
|
||||||
|
@ -192,6 +204,9 @@ def converse_gemini(
|
||||||
model_name=model,
|
model_name=model,
|
||||||
max_prompt_size=max_prompt_size,
|
max_prompt_size=max_prompt_size,
|
||||||
tokenizer_name=tokenizer_name,
|
tokenizer_name=tokenizer_name,
|
||||||
|
query_images=query_images,
|
||||||
|
vision_enabled=vision_available,
|
||||||
|
model_type=ChatModelOptions.ModelType.GOOGLE,
|
||||||
)
|
)
|
||||||
|
|
||||||
messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
|
messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
|
||||||
|
|
|
@ -1,8 +1,11 @@
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
|
from io import BytesIO
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
|
||||||
import google.generativeai as genai
|
import google.generativeai as genai
|
||||||
|
import PIL.Image
|
||||||
|
import requests
|
||||||
from google.generativeai.types.answer_types import FinishReason
|
from google.generativeai.types.answer_types import FinishReason
|
||||||
from google.generativeai.types.generation_types import StopCandidateException
|
from google.generativeai.types.generation_types import StopCandidateException
|
||||||
from google.generativeai.types.safety_types import (
|
from google.generativeai.types.safety_types import (
|
||||||
|
@ -53,14 +56,14 @@ def gemini_completion_with_backoff(
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
|
formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]
|
||||||
|
|
||||||
# Start chat session. All messages up to the last are considered to be part of the chat history
|
# Start chat session. All messages up to the last are considered to be part of the chat history
|
||||||
chat_session = model.start_chat(history=formatted_messages[0:-1])
|
chat_session = model.start_chat(history=formatted_messages[0:-1])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Generate the response. The last message is considered to be the current prompt
|
# Generate the response. The last message is considered to be the current prompt
|
||||||
aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"][0])
|
aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"])
|
||||||
return aggregated_response.text
|
return aggregated_response.text
|
||||||
except StopCandidateException as e:
|
except StopCandidateException as e:
|
||||||
response_message, _ = handle_gemini_response(e.args)
|
response_message, _ = handle_gemini_response(e.args)
|
||||||
|
@ -117,11 +120,11 @@ def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_k
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
|
formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]
|
||||||
# all messages up to the last are considered to be part of the chat history
|
# all messages up to the last are considered to be part of the chat history
|
||||||
chat_session = model.start_chat(history=formatted_messages[0:-1])
|
chat_session = model.start_chat(history=formatted_messages[0:-1])
|
||||||
# the last message is considered to be the current prompt
|
# the last message is considered to be the current prompt
|
||||||
for chunk in chat_session.send_message(formatted_messages[-1]["parts"][0], stream=True):
|
for chunk in chat_session.send_message(formatted_messages[-1]["parts"], stream=True):
|
||||||
message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
|
message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
|
||||||
message = message or chunk.text
|
message = message or chunk.text
|
||||||
g.send(message)
|
g.send(message)
|
||||||
|
@ -148,6 +151,10 @@ def handle_gemini_response(candidates, prompt_feedback=None):
|
||||||
elif candidates[0].finish_reason == FinishReason.SAFETY:
|
elif candidates[0].finish_reason == FinishReason.SAFETY:
|
||||||
message = generate_safety_response(candidates[0].safety_ratings)
|
message = generate_safety_response(candidates[0].safety_ratings)
|
||||||
stopped = True
|
stopped = True
|
||||||
|
# Check if finish reason is empty, therefore generation is in progress
|
||||||
|
elif not candidates[0].finish_reason:
|
||||||
|
message = None
|
||||||
|
stopped = False
|
||||||
# Check if the response was stopped due to reaching maximum token limit or other reasons
|
# Check if the response was stopped due to reaching maximum token limit or other reasons
|
||||||
elif candidates[0].finish_reason != FinishReason.STOP:
|
elif candidates[0].finish_reason != FinishReason.STOP:
|
||||||
message = f"\nI can't talk further about that because of **{candidates[0].finish_reason.name} issue.**"
|
message = f"\nI can't talk further about that because of **{candidates[0].finish_reason.name} issue.**"
|
||||||
|
@ -187,14 +194,6 @@ def generate_safety_response(safety_ratings):
|
||||||
|
|
||||||
|
|
||||||
def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str = None) -> tuple[list[str], str]:
|
def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str = None) -> tuple[list[str], str]:
|
||||||
if len(messages) == 1:
|
|
||||||
messages[0].role = "user"
|
|
||||||
return messages, system_prompt
|
|
||||||
|
|
||||||
for message in messages:
|
|
||||||
if message.role == "assistant":
|
|
||||||
message.role = "model"
|
|
||||||
|
|
||||||
# Extract system message
|
# Extract system message
|
||||||
system_prompt = system_prompt or ""
|
system_prompt = system_prompt or ""
|
||||||
for message in messages.copy():
|
for message in messages.copy():
|
||||||
|
@ -203,4 +202,31 @@ def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str =
|
||||||
messages.remove(message)
|
messages.remove(message)
|
||||||
system_prompt = None if is_none_or_empty(system_prompt) else system_prompt
|
system_prompt = None if is_none_or_empty(system_prompt) else system_prompt
|
||||||
|
|
||||||
|
for message in messages:
|
||||||
|
# Convert message content to string list from chatml dictionary list
|
||||||
|
if isinstance(message.content, list):
|
||||||
|
# Convert image_urls to PIL.Image and place them at beginning of list (better for Gemini)
|
||||||
|
message.content = [
|
||||||
|
get_image_from_url(item["image_url"]["url"]) if item["type"] == "image_url" else item["text"]
|
||||||
|
for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1)
|
||||||
|
]
|
||||||
|
elif isinstance(message.content, str):
|
||||||
|
message.content = [message.content]
|
||||||
|
|
||||||
|
if message.role == "assistant":
|
||||||
|
message.role = "model"
|
||||||
|
|
||||||
|
if len(messages) == 1:
|
||||||
|
messages[0].role = "user"
|
||||||
|
|
||||||
return messages, system_prompt
|
return messages, system_prompt
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_from_url(image_url: str) -> PIL.Image:
|
||||||
|
try:
|
||||||
|
response = requests.get(image_url)
|
||||||
|
response.raise_for_status() # Check if the request was successful
|
||||||
|
return PIL.Image.open(BytesIO(response.content))
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
logger.error(f"Failed to get image from URL {image_url}: {e}")
|
||||||
|
return None
|
||||||
|
|
|
@ -1,127 +0,0 @@
|
||||||
from fastapi import HTTPException
|
|
||||||
|
|
||||||
from khoj.database.adapters import ConversationAdapters, ais_user_subscribed
|
|
||||||
from khoj.database.models import ChatModelOptions, KhojUser
|
|
||||||
from khoj.processor.conversation.anthropic.anthropic_chat import (
|
|
||||||
anthropic_send_message_to_model,
|
|
||||||
)
|
|
||||||
from khoj.processor.conversation.google.gemini_chat import gemini_send_message_to_model
|
|
||||||
from khoj.processor.conversation.offline.chat_model import send_message_to_model_offline
|
|
||||||
from khoj.processor.conversation.openai.gpt import send_message_to_model
|
|
||||||
from khoj.processor.conversation.utils import generate_chatml_messages_with_context
|
|
||||||
from khoj.utils import state
|
|
||||||
from khoj.utils.config import OfflineChatProcessorModel
|
|
||||||
|
|
||||||
|
|
||||||
async def send_message_to_model_wrapper(
|
|
||||||
message: str,
|
|
||||||
system_message: str = "",
|
|
||||||
response_type: str = "text",
|
|
||||||
chat_model_option: ChatModelOptions = None,
|
|
||||||
user: KhojUser = None,
|
|
||||||
uploaded_image_url: str = None,
|
|
||||||
):
|
|
||||||
conversation_config: ChatModelOptions = (
|
|
||||||
chat_model_option or await ConversationAdapters.aget_default_conversation_config(user)
|
|
||||||
)
|
|
||||||
|
|
||||||
vision_available = conversation_config.vision_enabled
|
|
||||||
if not vision_available and uploaded_image_url:
|
|
||||||
vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
|
|
||||||
if vision_enabled_config:
|
|
||||||
conversation_config = vision_enabled_config
|
|
||||||
vision_available = True
|
|
||||||
|
|
||||||
subscribed = await ais_user_subscribed(user)
|
|
||||||
chat_model = conversation_config.chat_model
|
|
||||||
max_tokens = (
|
|
||||||
conversation_config.subscribed_max_prompt_size
|
|
||||||
if subscribed and conversation_config.subscribed_max_prompt_size
|
|
||||||
else conversation_config.max_prompt_size
|
|
||||||
)
|
|
||||||
tokenizer = conversation_config.tokenizer
|
|
||||||
model_type = conversation_config.model_type
|
|
||||||
vision_available = conversation_config.vision_enabled
|
|
||||||
|
|
||||||
if model_type == ChatModelOptions.ModelType.OFFLINE:
|
|
||||||
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
|
||||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
|
||||||
|
|
||||||
loaded_model = state.offline_chat_processor_config.loaded_model
|
|
||||||
truncated_messages = generate_chatml_messages_with_context(
|
|
||||||
user_message=message,
|
|
||||||
system_message=system_message,
|
|
||||||
model_name=chat_model,
|
|
||||||
loaded_model=loaded_model,
|
|
||||||
tokenizer_name=tokenizer,
|
|
||||||
max_prompt_size=max_tokens,
|
|
||||||
vision_enabled=vision_available,
|
|
||||||
model_type=conversation_config.model_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
return send_message_to_model_offline(
|
|
||||||
messages=truncated_messages,
|
|
||||||
loaded_model=loaded_model,
|
|
||||||
model=chat_model,
|
|
||||||
max_prompt_size=max_tokens,
|
|
||||||
streaming=False,
|
|
||||||
response_type=response_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
elif model_type == ChatModelOptions.ModelType.OPENAI:
|
|
||||||
openai_chat_config = conversation_config.openai_config
|
|
||||||
api_key = openai_chat_config.api_key
|
|
||||||
api_base_url = openai_chat_config.api_base_url
|
|
||||||
truncated_messages = generate_chatml_messages_with_context(
|
|
||||||
user_message=message,
|
|
||||||
system_message=system_message,
|
|
||||||
model_name=chat_model,
|
|
||||||
max_prompt_size=max_tokens,
|
|
||||||
tokenizer_name=tokenizer,
|
|
||||||
vision_enabled=vision_available,
|
|
||||||
uploaded_image_url=uploaded_image_url,
|
|
||||||
model_type=conversation_config.model_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
return send_message_to_model(
|
|
||||||
messages=truncated_messages,
|
|
||||||
api_key=api_key,
|
|
||||||
model=chat_model,
|
|
||||||
response_type=response_type,
|
|
||||||
api_base_url=api_base_url,
|
|
||||||
)
|
|
||||||
elif model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
|
||||||
api_key = conversation_config.openai_config.api_key
|
|
||||||
truncated_messages = generate_chatml_messages_with_context(
|
|
||||||
user_message=message,
|
|
||||||
system_message=system_message,
|
|
||||||
model_name=chat_model,
|
|
||||||
max_prompt_size=max_tokens,
|
|
||||||
tokenizer_name=tokenizer,
|
|
||||||
vision_enabled=vision_available,
|
|
||||||
uploaded_image_url=uploaded_image_url,
|
|
||||||
model_type=conversation_config.model_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
return anthropic_send_message_to_model(
|
|
||||||
messages=truncated_messages,
|
|
||||||
api_key=api_key,
|
|
||||||
model=chat_model,
|
|
||||||
)
|
|
||||||
elif model_type == ChatModelOptions.ModelType.GOOGLE:
|
|
||||||
api_key = conversation_config.openai_config.api_key
|
|
||||||
truncated_messages = generate_chatml_messages_with_context(
|
|
||||||
user_message=message,
|
|
||||||
system_message=system_message,
|
|
||||||
model_name=chat_model,
|
|
||||||
max_prompt_size=max_tokens,
|
|
||||||
tokenizer_name=tokenizer,
|
|
||||||
vision_enabled=vision_available,
|
|
||||||
uploaded_image_url=uploaded_image_url,
|
|
||||||
)
|
|
||||||
|
|
||||||
return gemini_send_message_to_model(
|
|
||||||
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
|
|
@ -30,7 +30,7 @@ def extract_questions(
|
||||||
api_base_url=None,
|
api_base_url=None,
|
||||||
location_data: LocationData = None,
|
location_data: LocationData = None,
|
||||||
user: KhojUser = None,
|
user: KhojUser = None,
|
||||||
uploaded_image_url: Optional[str] = None,
|
query_images: Optional[list[str]] = None,
|
||||||
vision_enabled: bool = False,
|
vision_enabled: bool = False,
|
||||||
personality_context: Optional[str] = None,
|
personality_context: Optional[str] = None,
|
||||||
):
|
):
|
||||||
|
@ -74,7 +74,7 @@ def extract_questions(
|
||||||
|
|
||||||
prompt = construct_structured_message(
|
prompt = construct_structured_message(
|
||||||
message=prompt,
|
message=prompt,
|
||||||
image_url=uploaded_image_url,
|
images=query_images,
|
||||||
model_type=ChatModelOptions.ModelType.OPENAI,
|
model_type=ChatModelOptions.ModelType.OPENAI,
|
||||||
vision_enabled=vision_enabled,
|
vision_enabled=vision_enabled,
|
||||||
)
|
)
|
||||||
|
@ -136,7 +136,7 @@ def converse(
|
||||||
location_data: LocationData = None,
|
location_data: LocationData = None,
|
||||||
user_name: str = None,
|
user_name: str = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
image_url: Optional[str] = None,
|
query_images: Optional[list[str]] = None,
|
||||||
vision_available: bool = False,
|
vision_available: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
@ -196,7 +196,7 @@ def converse(
|
||||||
model_name=model,
|
model_name=model,
|
||||||
max_prompt_size=max_prompt_size,
|
max_prompt_size=max_prompt_size,
|
||||||
tokenizer_name=tokenizer_name,
|
tokenizer_name=tokenizer_name,
|
||||||
uploaded_image_url=image_url,
|
query_images=query_images,
|
||||||
vision_enabled=vision_available,
|
vision_enabled=vision_available,
|
||||||
model_type=ChatModelOptions.ModelType.OPENAI,
|
model_type=ChatModelOptions.ModelType.OPENAI,
|
||||||
)
|
)
|
||||||
|
|
|
@ -49,7 +49,7 @@ Instructions:\n{bio}
|
||||||
# Prompt forked from https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models
|
# Prompt forked from https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models
|
||||||
gemini_verbose_language_personality = """
|
gemini_verbose_language_personality = """
|
||||||
All questions should be answered comprehensively with details, unless the user requests a concise response specifically.
|
All questions should be answered comprehensively with details, unless the user requests a concise response specifically.
|
||||||
Respond in the same language as the query.
|
Respond in the same language as the query. Use markdown to format your responses.
|
||||||
""".strip()
|
""".strip()
|
||||||
|
|
||||||
## General Conversation
|
## General Conversation
|
||||||
|
@ -176,6 +176,150 @@ Improved Prompt:
|
||||||
""".strip()
|
""".strip()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
## Diagram Generation
|
||||||
|
## --
|
||||||
|
|
||||||
|
improve_diagram_description_prompt = PromptTemplate.from_template(
|
||||||
|
"""
|
||||||
|
you are an architect working with a novice artist using a diagramming tool.
|
||||||
|
{personality_context}
|
||||||
|
|
||||||
|
you need to convert the user's query to a description format that the novice artist can use very well. you are allowed to use primitives like
|
||||||
|
- text
|
||||||
|
- rectangle
|
||||||
|
- diamond
|
||||||
|
- ellipse
|
||||||
|
- line
|
||||||
|
- arrow
|
||||||
|
- frame
|
||||||
|
|
||||||
|
use these primitives to describe what sort of diagram the drawer should create. the artist must recreate the diagram every time, so include all relevant prior information in your description.
|
||||||
|
|
||||||
|
use simple, concise language.
|
||||||
|
|
||||||
|
Today's Date: {current_date}
|
||||||
|
User's Location: {location}
|
||||||
|
|
||||||
|
User's Notes:
|
||||||
|
{references}
|
||||||
|
|
||||||
|
Online References:
|
||||||
|
{online_results}
|
||||||
|
|
||||||
|
Conversation Log:
|
||||||
|
{chat_history}
|
||||||
|
|
||||||
|
Query: {query}
|
||||||
|
|
||||||
|
|
||||||
|
""".strip()
|
||||||
|
)
|
||||||
|
|
||||||
|
excalidraw_diagram_generation_prompt = PromptTemplate.from_template(
|
||||||
|
"""
|
||||||
|
You are a program manager with the ability to describe diagrams to compose in professional, fine detail.
|
||||||
|
{personality_context}
|
||||||
|
|
||||||
|
You need to create a declarative description of the diagram and relevant components, using this base schema. Use the `label` property to specify the text to be rendered in the respective elements. Always use light colors for the `backgroundColor` property, like white, or light blue, green, red. "type", "x", "y", "id", are required properties for all elements.
|
||||||
|
|
||||||
|
{{
|
||||||
|
type: string,
|
||||||
|
x: number,
|
||||||
|
y: number,
|
||||||
|
strokeColor: string,
|
||||||
|
backgroundColor: string,
|
||||||
|
width: number,
|
||||||
|
height: number,
|
||||||
|
id: string,
|
||||||
|
label: {{
|
||||||
|
text: string,
|
||||||
|
}}
|
||||||
|
}}
|
||||||
|
|
||||||
|
Valid types:
|
||||||
|
- text
|
||||||
|
- rectangle
|
||||||
|
- diamond
|
||||||
|
- ellipse
|
||||||
|
- line
|
||||||
|
- arrow
|
||||||
|
|
||||||
|
For arrows and lines, you can use the `points` property to specify the start and end points of the arrow. You may also use the `label` property to specify the text to be rendered. You may use the `start` and `end` properties to connect the linear elements to other elements. The start and end point can either be the ID to map to an existing object, or the `type` to create a new object. Mapping to an existing object is useful if you want to connect it to multiple objects. Lines and arrows can only start and end at rectangle, text, diamond, or ellipse elements.
|
||||||
|
|
||||||
|
{{
|
||||||
|
type: "arrow",
|
||||||
|
id: string,
|
||||||
|
x: number,
|
||||||
|
y: number,
|
||||||
|
width: number,
|
||||||
|
height: number,
|
||||||
|
strokeColor: string,
|
||||||
|
start: {{
|
||||||
|
id: string,
|
||||||
|
type: string,
|
||||||
|
}},
|
||||||
|
end: {{
|
||||||
|
id: string,
|
||||||
|
type: string,
|
||||||
|
}},
|
||||||
|
label: {{
|
||||||
|
text: string,
|
||||||
|
}}
|
||||||
|
points: [
|
||||||
|
[number, number],
|
||||||
|
[number, number],
|
||||||
|
]
|
||||||
|
}}
|
||||||
|
|
||||||
|
For text, you must use the `text` property to specify the text to be rendered. You may also use `fontSize` property to specify the font size of the text. Only use the `text` element for titles, subtitles, and overviews. For labels, use the `label` property in the respective elements.
|
||||||
|
|
||||||
|
{{
|
||||||
|
type: "text",
|
||||||
|
id: string,
|
||||||
|
x: number,
|
||||||
|
y: number,
|
||||||
|
fontSize: number,
|
||||||
|
text: string,
|
||||||
|
}}
|
||||||
|
|
||||||
|
For frames, use the `children` property to specify the elements that are inside the frame by their ids.
|
||||||
|
|
||||||
|
{{
|
||||||
|
type: "frame",
|
||||||
|
id: string,
|
||||||
|
x: number,
|
||||||
|
y: number,
|
||||||
|
width: number,
|
||||||
|
height: number,
|
||||||
|
name: string,
|
||||||
|
children: [
|
||||||
|
string
|
||||||
|
]
|
||||||
|
}}
|
||||||
|
|
||||||
|
Here's an example of a valid diagram:
|
||||||
|
|
||||||
|
Design Description: Create a diagram describing a circular development process with 3 stages: design, implementation and feedback. The design stage is connected to the implementation stage and the implementation stage is connected to the feedback stage and the feedback stage is connected to the design stage. Each stage should be labeled with the stage name.
|
||||||
|
|
||||||
|
Response:
|
||||||
|
|
||||||
|
[
|
||||||
|
{{"type":"text","x":-150,"y":50,"width":300,"height":40,"id":"title_text","text":"Circular Development Process","fontSize":24}},
|
||||||
|
{{"type":"ellipse","x":-169,"y":113,"width":188,"height":202,"id":"design_ellipse", "label": {{"text": "Design"}}}},
|
||||||
|
{{"type":"ellipse","x":62,"y":394,"width":186,"height":188,"id":"implement_ellipse", "label": {{"text": "Implement"}}}},
|
||||||
|
{{"type":"ellipse","x":-348,"y":430,"width":184,"height":170,"id":"feedback_ellipse", "label": {{"text": "Feedback"}}}},
|
||||||
|
{{"type":"arrow","x":21,"y":273,"id":"design_to_implement_arrow","points":[[0,0],[86,105]],"start":{{"id":"design_ellipse"}}, "end":{{"id":"implement_ellipse"}}}},
|
||||||
|
{{"type":"arrow","x":50,"y":519,"id":"implement_to_feedback_arrow","points":[[0,0],[-198,-6]],"start":{{"id":"implement_ellipse"}}, "end":{{"id":"feedback_ellipse"}}}},
|
||||||
|
{{"type":"arrow","x":-228,"y":417,"id":"feedback_to_design_arrow","points":[[0,0],[85,-123]],"start":{{"id":"feedback_ellipse"}}, "end":{{"id":"design_ellipse"}}}},
|
||||||
|
]
|
||||||
|
|
||||||
|
Create a detailed diagram from the provided context and user prompt below. Return a valid JSON object:
|
||||||
|
|
||||||
|
Diagram Description: {query}
|
||||||
|
|
||||||
|
""".strip()
|
||||||
|
)
|
||||||
|
|
||||||
## Online Search Conversation
|
## Online Search Conversation
|
||||||
## --
|
## --
|
||||||
online_search_conversation = PromptTemplate.from_template(
|
online_search_conversation = PromptTemplate.from_template(
|
||||||
|
|
|
@ -168,7 +168,7 @@ def save_to_conversation_log(
|
||||||
client_application: ClientApplication = None,
|
client_application: ClientApplication = None,
|
||||||
conversation_id: str = None,
|
conversation_id: str = None,
|
||||||
automation_id: str = None,
|
automation_id: str = None,
|
||||||
uploaded_image_url: str = None,
|
query_images: List[str] = None,
|
||||||
):
|
):
|
||||||
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
updated_conversation = message_to_log(
|
updated_conversation = message_to_log(
|
||||||
|
@ -176,7 +176,7 @@ def save_to_conversation_log(
|
||||||
chat_response=chat_response,
|
chat_response=chat_response,
|
||||||
user_message_metadata={
|
user_message_metadata={
|
||||||
"created": user_message_time,
|
"created": user_message_time,
|
||||||
"uploadedImageData": uploaded_image_url,
|
"images": query_images,
|
||||||
},
|
},
|
||||||
khoj_message_metadata={
|
khoj_message_metadata={
|
||||||
"context": compiled_references,
|
"context": compiled_references,
|
||||||
|
@ -205,10 +205,18 @@ Khoj: "{inferred_queries if ("text-to-image" in intent_type) else chat_response}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Format user and system messages to chatml format
|
def construct_structured_message(message: str, images: list[str], model_type: str, vision_enabled: bool):
|
||||||
def construct_structured_message(message, image_url, model_type, vision_enabled):
|
"""
|
||||||
if image_url and vision_enabled and model_type == ChatModelOptions.ModelType.OPENAI:
|
Format messages into appropriate multimedia format for supported chat model types
|
||||||
return [{"type": "text", "text": message}, {"type": "image_url", "image_url": {"url": image_url}}]
|
"""
|
||||||
|
if not images or not vision_enabled:
|
||||||
|
return message
|
||||||
|
|
||||||
|
if model_type in [ChatModelOptions.ModelType.OPENAI, ChatModelOptions.ModelType.GOOGLE]:
|
||||||
|
return [
|
||||||
|
{"type": "text", "text": message},
|
||||||
|
*[{"type": "image_url", "image_url": {"url": image}} for image in images],
|
||||||
|
]
|
||||||
return message
|
return message
|
||||||
|
|
||||||
|
|
||||||
|
@ -220,7 +228,7 @@ def generate_chatml_messages_with_context(
|
||||||
loaded_model: Optional[Llama] = None,
|
loaded_model: Optional[Llama] = None,
|
||||||
max_prompt_size=None,
|
max_prompt_size=None,
|
||||||
tokenizer_name=None,
|
tokenizer_name=None,
|
||||||
uploaded_image_url=None,
|
query_images=None,
|
||||||
vision_enabled=False,
|
vision_enabled=False,
|
||||||
model_type="",
|
model_type="",
|
||||||
):
|
):
|
||||||
|
@ -241,11 +249,12 @@ def generate_chatml_messages_with_context(
|
||||||
message_notes = f'\n\n Notes:\n{chat.get("context")}' if chat.get("context") else "\n"
|
message_notes = f'\n\n Notes:\n{chat.get("context")}' if chat.get("context") else "\n"
|
||||||
role = "user" if chat["by"] == "you" else "assistant"
|
role = "user" if chat["by"] == "you" else "assistant"
|
||||||
|
|
||||||
message_content = chat["message"] + message_notes
|
if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type"):
|
||||||
|
message_content = chat.get("intent").get("inferred-queries")[0] + message_notes
|
||||||
|
else:
|
||||||
|
message_content = chat["message"] + message_notes
|
||||||
|
|
||||||
message_content = construct_structured_message(
|
message_content = construct_structured_message(message_content, chat.get("images"), model_type, vision_enabled)
|
||||||
message_content, chat.get("uploadedImageData"), model_type, vision_enabled
|
|
||||||
)
|
|
||||||
|
|
||||||
reconstructed_message = ChatMessage(content=message_content, role=role)
|
reconstructed_message = ChatMessage(content=message_content, role=role)
|
||||||
|
|
||||||
|
@ -258,7 +267,7 @@ def generate_chatml_messages_with_context(
|
||||||
if not is_none_or_empty(user_message):
|
if not is_none_or_empty(user_message):
|
||||||
messages.append(
|
messages.append(
|
||||||
ChatMessage(
|
ChatMessage(
|
||||||
content=construct_structured_message(user_message, uploaded_image_url, model_type, vision_enabled),
|
content=construct_structured_message(user_message, query_images, model_type, vision_enabled),
|
||||||
role="user",
|
role="user",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -282,7 +291,6 @@ def truncate_messages(
|
||||||
tokenizer_name=None,
|
tokenizer_name=None,
|
||||||
) -> list[ChatMessage]:
|
) -> list[ChatMessage]:
|
||||||
"""Truncate messages to fit within max prompt size supported by model"""
|
"""Truncate messages to fit within max prompt size supported by model"""
|
||||||
|
|
||||||
default_tokenizer = "gpt-4o"
|
default_tokenizer = "gpt-4o"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -312,6 +320,7 @@ def truncate_messages(
|
||||||
system_message = messages.pop(idx)
|
system_message = messages.pop(idx)
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# TODO: Handle truncation of multi-part message.content, i.e when message.content is a list[dict] rather than a string
|
||||||
system_message_tokens = (
|
system_message_tokens = (
|
||||||
len(encoder.encode(system_message.content)) if system_message and type(system_message.content) == str else 0
|
len(encoder.encode(system_message.content)) if system_message and type(system_message.content) == str else 0
|
||||||
)
|
)
|
||||||
|
|
|
@ -26,7 +26,7 @@ async def text_to_image(
|
||||||
references: List[Dict[str, Any]],
|
references: List[Dict[str, Any]],
|
||||||
online_results: Dict[str, Any],
|
online_results: Dict[str, Any],
|
||||||
send_status_func: Optional[Callable] = None,
|
send_status_func: Optional[Callable] = None,
|
||||||
uploaded_image_url: Optional[str] = None,
|
query_images: Optional[List[str]] = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
):
|
):
|
||||||
status_code = 200
|
status_code = 200
|
||||||
|
@ -65,7 +65,7 @@ async def text_to_image(
|
||||||
note_references=references,
|
note_references=references,
|
||||||
online_results=online_results,
|
online_results=online_results,
|
||||||
model_type=text_to_image_config.model_type,
|
model_type=text_to_image_config.model_type,
|
||||||
uploaded_image_url=uploaded_image_url,
|
query_images=query_images,
|
||||||
user=user,
|
user=user,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
)
|
)
|
||||||
|
@ -87,18 +87,18 @@ async def text_to_image(
|
||||||
if "content_policy_violation" in e.message:
|
if "content_policy_violation" in e.message:
|
||||||
logger.error(f"Image Generation blocked by OpenAI: {e}")
|
logger.error(f"Image Generation blocked by OpenAI: {e}")
|
||||||
status_code = e.status_code # type: ignore
|
status_code = e.status_code # type: ignore
|
||||||
message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore
|
message = f"Image generation blocked by OpenAI due to policy violation" # type: ignore
|
||||||
yield image_url or image, status_code, message, intent_type.value
|
yield image_url or image, status_code, message, intent_type.value
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
||||||
message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore
|
message = f"Image generation failed using OpenAI" # type: ignore
|
||||||
status_code = e.status_code # type: ignore
|
status_code = e.status_code # type: ignore
|
||||||
yield image_url or image, status_code, message, intent_type.value
|
yield image_url or image, status_code, message, intent_type.value
|
||||||
return
|
return
|
||||||
except requests.RequestException as e:
|
except requests.RequestException as e:
|
||||||
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
logger.error(f"Image Generation failed with {e}", exc_info=True)
|
||||||
message = f"Image generation using {text2image_model} via {text_to_image_config.model_type} failed with error: {e}"
|
message = f"Image generation using {text2image_model} via {text_to_image_config.model_type} failed due to a network error."
|
||||||
status_code = 502
|
status_code = 502
|
||||||
yield image_url or image, status_code, message, intent_type.value
|
yield image_url or image, status_code, message, intent_type.value
|
||||||
return
|
return
|
||||||
|
|
|
@ -10,14 +10,22 @@ import aiohttp
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
from markdownify import markdownify
|
from markdownify import markdownify
|
||||||
|
|
||||||
from khoj.database.models import Agent, KhojUser
|
from khoj.database.adapters import ConversationAdapters
|
||||||
|
from khoj.database.models import Agent, KhojUser, WebScraper
|
||||||
|
from khoj.processor.conversation import prompts
|
||||||
from khoj.routers.helpers import (
|
from khoj.routers.helpers import (
|
||||||
ChatEvent,
|
ChatEvent,
|
||||||
extract_relevant_info,
|
extract_relevant_info,
|
||||||
generate_online_subqueries,
|
generate_online_subqueries,
|
||||||
infer_webpage_urls,
|
infer_webpage_urls,
|
||||||
)
|
)
|
||||||
from khoj.utils.helpers import is_internet_connected, is_none_or_empty, timer
|
from khoj.utils.helpers import (
|
||||||
|
is_env_var_true,
|
||||||
|
is_internal_url,
|
||||||
|
is_internet_connected,
|
||||||
|
is_none_or_empty,
|
||||||
|
timer,
|
||||||
|
)
|
||||||
from khoj.utils.rawconfig import LocationData
|
from khoj.utils.rawconfig import LocationData
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -25,12 +33,11 @@ logger = logging.getLogger(__name__)
|
||||||
SERPER_DEV_API_KEY = os.getenv("SERPER_DEV_API_KEY")
|
SERPER_DEV_API_KEY = os.getenv("SERPER_DEV_API_KEY")
|
||||||
SERPER_DEV_URL = "https://google.serper.dev/search"
|
SERPER_DEV_URL = "https://google.serper.dev/search"
|
||||||
|
|
||||||
JINA_READER_API_URL = "https://r.jina.ai/"
|
|
||||||
JINA_SEARCH_API_URL = "https://s.jina.ai/"
|
JINA_SEARCH_API_URL = "https://s.jina.ai/"
|
||||||
JINA_API_KEY = os.getenv("JINA_API_KEY")
|
JINA_API_KEY = os.getenv("JINA_API_KEY")
|
||||||
|
|
||||||
OLOSTEP_API_KEY = os.getenv("OLOSTEP_API_KEY")
|
FIRECRAWL_USE_LLM_EXTRACT = is_env_var_true("FIRECRAWL_USE_LLM_EXTRACT")
|
||||||
OLOSTEP_API_URL = "https://agent.olostep.com/olostep-p2p-incomingAPI"
|
|
||||||
OLOSTEP_QUERY_PARAMS = {
|
OLOSTEP_QUERY_PARAMS = {
|
||||||
"timeout": 35, # seconds
|
"timeout": 35, # seconds
|
||||||
"waitBeforeScraping": 1, # seconds
|
"waitBeforeScraping": 1, # seconds
|
||||||
|
@ -54,11 +61,10 @@ async def search_online(
|
||||||
conversation_history: dict,
|
conversation_history: dict,
|
||||||
location: LocationData,
|
location: LocationData,
|
||||||
user: KhojUser,
|
user: KhojUser,
|
||||||
subscribed: bool = False,
|
|
||||||
send_status_func: Optional[Callable] = None,
|
send_status_func: Optional[Callable] = None,
|
||||||
custom_filters: List[str] = [],
|
custom_filters: List[str] = [],
|
||||||
max_webpages_to_read: int = DEFAULT_MAX_WEBPAGES_TO_READ,
|
max_webpages_to_read: int = DEFAULT_MAX_WEBPAGES_TO_READ,
|
||||||
uploaded_image_url: str = None,
|
query_images: List[str] = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
):
|
):
|
||||||
query += " ".join(custom_filters)
|
query += " ".join(custom_filters)
|
||||||
|
@ -69,7 +75,7 @@ async def search_online(
|
||||||
|
|
||||||
# Breakdown the query into subqueries to get the correct answer
|
# Breakdown the query into subqueries to get the correct answer
|
||||||
subqueries = await generate_online_subqueries(
|
subqueries = await generate_online_subqueries(
|
||||||
query, conversation_history, location, user, uploaded_image_url=uploaded_image_url, agent=agent
|
query, conversation_history, location, user, query_images=query_images, agent=agent
|
||||||
)
|
)
|
||||||
response_dict = {}
|
response_dict = {}
|
||||||
|
|
||||||
|
@ -86,33 +92,31 @@ async def search_online(
|
||||||
search_results = await asyncio.gather(*search_tasks)
|
search_results = await asyncio.gather(*search_tasks)
|
||||||
response_dict = {subquery: search_result for subquery, search_result in search_results}
|
response_dict = {subquery: search_result for subquery, search_result in search_results}
|
||||||
|
|
||||||
# Gather distinct web page data from organic results of each subquery without an instant answer.
|
# Gather distinct web pages from organic results for subqueries without an instant answer.
|
||||||
# Content of web pages is directly available when Jina is used for search.
|
# Content of web pages is directly available when Jina is used for search.
|
||||||
webpages = {
|
webpages = set()
|
||||||
(organic.get("link"), subquery, organic.get("content"))
|
for subquery in response_dict:
|
||||||
for subquery in response_dict
|
for organic in response_dict[subquery].get("organic", [])[:max_webpages_to_read]:
|
||||||
for organic in response_dict[subquery].get("organic", [])[:max_webpages_to_read]
|
if "answerBox" not in response_dict[subquery]:
|
||||||
if "answerBox" not in response_dict[subquery]
|
webpages.add(organic.get("link"), {"queries": {subquery}, "content": organic.get("content")})
|
||||||
}
|
|
||||||
|
|
||||||
# Read, extract relevant info from the retrieved web pages
|
# Read, extract relevant info from the retrieved web pages
|
||||||
if webpages:
|
if webpages:
|
||||||
webpage_links = set([link for link, _, _ in webpages])
|
logger.info(f"Reading web pages at: {webpages.keys()}")
|
||||||
logger.info(f"Reading web pages at: {list(webpage_links)}")
|
|
||||||
if send_status_func:
|
if send_status_func:
|
||||||
webpage_links_str = "\n- " + "\n- ".join(list(webpage_links))
|
webpage_links_str = "\n- " + "\n- ".join(webpages.keys())
|
||||||
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
|
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
|
||||||
yield {ChatEvent.STATUS: event}
|
yield {ChatEvent.STATUS: event}
|
||||||
tasks = [
|
tasks = [
|
||||||
read_webpage_and_extract_content(subquery, link, content, user=user, agent=agent)
|
read_webpage_and_extract_content(data["queries"], link, data["content"], user=user, agent=agent)
|
||||||
for link, subquery, content in webpages
|
for link, data in webpages.items()
|
||||||
]
|
]
|
||||||
results = await asyncio.gather(*tasks)
|
results = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
# Collect extracted info from the retrieved web pages
|
# Collect extracted info from the retrieved web pages
|
||||||
for subquery, webpage_extract, url in results:
|
for subqueries, url, webpage_extract in results:
|
||||||
if webpage_extract is not None:
|
if webpage_extract is not None:
|
||||||
response_dict[subquery]["webpages"] = {"link": url, "snippet": webpage_extract}
|
response_dict[subqueries.pop()]["webpages"] = {"link": url, "snippet": webpage_extract}
|
||||||
|
|
||||||
yield response_dict
|
yield response_dict
|
||||||
|
|
||||||
|
@ -144,7 +148,7 @@ async def read_webpages(
|
||||||
location: LocationData,
|
location: LocationData,
|
||||||
user: KhojUser,
|
user: KhojUser,
|
||||||
send_status_func: Optional[Callable] = None,
|
send_status_func: Optional[Callable] = None,
|
||||||
uploaded_image_url: str = None,
|
query_images: List[str] = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
):
|
):
|
||||||
"Infer web pages to read from the query and extract relevant information from them"
|
"Infer web pages to read from the query and extract relevant information from them"
|
||||||
|
@ -152,36 +156,73 @@ async def read_webpages(
|
||||||
if send_status_func:
|
if send_status_func:
|
||||||
async for event in send_status_func(f"**Inferring web pages to read**"):
|
async for event in send_status_func(f"**Inferring web pages to read**"):
|
||||||
yield {ChatEvent.STATUS: event}
|
yield {ChatEvent.STATUS: event}
|
||||||
urls = await infer_webpage_urls(query, conversation_history, location, user, uploaded_image_url)
|
urls = await infer_webpage_urls(query, conversation_history, location, user, query_images)
|
||||||
|
|
||||||
logger.info(f"Reading web pages at: {urls}")
|
logger.info(f"Reading web pages at: {urls}")
|
||||||
if send_status_func:
|
if send_status_func:
|
||||||
webpage_links_str = "\n- " + "\n- ".join(list(urls))
|
webpage_links_str = "\n- " + "\n- ".join(list(urls))
|
||||||
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
|
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
|
||||||
yield {ChatEvent.STATUS: event}
|
yield {ChatEvent.STATUS: event}
|
||||||
tasks = [read_webpage_and_extract_content(query, url, user=user, agent=agent) for url in urls]
|
tasks = [read_webpage_and_extract_content({query}, url, user=user, agent=agent) for url in urls]
|
||||||
results = await asyncio.gather(*tasks)
|
results = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
response: Dict[str, Dict] = defaultdict(dict)
|
response: Dict[str, Dict] = defaultdict(dict)
|
||||||
response[query]["webpages"] = [
|
response[query]["webpages"] = [
|
||||||
{"query": q, "link": url, "snippet": web_extract} for q, web_extract, url in results if web_extract is not None
|
{"query": qs.pop(), "link": url, "snippet": extract} for qs, url, extract in results if extract is not None
|
||||||
]
|
]
|
||||||
yield response
|
yield response
|
||||||
|
|
||||||
|
|
||||||
|
async def read_webpage(
|
||||||
|
url, scraper_type=None, api_key=None, api_url=None, subqueries=None, agent=None
|
||||||
|
) -> Tuple[str | None, str | None]:
|
||||||
|
if scraper_type == WebScraper.WebScraperType.FIRECRAWL and FIRECRAWL_USE_LLM_EXTRACT:
|
||||||
|
return None, await query_webpage_with_firecrawl(url, subqueries, api_key, api_url, agent)
|
||||||
|
elif scraper_type == WebScraper.WebScraperType.FIRECRAWL:
|
||||||
|
return await read_webpage_with_firecrawl(url, api_key, api_url), None
|
||||||
|
elif scraper_type == WebScraper.WebScraperType.OLOSTEP:
|
||||||
|
return await read_webpage_with_olostep(url, api_key, api_url), None
|
||||||
|
elif scraper_type == WebScraper.WebScraperType.JINA:
|
||||||
|
return await read_webpage_with_jina(url, api_key, api_url), None
|
||||||
|
else:
|
||||||
|
return await read_webpage_at_url(url), None
|
||||||
|
|
||||||
|
|
||||||
async def read_webpage_and_extract_content(
|
async def read_webpage_and_extract_content(
|
||||||
subquery: str, url: str, content: str = None, user: KhojUser = None, agent: Agent = None
|
subqueries: set[str], url: str, content: str = None, user: KhojUser = None, agent: Agent = None
|
||||||
) -> Tuple[str, Union[None, str], str]:
|
) -> Tuple[set[str], str, Union[None, str]]:
|
||||||
try:
|
# Select the web scrapers to use for reading the web page
|
||||||
if is_none_or_empty(content):
|
web_scrapers = await ConversationAdapters.aget_enabled_webscrapers()
|
||||||
with timer(f"Reading web page at '{url}' took", logger):
|
# Only use the direct web scraper for internal URLs
|
||||||
content = await read_webpage_with_olostep(url) if OLOSTEP_API_KEY else await read_webpage_with_jina(url)
|
if is_internal_url(url):
|
||||||
with timer(f"Extracting relevant information from web page at '{url}' took", logger):
|
web_scrapers = [scraper for scraper in web_scrapers if scraper.type == WebScraper.WebScraperType.DIRECT]
|
||||||
extracted_info = await extract_relevant_info(subquery, content, user=user, agent=agent)
|
|
||||||
return subquery, extracted_info, url
|
# Fallback through enabled web scrapers until we successfully read the web page
|
||||||
except Exception as e:
|
extracted_info = None
|
||||||
logger.error(f"Failed to read web page at '{url}' with {e}")
|
for scraper in web_scrapers:
|
||||||
return subquery, None, url
|
try:
|
||||||
|
# Read the web page
|
||||||
|
if is_none_or_empty(content):
|
||||||
|
with timer(f"Reading web page with {scraper.type} at '{url}' took", logger, log_level=logging.INFO):
|
||||||
|
content, extracted_info = await read_webpage(
|
||||||
|
url, scraper.type, scraper.api_key, scraper.api_url, subqueries, agent
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract relevant information from the web page
|
||||||
|
if is_none_or_empty(extracted_info):
|
||||||
|
with timer(f"Extracting relevant information from web page at '{url}' took", logger):
|
||||||
|
extracted_info = await extract_relevant_info(subqueries, content, user=user, agent=agent)
|
||||||
|
|
||||||
|
# If we successfully extracted information, break the loop
|
||||||
|
if not is_none_or_empty(extracted_info):
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to read web page with {scraper.type} at '{url}' with {e}")
|
||||||
|
# If this is the last web scraper in the list, log an error
|
||||||
|
if scraper.name == web_scrapers[-1].name:
|
||||||
|
logger.error(f"All web scrapers failed for '{url}'")
|
||||||
|
|
||||||
|
return subqueries, url, extracted_info
|
||||||
|
|
||||||
|
|
||||||
async def read_webpage_at_url(web_url: str) -> str:
|
async def read_webpage_at_url(web_url: str) -> str:
|
||||||
|
@ -198,23 +239,23 @@ async def read_webpage_at_url(web_url: str) -> str:
|
||||||
return markdownify(body)
|
return markdownify(body)
|
||||||
|
|
||||||
|
|
||||||
async def read_webpage_with_olostep(web_url: str) -> str:
|
async def read_webpage_with_olostep(web_url: str, api_key: str, api_url: str) -> str:
|
||||||
headers = {"Authorization": f"Bearer {OLOSTEP_API_KEY}"}
|
headers = {"Authorization": f"Bearer {api_key}"}
|
||||||
web_scraping_params: Dict[str, Union[str, int, bool]] = OLOSTEP_QUERY_PARAMS.copy() # type: ignore
|
web_scraping_params: Dict[str, Union[str, int, bool]] = OLOSTEP_QUERY_PARAMS.copy() # type: ignore
|
||||||
web_scraping_params["url"] = web_url
|
web_scraping_params["url"] = web_url
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.get(OLOSTEP_API_URL, params=web_scraping_params, headers=headers) as response:
|
async with session.get(api_url, params=web_scraping_params, headers=headers) as response:
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
response_json = await response.json()
|
response_json = await response.json()
|
||||||
return response_json["markdown_content"]
|
return response_json["markdown_content"]
|
||||||
|
|
||||||
|
|
||||||
async def read_webpage_with_jina(web_url: str) -> str:
|
async def read_webpage_with_jina(web_url: str, api_key: str, api_url: str) -> str:
|
||||||
jina_reader_api_url = f"{JINA_READER_API_URL}/{web_url}"
|
jina_reader_api_url = f"{api_url}/{web_url}"
|
||||||
headers = {"Accept": "application/json", "X-Timeout": "30"}
|
headers = {"Accept": "application/json", "X-Timeout": "30"}
|
||||||
if JINA_API_KEY:
|
if api_key:
|
||||||
headers["Authorization"] = f"Bearer {JINA_API_KEY}"
|
headers["Authorization"] = f"Bearer {api_key}"
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.get(jina_reader_api_url, headers=headers) as response:
|
async with session.get(jina_reader_api_url, headers=headers) as response:
|
||||||
|
@ -223,6 +264,54 @@ async def read_webpage_with_jina(web_url: str) -> str:
|
||||||
return response_json["data"]["content"]
|
return response_json["data"]["content"]
|
||||||
|
|
||||||
|
|
||||||
|
async def read_webpage_with_firecrawl(web_url: str, api_key: str, api_url: str) -> str:
|
||||||
|
firecrawl_api_url = f"{api_url}/v1/scrape"
|
||||||
|
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
|
||||||
|
params = {"url": web_url, "formats": ["markdown"], "excludeTags": ["script", ".ad"]}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(firecrawl_api_url, json=params, headers=headers) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
response_json = await response.json()
|
||||||
|
return response_json["data"]["markdown"]
|
||||||
|
|
||||||
|
|
||||||
|
async def query_webpage_with_firecrawl(
|
||||||
|
web_url: str, queries: set[str], api_key: str, api_url: str, agent: Agent = None
|
||||||
|
) -> str:
|
||||||
|
firecrawl_api_url = f"{api_url}/v1/scrape"
|
||||||
|
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
|
||||||
|
schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"relevant_extract": {"type": "string"},
|
||||||
|
},
|
||||||
|
"required": [
|
||||||
|
"relevant_extract",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
personality_context = (
|
||||||
|
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
|
||||||
|
)
|
||||||
|
system_prompt = f"""
|
||||||
|
{prompts.system_prompt_extract_relevant_information}
|
||||||
|
|
||||||
|
{personality_context}
|
||||||
|
User Query: {", ".join(queries)}
|
||||||
|
|
||||||
|
Collate only relevant information from the website to answer the target query and in the provided JSON schema.
|
||||||
|
""".strip()
|
||||||
|
|
||||||
|
params = {"url": web_url, "formats": ["extract"], "extract": {"systemPrompt": system_prompt, "schema": schema}}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(firecrawl_api_url, json=params, headers=headers) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
response_json = await response.json()
|
||||||
|
return response_json["data"]["extract"]["relevant_extract"]
|
||||||
|
|
||||||
|
|
||||||
async def search_with_jina(query: str, location: LocationData) -> Tuple[str, Dict[str, List[Dict]]]:
|
async def search_with_jina(query: str, location: LocationData) -> Tuple[str, Dict[str, List[Dict]]]:
|
||||||
encoded_query = urllib.parse.quote(query)
|
encoded_query = urllib.parse.quote(query)
|
||||||
jina_search_api_url = f"{JINA_SEARCH_API_URL}/{encoded_query}"
|
jina_search_api_url = f"{JINA_SEARCH_API_URL}/{encoded_query}"
|
||||||
|
|
|
@ -10,12 +10,12 @@ import aiohttp
|
||||||
from khoj.database.adapters import ais_user_subscribed
|
from khoj.database.adapters import ais_user_subscribed
|
||||||
from khoj.database.models import Agent, KhojUser
|
from khoj.database.models import Agent, KhojUser
|
||||||
from khoj.processor.conversation import prompts
|
from khoj.processor.conversation import prompts
|
||||||
from khoj.processor.conversation.helpers import send_message_to_model_wrapper
|
|
||||||
from khoj.processor.conversation.utils import (
|
from khoj.processor.conversation.utils import (
|
||||||
ChatEvent,
|
ChatEvent,
|
||||||
construct_chat_history,
|
construct_chat_history,
|
||||||
remove_json_codeblock,
|
remove_json_codeblock,
|
||||||
)
|
)
|
||||||
|
from khoj.routers.helpers import send_message_to_model_wrapper
|
||||||
from khoj.utils.helpers import timer
|
from khoj.utils.helpers import timer
|
||||||
from khoj.utils.rawconfig import LocationData
|
from khoj.utils.rawconfig import LocationData
|
||||||
|
|
||||||
|
@ -32,7 +32,7 @@ async def run_code(
|
||||||
location_data: LocationData,
|
location_data: LocationData,
|
||||||
user: KhojUser,
|
user: KhojUser,
|
||||||
send_status_func: Optional[Callable] = None,
|
send_status_func: Optional[Callable] = None,
|
||||||
uploaded_image_url: str = None,
|
query_images: List[str] = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
sandbox_url: str = SANDBOX_URL,
|
sandbox_url: str = SANDBOX_URL,
|
||||||
):
|
):
|
||||||
|
@ -43,7 +43,7 @@ async def run_code(
|
||||||
try:
|
try:
|
||||||
with timer("Chat actor: Generate programs to execute", logger):
|
with timer("Chat actor: Generate programs to execute", logger):
|
||||||
codes = await generate_python_code(
|
codes = await generate_python_code(
|
||||||
query, conversation_history, previous_iterations_history, location_data, user, uploaded_image_url, agent
|
query, conversation_history, previous_iterations_history, location_data, user, query_images, agent
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ValueError(f"Failed to generate code for {query} with error: {e}")
|
raise ValueError(f"Failed to generate code for {query} with error: {e}")
|
||||||
|
@ -70,7 +70,7 @@ async def generate_python_code(
|
||||||
previous_iterations_history: str,
|
previous_iterations_history: str,
|
||||||
location_data: LocationData,
|
location_data: LocationData,
|
||||||
user: KhojUser,
|
user: KhojUser,
|
||||||
uploaded_image_url: str = None,
|
query_images: List[str] = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
location = f"{location_data}" if location_data else "Unknown"
|
location = f"{location_data}" if location_data else "Unknown"
|
||||||
|
@ -95,7 +95,7 @@ async def generate_python_code(
|
||||||
|
|
||||||
response = await send_message_to_model_wrapper(
|
response = await send_message_to_model_wrapper(
|
||||||
code_generation_prompt,
|
code_generation_prompt,
|
||||||
uploaded_image_url=uploaded_image_url,
|
query_images=query_images,
|
||||||
response_type="json_object",
|
response_type="json_object",
|
||||||
user=user,
|
user=user,
|
||||||
)
|
)
|
||||||
|
|
|
@ -21,11 +21,13 @@ from starlette.authentication import has_required_scope, requires
|
||||||
from khoj.configure import initialize_content
|
from khoj.configure import initialize_content
|
||||||
from khoj.database import adapters
|
from khoj.database import adapters
|
||||||
from khoj.database.adapters import (
|
from khoj.database.adapters import (
|
||||||
|
AgentAdapters,
|
||||||
AutomationAdapters,
|
AutomationAdapters,
|
||||||
ConversationAdapters,
|
ConversationAdapters,
|
||||||
EntryAdapters,
|
EntryAdapters,
|
||||||
|
get_default_search_model,
|
||||||
|
get_user_default_search_model,
|
||||||
get_user_photo,
|
get_user_photo,
|
||||||
get_user_search_model_or_default,
|
|
||||||
)
|
)
|
||||||
from khoj.database.models import (
|
from khoj.database.models import (
|
||||||
Agent,
|
Agent,
|
||||||
|
@ -115,10 +117,16 @@ async def execute_search(
|
||||||
dedupe: Optional[bool] = True,
|
dedupe: Optional[bool] = True,
|
||||||
agent: Optional[Agent] = None,
|
agent: Optional[Agent] = None,
|
||||||
):
|
):
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
# Run validation checks
|
# Run validation checks
|
||||||
results: List[SearchResponse] = []
|
results: List[SearchResponse] = []
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# Ensure the agent, if present, is accessible by the user
|
||||||
|
if user and agent and not await AgentAdapters.ais_agent_accessible(agent, user):
|
||||||
|
logger.error(f"Agent {agent.slug} is not accessible by user {user}")
|
||||||
|
return results
|
||||||
|
|
||||||
if q is None or q == "":
|
if q is None or q == "":
|
||||||
logger.warning(f"No query param (q) passed in API call to initiate search")
|
logger.warning(f"No query param (q) passed in API call to initiate search")
|
||||||
return results
|
return results
|
||||||
|
@ -143,7 +151,7 @@ async def execute_search(
|
||||||
encoded_asymmetric_query = None
|
encoded_asymmetric_query = None
|
||||||
if t != SearchType.Image:
|
if t != SearchType.Image:
|
||||||
with timer("Encoding query took", logger=logger):
|
with timer("Encoding query took", logger=logger):
|
||||||
search_model = await sync_to_async(get_user_search_model_or_default)(user)
|
search_model = await sync_to_async(get_user_default_search_model)(user)
|
||||||
encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query)
|
encoded_asymmetric_query = state.embeddings_model[search_model.name].embed_query(defiltered_query)
|
||||||
|
|
||||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||||
|
@ -341,7 +349,7 @@ async def extract_references_and_questions(
|
||||||
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
|
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
|
||||||
location_data: LocationData = None,
|
location_data: LocationData = None,
|
||||||
send_status_func: Optional[Callable] = None,
|
send_status_func: Optional[Callable] = None,
|
||||||
uploaded_image_url: Optional[str] = None,
|
query_images: Optional[List[str]] = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
):
|
):
|
||||||
user = request.user.object if request.user.is_authenticated else None
|
user = request.user.object if request.user.is_authenticated else None
|
||||||
|
@ -430,7 +438,7 @@ async def extract_references_and_questions(
|
||||||
conversation_log=meta_log,
|
conversation_log=meta_log,
|
||||||
location_data=location_data,
|
location_data=location_data,
|
||||||
user=user,
|
user=user,
|
||||||
uploaded_image_url=uploaded_image_url,
|
query_images=query_images,
|
||||||
vision_enabled=vision_enabled,
|
vision_enabled=vision_enabled,
|
||||||
personality_context=personality_context,
|
personality_context=personality_context,
|
||||||
)
|
)
|
||||||
|
@ -451,12 +459,14 @@ async def extract_references_and_questions(
|
||||||
chat_model = conversation_config.chat_model
|
chat_model = conversation_config.chat_model
|
||||||
inferred_queries = extract_questions_gemini(
|
inferred_queries = extract_questions_gemini(
|
||||||
defiltered_query,
|
defiltered_query,
|
||||||
|
query_images=query_images,
|
||||||
model=chat_model,
|
model=chat_model,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
conversation_log=meta_log,
|
conversation_log=meta_log,
|
||||||
location_data=location_data,
|
location_data=location_data,
|
||||||
max_tokens=conversation_config.max_prompt_size,
|
max_tokens=conversation_config.max_prompt_size,
|
||||||
user=user,
|
user=user,
|
||||||
|
vision_enabled=vision_enabled,
|
||||||
personality_context=personality_context,
|
personality_context=personality_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import random
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
|
@ -9,8 +11,8 @@ from fastapi.responses import Response
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from starlette.authentication import requires
|
from starlette.authentication import requires
|
||||||
|
|
||||||
from khoj.database.adapters import AgentAdapters
|
from khoj.database.adapters import AgentAdapters, ConversationAdapters
|
||||||
from khoj.database.models import Agent, KhojUser
|
from khoj.database.models import Agent, Conversation, KhojUser
|
||||||
from khoj.routers.helpers import CommonQueryParams, acheck_if_safe_prompt
|
from khoj.routers.helpers import CommonQueryParams, acheck_if_safe_prompt
|
||||||
from khoj.utils.helpers import (
|
from khoj.utils.helpers import (
|
||||||
ConversationCommand,
|
ConversationCommand,
|
||||||
|
@ -45,30 +47,49 @@ async def all_agents(
|
||||||
) -> Response:
|
) -> Response:
|
||||||
user: KhojUser = request.user.object if request.user.is_authenticated else None
|
user: KhojUser = request.user.object if request.user.is_authenticated else None
|
||||||
agents = await AgentAdapters.aget_all_accessible_agents(user)
|
agents = await AgentAdapters.aget_all_accessible_agents(user)
|
||||||
|
default_agent = await AgentAdapters.aget_default_agent()
|
||||||
|
default_agent_packet = None
|
||||||
agents_packet = list()
|
agents_packet = list()
|
||||||
for agent in agents:
|
for agent in agents:
|
||||||
files = agent.fileobject_set.all()
|
files = agent.fileobject_set.all()
|
||||||
file_names = [file.file_name for file in files]
|
file_names = [file.file_name for file in files]
|
||||||
agents_packet.append(
|
agent_packet = {
|
||||||
{
|
"slug": agent.slug,
|
||||||
"slug": agent.slug,
|
"name": agent.name,
|
||||||
"name": agent.name,
|
"persona": agent.personality,
|
||||||
"persona": agent.personality,
|
"creator": agent.creator.username if agent.creator else None,
|
||||||
"creator": agent.creator.username if agent.creator else None,
|
"managed_by_admin": agent.managed_by_admin,
|
||||||
"managed_by_admin": agent.managed_by_admin,
|
"color": agent.style_color,
|
||||||
"color": agent.style_color,
|
"icon": agent.style_icon,
|
||||||
"icon": agent.style_icon,
|
"privacy_level": agent.privacy_level,
|
||||||
"privacy_level": agent.privacy_level,
|
"chat_model": agent.chat_model.chat_model,
|
||||||
"chat_model": agent.chat_model.chat_model,
|
"files": file_names,
|
||||||
"files": file_names,
|
"input_tools": agent.input_tools,
|
||||||
"input_tools": agent.input_tools,
|
"output_modes": agent.output_modes,
|
||||||
"output_modes": agent.output_modes,
|
}
|
||||||
}
|
if agent.slug == default_agent.slug:
|
||||||
)
|
default_agent_packet = agent_packet
|
||||||
|
else:
|
||||||
|
agents_packet.append(agent_packet)
|
||||||
|
|
||||||
|
# Load recent conversation sessions
|
||||||
|
min_date = datetime.min.replace(tzinfo=timezone.utc)
|
||||||
|
two_weeks_ago = datetime.today() - timedelta(weeks=2)
|
||||||
|
conversations = []
|
||||||
|
if user:
|
||||||
|
conversations = await sync_to_async(list[Conversation])(
|
||||||
|
ConversationAdapters.get_conversation_sessions(user, request.user.client_app)
|
||||||
|
.filter(updated_at__gte=two_weeks_ago)
|
||||||
|
.order_by("-updated_at")[:50]
|
||||||
|
)
|
||||||
|
conversation_times = {conv.agent.slug: conv.updated_at for conv in conversations if conv.agent}
|
||||||
|
|
||||||
|
# Put default agent first, then sort by mru and finally shuffle unused randomly
|
||||||
|
random.shuffle(agents_packet)
|
||||||
|
agents_packet.sort(key=lambda x: conversation_times.get(x["slug"]) or min_date, reverse=True)
|
||||||
|
if default_agent_packet:
|
||||||
|
agents_packet.insert(0, default_agent_packet)
|
||||||
|
|
||||||
# Make sure that the agent named 'khoj' is first in the list. Everything else is sorted by name.
|
|
||||||
agents_packet.sort(key=lambda x: x["name"])
|
|
||||||
agents_packet.sort(key=lambda x: x["slug"] == "khoj", reverse=True)
|
|
||||||
return Response(content=json.dumps(agents_packet), media_type="application/json", status_code=200)
|
return Response(content=json.dumps(agents_packet), media_type="application/json", status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -10,15 +10,15 @@ from urllib.parse import unquote
|
||||||
|
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
from fastapi.requests import Request
|
|
||||||
from fastapi.responses import Response, StreamingResponse
|
from fastapi.responses import Response, StreamingResponse
|
||||||
from starlette.authentication import has_required_scope, requires
|
from starlette.authentication import requires
|
||||||
|
|
||||||
from khoj.app.settings import ALLOWED_HOSTS
|
from khoj.app.settings import ALLOWED_HOSTS
|
||||||
from khoj.database.adapters import (
|
from khoj.database.adapters import (
|
||||||
AgentAdapters,
|
AgentAdapters,
|
||||||
ConversationAdapters,
|
ConversationAdapters,
|
||||||
EntryAdapters,
|
EntryAdapters,
|
||||||
|
FileObjectAdapters,
|
||||||
PublicConversationAdapters,
|
PublicConversationAdapters,
|
||||||
aget_user_name,
|
aget_user_name,
|
||||||
)
|
)
|
||||||
|
@ -31,8 +31,10 @@ from khoj.processor.tools.online_search import read_webpages, search_online
|
||||||
from khoj.processor.tools.run_code import run_code
|
from khoj.processor.tools.run_code import run_code
|
||||||
from khoj.routers.api import extract_references_and_questions
|
from khoj.routers.api import extract_references_and_questions
|
||||||
from khoj.routers.helpers import (
|
from khoj.routers.helpers import (
|
||||||
|
ApiImageRateLimiter,
|
||||||
ApiUserRateLimiter,
|
ApiUserRateLimiter,
|
||||||
ChatEvent,
|
ChatEvent,
|
||||||
|
ChatRequestBody,
|
||||||
CommonQueryParams,
|
CommonQueryParams,
|
||||||
ConversationCommandRateLimiter,
|
ConversationCommandRateLimiter,
|
||||||
agenerate_chat_response,
|
agenerate_chat_response,
|
||||||
|
@ -41,6 +43,8 @@ from khoj.routers.helpers import (
|
||||||
construct_automation_created_message,
|
construct_automation_created_message,
|
||||||
create_automation,
|
create_automation,
|
||||||
extract_relevant_info,
|
extract_relevant_info,
|
||||||
|
extract_relevant_summary,
|
||||||
|
generate_excalidraw_diagram,
|
||||||
generate_summary_from_files,
|
generate_summary_from_files,
|
||||||
get_conversation_command,
|
get_conversation_command,
|
||||||
is_query_empty,
|
is_query_empty,
|
||||||
|
@ -529,22 +533,6 @@ async def set_conversation_title(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ChatRequestBody(BaseModel):
|
|
||||||
q: str
|
|
||||||
n: Optional[int] = 7
|
|
||||||
d: Optional[float] = None
|
|
||||||
stream: Optional[bool] = False
|
|
||||||
title: Optional[str] = None
|
|
||||||
conversation_id: Optional[str] = None
|
|
||||||
city: Optional[str] = None
|
|
||||||
region: Optional[str] = None
|
|
||||||
country: Optional[str] = None
|
|
||||||
country_code: Optional[str] = None
|
|
||||||
timezone: Optional[str] = None
|
|
||||||
image: Optional[str] = None
|
|
||||||
create_new: Optional[bool] = False
|
|
||||||
|
|
||||||
|
|
||||||
@api_chat.post("")
|
@api_chat.post("")
|
||||||
@requires(["authenticated"])
|
@requires(["authenticated"])
|
||||||
async def chat(
|
async def chat(
|
||||||
|
@ -557,6 +545,7 @@ async def chat(
|
||||||
rate_limiter_per_day=Depends(
|
rate_limiter_per_day=Depends(
|
||||||
ApiUserRateLimiter(requests=600, subscribed_requests=6000, window=60 * 60 * 24, slug="chat_day")
|
ApiUserRateLimiter(requests=600, subscribed_requests=6000, window=60 * 60 * 24, slug="chat_day")
|
||||||
),
|
),
|
||||||
|
image_rate_limiter=Depends(ApiImageRateLimiter(max_images=10, max_combined_size_mb=20)),
|
||||||
):
|
):
|
||||||
# Access the parameters from the body
|
# Access the parameters from the body
|
||||||
q = body.q
|
q = body.q
|
||||||
|
@ -570,29 +559,28 @@ async def chat(
|
||||||
country = body.country or get_country_name_from_timezone(body.timezone)
|
country = body.country or get_country_name_from_timezone(body.timezone)
|
||||||
country_code = body.country_code or get_country_code_from_timezone(body.timezone)
|
country_code = body.country_code or get_country_code_from_timezone(body.timezone)
|
||||||
timezone = body.timezone
|
timezone = body.timezone
|
||||||
image = body.image
|
raw_images = body.images
|
||||||
|
|
||||||
async def event_generator(q: str, image: str):
|
async def event_generator(q: str, images: list[str]):
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
ttft = None
|
ttft = None
|
||||||
chat_metadata: dict = {}
|
chat_metadata: dict = {}
|
||||||
connection_alive = True
|
connection_alive = True
|
||||||
user: KhojUser = request.user.object
|
user: KhojUser = request.user.object
|
||||||
subscribed: bool = has_required_scope(request, ["premium"])
|
|
||||||
event_delimiter = "␃🔚␗"
|
event_delimiter = "␃🔚␗"
|
||||||
q = unquote(q)
|
q = unquote(q)
|
||||||
nonlocal conversation_id
|
nonlocal conversation_id
|
||||||
|
|
||||||
uploaded_image_url = None
|
uploaded_images: list[str] = []
|
||||||
if image:
|
if images:
|
||||||
decoded_string = unquote(image)
|
for image in images:
|
||||||
base64_data = decoded_string.split(",", 1)[1]
|
decoded_string = unquote(image)
|
||||||
image_bytes = base64.b64decode(base64_data)
|
base64_data = decoded_string.split(",", 1)[1]
|
||||||
webp_image_bytes = convert_image_to_webp(image_bytes)
|
image_bytes = base64.b64decode(base64_data)
|
||||||
try:
|
webp_image_bytes = convert_image_to_webp(image_bytes)
|
||||||
uploaded_image_url = upload_image_to_bucket(webp_image_bytes, request.user.object.id)
|
uploaded_image = upload_image_to_bucket(webp_image_bytes, request.user.object.id)
|
||||||
except:
|
if uploaded_image:
|
||||||
uploaded_image_url = None
|
uploaded_images.append(uploaded_image)
|
||||||
|
|
||||||
async def send_event(event_type: ChatEvent, data: str | dict):
|
async def send_event(event_type: ChatEvent, data: str | dict):
|
||||||
nonlocal connection_alive, ttft
|
nonlocal connection_alive, ttft
|
||||||
|
@ -645,7 +633,7 @@ async def chat(
|
||||||
request=request,
|
request=request,
|
||||||
telemetry_type="api",
|
telemetry_type="api",
|
||||||
api="chat",
|
api="chat",
|
||||||
client=request.user.client_app,
|
client=common.client,
|
||||||
user_agent=request.headers.get("user-agent"),
|
user_agent=request.headers.get("user-agent"),
|
||||||
host=request.headers.get("host"),
|
host=request.headers.get("host"),
|
||||||
metadata=chat_metadata,
|
metadata=chat_metadata,
|
||||||
|
@ -706,11 +694,10 @@ async def chat(
|
||||||
async for research_result in execute_information_collection(
|
async for research_result in execute_information_collection(
|
||||||
request=request,
|
request=request,
|
||||||
user=user,
|
user=user,
|
||||||
query=q,
|
query=defiltered_query,
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
conversation_history=meta_log,
|
conversation_history=meta_log,
|
||||||
subscribed=subscribed,
|
query_images=raw_images,
|
||||||
uploaded_image_url=uploaded_image_url,
|
|
||||||
agent=agent,
|
agent=agent,
|
||||||
send_status_func=partial(send_event, ChatEvent.STATUS),
|
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||||
user_name=user_name,
|
user_name=user_name,
|
||||||
|
@ -743,11 +730,11 @@ async def chat(
|
||||||
meta_log,
|
meta_log,
|
||||||
is_automated_task,
|
is_automated_task,
|
||||||
user=user,
|
user=user,
|
||||||
uploaded_image_url=uploaded_image_url,
|
query_images=uploaded_images,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
)
|
)
|
||||||
|
|
||||||
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_image_url, agent)
|
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_images, agent)
|
||||||
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
|
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
|
||||||
yield result
|
yield result
|
||||||
if mode not in conversation_commands:
|
if mode not in conversation_commands:
|
||||||
|
@ -788,11 +775,15 @@ async def chat(
|
||||||
user=user,
|
user=user,
|
||||||
file_filters=file_filters,
|
file_filters=file_filters,
|
||||||
meta_log=meta_log,
|
meta_log=meta_log,
|
||||||
subscribed=subscribed,
|
query_images=uploaded_images,
|
||||||
|
agent=agent,
|
||||||
send_status_func=partial(send_event, ChatEvent.STATUS),
|
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||||
send_response_func=partial(send_llm_response),
|
send_response_func=partial(send_llm_response),
|
||||||
):
|
):
|
||||||
yield response
|
if isinstance(response, dict) and ChatEvent.STATUS in response:
|
||||||
|
yield result[ChatEvent.STATUS]
|
||||||
|
else:
|
||||||
|
response
|
||||||
|
|
||||||
await sync_to_async(save_to_conversation_log)(
|
await sync_to_async(save_to_conversation_log)(
|
||||||
q,
|
q,
|
||||||
|
@ -803,7 +794,7 @@ async def chat(
|
||||||
intent_type="summarize",
|
intent_type="summarize",
|
||||||
client_application=request.user.client_app,
|
client_application=request.user.client_app,
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
uploaded_image_url=uploaded_image_url,
|
query_images=uploaded_images,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -846,189 +837,189 @@ async def chat(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
inferred_queries=[query_to_run],
|
inferred_queries=[query_to_run],
|
||||||
automation_id=automation.id,
|
automation_id=automation.id,
|
||||||
uploaded_image_url=uploaded_image_url,
|
query_images=uploaded_images,
|
||||||
)
|
)
|
||||||
async for result in send_llm_response(llm_response):
|
async for result in send_llm_response(llm_response):
|
||||||
yield result
|
yield result
|
||||||
return
|
return
|
||||||
|
|
||||||
# # Gather Context
|
# Gather Context
|
||||||
# # Extract Document References
|
## Extract Document References
|
||||||
# try:
|
if pending_research:
|
||||||
# async for result in extract_references_and_questions(
|
try:
|
||||||
# request,
|
async for result in extract_references_and_questions(
|
||||||
# meta_log,
|
request,
|
||||||
# q,
|
meta_log,
|
||||||
# (n or 7),
|
q,
|
||||||
# d,
|
(n or 7),
|
||||||
# conversation_id,
|
d,
|
||||||
# conversation_commands,
|
conversation_id,
|
||||||
# location,
|
conversation_commands,
|
||||||
# partial(send_event, ChatEvent.STATUS),
|
location,
|
||||||
# uploaded_image_url=uploaded_image_url,
|
partial(send_event, ChatEvent.STATUS),
|
||||||
# agent=agent,
|
query_images=uploaded_images,
|
||||||
# ):
|
agent=agent,
|
||||||
# if isinstance(result, dict) and ChatEvent.STATUS in result:
|
):
|
||||||
# yield result[ChatEvent.STATUS]
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
# else:
|
yield result[ChatEvent.STATUS]
|
||||||
# compiled_references.extend(result[0])
|
else:
|
||||||
# inferred_queries.extend(result[1])
|
compiled_references.extend(result[0])
|
||||||
# defiltered_query = result[2]
|
inferred_queries.extend(result[1])
|
||||||
# except Exception as e:
|
defiltered_query = result[2]
|
||||||
# error_message = f"Error searching knowledge base: {e}. Attempting to respond without document references."
|
except Exception as e:
|
||||||
# logger.warning(error_message)
|
error_message = (
|
||||||
# async for result in send_event(
|
f"Error searching knowledge base: {e}. Attempting to respond without document references."
|
||||||
# ChatEvent.STATUS, "Document search failed. I'll try respond without document references"
|
)
|
||||||
# ):
|
logger.error(error_message, exc_info=True)
|
||||||
# yield result
|
async for result in send_event(
|
||||||
#
|
ChatEvent.STATUS, "Document search failed. I'll try respond without document references"
|
||||||
# # if not is_none_or_empty(compiled_references):
|
):
|
||||||
# try:
|
yield result
|
||||||
# headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
|
|
||||||
# # Strip only leading # from headings
|
|
||||||
# headings = headings.replace("#", "")
|
|
||||||
# async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"):
|
|
||||||
# yield result
|
|
||||||
# except Exception as e:
|
|
||||||
# # TODO Get correct type for compiled across research notes extraction
|
|
||||||
# logger.error(f"Error extracting references: {e}", exc_info=True)
|
|
||||||
|
|
||||||
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
|
if not is_none_or_empty(compiled_references):
|
||||||
async for result in send_llm_response(f"{no_entries_found.format()}"):
|
headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
|
||||||
yield result
|
# Strip only leading # from headings
|
||||||
return
|
headings = headings.replace("#", "")
|
||||||
|
async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"):
|
||||||
|
yield result
|
||||||
|
|
||||||
|
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
|
||||||
|
async for result in send_llm_response(f"{no_entries_found.format()}"):
|
||||||
|
yield result
|
||||||
|
return
|
||||||
|
|
||||||
if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
|
if ConversationCommand.Notes in conversation_commands and is_none_or_empty(compiled_references):
|
||||||
conversation_commands.remove(ConversationCommand.Notes)
|
conversation_commands.remove(ConversationCommand.Notes)
|
||||||
|
|
||||||
## Gather Online References
|
if pending_research:
|
||||||
if ConversationCommand.Online in conversation_commands and pending_research:
|
## Gather Online References
|
||||||
try:
|
if ConversationCommand.Online in conversation_commands:
|
||||||
async for result in search_online(
|
try:
|
||||||
defiltered_query,
|
async for result in search_online(
|
||||||
meta_log,
|
defiltered_query,
|
||||||
location,
|
meta_log,
|
||||||
user,
|
location,
|
||||||
subscribed,
|
user,
|
||||||
partial(send_event, ChatEvent.STATUS),
|
partial(send_event, ChatEvent.STATUS),
|
||||||
custom_filters,
|
custom_filters,
|
||||||
uploaded_image_url=uploaded_image_url,
|
query_images=uploaded_images,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
):
|
):
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
yield result[ChatEvent.STATUS]
|
yield result[ChatEvent.STATUS]
|
||||||
else:
|
else:
|
||||||
online_results = result
|
online_results = result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_message = f"Error searching online: {e}. Attempting to respond without online results"
|
error_message = f"Error searching online: {e}. Attempting to respond without online results"
|
||||||
logger.warning(error_message)
|
logger.warning(error_message)
|
||||||
async for result in send_event(
|
async for result in send_event(
|
||||||
ChatEvent.STATUS, "Online search failed. I'll try respond without online references"
|
ChatEvent.STATUS, "Online search failed. I'll try respond without online references"
|
||||||
):
|
):
|
||||||
yield result
|
yield result
|
||||||
|
|
||||||
## Gather Webpage References
|
if pending_research:
|
||||||
if ConversationCommand.Webpage in conversation_commands and pending_research:
|
## Gather Webpage References
|
||||||
try:
|
if ConversationCommand.Webpage in conversation_commands:
|
||||||
async for result in read_webpages(
|
try:
|
||||||
defiltered_query,
|
async for result in read_webpages(
|
||||||
meta_log,
|
defiltered_query,
|
||||||
location,
|
meta_log,
|
||||||
user,
|
location,
|
||||||
partial(send_event, ChatEvent.STATUS),
|
user,
|
||||||
uploaded_image_url=uploaded_image_url,
|
partial(send_event, ChatEvent.STATUS),
|
||||||
agent=agent,
|
query_images=uploaded_images,
|
||||||
):
|
agent=agent,
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
):
|
||||||
yield result[ChatEvent.STATUS]
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
else:
|
yield result[ChatEvent.STATUS]
|
||||||
direct_web_pages = result
|
else:
|
||||||
webpages = []
|
direct_web_pages = result
|
||||||
for query in direct_web_pages:
|
webpages = []
|
||||||
if online_results.get(query):
|
for query in direct_web_pages:
|
||||||
online_results[query]["webpages"] = direct_web_pages[query]["webpages"]
|
if online_results.get(query):
|
||||||
else:
|
online_results[query]["webpages"] = direct_web_pages[query]["webpages"]
|
||||||
online_results[query] = {"webpages": direct_web_pages[query]["webpages"]}
|
else:
|
||||||
|
online_results[query] = {"webpages": direct_web_pages[query]["webpages"]}
|
||||||
|
|
||||||
for webpage in direct_web_pages[query]["webpages"]:
|
for webpage in direct_web_pages[query]["webpages"]:
|
||||||
webpages.append(webpage["link"])
|
webpages.append(webpage["link"])
|
||||||
async for result in send_event(ChatEvent.STATUS, f"**Read web pages**: {webpages}"):
|
async for result in send_event(ChatEvent.STATUS, f"**Read web pages**: {webpages}"):
|
||||||
yield result
|
yield result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Error reading webpages: {e}. Attempting to respond without webpage results",
|
f"Error reading webpages: {e}. Attempting to respond without webpage results",
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
async for result in send_event(
|
async for result in send_event(
|
||||||
ChatEvent.STATUS, "Webpage read failed. I'll try respond without webpage references"
|
ChatEvent.STATUS, "Webpage read failed. I'll try respond without webpage references"
|
||||||
):
|
):
|
||||||
yield result
|
yield result
|
||||||
|
|
||||||
## Gather Code Results
|
## Send Gathered References
|
||||||
if ConversationCommand.Code in conversation_commands and pending_research:
|
async for result in send_event(
|
||||||
try:
|
ChatEvent.REFERENCES,
|
||||||
previous_iteration_history = (
|
{
|
||||||
f"# Iteration 1:\n#---\nNotes:\n{compiled_references}\n\nOnline Results:{online_results}"
|
"inferredQueries": inferred_queries,
|
||||||
)
|
"context": compiled_references,
|
||||||
async for result in run_code(
|
"onlineContext": online_results,
|
||||||
defiltered_query,
|
},
|
||||||
meta_log,
|
):
|
||||||
previous_iteration_history,
|
yield result
|
||||||
location,
|
|
||||||
user,
|
|
||||||
partial(send_event, ChatEvent.STATUS),
|
|
||||||
uploaded_image_url=uploaded_image_url,
|
|
||||||
agent=agent,
|
|
||||||
):
|
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
|
||||||
yield result[ChatEvent.STATUS]
|
|
||||||
else:
|
|
||||||
code_results = result
|
|
||||||
async for result in send_event(ChatEvent.STATUS, f"**Ran code snippets**: {len(code_results)}"):
|
|
||||||
yield result
|
|
||||||
except ValueError as e:
|
|
||||||
logger.warning(
|
|
||||||
f"Failed to use code tool: {e}. Attempting to respond without code results",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
## Send Gathered References
|
if pending_research:
|
||||||
async for result in send_event(
|
## Gather Code Results
|
||||||
ChatEvent.REFERENCES,
|
if ConversationCommand.Code in conversation_commands and pending_research:
|
||||||
{
|
try:
|
||||||
"inferredQueries": inferred_queries,
|
previous_iteration_history = (
|
||||||
"context": compiled_references,
|
f"# Iteration 1:\n#---\nNotes:\n{compiled_references}\n\nOnline Results:{online_results}"
|
||||||
"onlineContext": online_results,
|
)
|
||||||
"codeContext": code_results,
|
async for result in run_code(
|
||||||
},
|
defiltered_query,
|
||||||
):
|
meta_log,
|
||||||
yield result
|
previous_iteration_history,
|
||||||
|
location,
|
||||||
|
user,
|
||||||
|
partial(send_event, ChatEvent.STATUS),
|
||||||
|
query_images=uploaded_images,
|
||||||
|
agent=agent,
|
||||||
|
):
|
||||||
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
|
yield result[ChatEvent.STATUS]
|
||||||
|
else:
|
||||||
|
code_results = result
|
||||||
|
async for result in send_event(ChatEvent.STATUS, f"**Ran code snippets**: {len(code_results)}"):
|
||||||
|
yield result
|
||||||
|
except ValueError as e:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to use code tool: {e}. Attempting to respond without code results",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
# Generate Output
|
# Generate Output
|
||||||
## Generate Image Output
|
## Generate Image Output
|
||||||
if ConversationCommand.Image in conversation_commands and pending_research:
|
if ConversationCommand.Image in conversation_commands:
|
||||||
async for result in text_to_image(
|
async for result in text_to_image(
|
||||||
q,
|
defiltered_query,
|
||||||
user,
|
user,
|
||||||
meta_log,
|
meta_log,
|
||||||
location_data=location,
|
location_data=location,
|
||||||
references=compiled_references,
|
references=compiled_references,
|
||||||
online_results=online_results,
|
online_results=online_results,
|
||||||
send_status_func=partial(send_event, ChatEvent.STATUS),
|
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||||
uploaded_image_url=uploaded_image_url,
|
query_images=uploaded_images,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
):
|
):
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
yield result[ChatEvent.STATUS]
|
yield result[ChatEvent.STATUS]
|
||||||
else:
|
else:
|
||||||
image, status_code, improved_image_prompt, intent_type = result
|
generated_image, status_code, improved_image_prompt, intent_type = result
|
||||||
|
|
||||||
if image is None or status_code != 200:
|
if generated_image is None or status_code != 200:
|
||||||
content_obj = {
|
content_obj = {
|
||||||
"content-type": "application/json",
|
"content-type": "application/json",
|
||||||
"intentType": intent_type,
|
"intentType": intent_type,
|
||||||
"detail": improved_image_prompt,
|
"detail": improved_image_prompt,
|
||||||
"image": image,
|
"image": None,
|
||||||
}
|
}
|
||||||
async for result in send_llm_response(json.dumps(content_obj)):
|
async for result in send_llm_response(json.dumps(content_obj)):
|
||||||
yield result
|
yield result
|
||||||
|
@ -1036,7 +1027,7 @@ async def chat(
|
||||||
|
|
||||||
await sync_to_async(save_to_conversation_log)(
|
await sync_to_async(save_to_conversation_log)(
|
||||||
q,
|
q,
|
||||||
image,
|
generated_image,
|
||||||
user,
|
user,
|
||||||
meta_log,
|
meta_log,
|
||||||
user_message_time,
|
user_message_time,
|
||||||
|
@ -1046,22 +1037,73 @@ async def chat(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
compiled_references=compiled_references,
|
compiled_references=compiled_references,
|
||||||
online_results=online_results,
|
online_results=online_results,
|
||||||
uploaded_image_url=uploaded_image_url,
|
query_images=uploaded_images,
|
||||||
)
|
)
|
||||||
content_obj = {
|
content_obj = {
|
||||||
"intentType": intent_type,
|
"intentType": intent_type,
|
||||||
"inferredQueries": [improved_image_prompt],
|
"inferredQueries": [improved_image_prompt],
|
||||||
"image": image,
|
"image": generated_image,
|
||||||
}
|
}
|
||||||
async for result in send_llm_response(json.dumps(content_obj)):
|
async for result in send_llm_response(json.dumps(content_obj)):
|
||||||
yield result
|
yield result
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if ConversationCommand.Diagram in conversation_commands:
|
||||||
|
async for result in send_event(ChatEvent.STATUS, f"Creating diagram"):
|
||||||
|
yield result
|
||||||
|
|
||||||
|
intent_type = "excalidraw"
|
||||||
|
inferred_queries = []
|
||||||
|
diagram_description = ""
|
||||||
|
|
||||||
|
async for result in generate_excalidraw_diagram(
|
||||||
|
q=defiltered_query,
|
||||||
|
conversation_history=meta_log,
|
||||||
|
location_data=location,
|
||||||
|
note_references=compiled_references,
|
||||||
|
online_results=online_results,
|
||||||
|
query_images=uploaded_images,
|
||||||
|
user=user,
|
||||||
|
agent=agent,
|
||||||
|
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||||
|
):
|
||||||
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
|
yield result[ChatEvent.STATUS]
|
||||||
|
else:
|
||||||
|
better_diagram_description_prompt, excalidraw_diagram_description = result
|
||||||
|
inferred_queries.append(better_diagram_description_prompt)
|
||||||
|
diagram_description = excalidraw_diagram_description
|
||||||
|
|
||||||
|
content_obj = {
|
||||||
|
"intentType": intent_type,
|
||||||
|
"inferredQueries": inferred_queries,
|
||||||
|
"image": diagram_description,
|
||||||
|
}
|
||||||
|
|
||||||
|
await sync_to_async(save_to_conversation_log)(
|
||||||
|
q,
|
||||||
|
excalidraw_diagram_description,
|
||||||
|
user,
|
||||||
|
meta_log,
|
||||||
|
user_message_time,
|
||||||
|
intent_type="excalidraw",
|
||||||
|
inferred_queries=[better_diagram_description_prompt],
|
||||||
|
client_application=request.user.client_app,
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
compiled_references=compiled_references,
|
||||||
|
online_results=online_results,
|
||||||
|
query_images=uploaded_images,
|
||||||
|
)
|
||||||
|
|
||||||
|
async for result in send_llm_response(json.dumps(content_obj)):
|
||||||
|
yield result
|
||||||
|
return
|
||||||
|
|
||||||
## Generate Text Output
|
## Generate Text Output
|
||||||
async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"):
|
async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"):
|
||||||
yield result
|
yield result
|
||||||
llm_response, chat_metadata = await agenerate_chat_response(
|
llm_response, chat_metadata = await agenerate_chat_response(
|
||||||
q,
|
defiltered_query,
|
||||||
meta_log,
|
meta_log,
|
||||||
conversation,
|
conversation,
|
||||||
compiled_references,
|
compiled_references,
|
||||||
|
@ -1074,8 +1116,8 @@ async def chat(
|
||||||
conversation_id,
|
conversation_id,
|
||||||
location,
|
location,
|
||||||
user_name,
|
user_name,
|
||||||
uploaded_image_url,
|
|
||||||
researched_results,
|
researched_results,
|
||||||
|
uploaded_images,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Send Response
|
# Send Response
|
||||||
|
@ -1101,9 +1143,9 @@ async def chat(
|
||||||
|
|
||||||
## Stream Text Response
|
## Stream Text Response
|
||||||
if stream:
|
if stream:
|
||||||
return StreamingResponse(event_generator(q, image=image), media_type="text/plain")
|
return StreamingResponse(event_generator(q, images=raw_images), media_type="text/plain")
|
||||||
## Non-Streaming Text Response
|
## Non-Streaming Text Response
|
||||||
else:
|
else:
|
||||||
response_iterator = event_generator(q, image=image)
|
response_iterator = event_generator(q, images=raw_images)
|
||||||
response_data = await read_chat_stream(response_iterator)
|
response_data = await read_chat_stream(response_iterator)
|
||||||
return Response(content=json.dumps(response_data), media_type="application/json", status_code=200)
|
return Response(content=json.dumps(response_data), media_type="application/json", status_code=200)
|
||||||
|
|
|
@ -94,39 +94,6 @@ async def update_voice_model(
|
||||||
return Response(status_code=202, content=json.dumps({"status": "ok"}))
|
return Response(status_code=202, content=json.dumps({"status": "ok"}))
|
||||||
|
|
||||||
|
|
||||||
@api_model.post("/search", status_code=200)
|
|
||||||
@requires(["authenticated"])
|
|
||||||
async def update_search_model(
|
|
||||||
request: Request,
|
|
||||||
id: str,
|
|
||||||
client: Optional[str] = None,
|
|
||||||
):
|
|
||||||
user = request.user.object
|
|
||||||
|
|
||||||
prev_config = await adapters.aget_user_search_model(user)
|
|
||||||
new_config = await adapters.aset_user_search_model(user, int(id))
|
|
||||||
|
|
||||||
if prev_config and int(id) != prev_config.id and new_config:
|
|
||||||
await EntryAdapters.adelete_all_entries(user)
|
|
||||||
|
|
||||||
if not prev_config:
|
|
||||||
# If the use was just using the default config, delete all the entries and set the new config.
|
|
||||||
await EntryAdapters.adelete_all_entries(user)
|
|
||||||
|
|
||||||
if new_config is None:
|
|
||||||
return {"status": "error", "message": "Model not found"}
|
|
||||||
else:
|
|
||||||
update_telemetry_state(
|
|
||||||
request=request,
|
|
||||||
telemetry_type="api",
|
|
||||||
api="set_search_model",
|
|
||||||
client=client,
|
|
||||||
metadata={"search_model": new_config.setting.name},
|
|
||||||
)
|
|
||||||
|
|
||||||
return {"status": "ok"}
|
|
||||||
|
|
||||||
|
|
||||||
@api_model.post("/paint", status_code=200)
|
@api_model.post("/paint", status_code=200)
|
||||||
@requires(["authenticated"])
|
@requires(["authenticated"])
|
||||||
async def update_paint_model(
|
async def update_paint_model(
|
||||||
|
|
|
@ -1,12 +1,14 @@
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
from fastapi import APIRouter, Request
|
from fastapi import APIRouter, Request, Response
|
||||||
from starlette.authentication import requires
|
from starlette.authentication import requires
|
||||||
|
|
||||||
from khoj.database import adapters
|
from khoj.database import adapters
|
||||||
|
from khoj.database.models import KhojUser, Subscription
|
||||||
from khoj.routers.helpers import update_telemetry_state
|
from khoj.routers.helpers import update_telemetry_state
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
|
|
||||||
|
@ -73,7 +75,7 @@ async def subscribe(request: Request):
|
||||||
elif event_type in {"customer.subscription.deleted"}:
|
elif event_type in {"customer.subscription.deleted"}:
|
||||||
# Reset the user to trial state
|
# Reset the user to trial state
|
||||||
user, is_new = await adapters.set_user_subscription(
|
user, is_new = await adapters.set_user_subscription(
|
||||||
customer_email, is_recurring=False, renewal_date=False, type="trial"
|
customer_email, is_recurring=False, renewal_date=False, type=Subscription.Type.TRIAL
|
||||||
)
|
)
|
||||||
success = user is not None
|
success = user is not None
|
||||||
|
|
||||||
|
@ -82,7 +84,7 @@ async def subscribe(request: Request):
|
||||||
request=request,
|
request=request,
|
||||||
telemetry_type="api",
|
telemetry_type="api",
|
||||||
api="create_user",
|
api="create_user",
|
||||||
metadata={"user_id": str(user.user.uuid)},
|
metadata={"server_id": str(user.user.uuid)},
|
||||||
)
|
)
|
||||||
logger.log(logging.INFO, f"🥳 New User Created: {user.user.uuid}")
|
logger.log(logging.INFO, f"🥳 New User Created: {user.user.uuid}")
|
||||||
|
|
||||||
|
@ -116,3 +118,19 @@ async def update_subscription(request: Request, email: str, operation: str):
|
||||||
return {"success": False, "message": "No subscription found that is set to cancel"}
|
return {"success": False, "message": "No subscription found that is set to cancel"}
|
||||||
|
|
||||||
return {"success": False, "message": "Invalid operation"}
|
return {"success": False, "message": "Invalid operation"}
|
||||||
|
|
||||||
|
|
||||||
|
@subscription_router.post("/trial", response_class=Response)
|
||||||
|
@requires(["authenticated"])
|
||||||
|
async def start_trial(request: Request) -> Response:
|
||||||
|
user: KhojUser = request.user.object
|
||||||
|
|
||||||
|
# Start a trial for the user
|
||||||
|
updated_subscription = await adapters.astart_trial_subscription(user)
|
||||||
|
|
||||||
|
# Return trial status as a JSON response
|
||||||
|
return Response(
|
||||||
|
content=json.dumps({"trial_enabled": updated_subscription is not None}),
|
||||||
|
media_type="application/json",
|
||||||
|
status_code=200,
|
||||||
|
)
|
|
@ -90,7 +90,7 @@ async def login_magic_link(request: Request, form: MagicLinkForm):
|
||||||
request=request,
|
request=request,
|
||||||
telemetry_type="api",
|
telemetry_type="api",
|
||||||
api="create_user",
|
api="create_user",
|
||||||
metadata={"user_id": str(user.uuid)},
|
metadata={"server_id": str(user.uuid)},
|
||||||
)
|
)
|
||||||
logger.log(logging.INFO, f"🥳 New User Created: {user.uuid}")
|
logger.log(logging.INFO, f"🥳 New User Created: {user.uuid}")
|
||||||
|
|
||||||
|
@ -175,7 +175,7 @@ async def auth(request: Request):
|
||||||
request=request,
|
request=request,
|
||||||
telemetry_type="api",
|
telemetry_type="api",
|
||||||
api="create_user",
|
api="create_user",
|
||||||
metadata={"user_id": str(khoj_user.uuid)},
|
metadata={"server_id": str(khoj_user.uuid)},
|
||||||
)
|
)
|
||||||
logger.log(logging.INFO, f"🥳 New User Created: {khoj_user.uuid}")
|
logger.log(logging.INFO, f"🥳 New User Created: {khoj_user.uuid}")
|
||||||
return RedirectResponse(url=next_url, status_code=HTTP_302_FOUND)
|
return RedirectResponse(url=next_url, status_code=HTTP_302_FOUND)
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import base64
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
@ -22,7 +23,7 @@ from typing import (
|
||||||
Tuple,
|
Tuple,
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
from urllib.parse import parse_qs, quote, urljoin, urlparse
|
from urllib.parse import parse_qs, quote, unquote, urljoin, urlparse
|
||||||
|
|
||||||
import cron_descriptor
|
import cron_descriptor
|
||||||
import pytz
|
import pytz
|
||||||
|
@ -31,16 +32,19 @@ from apscheduler.job import Job
|
||||||
from apscheduler.triggers.cron import CronTrigger
|
from apscheduler.triggers.cron import CronTrigger
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
||||||
|
from pydantic import BaseModel
|
||||||
from starlette.authentication import has_required_scope
|
from starlette.authentication import has_required_scope
|
||||||
from starlette.requests import URL
|
from starlette.requests import URL
|
||||||
|
|
||||||
from khoj.database import adapters
|
from khoj.database import adapters
|
||||||
from khoj.database.adapters import (
|
from khoj.database.adapters import (
|
||||||
|
LENGTH_OF_FREE_TRIAL,
|
||||||
AgentAdapters,
|
AgentAdapters,
|
||||||
AutomationAdapters,
|
AutomationAdapters,
|
||||||
ConversationAdapters,
|
ConversationAdapters,
|
||||||
EntryAdapters,
|
EntryAdapters,
|
||||||
FileObjectAdapters,
|
FileObjectAdapters,
|
||||||
|
ais_user_subscribed,
|
||||||
create_khoj_token,
|
create_khoj_token,
|
||||||
get_khoj_tokens,
|
get_khoj_tokens,
|
||||||
get_user_name,
|
get_user_name,
|
||||||
|
@ -210,6 +214,21 @@ def get_next_url(request: Request) -> str:
|
||||||
return urljoin(str(request.base_url).rstrip("/"), next_path)
|
return urljoin(str(request.base_url).rstrip("/"), next_path)
|
||||||
|
|
||||||
|
|
||||||
|
def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
|
||||||
|
chat_history = ""
|
||||||
|
for chat in conversation_history.get("chat", [])[-n:]:
|
||||||
|
if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder", "summarize"]:
|
||||||
|
chat_history += f"User: {chat['intent']['query']}\n"
|
||||||
|
chat_history += f"{agent_name}: {chat['message']}\n"
|
||||||
|
elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")):
|
||||||
|
chat_history += f"User: {chat['intent']['query']}\n"
|
||||||
|
chat_history += f"{agent_name}: [generated image redacted for space]\n"
|
||||||
|
elif chat["by"] == "khoj" and ("excalidraw" in chat["intent"].get("type")):
|
||||||
|
chat_history += f"User: {chat['intent']['query']}\n"
|
||||||
|
chat_history += f"{agent_name}: {chat['intent']['inferred-queries'][0]}\n"
|
||||||
|
return chat_history
|
||||||
|
|
||||||
|
|
||||||
def get_conversation_command(query: str, any_references: bool = False) -> ConversationCommand:
|
def get_conversation_command(query: str, any_references: bool = False) -> ConversationCommand:
|
||||||
if query.startswith("/notes"):
|
if query.startswith("/notes"):
|
||||||
return ConversationCommand.Notes
|
return ConversationCommand.Notes
|
||||||
|
@ -227,6 +246,8 @@ def get_conversation_command(query: str, any_references: bool = False) -> Conver
|
||||||
return ConversationCommand.AutomatedTask
|
return ConversationCommand.AutomatedTask
|
||||||
elif query.startswith("/summarize"):
|
elif query.startswith("/summarize"):
|
||||||
return ConversationCommand.Summarize
|
return ConversationCommand.Summarize
|
||||||
|
elif query.startswith("/diagram"):
|
||||||
|
return ConversationCommand.Diagram
|
||||||
# If no relevant notes found for the given query
|
# If no relevant notes found for the given query
|
||||||
elif not any_references:
|
elif not any_references:
|
||||||
return ConversationCommand.General
|
return ConversationCommand.General
|
||||||
|
@ -282,7 +303,7 @@ async def aget_relevant_information_sources(
|
||||||
conversation_history: dict,
|
conversation_history: dict,
|
||||||
is_task: bool,
|
is_task: bool,
|
||||||
user: KhojUser,
|
user: KhojUser,
|
||||||
uploaded_image_url: str = None,
|
query_images: List[str] = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
@ -301,8 +322,8 @@ async def aget_relevant_information_sources(
|
||||||
|
|
||||||
chat_history = construct_chat_history(conversation_history)
|
chat_history = construct_chat_history(conversation_history)
|
||||||
|
|
||||||
if uploaded_image_url:
|
if query_images:
|
||||||
query = f"[placeholder for user attached image]\n{query}"
|
query = f"[placeholder for {len(query_images)} user attached images]\n{query}"
|
||||||
|
|
||||||
personality_context = (
|
personality_context = (
|
||||||
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
|
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
|
||||||
|
@ -359,7 +380,7 @@ async def aget_relevant_output_modes(
|
||||||
conversation_history: dict,
|
conversation_history: dict,
|
||||||
is_task: bool = False,
|
is_task: bool = False,
|
||||||
user: KhojUser = None,
|
user: KhojUser = None,
|
||||||
uploaded_image_url: str = None,
|
query_images: List[str] = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
@ -381,8 +402,8 @@ async def aget_relevant_output_modes(
|
||||||
|
|
||||||
chat_history = construct_chat_history(conversation_history)
|
chat_history = construct_chat_history(conversation_history)
|
||||||
|
|
||||||
if uploaded_image_url:
|
if query_images:
|
||||||
query = f"[placeholder for user attached image]\n{query}"
|
query = f"[placeholder for {len(query_images)} user attached images]\n{query}"
|
||||||
|
|
||||||
personality_context = (
|
personality_context = (
|
||||||
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
|
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
|
||||||
|
@ -425,7 +446,7 @@ async def infer_webpage_urls(
|
||||||
conversation_history: dict,
|
conversation_history: dict,
|
||||||
location_data: LocationData,
|
location_data: LocationData,
|
||||||
user: KhojUser,
|
user: KhojUser,
|
||||||
uploaded_image_url: str = None,
|
query_images: List[str] = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""
|
"""
|
||||||
|
@ -451,7 +472,7 @@ async def infer_webpage_urls(
|
||||||
|
|
||||||
with timer("Chat actor: Infer webpage urls to read", logger):
|
with timer("Chat actor: Infer webpage urls to read", logger):
|
||||||
response = await send_message_to_model_wrapper(
|
response = await send_message_to_model_wrapper(
|
||||||
online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
|
online_queries_prompt, query_images=query_images, response_type="json_object", user=user
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate that the response is a non-empty, JSON-serializable list of URLs
|
# Validate that the response is a non-empty, JSON-serializable list of URLs
|
||||||
|
@ -472,7 +493,7 @@ async def generate_online_subqueries(
|
||||||
conversation_history: dict,
|
conversation_history: dict,
|
||||||
location_data: LocationData,
|
location_data: LocationData,
|
||||||
user: KhojUser,
|
user: KhojUser,
|
||||||
uploaded_image_url: str = None,
|
query_images: List[str] = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""
|
"""
|
||||||
|
@ -498,7 +519,7 @@ async def generate_online_subqueries(
|
||||||
|
|
||||||
with timer("Chat actor: Generate online search subqueries", logger):
|
with timer("Chat actor: Generate online search subqueries", logger):
|
||||||
response = await send_message_to_model_wrapper(
|
response = await send_message_to_model_wrapper(
|
||||||
online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
|
online_queries_prompt, query_images=query_images, response_type="json_object", user=user
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate that the response is a non-empty, JSON-serializable list
|
# Validate that the response is a non-empty, JSON-serializable list
|
||||||
|
@ -517,7 +538,7 @@ async def generate_online_subqueries(
|
||||||
|
|
||||||
|
|
||||||
async def schedule_query(
|
async def schedule_query(
|
||||||
q: str, conversation_history: dict, user: KhojUser, uploaded_image_url: str = None
|
q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None
|
||||||
) -> Tuple[str, ...]:
|
) -> Tuple[str, ...]:
|
||||||
"""
|
"""
|
||||||
Schedule the date, time to run the query. Assume the server timezone is UTC.
|
Schedule the date, time to run the query. Assume the server timezone is UTC.
|
||||||
|
@ -530,7 +551,7 @@ async def schedule_query(
|
||||||
)
|
)
|
||||||
|
|
||||||
raw_response = await send_message_to_model_wrapper(
|
raw_response = await send_message_to_model_wrapper(
|
||||||
crontime_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
|
crontime_prompt, query_images=query_images, response_type="json_object", user=user
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate that the response is a non-empty, JSON-serializable list
|
# Validate that the response is a non-empty, JSON-serializable list
|
||||||
|
@ -544,12 +565,14 @@ async def schedule_query(
|
||||||
raise AssertionError(f"Invalid response for scheduling query: {raw_response}")
|
raise AssertionError(f"Invalid response for scheduling query: {raw_response}")
|
||||||
|
|
||||||
|
|
||||||
async def extract_relevant_info(q: str, corpus: str, user: KhojUser = None, agent: Agent = None) -> Union[str, None]:
|
async def extract_relevant_info(
|
||||||
|
qs: set[str], corpus: str, user: KhojUser = None, agent: Agent = None
|
||||||
|
) -> Union[str, None]:
|
||||||
"""
|
"""
|
||||||
Extract relevant information for a given query from the target corpus
|
Extract relevant information for a given query from the target corpus
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if is_none_or_empty(corpus) or is_none_or_empty(q):
|
if is_none_or_empty(corpus) or is_none_or_empty(qs):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
personality_context = (
|
personality_context = (
|
||||||
|
@ -557,17 +580,16 @@ async def extract_relevant_info(q: str, corpus: str, user: KhojUser = None, agen
|
||||||
)
|
)
|
||||||
|
|
||||||
extract_relevant_information = prompts.extract_relevant_information.format(
|
extract_relevant_information = prompts.extract_relevant_information.format(
|
||||||
query=q,
|
query=", ".join(qs),
|
||||||
corpus=corpus.strip(),
|
corpus=corpus.strip(),
|
||||||
personality_context=personality_context,
|
personality_context=personality_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
with timer("Chat actor: Extract relevant information from data", logger):
|
response = await send_message_to_model_wrapper(
|
||||||
response = await send_message_to_model_wrapper(
|
extract_relevant_information,
|
||||||
extract_relevant_information,
|
prompts.system_prompt_extract_relevant_information,
|
||||||
prompts.system_prompt_extract_relevant_information,
|
user=user,
|
||||||
user=user,
|
)
|
||||||
)
|
|
||||||
return response.strip()
|
return response.strip()
|
||||||
|
|
||||||
|
|
||||||
|
@ -575,7 +597,7 @@ async def extract_relevant_summary(
|
||||||
q: str,
|
q: str,
|
||||||
corpus: str,
|
corpus: str,
|
||||||
conversation_history: dict,
|
conversation_history: dict,
|
||||||
uploaded_image_url: str = None,
|
query_images: List[str] = None,
|
||||||
user: KhojUser = None,
|
user: KhojUser = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
) -> Union[str, None]:
|
) -> Union[str, None]:
|
||||||
|
@ -604,7 +626,7 @@ async def extract_relevant_summary(
|
||||||
extract_relevant_information,
|
extract_relevant_information,
|
||||||
prompts.system_prompt_extract_relevant_summary,
|
prompts.system_prompt_extract_relevant_summary,
|
||||||
user=user,
|
user=user,
|
||||||
uploaded_image_url=uploaded_image_url,
|
query_images=query_images,
|
||||||
)
|
)
|
||||||
return response.strip()
|
return response.strip()
|
||||||
|
|
||||||
|
@ -614,8 +636,7 @@ async def generate_summary_from_files(
|
||||||
user: KhojUser,
|
user: KhojUser,
|
||||||
file_filters: List[str],
|
file_filters: List[str],
|
||||||
meta_log: dict,
|
meta_log: dict,
|
||||||
subscribed: bool,
|
query_images: List[str] = None,
|
||||||
uploaded_image_url: str = None,
|
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
send_status_func: Optional[Callable] = None,
|
send_status_func: Optional[Callable] = None,
|
||||||
send_response_func: Optional[Callable] = None,
|
send_response_func: Optional[Callable] = None,
|
||||||
|
@ -647,7 +668,7 @@ async def generate_summary_from_files(
|
||||||
q,
|
q,
|
||||||
contextual_data,
|
contextual_data,
|
||||||
conversation_history=meta_log,
|
conversation_history=meta_log,
|
||||||
uploaded_image_url=uploaded_image_url,
|
query_images=query_images,
|
||||||
user=user,
|
user=user,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
)
|
)
|
||||||
|
@ -661,6 +682,129 @@ async def generate_summary_from_files(
|
||||||
yield result
|
yield result
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_excalidraw_diagram(
|
||||||
|
q: str,
|
||||||
|
conversation_history: Dict[str, Any],
|
||||||
|
location_data: LocationData,
|
||||||
|
note_references: List[Dict[str, Any]],
|
||||||
|
online_results: Optional[dict] = None,
|
||||||
|
query_images: List[str] = None,
|
||||||
|
user: KhojUser = None,
|
||||||
|
agent: Agent = None,
|
||||||
|
send_status_func: Optional[Callable] = None,
|
||||||
|
):
|
||||||
|
if send_status_func:
|
||||||
|
async for event in send_status_func("**Enhancing the Diagramming Prompt**"):
|
||||||
|
yield {ChatEvent.STATUS: event}
|
||||||
|
|
||||||
|
better_diagram_description_prompt = await generate_better_diagram_description(
|
||||||
|
q=q,
|
||||||
|
conversation_history=conversation_history,
|
||||||
|
location_data=location_data,
|
||||||
|
note_references=note_references,
|
||||||
|
online_results=online_results,
|
||||||
|
query_images=query_images,
|
||||||
|
user=user,
|
||||||
|
agent=agent,
|
||||||
|
)
|
||||||
|
|
||||||
|
if send_status_func:
|
||||||
|
async for event in send_status_func(f"**Diagram to Create:**:\n{better_diagram_description_prompt}"):
|
||||||
|
yield {ChatEvent.STATUS: event}
|
||||||
|
|
||||||
|
excalidraw_diagram_description = await generate_excalidraw_diagram_from_description(
|
||||||
|
q=better_diagram_description_prompt,
|
||||||
|
user=user,
|
||||||
|
agent=agent,
|
||||||
|
)
|
||||||
|
|
||||||
|
yield better_diagram_description_prompt, excalidraw_diagram_description
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_better_diagram_description(
|
||||||
|
q: str,
|
||||||
|
conversation_history: Dict[str, Any],
|
||||||
|
location_data: LocationData,
|
||||||
|
note_references: List[Dict[str, Any]],
|
||||||
|
online_results: Optional[dict] = None,
|
||||||
|
query_images: List[str] = None,
|
||||||
|
user: KhojUser = None,
|
||||||
|
agent: Agent = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Generate a diagram description from the given query and context
|
||||||
|
"""
|
||||||
|
|
||||||
|
today_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d, %A")
|
||||||
|
personality_context = (
|
||||||
|
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
|
||||||
|
)
|
||||||
|
|
||||||
|
if location_data:
|
||||||
|
location_prompt = prompts.user_location.format(location=f"{location_data}")
|
||||||
|
else:
|
||||||
|
location_prompt = "Unknown"
|
||||||
|
|
||||||
|
user_references = "\n\n".join([f"# {item['compiled']}" for item in note_references])
|
||||||
|
|
||||||
|
chat_history = construct_chat_history(conversation_history)
|
||||||
|
|
||||||
|
simplified_online_results = {}
|
||||||
|
|
||||||
|
if online_results:
|
||||||
|
for result in online_results:
|
||||||
|
if online_results[result].get("answerBox"):
|
||||||
|
simplified_online_results[result] = online_results[result]["answerBox"]
|
||||||
|
elif online_results[result].get("webpages"):
|
||||||
|
simplified_online_results[result] = online_results[result]["webpages"]
|
||||||
|
|
||||||
|
improve_diagram_description_prompt = prompts.improve_diagram_description_prompt.format(
|
||||||
|
query=q,
|
||||||
|
chat_history=chat_history,
|
||||||
|
location=location_prompt,
|
||||||
|
current_date=today_date,
|
||||||
|
references=user_references,
|
||||||
|
online_results=simplified_online_results,
|
||||||
|
personality_context=personality_context,
|
||||||
|
)
|
||||||
|
|
||||||
|
with timer("Chat actor: Generate better diagram description", logger):
|
||||||
|
response = await send_message_to_model_wrapper(
|
||||||
|
improve_diagram_description_prompt, query_images=query_images, user=user
|
||||||
|
)
|
||||||
|
response = response.strip()
|
||||||
|
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
|
||||||
|
response = response[1:-1]
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_excalidraw_diagram_from_description(
|
||||||
|
q: str,
|
||||||
|
user: KhojUser = None,
|
||||||
|
agent: Agent = None,
|
||||||
|
) -> str:
|
||||||
|
personality_context = (
|
||||||
|
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
|
||||||
|
)
|
||||||
|
|
||||||
|
excalidraw_diagram_generation = prompts.excalidraw_diagram_generation_prompt.format(
|
||||||
|
personality_context=personality_context,
|
||||||
|
query=q,
|
||||||
|
)
|
||||||
|
|
||||||
|
with timer("Chat actor: Generate excalidraw diagram", logger):
|
||||||
|
raw_response = await send_message_to_model_wrapper(message=excalidraw_diagram_generation, user=user)
|
||||||
|
raw_response = raw_response.strip()
|
||||||
|
raw_response = remove_json_codeblock(raw_response)
|
||||||
|
response: Dict[str, str] = json.loads(raw_response)
|
||||||
|
if not response or not isinstance(response, List) or not isinstance(response[0], Dict):
|
||||||
|
# TODO Some additional validation here that it's a valid Excalidraw diagram
|
||||||
|
raise AssertionError(f"Invalid response for improving diagram description: {response}")
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
async def generate_better_image_prompt(
|
async def generate_better_image_prompt(
|
||||||
q: str,
|
q: str,
|
||||||
conversation_history: str,
|
conversation_history: str,
|
||||||
|
@ -668,7 +812,7 @@ async def generate_better_image_prompt(
|
||||||
note_references: List[Dict[str, Any]],
|
note_references: List[Dict[str, Any]],
|
||||||
online_results: Optional[dict] = None,
|
online_results: Optional[dict] = None,
|
||||||
model_type: Optional[str] = None,
|
model_type: Optional[str] = None,
|
||||||
uploaded_image_url: Optional[str] = None,
|
query_images: Optional[List[str]] = None,
|
||||||
user: KhojUser = None,
|
user: KhojUser = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
|
@ -720,7 +864,7 @@ async def generate_better_image_prompt(
|
||||||
)
|
)
|
||||||
|
|
||||||
with timer("Chat actor: Generate contextual image prompt", logger):
|
with timer("Chat actor: Generate contextual image prompt", logger):
|
||||||
response = await send_message_to_model_wrapper(image_prompt, uploaded_image_url=uploaded_image_url, user=user)
|
response = await send_message_to_model_wrapper(image_prompt, query_images=query_images, user=user)
|
||||||
response = response.strip()
|
response = response.strip()
|
||||||
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
|
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
|
||||||
response = response[1:-1]
|
response = response[1:-1]
|
||||||
|
@ -728,6 +872,117 @@ async def generate_better_image_prompt(
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
async def send_message_to_model_wrapper(
|
||||||
|
message: str,
|
||||||
|
system_message: str = "",
|
||||||
|
response_type: str = "text",
|
||||||
|
user: KhojUser = None,
|
||||||
|
query_images: List[str] = None,
|
||||||
|
):
|
||||||
|
conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
|
||||||
|
vision_available = conversation_config.vision_enabled
|
||||||
|
if not vision_available and query_images:
|
||||||
|
vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
|
||||||
|
if vision_enabled_config:
|
||||||
|
conversation_config = vision_enabled_config
|
||||||
|
vision_available = True
|
||||||
|
|
||||||
|
subscribed = await ais_user_subscribed(user)
|
||||||
|
chat_model = conversation_config.chat_model
|
||||||
|
max_tokens = (
|
||||||
|
conversation_config.subscribed_max_prompt_size
|
||||||
|
if subscribed and conversation_config.subscribed_max_prompt_size
|
||||||
|
else conversation_config.max_prompt_size
|
||||||
|
)
|
||||||
|
tokenizer = conversation_config.tokenizer
|
||||||
|
model_type = conversation_config.model_type
|
||||||
|
vision_available = conversation_config.vision_enabled
|
||||||
|
|
||||||
|
if model_type == ChatModelOptions.ModelType.OFFLINE:
|
||||||
|
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
||||||
|
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
||||||
|
|
||||||
|
loaded_model = state.offline_chat_processor_config.loaded_model
|
||||||
|
truncated_messages = generate_chatml_messages_with_context(
|
||||||
|
user_message=message,
|
||||||
|
system_message=system_message,
|
||||||
|
model_name=chat_model,
|
||||||
|
loaded_model=loaded_model,
|
||||||
|
tokenizer_name=tokenizer,
|
||||||
|
max_prompt_size=max_tokens,
|
||||||
|
vision_enabled=vision_available,
|
||||||
|
model_type=conversation_config.model_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
return send_message_to_model_offline(
|
||||||
|
messages=truncated_messages,
|
||||||
|
loaded_model=loaded_model,
|
||||||
|
model=chat_model,
|
||||||
|
max_prompt_size=max_tokens,
|
||||||
|
streaming=False,
|
||||||
|
response_type=response_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif model_type == ChatModelOptions.ModelType.OPENAI:
|
||||||
|
openai_chat_config = conversation_config.openai_config
|
||||||
|
api_key = openai_chat_config.api_key
|
||||||
|
api_base_url = openai_chat_config.api_base_url
|
||||||
|
truncated_messages = generate_chatml_messages_with_context(
|
||||||
|
user_message=message,
|
||||||
|
system_message=system_message,
|
||||||
|
model_name=chat_model,
|
||||||
|
max_prompt_size=max_tokens,
|
||||||
|
tokenizer_name=tokenizer,
|
||||||
|
vision_enabled=vision_available,
|
||||||
|
query_images=query_images,
|
||||||
|
model_type=conversation_config.model_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
return send_message_to_model(
|
||||||
|
messages=truncated_messages,
|
||||||
|
api_key=api_key,
|
||||||
|
model=chat_model,
|
||||||
|
response_type=response_type,
|
||||||
|
api_base_url=api_base_url,
|
||||||
|
)
|
||||||
|
elif model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
||||||
|
api_key = conversation_config.openai_config.api_key
|
||||||
|
truncated_messages = generate_chatml_messages_with_context(
|
||||||
|
user_message=message,
|
||||||
|
system_message=system_message,
|
||||||
|
model_name=chat_model,
|
||||||
|
max_prompt_size=max_tokens,
|
||||||
|
tokenizer_name=tokenizer,
|
||||||
|
vision_enabled=vision_available,
|
||||||
|
query_images=query_images,
|
||||||
|
model_type=conversation_config.model_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
return anthropic_send_message_to_model(
|
||||||
|
messages=truncated_messages,
|
||||||
|
api_key=api_key,
|
||||||
|
model=chat_model,
|
||||||
|
)
|
||||||
|
elif model_type == ChatModelOptions.ModelType.GOOGLE:
|
||||||
|
api_key = conversation_config.openai_config.api_key
|
||||||
|
truncated_messages = generate_chatml_messages_with_context(
|
||||||
|
user_message=message,
|
||||||
|
system_message=system_message,
|
||||||
|
model_name=chat_model,
|
||||||
|
max_prompt_size=max_tokens,
|
||||||
|
tokenizer_name=tokenizer,
|
||||||
|
vision_enabled=vision_available,
|
||||||
|
query_images=query_images,
|
||||||
|
model_type=conversation_config.model_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
return gemini_send_message_to_model(
|
||||||
|
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
||||||
|
|
||||||
|
|
||||||
def send_message_to_model_wrapper_sync(
|
def send_message_to_model_wrapper_sync(
|
||||||
message: str,
|
message: str,
|
||||||
system_message: str = "",
|
system_message: str = "",
|
||||||
|
@ -809,12 +1064,14 @@ def send_message_to_model_wrapper_sync(
|
||||||
model_name=chat_model,
|
model_name=chat_model,
|
||||||
max_prompt_size=max_tokens,
|
max_prompt_size=max_tokens,
|
||||||
vision_enabled=vision_available,
|
vision_enabled=vision_available,
|
||||||
|
model_type=conversation_config.model_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
return gemini_send_message_to_model(
|
return gemini_send_message_to_model(
|
||||||
messages=truncated_messages,
|
messages=truncated_messages,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
model=chat_model,
|
model=chat_model,
|
||||||
|
response_type=response_type,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
||||||
|
@ -834,8 +1091,8 @@ def generate_chat_response(
|
||||||
conversation_id: str = None,
|
conversation_id: str = None,
|
||||||
location_data: LocationData = None,
|
location_data: LocationData = None,
|
||||||
user_name: Optional[str] = None,
|
user_name: Optional[str] = None,
|
||||||
uploaded_image_url: Optional[str] = None,
|
|
||||||
meta_research: str = "",
|
meta_research: str = "",
|
||||||
|
query_images: Optional[List[str]] = None,
|
||||||
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
|
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
|
||||||
# Initialize Variables
|
# Initialize Variables
|
||||||
chat_response = None
|
chat_response = None
|
||||||
|
@ -858,12 +1115,12 @@ def generate_chat_response(
|
||||||
inferred_queries=inferred_queries,
|
inferred_queries=inferred_queries,
|
||||||
client_application=client_application,
|
client_application=client_application,
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
uploaded_image_url=uploaded_image_url,
|
query_images=query_images,
|
||||||
)
|
)
|
||||||
|
|
||||||
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
|
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
|
||||||
vision_available = conversation_config.vision_enabled
|
vision_available = conversation_config.vision_enabled
|
||||||
if not vision_available and uploaded_image_url:
|
if not vision_available and query_images:
|
||||||
vision_enabled_config = ConversationAdapters.get_vision_enabled_config()
|
vision_enabled_config = ConversationAdapters.get_vision_enabled_config()
|
||||||
if vision_enabled_config:
|
if vision_enabled_config:
|
||||||
conversation_config = vision_enabled_config
|
conversation_config = vision_enabled_config
|
||||||
|
@ -894,7 +1151,8 @@ def generate_chat_response(
|
||||||
chat_response = converse(
|
chat_response = converse(
|
||||||
compiled_references,
|
compiled_references,
|
||||||
query_to_run,
|
query_to_run,
|
||||||
image_url=uploaded_image_url,
|
q,
|
||||||
|
query_images=query_images,
|
||||||
online_results=online_results,
|
online_results=online_results,
|
||||||
code_results=code_results,
|
code_results=code_results,
|
||||||
conversation_log=meta_log,
|
conversation_log=meta_log,
|
||||||
|
@ -937,6 +1195,10 @@ def generate_chat_response(
|
||||||
online_results,
|
online_results,
|
||||||
code_results,
|
code_results,
|
||||||
meta_log,
|
meta_log,
|
||||||
|
q,
|
||||||
|
query_images=query_images,
|
||||||
|
online_results=online_results,
|
||||||
|
conversation_log=meta_log,
|
||||||
model=conversation_config.chat_model,
|
model=conversation_config.chat_model,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
completion_func=partial_completion,
|
completion_func=partial_completion,
|
||||||
|
@ -946,6 +1208,7 @@ def generate_chat_response(
|
||||||
location_data=location_data,
|
location_data=location_data,
|
||||||
user_name=user_name,
|
user_name=user_name,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
|
vision_available=vision_available,
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata.update({"chat_model": conversation_config.chat_model})
|
metadata.update({"chat_model": conversation_config.chat_model})
|
||||||
|
@ -957,6 +1220,22 @@ def generate_chat_response(
|
||||||
return chat_response, metadata
|
return chat_response, metadata
|
||||||
|
|
||||||
|
|
||||||
|
class ChatRequestBody(BaseModel):
|
||||||
|
q: str
|
||||||
|
n: Optional[int] = 7
|
||||||
|
d: Optional[float] = None
|
||||||
|
stream: Optional[bool] = False
|
||||||
|
title: Optional[str] = None
|
||||||
|
conversation_id: Optional[str] = None
|
||||||
|
city: Optional[str] = None
|
||||||
|
region: Optional[str] = None
|
||||||
|
country: Optional[str] = None
|
||||||
|
country_code: Optional[str] = None
|
||||||
|
timezone: Optional[str] = None
|
||||||
|
images: Optional[list[str]] = None
|
||||||
|
create_new: Optional[bool] = False
|
||||||
|
|
||||||
|
|
||||||
class ApiUserRateLimiter:
|
class ApiUserRateLimiter:
|
||||||
def __init__(self, requests: int, subscribed_requests: int, window: int, slug: str):
|
def __init__(self, requests: int, subscribed_requests: int, window: int, slug: str):
|
||||||
self.requests = requests
|
self.requests = requests
|
||||||
|
@ -1002,13 +1281,58 @@ class ApiUserRateLimiter:
|
||||||
)
|
)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=429,
|
status_code=429,
|
||||||
detail="We're glad you're enjoying Khoj! You've exceeded your usage limit for today. Come back tomorrow or subscribe to increase your usage limit via [your settings](https://app.khoj.dev/settings).",
|
detail="I'm glad you're enjoying interacting with me! But you've exceeded your usage limit for today. Come back tomorrow or subscribe to increase your usage limit via [your settings](https://app.khoj.dev/settings).",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add the current request to the cache
|
# Add the current request to the cache
|
||||||
UserRequests.objects.create(user=user, slug=self.slug)
|
UserRequests.objects.create(user=user, slug=self.slug)
|
||||||
|
|
||||||
|
|
||||||
|
class ApiImageRateLimiter:
|
||||||
|
def __init__(self, max_images: int = 10, max_combined_size_mb: float = 10):
|
||||||
|
self.max_images = max_images
|
||||||
|
self.max_combined_size_mb = max_combined_size_mb
|
||||||
|
|
||||||
|
def __call__(self, request: Request, body: ChatRequestBody):
|
||||||
|
if state.billing_enabled is False:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Rate limiting is disabled if user unauthenticated.
|
||||||
|
# Other systems handle authentication
|
||||||
|
if not request.user.is_authenticated:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not body.images:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check number of images
|
||||||
|
if len(body.images) > self.max_images:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=429,
|
||||||
|
detail=f"Those are way too many images for me! I can handle up to {self.max_images} images per message.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check total size of images
|
||||||
|
total_size_mb = 0.0
|
||||||
|
for image in body.images:
|
||||||
|
# Unquote the image in case it's URL encoded
|
||||||
|
image = unquote(image)
|
||||||
|
# Assuming the image is a base64 encoded string
|
||||||
|
# Remove the data:image/jpeg;base64, part if present
|
||||||
|
if "," in image:
|
||||||
|
image = image.split(",", 1)[1]
|
||||||
|
|
||||||
|
# Decode base64 to get the actual size
|
||||||
|
image_bytes = base64.b64decode(image)
|
||||||
|
total_size_mb += len(image_bytes) / (1024 * 1024) # Convert bytes to MB
|
||||||
|
|
||||||
|
if total_size_mb > self.max_combined_size_mb:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=429,
|
||||||
|
detail=f"Those images are way too large for me! I can handle up to {self.max_combined_size_mb}MB of images per message.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ConversationCommandRateLimiter:
|
class ConversationCommandRateLimiter:
|
||||||
def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int, slug: str):
|
def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int, slug: str):
|
||||||
self.slug = slug
|
self.slug = slug
|
||||||
|
@ -1411,10 +1735,16 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
|
||||||
|
|
||||||
user_subscription_state = get_user_subscription_state(user.email)
|
user_subscription_state = get_user_subscription_state(user.email)
|
||||||
user_subscription = adapters.get_user_subscription(user.email)
|
user_subscription = adapters.get_user_subscription(user.email)
|
||||||
|
|
||||||
subscription_renewal_date = (
|
subscription_renewal_date = (
|
||||||
user_subscription.renewal_date.strftime("%d %b %Y")
|
user_subscription.renewal_date.strftime("%d %b %Y")
|
||||||
if user_subscription and user_subscription.renewal_date
|
if user_subscription and user_subscription.renewal_date
|
||||||
else (user_subscription.created_at + timedelta(days=7)).strftime("%d %b %Y")
|
else None
|
||||||
|
)
|
||||||
|
subscription_enabled_trial_at = (
|
||||||
|
user_subscription.enabled_trial_at.strftime("%d %b %Y")
|
||||||
|
if user_subscription and user_subscription.enabled_trial_at
|
||||||
|
else None
|
||||||
)
|
)
|
||||||
given_name = get_user_name(user)
|
given_name = get_user_name(user)
|
||||||
|
|
||||||
|
@ -1437,13 +1767,6 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
|
||||||
for chat_model in chat_models:
|
for chat_model in chat_models:
|
||||||
chat_model_options.append({"name": chat_model.chat_model, "id": chat_model.id})
|
chat_model_options.append({"name": chat_model.chat_model, "id": chat_model.id})
|
||||||
|
|
||||||
search_model_options = adapters.get_or_create_search_models().all()
|
|
||||||
all_search_model_options = list()
|
|
||||||
for search_model_option in search_model_options:
|
|
||||||
all_search_model_options.append({"name": search_model_option.name, "id": search_model_option.id})
|
|
||||||
|
|
||||||
current_search_model_option = adapters.get_user_search_model_or_default(user)
|
|
||||||
|
|
||||||
selected_paint_model_config = ConversationAdapters.get_user_text_to_image_model_config(user)
|
selected_paint_model_config = ConversationAdapters.get_user_text_to_image_model_config(user)
|
||||||
paint_model_options = ConversationAdapters.get_text_to_image_model_options().all()
|
paint_model_options = ConversationAdapters.get_text_to_image_model_options().all()
|
||||||
all_paint_model_options = list()
|
all_paint_model_options = list()
|
||||||
|
@ -1476,8 +1799,6 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
|
||||||
"has_documents": has_documents,
|
"has_documents": has_documents,
|
||||||
"notion_token": notion_token,
|
"notion_token": notion_token,
|
||||||
# user model settings
|
# user model settings
|
||||||
"search_model_options": all_search_model_options,
|
|
||||||
"selected_search_model_config": current_search_model_option.id,
|
|
||||||
"chat_model_options": chat_model_options,
|
"chat_model_options": chat_model_options,
|
||||||
"selected_chat_model_config": selected_chat_model_config.id if selected_chat_model_config else None,
|
"selected_chat_model_config": selected_chat_model_config.id if selected_chat_model_config else None,
|
||||||
"paint_model_options": all_paint_model_options,
|
"paint_model_options": all_paint_model_options,
|
||||||
|
@ -1487,6 +1808,7 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
|
||||||
# user billing info
|
# user billing info
|
||||||
"subscription_state": user_subscription_state,
|
"subscription_state": user_subscription_state,
|
||||||
"subscription_renewal_date": subscription_renewal_date,
|
"subscription_renewal_date": subscription_renewal_date,
|
||||||
|
"subscription_enabled_trial_at": subscription_enabled_trial_at,
|
||||||
# server settings
|
# server settings
|
||||||
"khoj_cloud_subscription_url": os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL"),
|
"khoj_cloud_subscription_url": os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL"),
|
||||||
"billing_enabled": state.billing_enabled,
|
"billing_enabled": state.billing_enabled,
|
||||||
|
@ -1495,6 +1817,7 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
|
||||||
"khoj_version": state.khoj_version,
|
"khoj_version": state.khoj_version,
|
||||||
"anonymous_mode": state.anonymous_mode,
|
"anonymous_mode": state.anonymous_mode,
|
||||||
"notion_oauth_url": notion_oauth_url,
|
"notion_oauth_url": notion_oauth_url,
|
||||||
|
"length_of_free_trial": LENGTH_OF_FREE_TRIAL,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -38,7 +38,7 @@ async def apick_next_tool(
|
||||||
query: str,
|
query: str,
|
||||||
conversation_history: dict,
|
conversation_history: dict,
|
||||||
user: KhojUser = None,
|
user: KhojUser = None,
|
||||||
uploaded_image_url: str = None,
|
query_images: List[str] = [],
|
||||||
location: LocationData = None,
|
location: LocationData = None,
|
||||||
user_name: str = None,
|
user_name: str = None,
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
|
@ -62,8 +62,8 @@ async def apick_next_tool(
|
||||||
|
|
||||||
chat_history = construct_chat_history(conversation_history, agent_name=agent.name if agent else "Khoj")
|
chat_history = construct_chat_history(conversation_history, agent_name=agent.name if agent else "Khoj")
|
||||||
|
|
||||||
if uploaded_image_url:
|
if query_images:
|
||||||
query = f"[placeholder for user attached image]\n{query}"
|
query = f"[placeholder for user attached images]\n{query}"
|
||||||
|
|
||||||
personality_context = (
|
personality_context = (
|
||||||
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
|
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
|
||||||
|
@ -131,8 +131,7 @@ async def execute_information_collection(
|
||||||
query: str,
|
query: str,
|
||||||
conversation_id: str,
|
conversation_id: str,
|
||||||
conversation_history: dict,
|
conversation_history: dict,
|
||||||
subscribed: bool,
|
query_images: List[str],
|
||||||
uploaded_image_url: str = None,
|
|
||||||
agent: Agent = None,
|
agent: Agent = None,
|
||||||
send_status_func: Optional[Callable] = None,
|
send_status_func: Optional[Callable] = None,
|
||||||
user_name: str = None,
|
user_name: str = None,
|
||||||
|
@ -154,7 +153,7 @@ async def execute_information_collection(
|
||||||
query,
|
query,
|
||||||
conversation_history,
|
conversation_history,
|
||||||
user,
|
user,
|
||||||
uploaded_image_url,
|
query_images,
|
||||||
location,
|
location,
|
||||||
user_name,
|
user_name,
|
||||||
agent,
|
agent,
|
||||||
|
@ -180,7 +179,7 @@ async def execute_information_collection(
|
||||||
[ConversationCommand.Default],
|
[ConversationCommand.Default],
|
||||||
location,
|
location,
|
||||||
send_status_func,
|
send_status_func,
|
||||||
uploaded_image_url=uploaded_image_url,
|
query_images,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
):
|
):
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
|
@ -208,11 +207,10 @@ async def execute_information_collection(
|
||||||
conversation_history,
|
conversation_history,
|
||||||
location,
|
location,
|
||||||
user,
|
user,
|
||||||
subscribed,
|
|
||||||
send_status_func,
|
send_status_func,
|
||||||
[],
|
[],
|
||||||
max_webpages_to_read=0,
|
max_webpages_to_read=0,
|
||||||
uploaded_image_url=uploaded_image_url,
|
query_images=query_images,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
):
|
):
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
|
@ -229,7 +227,7 @@ async def execute_information_collection(
|
||||||
location,
|
location,
|
||||||
user,
|
user,
|
||||||
send_status_func,
|
send_status_func,
|
||||||
uploaded_image_url=uploaded_image_url,
|
query_images=query_images,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
):
|
):
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
|
@ -259,7 +257,7 @@ async def execute_information_collection(
|
||||||
location,
|
location,
|
||||||
user,
|
user,
|
||||||
send_status_func,
|
send_status_func,
|
||||||
uploaded_image_url=uploaded_image_url,
|
query_images=query_images,
|
||||||
agent=agent,
|
agent=agent,
|
||||||
):
|
):
|
||||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||||
|
|
|
@ -51,17 +51,6 @@ def chat_page(request: Request):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@web_client.get("/experimental", response_class=FileResponse)
|
|
||||||
@requires(["authenticated"], redirect="login_page")
|
|
||||||
def experimental_page(request: Request):
|
|
||||||
return templates.TemplateResponse(
|
|
||||||
"index.html",
|
|
||||||
context={
|
|
||||||
"request": request,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@web_client.get("/factchecker", response_class=FileResponse)
|
@web_client.get("/factchecker", response_class=FileResponse)
|
||||||
def fact_checker_page(request: Request):
|
def fact_checker_page(request: Request):
|
||||||
return templates.TemplateResponse(
|
return templates.TemplateResponse(
|
||||||
|
|
|
@ -8,7 +8,11 @@ import torch
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
from sentence_transformers import util
|
from sentence_transformers import util
|
||||||
|
|
||||||
from khoj.database.adapters import EntryAdapters, get_user_search_model_or_default
|
from khoj.database.adapters import (
|
||||||
|
EntryAdapters,
|
||||||
|
get_default_search_model,
|
||||||
|
get_user_default_search_model,
|
||||||
|
)
|
||||||
from khoj.database.models import Agent
|
from khoj.database.models import Agent
|
||||||
from khoj.database.models import Entry as DbEntry
|
from khoj.database.models import Entry as DbEntry
|
||||||
from khoj.database.models import KhojUser
|
from khoj.database.models import KhojUser
|
||||||
|
@ -110,7 +114,7 @@ async def query(
|
||||||
file_type = search_type_to_embeddings_type[type.value]
|
file_type = search_type_to_embeddings_type[type.value]
|
||||||
|
|
||||||
query = raw_query
|
query = raw_query
|
||||||
search_model = await sync_to_async(get_user_search_model_or_default)(user)
|
search_model = await sync_to_async(get_user_default_search_model)(user)
|
||||||
if not max_distance:
|
if not max_distance:
|
||||||
if search_model.bi_encoder_confidence_threshold:
|
if search_model.bi_encoder_confidence_threshold:
|
||||||
max_distance = search_model.bi_encoder_confidence_threshold
|
max_distance = search_model.bi_encoder_confidence_threshold
|
||||||
|
|
|
@ -2,10 +2,12 @@ from __future__ import annotations # to avoid quoting type hints
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
import io
|
import io
|
||||||
|
import ipaddress
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
import random
|
import random
|
||||||
|
import urllib.parse
|
||||||
import uuid
|
import uuid
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
@ -125,6 +127,8 @@ def get_file_type(file_type: str, file_content: bytes) -> tuple[str, str]:
|
||||||
return "image", encoding
|
return "image", encoding
|
||||||
elif file_type in ["image/png"]:
|
elif file_type in ["image/png"]:
|
||||||
return "image", encoding
|
return "image", encoding
|
||||||
|
elif file_type in ["image/webp"]:
|
||||||
|
return "image", encoding
|
||||||
elif content_group in ["code", "text"]:
|
elif content_group in ["code", "text"]:
|
||||||
return "plaintext", encoding
|
return "plaintext", encoding
|
||||||
else:
|
else:
|
||||||
|
@ -164,9 +168,9 @@ def get_class_by_name(name: str) -> object:
|
||||||
class timer:
|
class timer:
|
||||||
"""Context manager to log time taken for a block of code to run"""
|
"""Context manager to log time taken for a block of code to run"""
|
||||||
|
|
||||||
def __init__(self, message: str, logger: logging.Logger, device: torch.device = None):
|
def __init__(self, message: str, logger: logging.Logger, device: torch.device = None, log_level=logging.DEBUG):
|
||||||
self.message = message
|
self.message = message
|
||||||
self.logger = logger
|
self.logger = logger.debug if log_level == logging.DEBUG else logger.info
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
|
@ -176,9 +180,9 @@ class timer:
|
||||||
def __exit__(self, *_):
|
def __exit__(self, *_):
|
||||||
elapsed = perf_counter() - self.start
|
elapsed = perf_counter() - self.start
|
||||||
if self.device is None:
|
if self.device is None:
|
||||||
self.logger.debug(f"{self.message}: {elapsed:.3f} seconds")
|
self.logger(f"{self.message}: {elapsed:.3f} seconds")
|
||||||
else:
|
else:
|
||||||
self.logger.debug(f"{self.message}: {elapsed:.3f} seconds on device: {self.device}")
|
self.logger(f"{self.message}: {elapsed:.3f} seconds on device: {self.device}")
|
||||||
|
|
||||||
|
|
||||||
class LRU(OrderedDict):
|
class LRU(OrderedDict):
|
||||||
|
@ -315,6 +319,7 @@ class ConversationCommand(str, Enum):
|
||||||
Automation = "automation"
|
Automation = "automation"
|
||||||
AutomatedTask = "automated_task"
|
AutomatedTask = "automated_task"
|
||||||
Summarize = "summarize"
|
Summarize = "summarize"
|
||||||
|
Diagram = "diagram"
|
||||||
|
|
||||||
|
|
||||||
command_descriptions = {
|
command_descriptions = {
|
||||||
|
@ -324,10 +329,11 @@ command_descriptions = {
|
||||||
ConversationCommand.Online: "Search for information on the internet.",
|
ConversationCommand.Online: "Search for information on the internet.",
|
||||||
ConversationCommand.Webpage: "Get information from webpage suggested by you.",
|
ConversationCommand.Webpage: "Get information from webpage suggested by you.",
|
||||||
ConversationCommand.Code: "Run Python code to parse information, run complex calculations, create documents and charts.",
|
ConversationCommand.Code: "Run Python code to parse information, run complex calculations, create documents and charts.",
|
||||||
ConversationCommand.Image: "Generate images by describing your imagination in words.",
|
ConversationCommand.Image: "Generate illustrative, creative images by describing your imagination in words.",
|
||||||
ConversationCommand.Automation: "Automatically run your query at a specified time or interval.",
|
ConversationCommand.Automation: "Automatically run your query at a specified time or interval.",
|
||||||
ConversationCommand.Help: "Get help with how to use or setup Khoj from the documentation",
|
ConversationCommand.Help: "Get help with how to use or setup Khoj from the documentation",
|
||||||
ConversationCommand.Summarize: "Get help with a question pertaining to an entire document.",
|
ConversationCommand.Summarize: "Get help with a question pertaining to an entire document.",
|
||||||
|
ConversationCommand.Diagram: "Draw a flowchart, diagram, or any other visual representation best expressed with primitives like lines, rectangles, and text.",
|
||||||
}
|
}
|
||||||
|
|
||||||
command_descriptions_for_agent = {
|
command_descriptions_for_agent = {
|
||||||
|
@ -359,11 +365,16 @@ mode_descriptions_for_llm = {
|
||||||
ConversationCommand.Image: "Use this if the user is requesting you to generate images based on their description. This does not support generating charts or graphs.",
|
ConversationCommand.Image: "Use this if the user is requesting you to generate images based on their description. This does not support generating charts or graphs.",
|
||||||
ConversationCommand.Automation: "Use this if the user is requesting a response at a scheduled date or time.",
|
ConversationCommand.Automation: "Use this if the user is requesting a response at a scheduled date or time.",
|
||||||
ConversationCommand.Text: "Use this if the other response modes don't seem to fit the query.",
|
ConversationCommand.Text: "Use this if the other response modes don't seem to fit the query.",
|
||||||
|
ConversationCommand.Automation: "Use this if you are confident the user is requesting a response at a scheduled date, time and frequency",
|
||||||
|
ConversationCommand.Text: "Use this if a normal text response would be sufficient for accurately responding to the query.",
|
||||||
|
ConversationCommand.Diagram: "Use this if the user is requesting a visual representation that requires primitives like lines, rectangles, and text.",
|
||||||
}
|
}
|
||||||
|
|
||||||
mode_descriptions_for_agent = {
|
mode_descriptions_for_agent = {
|
||||||
ConversationCommand.Image: "Agent can generate images in response. It cannot not use this to generate charts and graphs.",
|
ConversationCommand.Image: "Agent can generate images in response. It cannot not use this to generate charts and graphs.",
|
||||||
|
ConversationCommand.Automation: "Agent can schedule a task to run at a scheduled date, time and frequency in response.",
|
||||||
ConversationCommand.Text: "Agent can generate text in response.",
|
ConversationCommand.Text: "Agent can generate text in response.",
|
||||||
|
ConversationCommand.Diagram: "Agent can generate a visual representation that requires primitives like lines, rectangles, and text.",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -445,6 +456,46 @@ def is_internet_connected():
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_internal_url(url: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if a URL is likely to be internal/non-public.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url (str): The URL to check.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the URL is likely internal, False otherwise.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
parsed_url = urllib.parse.urlparse(url)
|
||||||
|
hostname = parsed_url.hostname
|
||||||
|
|
||||||
|
# Check for localhost
|
||||||
|
if hostname in ["localhost", "127.0.0.1", "::1"]:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check for IP addresses in private ranges
|
||||||
|
try:
|
||||||
|
ip = ipaddress.ip_address(hostname)
|
||||||
|
return ip.is_private
|
||||||
|
except ValueError:
|
||||||
|
pass # Not an IP address, continue with other checks
|
||||||
|
|
||||||
|
# Check for common internal TLDs
|
||||||
|
internal_tlds = [".local", ".internal", ".private", ".corp", ".home", ".lan"]
|
||||||
|
if any(hostname.endswith(tld) for tld in internal_tlds):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check for URLs without a TLD
|
||||||
|
if "." not in hostname:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
except Exception:
|
||||||
|
# If we can't parse the URL or something else goes wrong, assume it's not internal
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def convert_image_to_webp(image_bytes):
|
def convert_image_to_webp(image_bytes):
|
||||||
"""Convert image bytes to webp format for faster loading"""
|
"""Convert image bytes to webp format for faster loading"""
|
||||||
image_io = io.BytesIO(image_bytes)
|
image_io = io.BytesIO(image_bytes)
|
||||||
|
|
|
@ -178,6 +178,13 @@ def api_user4(default_user4):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db
|
||||||
|
@pytest.fixture
|
||||||
|
def default_openai_chat_model_option():
|
||||||
|
chat_model = ChatModelOptionsFactory(chat_model="gpt-4o-mini", model_type="openai")
|
||||||
|
return chat_model
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.django_db
|
@pytest.mark.django_db
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def offline_agent():
|
def offline_agent():
|
||||||
|
|
|
@ -86,7 +86,7 @@ class SubscriptionFactory(factory.django.DjangoModelFactory):
|
||||||
model = Subscription
|
model = Subscription
|
||||||
|
|
||||||
user = factory.SubFactory(UserFactory)
|
user = factory.SubFactory(UserFactory)
|
||||||
type = "standard"
|
type = Subscription.Type.STANDARD
|
||||||
is_recurring = False
|
is_recurring = False
|
||||||
renewal_date = make_aware(datetime.strptime("2100-04-01", "%Y-%m-%d"))
|
renewal_date = make_aware(datetime.strptime("2100-04-01", "%Y-%m-%d"))
|
||||||
|
|
||||||
|
|
211
tests/test_agents.py
Normal file
211
tests/test_agents.py
Normal file
|
@ -0,0 +1,211 @@
|
||||||
|
# tests/test_agents.py
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from asgiref.sync import sync_to_async
|
||||||
|
|
||||||
|
from khoj.database.adapters import AgentAdapters
|
||||||
|
from khoj.database.models import Agent, ChatModelOptions, Entry, KhojUser
|
||||||
|
from khoj.routers.api import execute_search
|
||||||
|
from khoj.utils.helpers import get_absolute_path
|
||||||
|
from tests.helpers import ChatModelOptionsFactory
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_default_agent(default_user: KhojUser):
|
||||||
|
ChatModelOptionsFactory()
|
||||||
|
|
||||||
|
agent = AgentAdapters.create_default_agent(default_user)
|
||||||
|
assert agent is not None
|
||||||
|
assert agent.input_tools == []
|
||||||
|
assert agent.output_modes == []
|
||||||
|
assert agent.privacy_level == Agent.PrivacyLevel.PUBLIC
|
||||||
|
assert agent.managed_by_admin == True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
@pytest.mark.django_db(transaction=True)
|
||||||
|
async def test_create_or_update_agent(default_user: KhojUser, default_openai_chat_model_option: ChatModelOptions):
|
||||||
|
new_agent = await AgentAdapters.aupdate_agent(
|
||||||
|
default_user,
|
||||||
|
"Test Agent",
|
||||||
|
"Test Personality",
|
||||||
|
Agent.PrivacyLevel.PRIVATE,
|
||||||
|
"icon",
|
||||||
|
"color",
|
||||||
|
default_openai_chat_model_option.chat_model,
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
assert new_agent is not None
|
||||||
|
assert new_agent.name == "Test Agent"
|
||||||
|
assert new_agent.privacy_level == Agent.PrivacyLevel.PRIVATE
|
||||||
|
assert new_agent.creator == default_user
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
@pytest.mark.django_db(transaction=True)
|
||||||
|
async def test_create_or_update_agent_with_knowledge_base(
|
||||||
|
default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client
|
||||||
|
):
|
||||||
|
full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
|
||||||
|
new_agent = await AgentAdapters.aupdate_agent(
|
||||||
|
default_user2,
|
||||||
|
"Test Agent",
|
||||||
|
"Test Personality",
|
||||||
|
Agent.PrivacyLevel.PRIVATE,
|
||||||
|
"icon",
|
||||||
|
"color",
|
||||||
|
default_openai_chat_model_option.chat_model,
|
||||||
|
[full_filename],
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
entries = await sync_to_async(list)(Entry.objects.filter(agent=new_agent))
|
||||||
|
file_names = set()
|
||||||
|
for entry in entries:
|
||||||
|
file_names.add(entry.file_path)
|
||||||
|
|
||||||
|
assert new_agent is not None
|
||||||
|
assert new_agent.name == "Test Agent"
|
||||||
|
assert new_agent.privacy_level == Agent.PrivacyLevel.PRIVATE
|
||||||
|
assert new_agent.creator == default_user2
|
||||||
|
assert len(entries) > 0
|
||||||
|
assert full_filename in file_names
|
||||||
|
assert len(file_names) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
@pytest.mark.django_db(transaction=True)
|
||||||
|
async def test_create_or_update_agent_with_knowledge_base_and_search(
|
||||||
|
default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client
|
||||||
|
):
|
||||||
|
full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
|
||||||
|
new_agent = await AgentAdapters.aupdate_agent(
|
||||||
|
default_user2,
|
||||||
|
"Test Agent",
|
||||||
|
"Test Personality",
|
||||||
|
Agent.PrivacyLevel.PRIVATE,
|
||||||
|
"icon",
|
||||||
|
"color",
|
||||||
|
default_openai_chat_model_option.chat_model,
|
||||||
|
[full_filename],
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
|
||||||
|
search_result = await execute_search(user=default_user2, q="having kids", agent=new_agent)
|
||||||
|
|
||||||
|
assert len(search_result) == 5
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
@pytest.mark.django_db(transaction=True)
|
||||||
|
async def test_agent_with_knowledge_base_and_search_not_creator(
|
||||||
|
default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client, default_user3: KhojUser
|
||||||
|
):
|
||||||
|
full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
|
||||||
|
new_agent = await AgentAdapters.aupdate_agent(
|
||||||
|
default_user2,
|
||||||
|
"Test Agent",
|
||||||
|
"Test Personality",
|
||||||
|
Agent.PrivacyLevel.PUBLIC,
|
||||||
|
"icon",
|
||||||
|
"color",
|
||||||
|
default_openai_chat_model_option.chat_model,
|
||||||
|
[full_filename],
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
|
||||||
|
search_result = await execute_search(user=default_user3, q="having kids", agent=new_agent)
|
||||||
|
|
||||||
|
assert len(search_result) == 5
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
@pytest.mark.django_db(transaction=True)
|
||||||
|
async def test_agent_with_knowledge_base_and_search_not_creator_and_private(
|
||||||
|
default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client, default_user3: KhojUser
|
||||||
|
):
|
||||||
|
full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
|
||||||
|
new_agent = await AgentAdapters.aupdate_agent(
|
||||||
|
default_user2,
|
||||||
|
"Test Agent",
|
||||||
|
"Test Personality",
|
||||||
|
Agent.PrivacyLevel.PRIVATE,
|
||||||
|
"icon",
|
||||||
|
"color",
|
||||||
|
default_openai_chat_model_option.chat_model,
|
||||||
|
[full_filename],
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
|
||||||
|
search_result = await execute_search(user=default_user3, q="having kids", agent=new_agent)
|
||||||
|
|
||||||
|
assert len(search_result) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
@pytest.mark.django_db(transaction=True)
|
||||||
|
async def test_agent_with_knowledge_base_and_search_not_creator_and_private_accessible_to_none(
|
||||||
|
default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client
|
||||||
|
):
|
||||||
|
full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
|
||||||
|
new_agent = await AgentAdapters.aupdate_agent(
|
||||||
|
default_user2,
|
||||||
|
"Test Agent",
|
||||||
|
"Test Personality",
|
||||||
|
Agent.PrivacyLevel.PRIVATE,
|
||||||
|
"icon",
|
||||||
|
"color",
|
||||||
|
default_openai_chat_model_option.chat_model,
|
||||||
|
[full_filename],
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
|
||||||
|
search_result = await execute_search(user=None, q="having kids", agent=new_agent)
|
||||||
|
|
||||||
|
assert len(search_result) == 5
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
@pytest.mark.django_db(transaction=True)
|
||||||
|
async def test_multiple_agents_with_knowledge_base_and_users(
|
||||||
|
default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client, default_user3: KhojUser
|
||||||
|
):
|
||||||
|
full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
|
||||||
|
new_agent = await AgentAdapters.aupdate_agent(
|
||||||
|
default_user2,
|
||||||
|
"Test Agent",
|
||||||
|
"Test Personality",
|
||||||
|
Agent.PrivacyLevel.PUBLIC,
|
||||||
|
"icon",
|
||||||
|
"color",
|
||||||
|
default_openai_chat_model_option.chat_model,
|
||||||
|
[full_filename],
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
|
||||||
|
full_filename2 = get_absolute_path("tests/data/markdown/Namita.markdown")
|
||||||
|
new_agent2 = await AgentAdapters.aupdate_agent(
|
||||||
|
default_user2,
|
||||||
|
"Test Agent 2",
|
||||||
|
"Test Personality",
|
||||||
|
Agent.PrivacyLevel.PUBLIC,
|
||||||
|
"icon",
|
||||||
|
"color",
|
||||||
|
default_openai_chat_model_option.chat_model,
|
||||||
|
[full_filename2],
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
|
||||||
|
search_result = await execute_search(user=default_user3, q="having kids", agent=new_agent2)
|
||||||
|
search_result2 = await execute_search(user=default_user3, q="Namita", agent=new_agent2)
|
||||||
|
|
||||||
|
assert len(search_result) == 0
|
||||||
|
assert len(search_result2) == 1
|
|
@ -1,6 +1,8 @@
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from khoj.processor.content.pdf.pdf_to_entries import PdfToEntries
|
from khoj.processor.content.pdf.pdf_to_entries import PdfToEntries
|
||||||
from khoj.utils.fs_syncer import get_pdf_files
|
from khoj.utils.fs_syncer import get_pdf_files
|
||||||
from khoj.utils.rawconfig import TextContentConfig
|
from khoj.utils.rawconfig import TextContentConfig
|
||||||
|
@ -37,6 +39,7 @@ def test_multi_page_pdf_to_jsonl():
|
||||||
assert len(entries[1]) == 6
|
assert len(entries[1]) == 6
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="Temporarily disabled OCR due to performance issues")
|
||||||
def test_ocr_page_pdf_to_jsonl():
|
def test_ocr_page_pdf_to_jsonl():
|
||||||
"Convert multiple pages from single PDF file to jsonl."
|
"Convert multiple pages from single PDF file to jsonl."
|
||||||
# Arrange
|
# Arrange
|
||||||
|
|
|
@ -77,5 +77,10 @@
|
||||||
"1.23.3": "0.15.0",
|
"1.23.3": "0.15.0",
|
||||||
"1.24.0": "0.15.0",
|
"1.24.0": "0.15.0",
|
||||||
"1.24.1": "0.15.0",
|
"1.24.1": "0.15.0",
|
||||||
"1.25.0": "0.15.0"
|
"1.25.0": "0.15.0",
|
||||||
|
"1.26.0": "0.15.0",
|
||||||
|
"1.26.1": "0.15.0",
|
||||||
|
"1.26.2": "0.15.0",
|
||||||
|
"1.26.3": "0.15.0",
|
||||||
|
"1.26.4": "0.15.0"
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue