diff --git a/pyproject.toml b/pyproject.toml
index 519d2d60..5a206cce 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -42,7 +42,7 @@ dependencies = [
"fastapi >= 0.104.1",
"python-multipart >= 0.0.5",
"jinja2 == 3.1.2",
- "openai >= 0.27.0, < 1.0.0",
+ "openai >= 1.0.0",
"tiktoken >= 0.3.2",
"tenacity >= 8.2.2",
"pillow ~= 9.5.0",
diff --git a/src/interface/desktop/chat.html b/src/interface/desktop/chat.html
index 8c3fc49d..6b7fde07 100644
--- a/src/interface/desktop/chat.html
+++ b/src/interface/desktop/chat.html
@@ -179,7 +179,13 @@
return numOnlineReferences;
}
- function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null) {
+ function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null) {
+ if (intentType === "text-to-image") {
+ let imageMarkdown = `![](data:image/png;base64,${message})`;
+ renderMessage(imageMarkdown, by, dt);
+ return;
+ }
+
if (context == null && onlineContext == null) {
renderMessage(message, by, dt);
return;
@@ -244,6 +250,17 @@
// Remove any text between [INST] and tags. These are spurious instructions for the AI chat model.
newHTML = newHTML.replace(/\[INST\].+(<\/s>)?/g, '');
+ // Customize the rendering of images
+ md.renderer.rules.image = function(tokens, idx, options, env, self) {
+ let token = tokens[idx];
+
+ // Add class="text-to-image" to images
+ token.attrPush(['class', 'text-to-image']);
+
+ // Use the default renderer to render image markdown format
+ return self.renderToken(tokens, idx, options);
+ };
+
// Render markdown
newHTML = md.render(newHTML);
// Get any elements with a class that starts with "language"
@@ -328,109 +345,142 @@
let chatInput = document.getElementById("chat-input");
chatInput.classList.remove("option-enabled");
- // Call specified Khoj API which returns a streamed response of type text/plain
- fetch(url, { headers })
- .then(response => {
- const reader = response.body.getReader();
- const decoder = new TextDecoder();
- let rawResponse = "";
- let references = null;
+ // Call specified Khoj API
+ let response = await fetch(url, { headers });
+ let rawResponse = "";
+ const contentType = response.headers.get("content-type");
- function readStream() {
- reader.read().then(({ done, value }) => {
- if (done) {
- // Append any references after all the data has been streamed
- if (references != null) {
- newResponseText.appendChild(references);
- }
- document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
- document.getElementById("chat-input").removeAttribute("disabled");
- return;
+ if (contentType === "application/json") {
+ // Handle JSON response
+ try {
+ const responseAsJson = await response.json();
+ if (responseAsJson.image) {
+ // If response has image field, response is a generated image.
+ rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
+ }
+ if (responseAsJson.detail) {
+ // If response has detail field, response is an error message.
+ rawResponse += responseAsJson.detail;
+ }
+ } catch (error) {
+ // If the chunk is not a JSON object, just display it as is
+ rawResponse += chunk;
+ } finally {
+ newResponseText.innerHTML = "";
+ newResponseText.appendChild(formatHTMLMessage(rawResponse));
+
+ document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
+ document.getElementById("chat-input").removeAttribute("disabled");
+ }
+ } else {
+ // Handle streamed response of type text/event-stream or text/plain
+ const reader = response.body.getReader();
+ const decoder = new TextDecoder();
+ let references = null;
+
+ readStream();
+
+ function readStream() {
+ reader.read().then(({ done, value }) => {
+ if (done) {
+ // Append any references after all the data has been streamed
+ if (references != null) {
+ newResponseText.appendChild(references);
+ }
+ document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
+ document.getElementById("chat-input").removeAttribute("disabled");
+ return;
+ }
+
+ // Decode message chunk from stream
+ const chunk = decoder.decode(value, { stream: true });
+
+ if (chunk.includes("### compiled references:")) {
+ const additionalResponse = chunk.split("### compiled references:")[0];
+ rawResponse += additionalResponse;
+ newResponseText.innerHTML = "";
+ newResponseText.appendChild(formatHTMLMessage(rawResponse));
+
+ const rawReference = chunk.split("### compiled references:")[1];
+ const rawReferenceAsJson = JSON.parse(rawReference);
+ references = document.createElement('div');
+ references.classList.add("references");
+
+ let referenceExpandButton = document.createElement('button');
+ referenceExpandButton.classList.add("reference-expand-button");
+
+ let referenceSection = document.createElement('div');
+ referenceSection.classList.add("reference-section");
+ referenceSection.classList.add("collapsed");
+
+ let numReferences = 0;
+
+ // If rawReferenceAsJson is a list, then count the length
+ if (Array.isArray(rawReferenceAsJson)) {
+ numReferences = rawReferenceAsJson.length;
+
+ rawReferenceAsJson.forEach((reference, index) => {
+ let polishedReference = generateReference(reference, index);
+ referenceSection.appendChild(polishedReference);
+ });
+ } else {
+ numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson);
}
- // Decode message chunk from stream
- const chunk = decoder.decode(value, { stream: true });
+ references.appendChild(referenceExpandButton);
- if (chunk.includes("### compiled references:")) {
- const additionalResponse = chunk.split("### compiled references:")[0];
- rawResponse += additionalResponse;
+ referenceExpandButton.addEventListener('click', function() {
+ if (referenceSection.classList.contains("collapsed")) {
+ referenceSection.classList.remove("collapsed");
+ referenceSection.classList.add("expanded");
+ } else {
+ referenceSection.classList.add("collapsed");
+ referenceSection.classList.remove("expanded");
+ }
+ });
+
+ let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`;
+ referenceExpandButton.innerHTML = expandButtonText;
+ references.appendChild(referenceSection);
+ readStream();
+ } else {
+ // Display response from Khoj
+ if (newResponseText.getElementsByClassName("spinner").length > 0) {
+ newResponseText.removeChild(loadingSpinner);
+ }
+
+ // Try to parse the chunk as a JSON object. It will be a JSON object if there is an error.
+ if (chunk.startsWith("{") && chunk.endsWith("}")) {
+ try {
+ const responseAsJson = JSON.parse(chunk);
+ if (responseAsJson.image) {
+ rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
+ }
+ if (responseAsJson.detail) {
+ rawResponse += responseAsJson.detail;
+ }
+ } catch (error) {
+ // If the chunk is not a JSON object, just display it as is
+ rawResponse += chunk;
+ } finally {
+ newResponseText.innerHTML = "";
+ newResponseText.appendChild(formatHTMLMessage(rawResponse));
+ }
+ } else {
+ // If the chunk is not a JSON object, just display it as is
+ rawResponse += chunk;
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
- const rawReference = chunk.split("### compiled references:")[1];
- const rawReferenceAsJson = JSON.parse(rawReference);
- references = document.createElement('div');
- references.classList.add("references");
-
- let referenceExpandButton = document.createElement('button');
- referenceExpandButton.classList.add("reference-expand-button");
-
- let referenceSection = document.createElement('div');
- referenceSection.classList.add("reference-section");
- referenceSection.classList.add("collapsed");
-
- let numReferences = 0;
-
- // If rawReferenceAsJson is a list, then count the length
- if (Array.isArray(rawReferenceAsJson)) {
- numReferences = rawReferenceAsJson.length;
-
- rawReferenceAsJson.forEach((reference, index) => {
- let polishedReference = generateReference(reference, index);
- referenceSection.appendChild(polishedReference);
- });
- } else {
- numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson);
- }
-
- references.appendChild(referenceExpandButton);
-
- referenceExpandButton.addEventListener('click', function() {
- if (referenceSection.classList.contains("collapsed")) {
- referenceSection.classList.remove("collapsed");
- referenceSection.classList.add("expanded");
- } else {
- referenceSection.classList.add("collapsed");
- referenceSection.classList.remove("expanded");
- }
- });
-
- let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`;
- referenceExpandButton.innerHTML = expandButtonText;
- references.appendChild(referenceSection);
readStream();
- } else {
- // Display response from Khoj
- if (newResponseText.getElementsByClassName("spinner").length > 0) {
- newResponseText.removeChild(loadingSpinner);
- }
- // Try to parse the chunk as a JSON object. It will be a JSON object if there is an error.
- if (chunk.startsWith("{") && chunk.endsWith("}")) {
- try {
- const responseAsJson = JSON.parse(chunk);
- if (responseAsJson.detail) {
- newResponseText.innerHTML += responseAsJson.detail;
- }
- } catch (error) {
- // If the chunk is not a JSON object, just display it as is
- newResponseText.innerHTML += chunk;
- }
- } else {
- // If the chunk is not a JSON object, just display it as is
- rawResponse += chunk;
- newResponseText.innerHTML = "";
- newResponseText.appendChild(formatHTMLMessage(rawResponse));
-
- readStream();
- }
}
+ }
- // Scroll to bottom of chat window as chat response is streamed
- document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
- });
- }
- readStream();
- });
+ // Scroll to bottom of chat window as chat response is streamed
+ document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
+ });
+ }
+ }
}
function incrementalChat(event) {
@@ -522,7 +572,7 @@
.then(response => {
// Render conversation history, if any
response.forEach(chat_log => {
- renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext);
+ renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext, chat_log.intent?.type);
});
})
.catch(err => {
@@ -625,9 +675,13 @@
.then(response => response.ok ? response.json() : Promise.reject(response))
.then(data => { chatInput.value += data.text; })
.catch(err => {
- err.status == 422
- ? flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
- : flashStatusInChatInput("⛔️ Failed to transcribe audio")
+ if (err.status === 501) {
+ flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
+ } else if (err.status === 422) {
+ flashStatusInChatInput("⛔️ Audio file to large to process.")
+ } else {
+ flashStatusInChatInput("⛔️ Failed to transcribe audio.")
+ }
});
};
@@ -810,6 +864,9 @@
margin-top: -10px;
transform: rotate(-60deg)
}
+ img.text-to-image {
+ max-width: 60%;
+ }
#chat-footer {
padding: 0;
@@ -1050,6 +1107,9 @@
margin: 4px;
grid-template-columns: auto;
}
+ img.text-to-image {
+ max-width: 100%;
+ }
}
@media only screen and (min-width: 600px) {
body {
diff --git a/src/interface/obsidian/src/chat_modal.ts b/src/interface/obsidian/src/chat_modal.ts
index e5cff4d2..115f4c1f 100644
--- a/src/interface/obsidian/src/chat_modal.ts
+++ b/src/interface/obsidian/src/chat_modal.ts
@@ -2,6 +2,12 @@ import { App, MarkdownRenderer, Modal, request, requestUrl, setIcon } from 'obsi
import { KhojSetting } from 'src/settings';
import fetch from "node-fetch";
+export interface ChatJsonResult {
+ image?: string;
+ detail?: string;
+}
+
+
export class KhojChatModal extends Modal {
result: string;
setting: KhojSetting;
@@ -105,15 +111,19 @@ export class KhojChatModal extends Modal {
return referenceButton;
}
- renderMessageWithReferences(chatEl: Element, message: string, sender: string, context?: string[], dt?: Date) {
+ renderMessageWithReferences(chatEl: Element, message: string, sender: string, context?: string[], dt?: Date, intentType?: string) {
if (!message) {
return;
+ } else if (intentType === "text-to-image") {
+ let imageMarkdown = `![](data:image/png;base64,${message})`;
+ this.renderMessage(chatEl, imageMarkdown, sender, dt);
+ return;
} else if (!context) {
this.renderMessage(chatEl, message, sender, dt);
- return
+ return;
} else if (!!context && context?.length === 0) {
this.renderMessage(chatEl, message, sender, dt);
- return
+ return;
}
let chatMessageEl = this.renderMessage(chatEl, message, sender, dt);
let chatMessageBodyEl = chatMessageEl.getElementsByClassName("khoj-chat-message-text")[0]
@@ -225,7 +235,7 @@ export class KhojChatModal extends Modal {
let response = await request({ url: chatUrl, headers: headers });
let chatLogs = JSON.parse(response).response;
chatLogs.forEach((chatLog: any) => {
- this.renderMessageWithReferences(chatBodyEl, chatLog.message, chatLog.by, chatLog.context, new Date(chatLog.created));
+ this.renderMessageWithReferences(chatBodyEl, chatLog.message, chatLog.by, chatLog.context, new Date(chatLog.created), chatLog.intent?.type);
});
}
@@ -266,8 +276,25 @@ export class KhojChatModal extends Modal {
this.result = "";
responseElement.innerHTML = "";
+ if (response.headers.get("content-type") == "application/json") {
+ let responseText = ""
+ try {
+ const responseAsJson = await response.json() as ChatJsonResult;
+ if (responseAsJson.image) {
+ responseText = `![${query}](data:image/png;base64,${responseAsJson.image})`;
+ } else if (responseAsJson.detail) {
+ responseText = responseAsJson.detail;
+ }
+ } catch (error) {
+ // If the chunk is not a JSON object, just display it as is
+ responseText = response.body.read().toString()
+ } finally {
+ this.renderIncrementalMessage(responseElement, responseText);
+ }
+ }
+
for await (const chunk of response.body) {
- const responseText = chunk.toString();
+ let responseText = chunk.toString();
if (responseText.includes("### compiled references:")) {
const additionalResponse = responseText.split("### compiled references:")[0];
this.renderIncrementalMessage(responseElement, additionalResponse);
@@ -310,6 +337,12 @@ export class KhojChatModal extends Modal {
referenceExpandButton.innerHTML = expandButtonText;
references.appendChild(referenceSection);
} else {
+ if (responseText.startsWith("{") && responseText.endsWith("}")) {
+ } else {
+ // If the chunk is not a JSON object, just display it as is
+ continue;
+ }
+
this.renderIncrementalMessage(responseElement, responseText);
}
}
@@ -389,10 +422,12 @@ export class KhojChatModal extends Modal {
if (response.status === 200) {
console.log(response);
chatInput.value += response.json.text;
- } else if (response.status === 422) {
- throw new Error("⛔️ Failed to transcribe audio");
- } else {
+ } else if (response.status === 501) {
throw new Error("⛔️ Configure speech-to-text model on server.");
+ } else if (response.status === 422) {
+ throw new Error("⛔️ Audio file to large to process.");
+ } else {
+ throw new Error("⛔️ Failed to transcribe audio.");
}
};
diff --git a/src/interface/obsidian/styles.css b/src/interface/obsidian/styles.css
index 6a29e722..11fc0086 100644
--- a/src/interface/obsidian/styles.css
+++ b/src/interface/obsidian/styles.css
@@ -217,7 +217,9 @@ button.copy-button:hover {
background: #f5f5f5;
cursor: pointer;
}
-
+img {
+ max-width: 60%;
+}
#khoj-chat-footer {
padding: 0;
diff --git a/src/khoj/configure.py b/src/khoj/configure.py
index 19e7d403..717ad859 100644
--- a/src/khoj/configure.py
+++ b/src/khoj/configure.py
@@ -7,6 +7,7 @@ import requests
import os
# External Packages
+import openai
import schedule
from starlette.middleware.sessions import SessionMiddleware
from starlette.middleware.authentication import AuthenticationMiddleware
@@ -22,6 +23,7 @@ from starlette.authentication import (
# Internal Packages
from khoj.database.models import KhojUser, Subscription
from khoj.database.adapters import (
+ ConversationAdapters,
get_all_users,
get_or_create_search_model,
aget_user_subscription_state,
@@ -138,6 +140,10 @@ def configure_server(
config = FullConfig()
state.config = config
+ if ConversationAdapters.has_valid_openai_conversation_config():
+ openai_config = ConversationAdapters.get_openai_conversation_config()
+ state.openai_client = openai.OpenAI(api_key=openai_config.api_key)
+
# Initialize Search Models from Config and initialize content
try:
state.embeddings_model = EmbeddingsModel(get_or_create_search_model().bi_encoder)
diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py
index 9d18c815..6e17c15b 100644
--- a/src/khoj/database/adapters/__init__.py
+++ b/src/khoj/database/adapters/__init__.py
@@ -22,6 +22,7 @@ from khoj.database.models import (
GithubConfig,
GithubRepoConfig,
GoogleUser,
+ TextToImageModelConfig,
KhojApiUser,
KhojUser,
NotionConfig,
@@ -426,6 +427,10 @@ class ConversationAdapters:
else:
raise ValueError("Invalid conversation config - either configure offline chat or openai chat")
+ @staticmethod
+ async def aget_text_to_image_model_config():
+ return await TextToImageModelConfig.objects.filter().afirst()
+
class EntryAdapters:
word_filer = WordFilter()
diff --git a/src/khoj/database/migrations/0022_texttoimagemodelconfig.py b/src/khoj/database/migrations/0022_texttoimagemodelconfig.py
new file mode 100644
index 00000000..7450dc40
--- /dev/null
+++ b/src/khoj/database/migrations/0022_texttoimagemodelconfig.py
@@ -0,0 +1,25 @@
+# Generated by Django 4.2.7 on 2023-12-04 22:17
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("database", "0021_speechtotextmodeloptions_and_more"),
+ ]
+
+ operations = [
+ migrations.CreateModel(
+ name="TextToImageModelConfig",
+ fields=[
+ ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
+ ("created_at", models.DateTimeField(auto_now_add=True)),
+ ("updated_at", models.DateTimeField(auto_now=True)),
+ ("model_name", models.CharField(default="dall-e-3", max_length=200)),
+ ("model_type", models.CharField(choices=[("openai", "Openai")], default="openai", max_length=200)),
+ ],
+ options={
+ "abstract": False,
+ },
+ ),
+ ]
diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py
index 82348fbe..00700f2f 100644
--- a/src/khoj/database/models/__init__.py
+++ b/src/khoj/database/models/__init__.py
@@ -112,6 +112,14 @@ class SearchModelConfig(BaseModel):
cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2")
+class TextToImageModelConfig(BaseModel):
+ class ModelType(models.TextChoices):
+ OPENAI = "openai"
+
+ model_name = models.CharField(max_length=200, default="dall-e-3")
+ model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
+
+
class OpenAIProcessorConversationConfig(BaseModel):
api_key = models.CharField(max_length=200)
diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html
index a55d8e29..e85759fb 100644
--- a/src/khoj/interface/web/chat.html
+++ b/src/khoj/interface/web/chat.html
@@ -183,12 +183,18 @@ To get started, just start typing below. You can also type / to see a list of co
referenceSection.appendChild(polishedReference);
}
}
- }
+ }
return numOnlineReferences;
- }
+ }
+
+ function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null) {
+ if (intentType === "text-to-image") {
+ let imageMarkdown = `![](data:image/png;base64,${message})`;
+ renderMessage(imageMarkdown, by, dt);
+ return;
+ }
- function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null) {
if (context == null && onlineContext == null) {
renderMessage(message, by, dt);
return;
@@ -253,6 +259,17 @@ To get started, just start typing below. You can also type / to see a list of co
// Remove any text between [INST] and tags. These are spurious instructions for the AI chat model.
newHTML = newHTML.replace(/\[INST\].+(<\/s>)?/g, '');
+ // Customize the rendering of images
+ md.renderer.rules.image = function(tokens, idx, options, env, self) {
+ let token = tokens[idx];
+
+ // Add class="text-to-image" to images
+ token.attrPush(['class', 'text-to-image']);
+
+ // Use the default renderer to render image markdown format
+ return self.renderToken(tokens, idx, options);
+ };
+
// Render markdown
newHTML = md.render(newHTML);
// Get any elements with a class that starts with "language"
@@ -292,7 +309,7 @@ To get started, just start typing below. You can also type / to see a list of co
return element
}
- function chat() {
+ async function chat() {
// Extract required fields for search from form
let query = document.getElementById("chat-input").value.trim();
let resultsCount = localStorage.getItem("khojResultsCount") || 5;
@@ -333,113 +350,123 @@ To get started, just start typing below. You can also type / to see a list of co
let chatInput = document.getElementById("chat-input");
chatInput.classList.remove("option-enabled");
- // Call specified Khoj API which returns a streamed response of type text/plain
- fetch(url)
- .then(response => {
- const reader = response.body.getReader();
- const decoder = new TextDecoder();
- let rawResponse = "";
- let references = null;
+ // Call specified Khoj API
+ let response = await fetch(url);
+ let rawResponse = "";
+ const contentType = response.headers.get("content-type");
- function readStream() {
- reader.read().then(({ done, value }) => {
- if (done) {
- // Append any references after all the data has been streamed
- if (references != null) {
- newResponseText.appendChild(references);
- }
- document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
- document.getElementById("chat-input").removeAttribute("disabled");
- return;
- }
-
- // Decode message chunk from stream
- const chunk = decoder.decode(value, { stream: true });
-
- if (chunk.includes("### compiled references:")) {
- const additionalResponse = chunk.split("### compiled references:")[0];
- rawResponse += additionalResponse;
- newResponseText.innerHTML = "";
- newResponseText.appendChild(formatHTMLMessage(rawResponse));
-
- const rawReference = chunk.split("### compiled references:")[1];
- const rawReferenceAsJson = JSON.parse(rawReference);
- references = document.createElement('div');
- references.classList.add("references");
-
- let referenceExpandButton = document.createElement('button');
- referenceExpandButton.classList.add("reference-expand-button");
-
- let referenceSection = document.createElement('div');
- referenceSection.classList.add("reference-section");
- referenceSection.classList.add("collapsed");
-
- let numReferences = 0;
-
- // If rawReferenceAsJson is a list, then count the length
- if (Array.isArray(rawReferenceAsJson)) {
- numReferences = rawReferenceAsJson.length;
-
- rawReferenceAsJson.forEach((reference, index) => {
- let polishedReference = generateReference(reference, index);
- referenceSection.appendChild(polishedReference);
- });
- } else {
- numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson);
- }
-
- references.appendChild(referenceExpandButton);
-
- referenceExpandButton.addEventListener('click', function() {
- if (referenceSection.classList.contains("collapsed")) {
- referenceSection.classList.remove("collapsed");
- referenceSection.classList.add("expanded");
- } else {
- referenceSection.classList.add("collapsed");
- referenceSection.classList.remove("expanded");
- }
- });
-
- let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`;
- referenceExpandButton.innerHTML = expandButtonText;
- references.appendChild(referenceSection);
- readStream();
- } else {
- // Display response from Khoj
- if (newResponseText.getElementsByClassName("spinner").length > 0) {
- newResponseText.removeChild(loadingSpinner);
- }
-
- // Try to parse the chunk as a JSON object. It will be a JSON object if there is an error.
- if (chunk.startsWith("{") && chunk.endsWith("}")) {
- try {
- const responseAsJson = JSON.parse(chunk);
- if (responseAsJson.detail) {
- rawResponse += responseAsJson.detail;
- }
- } catch (error) {
- // If the chunk is not a JSON object, just display it as is
- rawResponse += chunk;
- } finally {
- newResponseText.innerHTML = "";
- newResponseText.appendChild(formatHTMLMessage(rawResponse));
- }
- } else {
- // If the chunk is not a JSON object, just display it as is
- rawResponse += chunk;
- newResponseText.innerHTML = "";
- newResponseText.appendChild(formatHTMLMessage(rawResponse));
- readStream();
- }
- }
-
- // Scroll to bottom of chat window as chat response is streamed
- document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
- });
+ if (contentType === "application/json") {
+ // Handle JSON response
+ try {
+ const responseAsJson = await response.json();
+ if (responseAsJson.image) {
+ // If response has image field, response is a generated image.
+ rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
}
- readStream();
- });
- }
+ if (responseAsJson.detail) {
+ // If response has detail field, response is an error message.
+ rawResponse += responseAsJson.detail;
+ }
+ } catch (error) {
+ // If the chunk is not a JSON object, just display it as is
+ rawResponse += chunk;
+ } finally {
+ newResponseText.innerHTML = "";
+ newResponseText.appendChild(formatHTMLMessage(rawResponse));
+
+ document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
+ document.getElementById("chat-input").removeAttribute("disabled");
+ }
+ } else {
+ // Handle streamed response of type text/event-stream or text/plain
+ const reader = response.body.getReader();
+ const decoder = new TextDecoder();
+ let references = null;
+
+ readStream();
+
+ function readStream() {
+ reader.read().then(({ done, value }) => {
+ if (done) {
+ // Append any references after all the data has been streamed
+ if (references != null) {
+ newResponseText.appendChild(references);
+ }
+ document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
+ document.getElementById("chat-input").removeAttribute("disabled");
+ return;
+ }
+
+ // Decode message chunk from stream
+ const chunk = decoder.decode(value, { stream: true });
+
+ if (chunk.includes("### compiled references:")) {
+ const additionalResponse = chunk.split("### compiled references:")[0];
+ rawResponse += additionalResponse;
+ newResponseText.innerHTML = "";
+ newResponseText.appendChild(formatHTMLMessage(rawResponse));
+
+ const rawReference = chunk.split("### compiled references:")[1];
+ const rawReferenceAsJson = JSON.parse(rawReference);
+ references = document.createElement('div');
+ references.classList.add("references");
+
+ let referenceExpandButton = document.createElement('button');
+ referenceExpandButton.classList.add("reference-expand-button");
+
+ let referenceSection = document.createElement('div');
+ referenceSection.classList.add("reference-section");
+ referenceSection.classList.add("collapsed");
+
+ let numReferences = 0;
+
+ // If rawReferenceAsJson is a list, then count the length
+ if (Array.isArray(rawReferenceAsJson)) {
+ numReferences = rawReferenceAsJson.length;
+
+ rawReferenceAsJson.forEach((reference, index) => {
+ let polishedReference = generateReference(reference, index);
+ referenceSection.appendChild(polishedReference);
+ });
+ } else {
+ numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson);
+ }
+
+ references.appendChild(referenceExpandButton);
+
+ referenceExpandButton.addEventListener('click', function() {
+ if (referenceSection.classList.contains("collapsed")) {
+ referenceSection.classList.remove("collapsed");
+ referenceSection.classList.add("expanded");
+ } else {
+ referenceSection.classList.add("collapsed");
+ referenceSection.classList.remove("expanded");
+ }
+ });
+
+ let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`;
+ referenceExpandButton.innerHTML = expandButtonText;
+ references.appendChild(referenceSection);
+ readStream();
+ } else {
+ // Display response from Khoj
+ if (newResponseText.getElementsByClassName("spinner").length > 0) {
+ newResponseText.removeChild(loadingSpinner);
+ }
+
+ // If the chunk is not a JSON object, just display it as is
+ rawResponse += chunk;
+ newResponseText.innerHTML = "";
+ newResponseText.appendChild(formatHTMLMessage(rawResponse));
+ readStream();
+ }
+ });
+
+ // Scroll to bottom of chat window as chat response is streamed
+ document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
+ };
+ }
+ };
function incrementalChat(event) {
if (!event.shiftKey && event.key === 'Enter') {
@@ -516,7 +543,7 @@ To get started, just start typing below. You can also type / to see a list of co
.then(response => {
// Render conversation history, if any
response.forEach(chat_log => {
- renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext);
+ renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext, chat_log.intent?.type);
});
})
.catch(err => {
@@ -611,9 +638,13 @@ To get started, just start typing below. You can also type / to see a list of co
.then(response => response.ok ? response.json() : Promise.reject(response))
.then(data => { chatInput.value += data.text; })
.catch(err => {
- err.status == 422
- ? flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
- : flashStatusInChatInput("⛔️ Failed to transcribe audio")
+ if (err.status === 501) {
+ flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
+ } else if (err.status === 422) {
+ flashStatusInChatInput("⛔️ Audio file to large to process.")
+ } else {
+ flashStatusInChatInput("⛔️ Failed to transcribe audio.")
+ }
});
};
@@ -902,6 +933,9 @@ To get started, just start typing below. You can also type / to see a list of co
margin-top: -10px;
transform: rotate(-60deg)
}
+ img.text-to-image {
+ max-width: 60%;
+ }
#chat-footer {
padding: 0;
@@ -1029,6 +1063,9 @@ To get started, just start typing below. You can also type / to see a list of co
margin: 4px;
grid-template-columns: auto;
}
+ img.text-to-image {
+ max-width: 100%;
+ }
}
@media only screen and (min-width: 700px) {
body {
diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py
index 41b1844b..211e3cac 100644
--- a/src/khoj/processor/conversation/offline/chat_model.py
+++ b/src/khoj/processor/conversation/offline/chat_model.py
@@ -47,7 +47,7 @@ def extract_questions_offline(
if use_history:
for chat in conversation_log.get("chat", [])[-4:]:
- if chat["by"] == "khoj":
+ if chat["by"] == "khoj" and chat["intent"].get("type") != "text-to-image":
chat_history += f"Q: {chat['intent']['query']}\n"
chat_history += f"A: {chat['message']}\n"
diff --git a/src/khoj/processor/conversation/offline/utils.py b/src/khoj/processor/conversation/offline/utils.py
index 0b876b26..3a1862f7 100644
--- a/src/khoj/processor/conversation/offline/utils.py
+++ b/src/khoj/processor/conversation/offline/utils.py
@@ -12,11 +12,12 @@ def download_model(model_name: str):
logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
raise e
- # Download the chat model
- chat_model_config = gpt4all.GPT4All.retrieve_model(model_name=model_name, allow_download=True)
-
# Decide whether to load model to GPU or CPU
+ chat_model_config = None
try:
+ # Download the chat model and its config
+ chat_model_config = gpt4all.GPT4All.retrieve_model(model_name=model_name, allow_download=True)
+
# Try load chat model to GPU if:
# 1. Loading chat model to GPU isn't disabled via CLI and
# 2. Machine has GPU
@@ -26,6 +27,12 @@ def download_model(model_name: str):
)
except ValueError:
device = "cpu"
+ except Exception as e:
+ if chat_model_config is None:
+ device = "cpu" # Fallback to CPU as can't determine if GPU has enough memory
+ logger.debug(f"Unable to download model config from gpt4all website: {e}")
+ else:
+ raise e
# Now load the downloaded chat model onto appropriate device
chat_model = gpt4all.GPT4All(model_name=model_name, device=device, allow_download=False)
diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py
index fed110f7..7bebc26c 100644
--- a/src/khoj/processor/conversation/openai/gpt.py
+++ b/src/khoj/processor/conversation/openai/gpt.py
@@ -41,7 +41,7 @@ def extract_questions(
[
f'Q: {chat["intent"]["query"]}\n\n{chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}\n\n{chat["message"]}\n\n'
for chat in conversation_log.get("chat", [])[-4:]
- if chat["by"] == "khoj"
+ if chat["by"] == "khoj" and chat["intent"].get("type") != "text-to-image"
]
)
@@ -123,8 +123,8 @@ def send_message_to_model(
def converse(
references,
- online_results,
user_query,
+ online_results=[],
conversation_log={},
model: str = "gpt-3.5-turbo",
api_key: Optional[str] = None,
diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py
index 69fab7e5..a5868cb8 100644
--- a/src/khoj/processor/conversation/openai/utils.py
+++ b/src/khoj/processor/conversation/openai/utils.py
@@ -36,11 +36,11 @@ class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
@retry(
retry=(
- retry_if_exception_type(openai.error.Timeout)
- | retry_if_exception_type(openai.error.APIError)
- | retry_if_exception_type(openai.error.APIConnectionError)
- | retry_if_exception_type(openai.error.RateLimitError)
- | retry_if_exception_type(openai.error.ServiceUnavailableError)
+ retry_if_exception_type(openai._exceptions.APITimeoutError)
+ | retry_if_exception_type(openai._exceptions.APIError)
+ | retry_if_exception_type(openai._exceptions.APIConnectionError)
+ | retry_if_exception_type(openai._exceptions.RateLimitError)
+ | retry_if_exception_type(openai._exceptions.APIStatusError)
),
wait=wait_random_exponential(min=1, max=10),
stop=stop_after_attempt(3),
@@ -57,11 +57,11 @@ def completion_with_backoff(**kwargs):
@retry(
retry=(
- retry_if_exception_type(openai.error.Timeout)
- | retry_if_exception_type(openai.error.APIError)
- | retry_if_exception_type(openai.error.APIConnectionError)
- | retry_if_exception_type(openai.error.RateLimitError)
- | retry_if_exception_type(openai.error.ServiceUnavailableError)
+ retry_if_exception_type(openai._exceptions.APITimeoutError)
+ | retry_if_exception_type(openai._exceptions.APIError)
+ | retry_if_exception_type(openai._exceptions.APIConnectionError)
+ | retry_if_exception_type(openai._exceptions.RateLimitError)
+ | retry_if_exception_type(openai._exceptions.APIStatusError)
),
wait=wait_exponential(multiplier=1, min=4, max=10),
stop=stop_after_attempt(3),
diff --git a/src/khoj/processor/conversation/openai/whisper.py b/src/khoj/processor/conversation/openai/whisper.py
index 72834d92..bd0e66df 100644
--- a/src/khoj/processor/conversation/openai/whisper.py
+++ b/src/khoj/processor/conversation/openai/whisper.py
@@ -3,13 +3,13 @@ from io import BufferedReader
# External Packages
from asgiref.sync import sync_to_async
-import openai
+from openai import OpenAI
-async def transcribe_audio(audio_file: BufferedReader, model, api_key) -> str:
+async def transcribe_audio(audio_file: BufferedReader, model, client: OpenAI) -> str:
"""
Transcribe audio file using Whisper model via OpenAI's API
"""
# Send the audio data to the Whisper API
- response = await sync_to_async(openai.Audio.translate)(model=model, file=audio_file, api_key=api_key)
- return response["text"]
+ response = await sync_to_async(client.audio.translations.create)(model=model, file=audio_file)
+ return response.text
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index e2164719..4606efd9 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -4,6 +4,7 @@ from time import perf_counter
import json
from datetime import datetime
import queue
+from typing import Any, Dict, List
import tiktoken
# External packages
@@ -11,6 +12,8 @@ from langchain.schema import ChatMessage
from transformers import AutoTokenizer
# Internal Packages
+from khoj.database.adapters import ConversationAdapters
+from khoj.database.models import KhojUser
from khoj.utils.helpers import merge_dicts
@@ -89,6 +92,32 @@ def message_to_log(
return conversation_log
+def save_to_conversation_log(
+ q: str,
+ chat_response: str,
+ user: KhojUser,
+ meta_log: Dict,
+ user_message_time: str = None,
+ compiled_references: List[str] = [],
+ online_results: Dict[str, Any] = {},
+ inferred_queries: List[str] = [],
+ intent_type: str = "remember",
+):
+ user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+ updated_conversation = message_to_log(
+ user_message=q,
+ chat_response=chat_response,
+ user_message_metadata={"created": user_message_time},
+ khoj_message_metadata={
+ "context": compiled_references,
+ "intent": {"inferred-queries": inferred_queries, "type": intent_type},
+ "onlineContext": online_results,
+ },
+ conversation_log=meta_log.get("chat", []),
+ )
+ ConversationAdapters.save_conversation(user, {"chat": updated_conversation})
+
+
def generate_chatml_messages_with_context(
user_message,
system_message,
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index ae125980..7efd8bfd 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -19,7 +19,7 @@ from starlette.authentication import requires
from khoj.configure import configure_server
from khoj.database import adapters
from khoj.database.adapters import ConversationAdapters, EntryAdapters
-from khoj.database.models import ChatModelOptions
+from khoj.database.models import ChatModelOptions, SpeechToTextModelOptions
from khoj.database.models import Entry as DbEntry
from khoj.database.models import (
GithubConfig,
@@ -35,12 +35,14 @@ from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
from khoj.processor.conversation.openai.gpt import extract_questions
from khoj.processor.conversation.openai.whisper import transcribe_audio
from khoj.processor.conversation.prompts import help_message, no_entries_found
+from khoj.processor.conversation.utils import save_to_conversation_log
from khoj.processor.tools.online_search import search_with_google
from khoj.routers.helpers import (
ApiUserRateLimiter,
CommonQueryParams,
agenerate_chat_response,
get_conversation_command,
+ text_to_image,
is_ready_to_chat,
update_telemetry_state,
validate_conversation_config,
@@ -622,17 +624,15 @@ async def transcribe(request: Request, common: CommonQueryParams, file: UploadFi
# Send the audio data to the Whisper API
speech_to_text_config = await ConversationAdapters.get_speech_to_text_config()
- openai_chat_config = await ConversationAdapters.get_openai_chat_config()
if not speech_to_text_config:
- # If the user has not configured a speech to text model, return an unprocessable entity error
- status_code = 422
- elif openai_chat_config and speech_to_text_config.model_type == ChatModelOptions.ModelType.OPENAI:
- api_key = openai_chat_config.api_key
+ # If the user has not configured a speech to text model, return an unsupported on server error
+ status_code = 501
+ elif state.openai_client and speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OPENAI:
speech2text_model = speech_to_text_config.model_name
- user_message = await transcribe_audio(audio_file, model=speech2text_model, api_key=api_key)
- elif speech_to_text_config.model_type == ChatModelOptions.ModelType.OFFLINE:
+ user_message = await transcribe_audio(audio_file, speech2text_model, client=state.openai_client)
+ elif speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OFFLINE:
speech2text_model = speech_to_text_config.model_name
- user_message = await transcribe_audio_offline(audio_filename, model=speech2text_model)
+ user_message = await transcribe_audio_offline(audio_filename, speech2text_model)
finally:
# Close and Delete the temporary audio file
audio_file.close()
@@ -665,7 +665,7 @@ async def chat(
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=60, window=60)),
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)),
) -> Response:
- user = request.user.object
+ user: KhojUser = request.user.object
await is_ready_to_chat(user)
conversation_command = get_conversation_command(query=q, any_references=True)
@@ -703,6 +703,11 @@ async def chat(
media_type="text/event-stream",
status_code=200,
)
+ elif conversation_command == ConversationCommand.Image:
+ image, status_code = await text_to_image(q)
+ await sync_to_async(save_to_conversation_log)(q, image, user, meta_log, intent_type="text-to-image")
+ content_obj = {"image": image, "intentType": "text-to-image"}
+ return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
# Get the (streamed) chat response from the LLM of choice.
llm_response, chat_metadata = await agenerate_chat_response(
@@ -786,7 +791,6 @@ async def extract_references_and_questions(
conversation_config = await ConversationAdapters.aget_conversation_config(user)
if conversation_config is None:
conversation_config = await ConversationAdapters.aget_default_conversation_config()
- openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
if (
offline_chat_config
and offline_chat_config.enabled
@@ -803,7 +807,7 @@ async def extract_references_and_questions(
inferred_queries = extract_questions_offline(
defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
)
- elif openai_chat_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
+ elif conversation_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
openai_chat = await ConversationAdapters.get_openai_chat()
api_key = openai_chat_config.api_key
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index 22b0d037..4e883f35 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -9,23 +9,23 @@ from functools import partial
from time import time
from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union
+# External Packages
from fastapi import Depends, Header, HTTPException, Request, UploadFile
+import openai
from starlette.authentication import has_required_scope
-from asgiref.sync import sync_to_async
-
+# Internal Packages
from khoj.database.adapters import ConversationAdapters, EntryAdapters
-from khoj.database.models import KhojUser, Subscription
+from khoj.database.models import KhojUser, Subscription, TextToImageModelConfig
from khoj.processor.conversation import prompts
from khoj.processor.conversation.offline.chat_model import converse_offline, send_message_to_model_offline
from khoj.processor.conversation.openai.gpt import converse, send_message_to_model
-from khoj.processor.conversation.utils import ThreadedGenerator, message_to_log
-
-# Internal Packages
+from khoj.processor.conversation.utils import ThreadedGenerator, save_to_conversation_log
from khoj.utils import state
from khoj.utils.config import GPT4AllProcessorModel
from khoj.utils.helpers import ConversationCommand, log_telemetry
+
logger = logging.getLogger(__name__)
executor = ThreadPoolExecutor(max_workers=1)
@@ -102,6 +102,8 @@ def get_conversation_command(query: str, any_references: bool = False) -> Conver
return ConversationCommand.General
elif query.startswith("/online"):
return ConversationCommand.Online
+ elif query.startswith("/image"):
+ return ConversationCommand.Image
# If no relevant notes found for the given query
elif not any_references:
return ConversationCommand.General
@@ -186,30 +188,7 @@ def generate_chat_response(
conversation_command: ConversationCommand = ConversationCommand.Default,
user: KhojUser = None,
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
- def _save_to_conversation_log(
- q: str,
- chat_response: str,
- user_message_time: str,
- compiled_references: List[str],
- online_results: Dict[str, Any],
- inferred_queries: List[str],
- meta_log,
- ):
- updated_conversation = message_to_log(
- user_message=q,
- chat_response=chat_response,
- user_message_metadata={"created": user_message_time},
- khoj_message_metadata={
- "context": compiled_references,
- "intent": {"inferred-queries": inferred_queries},
- "onlineContext": online_results,
- },
- conversation_log=meta_log.get("chat", []),
- )
- ConversationAdapters.save_conversation(user, {"chat": updated_conversation})
-
# Initialize Variables
- user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
chat_response = None
logger.debug(f"Conversation Type: {conversation_command.name}")
@@ -217,13 +196,13 @@ def generate_chat_response(
try:
partial_completion = partial(
- _save_to_conversation_log,
+ save_to_conversation_log,
q,
- user_message_time=user_message_time,
+ user=user,
+ meta_log=meta_log,
compiled_references=compiled_references,
online_results=online_results,
inferred_queries=inferred_queries,
- meta_log=meta_log,
)
conversation_config = ConversationAdapters.get_valid_conversation_config(user)
@@ -251,9 +230,9 @@ def generate_chat_response(
chat_model = conversation_config.chat_model
chat_response = converse(
compiled_references,
- online_results,
q,
- meta_log,
+ online_results=online_results,
+ conversation_log=meta_log,
model=chat_model,
api_key=api_key,
completion_func=partial_completion,
@@ -271,6 +250,29 @@ def generate_chat_response(
return chat_response, metadata
+async def text_to_image(message: str) -> Tuple[Optional[str], int]:
+ status_code = 200
+ image = None
+
+ # Send the audio data to the Whisper API
+ text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
+ if not text_to_image_config:
+ # If the user has not configured a text to image model, return an unsupported on server error
+ status_code = 501
+ elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
+ text2image_model = text_to_image_config.model_name
+ try:
+ response = state.openai_client.images.generate(
+ prompt=message, model=text2image_model, response_format="b64_json"
+ )
+ image = response.data[0].b64_json
+ except openai.OpenAIError as e:
+ logger.error(f"Image Generation failed with {e.http_status}: {e.error}")
+ status_code = 500
+
+ return image, status_code
+
+
class ApiUserRateLimiter:
def __init__(self, requests: int, subscribed_requests: int, window: int):
self.requests = requests
diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py
index 42e3835d..21fe7e98 100644
--- a/src/khoj/utils/helpers.py
+++ b/src/khoj/utils/helpers.py
@@ -273,6 +273,7 @@ class ConversationCommand(str, Enum):
Notes = "notes"
Help = "help"
Online = "online"
+ Image = "image"
command_descriptions = {
@@ -280,6 +281,7 @@ command_descriptions = {
ConversationCommand.Notes: "Only talk about information that is available in your knowledge base.",
ConversationCommand.Default: "The default command when no command specified. It intelligently auto-switches between general and notes mode.",
ConversationCommand.Online: "Look up information on the internet.",
+ ConversationCommand.Image: "Generate images by describing your imagination in words.",
ConversationCommand.Help: "Display a help message with all available commands and other metadata.",
}
diff --git a/src/khoj/utils/initialization.py b/src/khoj/utils/initialization.py
index 313b18fc..0bb78dbe 100644
--- a/src/khoj/utils/initialization.py
+++ b/src/khoj/utils/initialization.py
@@ -7,6 +7,7 @@ from khoj.database.models import (
OpenAIProcessorConversationConfig,
ChatModelOptions,
SpeechToTextModelOptions,
+ TextToImageModelConfig,
)
from khoj.utils.constants import default_offline_chat_model, default_online_chat_model
@@ -103,6 +104,15 @@ def initialization():
model_name=openai_speech2text_model, model_type=SpeechToTextModelOptions.ModelType.OPENAI
)
+ default_text_to_image_model = "dall-e-3"
+ openai_text_to_image_model = input(
+ f"Enter the OpenAI text to image model you want to use (default: {default_text_to_image_model}): "
+ )
+ openai_speech2text_model = openai_text_to_image_model or default_text_to_image_model
+ TextToImageModelConfig.objects.create(
+ model_name=openai_text_to_image_model, model_type=TextToImageModelConfig.ModelType.OPENAI
+ )
+
if use_offline_model == "y" or use_openai_model == "y":
logger.info("🗣️ Chat model configuration complete")
diff --git a/src/khoj/utils/models.py b/src/khoj/utils/models.py
index b5bbe292..e1298b08 100644
--- a/src/khoj/utils/models.py
+++ b/src/khoj/utils/models.py
@@ -22,17 +22,9 @@ class BaseEncoder(ABC):
class OpenAI(BaseEncoder):
- def __init__(self, model_name, device=None):
+ def __init__(self, model_name, client: openai.OpenAI, device=None):
self.model_name = model_name
- if (
- not state.processor_config
- or not state.processor_config.conversation
- or not state.processor_config.conversation.openai_model
- ):
- raise Exception(
- f"Set OpenAI API key under processor-config > conversation > openai-api-key in config file: {state.config_file}"
- )
- openai.api_key = state.processor_config.conversation.openai_model.api_key
+ self.openai_client = client
self.embedding_dimensions = None
def encode(self, entries, device=None, **kwargs):
@@ -43,7 +35,7 @@ class OpenAI(BaseEncoder):
processed_entry = entries[index].replace("\n", " ")
try:
- response = openai.Embedding.create(input=processed_entry, model=self.model_name)
+ response = self.openai_client.embeddings.create(input=processed_entry, model=self.model_name)
embedding_tensors += [torch.tensor(response.data[0].embedding, device=device)]
# Use current models embedding dimension, once available
# Else default to embedding dimensions of the text-embedding-ada-002 model
diff --git a/src/khoj/utils/state.py b/src/khoj/utils/state.py
index b54cf4b3..d5358868 100644
--- a/src/khoj/utils/state.py
+++ b/src/khoj/utils/state.py
@@ -1,15 +1,16 @@
# Standard Packages
+from collections import defaultdict
import os
+from pathlib import Path
import threading
from typing import List, Dict
-from collections import defaultdict
# External Packages
-from pathlib import Path
-from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
+from openai import OpenAI
from whisper import Whisper
# Internal Packages
+from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from khoj.utils import config as utils_config
from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel
from khoj.utils.helpers import LRU, get_device
@@ -21,6 +22,7 @@ search_models = SearchModels()
embeddings_model: EmbeddingsModel = None
cross_encoder_model: CrossEncoderModel = None
content_index = ContentIndex()
+openai_client: OpenAI = None
gpt4all_processor_config: GPT4AllProcessorModel = None
whisper_model: Whisper = None
config_file: Path = None
diff --git a/tests/test_openai_chat_director.py b/tests/test_openai_chat_director.py
index 7b98d4da..50ad60af 100644
--- a/tests/test_openai_chat_director.py
+++ b/tests/test_openai_chat_director.py
@@ -68,10 +68,10 @@ def test_chat_with_online_content(chat_client):
response_message = response_message.split("### compiled references")[0]
# Assert
- expected_responses = ["http://www.paulgraham.com/greatwork.html"]
+ expected_responses = ["http://www.paulgraham.com/greatwork.html", "Please set your SERPER_DEV_API_KEY"]
assert response.status_code == 200
assert any([expected_response in response_message for expected_response in expected_responses]), (
- "Expected assistants name, [K|k]hoj, in response but got: " + response_message
+ "Expected links or serper not setup in response but got: " + response_message
)