mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Merge pull request #580 from khoj-ai/fix-upgrade-chat-to-create-images
Support Image Generation with Khoj
This commit is contained in:
commit
9b961ed496
22 changed files with 529 additions and 303 deletions
|
@ -42,7 +42,7 @@ dependencies = [
|
|||
"fastapi >= 0.104.1",
|
||||
"python-multipart >= 0.0.5",
|
||||
"jinja2 == 3.1.2",
|
||||
"openai >= 0.27.0, < 1.0.0",
|
||||
"openai >= 1.0.0",
|
||||
"tiktoken >= 0.3.2",
|
||||
"tenacity >= 8.2.2",
|
||||
"pillow ~= 9.5.0",
|
||||
|
|
|
@ -179,7 +179,13 @@
|
|||
return numOnlineReferences;
|
||||
}
|
||||
|
||||
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null) {
|
||||
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null) {
|
||||
if (intentType === "text-to-image") {
|
||||
let imageMarkdown = `![](data:image/png;base64,${message})`;
|
||||
renderMessage(imageMarkdown, by, dt);
|
||||
return;
|
||||
}
|
||||
|
||||
if (context == null && onlineContext == null) {
|
||||
renderMessage(message, by, dt);
|
||||
return;
|
||||
|
@ -244,6 +250,17 @@
|
|||
// Remove any text between <s>[INST] and </s> tags. These are spurious instructions for the AI chat model.
|
||||
newHTML = newHTML.replace(/<s>\[INST\].+(<\/s>)?/g, '');
|
||||
|
||||
// Customize the rendering of images
|
||||
md.renderer.rules.image = function(tokens, idx, options, env, self) {
|
||||
let token = tokens[idx];
|
||||
|
||||
// Add class="text-to-image" to images
|
||||
token.attrPush(['class', 'text-to-image']);
|
||||
|
||||
// Use the default renderer to render image markdown format
|
||||
return self.renderToken(tokens, idx, options);
|
||||
};
|
||||
|
||||
// Render markdown
|
||||
newHTML = md.render(newHTML);
|
||||
// Get any elements with a class that starts with "language"
|
||||
|
@ -328,14 +345,41 @@
|
|||
let chatInput = document.getElementById("chat-input");
|
||||
chatInput.classList.remove("option-enabled");
|
||||
|
||||
// Call specified Khoj API which returns a streamed response of type text/plain
|
||||
fetch(url, { headers })
|
||||
.then(response => {
|
||||
// Call specified Khoj API
|
||||
let response = await fetch(url, { headers });
|
||||
let rawResponse = "";
|
||||
const contentType = response.headers.get("content-type");
|
||||
|
||||
if (contentType === "application/json") {
|
||||
// Handle JSON response
|
||||
try {
|
||||
const responseAsJson = await response.json();
|
||||
if (responseAsJson.image) {
|
||||
// If response has image field, response is a generated image.
|
||||
rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
|
||||
}
|
||||
if (responseAsJson.detail) {
|
||||
// If response has detail field, response is an error message.
|
||||
rawResponse += responseAsJson.detail;
|
||||
}
|
||||
} catch (error) {
|
||||
// If the chunk is not a JSON object, just display it as is
|
||||
rawResponse += chunk;
|
||||
} finally {
|
||||
newResponseText.innerHTML = "";
|
||||
newResponseText.appendChild(formatHTMLMessage(rawResponse));
|
||||
|
||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||
document.getElementById("chat-input").removeAttribute("disabled");
|
||||
}
|
||||
} else {
|
||||
// Handle streamed response of type text/event-stream or text/plain
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let rawResponse = "";
|
||||
let references = null;
|
||||
|
||||
readStream();
|
||||
|
||||
function readStream() {
|
||||
reader.read().then(({ done, value }) => {
|
||||
if (done) {
|
||||
|
@ -404,16 +448,23 @@
|
|||
if (newResponseText.getElementsByClassName("spinner").length > 0) {
|
||||
newResponseText.removeChild(loadingSpinner);
|
||||
}
|
||||
|
||||
// Try to parse the chunk as a JSON object. It will be a JSON object if there is an error.
|
||||
if (chunk.startsWith("{") && chunk.endsWith("}")) {
|
||||
try {
|
||||
const responseAsJson = JSON.parse(chunk);
|
||||
if (responseAsJson.image) {
|
||||
rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
|
||||
}
|
||||
if (responseAsJson.detail) {
|
||||
newResponseText.innerHTML += responseAsJson.detail;
|
||||
rawResponse += responseAsJson.detail;
|
||||
}
|
||||
} catch (error) {
|
||||
// If the chunk is not a JSON object, just display it as is
|
||||
newResponseText.innerHTML += chunk;
|
||||
rawResponse += chunk;
|
||||
} finally {
|
||||
newResponseText.innerHTML = "";
|
||||
newResponseText.appendChild(formatHTMLMessage(rawResponse));
|
||||
}
|
||||
} else {
|
||||
// If the chunk is not a JSON object, just display it as is
|
||||
|
@ -429,8 +480,7 @@
|
|||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||
});
|
||||
}
|
||||
readStream();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
function incrementalChat(event) {
|
||||
|
@ -522,7 +572,7 @@
|
|||
.then(response => {
|
||||
// Render conversation history, if any
|
||||
response.forEach(chat_log => {
|
||||
renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext);
|
||||
renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext, chat_log.intent?.type);
|
||||
});
|
||||
})
|
||||
.catch(err => {
|
||||
|
@ -625,9 +675,13 @@
|
|||
.then(response => response.ok ? response.json() : Promise.reject(response))
|
||||
.then(data => { chatInput.value += data.text; })
|
||||
.catch(err => {
|
||||
err.status == 422
|
||||
? flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
|
||||
: flashStatusInChatInput("⛔️ Failed to transcribe audio")
|
||||
if (err.status === 501) {
|
||||
flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
|
||||
} else if (err.status === 422) {
|
||||
flashStatusInChatInput("⛔️ Audio file to large to process.")
|
||||
} else {
|
||||
flashStatusInChatInput("⛔️ Failed to transcribe audio.")
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
|
@ -810,6 +864,9 @@
|
|||
margin-top: -10px;
|
||||
transform: rotate(-60deg)
|
||||
}
|
||||
img.text-to-image {
|
||||
max-width: 60%;
|
||||
}
|
||||
|
||||
#chat-footer {
|
||||
padding: 0;
|
||||
|
@ -1050,6 +1107,9 @@
|
|||
margin: 4px;
|
||||
grid-template-columns: auto;
|
||||
}
|
||||
img.text-to-image {
|
||||
max-width: 100%;
|
||||
}
|
||||
}
|
||||
@media only screen and (min-width: 600px) {
|
||||
body {
|
||||
|
|
|
@ -2,6 +2,12 @@ import { App, MarkdownRenderer, Modal, request, requestUrl, setIcon } from 'obsi
|
|||
import { KhojSetting } from 'src/settings';
|
||||
import fetch from "node-fetch";
|
||||
|
||||
export interface ChatJsonResult {
|
||||
image?: string;
|
||||
detail?: string;
|
||||
}
|
||||
|
||||
|
||||
export class KhojChatModal extends Modal {
|
||||
result: string;
|
||||
setting: KhojSetting;
|
||||
|
@ -105,15 +111,19 @@ export class KhojChatModal extends Modal {
|
|||
return referenceButton;
|
||||
}
|
||||
|
||||
renderMessageWithReferences(chatEl: Element, message: string, sender: string, context?: string[], dt?: Date) {
|
||||
renderMessageWithReferences(chatEl: Element, message: string, sender: string, context?: string[], dt?: Date, intentType?: string) {
|
||||
if (!message) {
|
||||
return;
|
||||
} else if (intentType === "text-to-image") {
|
||||
let imageMarkdown = `![](data:image/png;base64,${message})`;
|
||||
this.renderMessage(chatEl, imageMarkdown, sender, dt);
|
||||
return;
|
||||
} else if (!context) {
|
||||
this.renderMessage(chatEl, message, sender, dt);
|
||||
return
|
||||
return;
|
||||
} else if (!!context && context?.length === 0) {
|
||||
this.renderMessage(chatEl, message, sender, dt);
|
||||
return
|
||||
return;
|
||||
}
|
||||
let chatMessageEl = this.renderMessage(chatEl, message, sender, dt);
|
||||
let chatMessageBodyEl = chatMessageEl.getElementsByClassName("khoj-chat-message-text")[0]
|
||||
|
@ -225,7 +235,7 @@ export class KhojChatModal extends Modal {
|
|||
let response = await request({ url: chatUrl, headers: headers });
|
||||
let chatLogs = JSON.parse(response).response;
|
||||
chatLogs.forEach((chatLog: any) => {
|
||||
this.renderMessageWithReferences(chatBodyEl, chatLog.message, chatLog.by, chatLog.context, new Date(chatLog.created));
|
||||
this.renderMessageWithReferences(chatBodyEl, chatLog.message, chatLog.by, chatLog.context, new Date(chatLog.created), chatLog.intent?.type);
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -266,8 +276,25 @@ export class KhojChatModal extends Modal {
|
|||
|
||||
this.result = "";
|
||||
responseElement.innerHTML = "";
|
||||
if (response.headers.get("content-type") == "application/json") {
|
||||
let responseText = ""
|
||||
try {
|
||||
const responseAsJson = await response.json() as ChatJsonResult;
|
||||
if (responseAsJson.image) {
|
||||
responseText = `![${query}](data:image/png;base64,${responseAsJson.image})`;
|
||||
} else if (responseAsJson.detail) {
|
||||
responseText = responseAsJson.detail;
|
||||
}
|
||||
} catch (error) {
|
||||
// If the chunk is not a JSON object, just display it as is
|
||||
responseText = response.body.read().toString()
|
||||
} finally {
|
||||
this.renderIncrementalMessage(responseElement, responseText);
|
||||
}
|
||||
}
|
||||
|
||||
for await (const chunk of response.body) {
|
||||
const responseText = chunk.toString();
|
||||
let responseText = chunk.toString();
|
||||
if (responseText.includes("### compiled references:")) {
|
||||
const additionalResponse = responseText.split("### compiled references:")[0];
|
||||
this.renderIncrementalMessage(responseElement, additionalResponse);
|
||||
|
@ -310,6 +337,12 @@ export class KhojChatModal extends Modal {
|
|||
referenceExpandButton.innerHTML = expandButtonText;
|
||||
references.appendChild(referenceSection);
|
||||
} else {
|
||||
if (responseText.startsWith("{") && responseText.endsWith("}")) {
|
||||
} else {
|
||||
// If the chunk is not a JSON object, just display it as is
|
||||
continue;
|
||||
}
|
||||
|
||||
this.renderIncrementalMessage(responseElement, responseText);
|
||||
}
|
||||
}
|
||||
|
@ -389,10 +422,12 @@ export class KhojChatModal extends Modal {
|
|||
if (response.status === 200) {
|
||||
console.log(response);
|
||||
chatInput.value += response.json.text;
|
||||
} else if (response.status === 422) {
|
||||
throw new Error("⛔️ Failed to transcribe audio");
|
||||
} else {
|
||||
} else if (response.status === 501) {
|
||||
throw new Error("⛔️ Configure speech-to-text model on server.");
|
||||
} else if (response.status === 422) {
|
||||
throw new Error("⛔️ Audio file to large to process.");
|
||||
} else {
|
||||
throw new Error("⛔️ Failed to transcribe audio.");
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -217,7 +217,9 @@ button.copy-button:hover {
|
|||
background: #f5f5f5;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
img {
|
||||
max-width: 60%;
|
||||
}
|
||||
|
||||
#khoj-chat-footer {
|
||||
padding: 0;
|
||||
|
|
|
@ -7,6 +7,7 @@ import requests
|
|||
import os
|
||||
|
||||
# External Packages
|
||||
import openai
|
||||
import schedule
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
from starlette.middleware.authentication import AuthenticationMiddleware
|
||||
|
@ -22,6 +23,7 @@ from starlette.authentication import (
|
|||
# Internal Packages
|
||||
from khoj.database.models import KhojUser, Subscription
|
||||
from khoj.database.adapters import (
|
||||
ConversationAdapters,
|
||||
get_all_users,
|
||||
get_or_create_search_model,
|
||||
aget_user_subscription_state,
|
||||
|
@ -138,6 +140,10 @@ def configure_server(
|
|||
config = FullConfig()
|
||||
state.config = config
|
||||
|
||||
if ConversationAdapters.has_valid_openai_conversation_config():
|
||||
openai_config = ConversationAdapters.get_openai_conversation_config()
|
||||
state.openai_client = openai.OpenAI(api_key=openai_config.api_key)
|
||||
|
||||
# Initialize Search Models from Config and initialize content
|
||||
try:
|
||||
state.embeddings_model = EmbeddingsModel(get_or_create_search_model().bi_encoder)
|
||||
|
|
|
@ -22,6 +22,7 @@ from khoj.database.models import (
|
|||
GithubConfig,
|
||||
GithubRepoConfig,
|
||||
GoogleUser,
|
||||
TextToImageModelConfig,
|
||||
KhojApiUser,
|
||||
KhojUser,
|
||||
NotionConfig,
|
||||
|
@ -426,6 +427,10 @@ class ConversationAdapters:
|
|||
else:
|
||||
raise ValueError("Invalid conversation config - either configure offline chat or openai chat")
|
||||
|
||||
@staticmethod
|
||||
async def aget_text_to_image_model_config():
|
||||
return await TextToImageModelConfig.objects.filter().afirst()
|
||||
|
||||
|
||||
class EntryAdapters:
|
||||
word_filer = WordFilter()
|
||||
|
|
25
src/khoj/database/migrations/0022_texttoimagemodelconfig.py
Normal file
25
src/khoj/database/migrations/0022_texttoimagemodelconfig.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
# Generated by Django 4.2.7 on 2023-12-04 22:17
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0021_speechtotextmodeloptions_and_more"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name="TextToImageModelConfig",
|
||||
fields=[
|
||||
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||
("updated_at", models.DateTimeField(auto_now=True)),
|
||||
("model_name", models.CharField(default="dall-e-3", max_length=200)),
|
||||
("model_type", models.CharField(choices=[("openai", "Openai")], default="openai", max_length=200)),
|
||||
],
|
||||
options={
|
||||
"abstract": False,
|
||||
},
|
||||
),
|
||||
]
|
|
@ -112,6 +112,14 @@ class SearchModelConfig(BaseModel):
|
|||
cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2")
|
||||
|
||||
|
||||
class TextToImageModelConfig(BaseModel):
|
||||
class ModelType(models.TextChoices):
|
||||
OPENAI = "openai"
|
||||
|
||||
model_name = models.CharField(max_length=200, default="dall-e-3")
|
||||
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
|
||||
|
||||
|
||||
class OpenAIProcessorConversationConfig(BaseModel):
|
||||
api_key = models.CharField(max_length=200)
|
||||
|
||||
|
|
|
@ -188,7 +188,13 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
return numOnlineReferences;
|
||||
}
|
||||
|
||||
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null) {
|
||||
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null) {
|
||||
if (intentType === "text-to-image") {
|
||||
let imageMarkdown = `![](data:image/png;base64,${message})`;
|
||||
renderMessage(imageMarkdown, by, dt);
|
||||
return;
|
||||
}
|
||||
|
||||
if (context == null && onlineContext == null) {
|
||||
renderMessage(message, by, dt);
|
||||
return;
|
||||
|
@ -253,6 +259,17 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
// Remove any text between <s>[INST] and </s> tags. These are spurious instructions for the AI chat model.
|
||||
newHTML = newHTML.replace(/<s>\[INST\].+(<\/s>)?/g, '');
|
||||
|
||||
// Customize the rendering of images
|
||||
md.renderer.rules.image = function(tokens, idx, options, env, self) {
|
||||
let token = tokens[idx];
|
||||
|
||||
// Add class="text-to-image" to images
|
||||
token.attrPush(['class', 'text-to-image']);
|
||||
|
||||
// Use the default renderer to render image markdown format
|
||||
return self.renderToken(tokens, idx, options);
|
||||
};
|
||||
|
||||
// Render markdown
|
||||
newHTML = md.render(newHTML);
|
||||
// Get any elements with a class that starts with "language"
|
||||
|
@ -292,7 +309,7 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
return element
|
||||
}
|
||||
|
||||
function chat() {
|
||||
async function chat() {
|
||||
// Extract required fields for search from form
|
||||
let query = document.getElementById("chat-input").value.trim();
|
||||
let resultsCount = localStorage.getItem("khojResultsCount") || 5;
|
||||
|
@ -333,14 +350,41 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
let chatInput = document.getElementById("chat-input");
|
||||
chatInput.classList.remove("option-enabled");
|
||||
|
||||
// Call specified Khoj API which returns a streamed response of type text/plain
|
||||
fetch(url)
|
||||
.then(response => {
|
||||
// Call specified Khoj API
|
||||
let response = await fetch(url);
|
||||
let rawResponse = "";
|
||||
const contentType = response.headers.get("content-type");
|
||||
|
||||
if (contentType === "application/json") {
|
||||
// Handle JSON response
|
||||
try {
|
||||
const responseAsJson = await response.json();
|
||||
if (responseAsJson.image) {
|
||||
// If response has image field, response is a generated image.
|
||||
rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
|
||||
}
|
||||
if (responseAsJson.detail) {
|
||||
// If response has detail field, response is an error message.
|
||||
rawResponse += responseAsJson.detail;
|
||||
}
|
||||
} catch (error) {
|
||||
// If the chunk is not a JSON object, just display it as is
|
||||
rawResponse += chunk;
|
||||
} finally {
|
||||
newResponseText.innerHTML = "";
|
||||
newResponseText.appendChild(formatHTMLMessage(rawResponse));
|
||||
|
||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||
document.getElementById("chat-input").removeAttribute("disabled");
|
||||
}
|
||||
} else {
|
||||
// Handle streamed response of type text/event-stream or text/plain
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let rawResponse = "";
|
||||
let references = null;
|
||||
|
||||
readStream();
|
||||
|
||||
function readStream() {
|
||||
reader.read().then(({ done, value }) => {
|
||||
if (done) {
|
||||
|
@ -410,36 +454,19 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
newResponseText.removeChild(loadingSpinner);
|
||||
}
|
||||
|
||||
// Try to parse the chunk as a JSON object. It will be a JSON object if there is an error.
|
||||
if (chunk.startsWith("{") && chunk.endsWith("}")) {
|
||||
try {
|
||||
const responseAsJson = JSON.parse(chunk);
|
||||
if (responseAsJson.detail) {
|
||||
rawResponse += responseAsJson.detail;
|
||||
}
|
||||
} catch (error) {
|
||||
// If the chunk is not a JSON object, just display it as is
|
||||
rawResponse += chunk;
|
||||
} finally {
|
||||
newResponseText.innerHTML = "";
|
||||
newResponseText.appendChild(formatHTMLMessage(rawResponse));
|
||||
}
|
||||
} else {
|
||||
// If the chunk is not a JSON object, just display it as is
|
||||
rawResponse += chunk;
|
||||
newResponseText.innerHTML = "";
|
||||
newResponseText.appendChild(formatHTMLMessage(rawResponse));
|
||||
readStream();
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Scroll to bottom of chat window as chat response is streamed
|
||||
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
|
||||
});
|
||||
}
|
||||
readStream();
|
||||
});
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
function incrementalChat(event) {
|
||||
if (!event.shiftKey && event.key === 'Enter') {
|
||||
|
@ -516,7 +543,7 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
.then(response => {
|
||||
// Render conversation history, if any
|
||||
response.forEach(chat_log => {
|
||||
renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext);
|
||||
renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext, chat_log.intent?.type);
|
||||
});
|
||||
})
|
||||
.catch(err => {
|
||||
|
@ -611,9 +638,13 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
.then(response => response.ok ? response.json() : Promise.reject(response))
|
||||
.then(data => { chatInput.value += data.text; })
|
||||
.catch(err => {
|
||||
err.status == 422
|
||||
? flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
|
||||
: flashStatusInChatInput("⛔️ Failed to transcribe audio")
|
||||
if (err.status === 501) {
|
||||
flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
|
||||
} else if (err.status === 422) {
|
||||
flashStatusInChatInput("⛔️ Audio file to large to process.")
|
||||
} else {
|
||||
flashStatusInChatInput("⛔️ Failed to transcribe audio.")
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
|
@ -902,6 +933,9 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
margin-top: -10px;
|
||||
transform: rotate(-60deg)
|
||||
}
|
||||
img.text-to-image {
|
||||
max-width: 60%;
|
||||
}
|
||||
|
||||
#chat-footer {
|
||||
padding: 0;
|
||||
|
@ -1029,6 +1063,9 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
margin: 4px;
|
||||
grid-template-columns: auto;
|
||||
}
|
||||
img.text-to-image {
|
||||
max-width: 100%;
|
||||
}
|
||||
}
|
||||
@media only screen and (min-width: 700px) {
|
||||
body {
|
||||
|
|
|
@ -47,7 +47,7 @@ def extract_questions_offline(
|
|||
|
||||
if use_history:
|
||||
for chat in conversation_log.get("chat", [])[-4:]:
|
||||
if chat["by"] == "khoj":
|
||||
if chat["by"] == "khoj" and chat["intent"].get("type") != "text-to-image":
|
||||
chat_history += f"Q: {chat['intent']['query']}\n"
|
||||
chat_history += f"A: {chat['message']}\n"
|
||||
|
||||
|
|
|
@ -12,11 +12,12 @@ def download_model(model_name: str):
|
|||
logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
|
||||
raise e
|
||||
|
||||
# Download the chat model
|
||||
# Decide whether to load model to GPU or CPU
|
||||
chat_model_config = None
|
||||
try:
|
||||
# Download the chat model and its config
|
||||
chat_model_config = gpt4all.GPT4All.retrieve_model(model_name=model_name, allow_download=True)
|
||||
|
||||
# Decide whether to load model to GPU or CPU
|
||||
try:
|
||||
# Try load chat model to GPU if:
|
||||
# 1. Loading chat model to GPU isn't disabled via CLI and
|
||||
# 2. Machine has GPU
|
||||
|
@ -26,6 +27,12 @@ def download_model(model_name: str):
|
|||
)
|
||||
except ValueError:
|
||||
device = "cpu"
|
||||
except Exception as e:
|
||||
if chat_model_config is None:
|
||||
device = "cpu" # Fallback to CPU as can't determine if GPU has enough memory
|
||||
logger.debug(f"Unable to download model config from gpt4all website: {e}")
|
||||
else:
|
||||
raise e
|
||||
|
||||
# Now load the downloaded chat model onto appropriate device
|
||||
chat_model = gpt4all.GPT4All(model_name=model_name, device=device, allow_download=False)
|
||||
|
|
|
@ -41,7 +41,7 @@ def extract_questions(
|
|||
[
|
||||
f'Q: {chat["intent"]["query"]}\n\n{chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}\n\n{chat["message"]}\n\n'
|
||||
for chat in conversation_log.get("chat", [])[-4:]
|
||||
if chat["by"] == "khoj"
|
||||
if chat["by"] == "khoj" and chat["intent"].get("type") != "text-to-image"
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -123,8 +123,8 @@ def send_message_to_model(
|
|||
|
||||
def converse(
|
||||
references,
|
||||
online_results,
|
||||
user_query,
|
||||
online_results=[],
|
||||
conversation_log={},
|
||||
model: str = "gpt-3.5-turbo",
|
||||
api_key: Optional[str] = None,
|
||||
|
|
|
@ -36,11 +36,11 @@ class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
|
|||
|
||||
@retry(
|
||||
retry=(
|
||||
retry_if_exception_type(openai.error.Timeout)
|
||||
| retry_if_exception_type(openai.error.APIError)
|
||||
| retry_if_exception_type(openai.error.APIConnectionError)
|
||||
| retry_if_exception_type(openai.error.RateLimitError)
|
||||
| retry_if_exception_type(openai.error.ServiceUnavailableError)
|
||||
retry_if_exception_type(openai._exceptions.APITimeoutError)
|
||||
| retry_if_exception_type(openai._exceptions.APIError)
|
||||
| retry_if_exception_type(openai._exceptions.APIConnectionError)
|
||||
| retry_if_exception_type(openai._exceptions.RateLimitError)
|
||||
| retry_if_exception_type(openai._exceptions.APIStatusError)
|
||||
),
|
||||
wait=wait_random_exponential(min=1, max=10),
|
||||
stop=stop_after_attempt(3),
|
||||
|
@ -57,11 +57,11 @@ def completion_with_backoff(**kwargs):
|
|||
|
||||
@retry(
|
||||
retry=(
|
||||
retry_if_exception_type(openai.error.Timeout)
|
||||
| retry_if_exception_type(openai.error.APIError)
|
||||
| retry_if_exception_type(openai.error.APIConnectionError)
|
||||
| retry_if_exception_type(openai.error.RateLimitError)
|
||||
| retry_if_exception_type(openai.error.ServiceUnavailableError)
|
||||
retry_if_exception_type(openai._exceptions.APITimeoutError)
|
||||
| retry_if_exception_type(openai._exceptions.APIError)
|
||||
| retry_if_exception_type(openai._exceptions.APIConnectionError)
|
||||
| retry_if_exception_type(openai._exceptions.RateLimitError)
|
||||
| retry_if_exception_type(openai._exceptions.APIStatusError)
|
||||
),
|
||||
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||
stop=stop_after_attempt(3),
|
||||
|
|
|
@ -3,13 +3,13 @@ from io import BufferedReader
|
|||
|
||||
# External Packages
|
||||
from asgiref.sync import sync_to_async
|
||||
import openai
|
||||
from openai import OpenAI
|
||||
|
||||
|
||||
async def transcribe_audio(audio_file: BufferedReader, model, api_key) -> str:
|
||||
async def transcribe_audio(audio_file: BufferedReader, model, client: OpenAI) -> str:
|
||||
"""
|
||||
Transcribe audio file using Whisper model via OpenAI's API
|
||||
"""
|
||||
# Send the audio data to the Whisper API
|
||||
response = await sync_to_async(openai.Audio.translate)(model=model, file=audio_file, api_key=api_key)
|
||||
return response["text"]
|
||||
response = await sync_to_async(client.audio.translations.create)(model=model, file=audio_file)
|
||||
return response.text
|
||||
|
|
|
@ -4,6 +4,7 @@ from time import perf_counter
|
|||
import json
|
||||
from datetime import datetime
|
||||
import queue
|
||||
from typing import Any, Dict, List
|
||||
import tiktoken
|
||||
|
||||
# External packages
|
||||
|
@ -11,6 +12,8 @@ from langchain.schema import ChatMessage
|
|||
from transformers import AutoTokenizer
|
||||
|
||||
# Internal Packages
|
||||
from khoj.database.adapters import ConversationAdapters
|
||||
from khoj.database.models import KhojUser
|
||||
from khoj.utils.helpers import merge_dicts
|
||||
|
||||
|
||||
|
@ -89,6 +92,32 @@ def message_to_log(
|
|||
return conversation_log
|
||||
|
||||
|
||||
def save_to_conversation_log(
|
||||
q: str,
|
||||
chat_response: str,
|
||||
user: KhojUser,
|
||||
meta_log: Dict,
|
||||
user_message_time: str = None,
|
||||
compiled_references: List[str] = [],
|
||||
online_results: Dict[str, Any] = {},
|
||||
inferred_queries: List[str] = [],
|
||||
intent_type: str = "remember",
|
||||
):
|
||||
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
updated_conversation = message_to_log(
|
||||
user_message=q,
|
||||
chat_response=chat_response,
|
||||
user_message_metadata={"created": user_message_time},
|
||||
khoj_message_metadata={
|
||||
"context": compiled_references,
|
||||
"intent": {"inferred-queries": inferred_queries, "type": intent_type},
|
||||
"onlineContext": online_results,
|
||||
},
|
||||
conversation_log=meta_log.get("chat", []),
|
||||
)
|
||||
ConversationAdapters.save_conversation(user, {"chat": updated_conversation})
|
||||
|
||||
|
||||
def generate_chatml_messages_with_context(
|
||||
user_message,
|
||||
system_message,
|
||||
|
|
|
@ -19,7 +19,7 @@ from starlette.authentication import requires
|
|||
from khoj.configure import configure_server
|
||||
from khoj.database import adapters
|
||||
from khoj.database.adapters import ConversationAdapters, EntryAdapters
|
||||
from khoj.database.models import ChatModelOptions
|
||||
from khoj.database.models import ChatModelOptions, SpeechToTextModelOptions
|
||||
from khoj.database.models import Entry as DbEntry
|
||||
from khoj.database.models import (
|
||||
GithubConfig,
|
||||
|
@ -35,12 +35,14 @@ from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
|
|||
from khoj.processor.conversation.openai.gpt import extract_questions
|
||||
from khoj.processor.conversation.openai.whisper import transcribe_audio
|
||||
from khoj.processor.conversation.prompts import help_message, no_entries_found
|
||||
from khoj.processor.conversation.utils import save_to_conversation_log
|
||||
from khoj.processor.tools.online_search import search_with_google
|
||||
from khoj.routers.helpers import (
|
||||
ApiUserRateLimiter,
|
||||
CommonQueryParams,
|
||||
agenerate_chat_response,
|
||||
get_conversation_command,
|
||||
text_to_image,
|
||||
is_ready_to_chat,
|
||||
update_telemetry_state,
|
||||
validate_conversation_config,
|
||||
|
@ -622,17 +624,15 @@ async def transcribe(request: Request, common: CommonQueryParams, file: UploadFi
|
|||
|
||||
# Send the audio data to the Whisper API
|
||||
speech_to_text_config = await ConversationAdapters.get_speech_to_text_config()
|
||||
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
|
||||
if not speech_to_text_config:
|
||||
# If the user has not configured a speech to text model, return an unprocessable entity error
|
||||
status_code = 422
|
||||
elif openai_chat_config and speech_to_text_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
api_key = openai_chat_config.api_key
|
||||
# If the user has not configured a speech to text model, return an unsupported on server error
|
||||
status_code = 501
|
||||
elif state.openai_client and speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OPENAI:
|
||||
speech2text_model = speech_to_text_config.model_name
|
||||
user_message = await transcribe_audio(audio_file, model=speech2text_model, api_key=api_key)
|
||||
elif speech_to_text_config.model_type == ChatModelOptions.ModelType.OFFLINE:
|
||||
user_message = await transcribe_audio(audio_file, speech2text_model, client=state.openai_client)
|
||||
elif speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OFFLINE:
|
||||
speech2text_model = speech_to_text_config.model_name
|
||||
user_message = await transcribe_audio_offline(audio_filename, model=speech2text_model)
|
||||
user_message = await transcribe_audio_offline(audio_filename, speech2text_model)
|
||||
finally:
|
||||
# Close and Delete the temporary audio file
|
||||
audio_file.close()
|
||||
|
@ -665,7 +665,7 @@ async def chat(
|
|||
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=60, window=60)),
|
||||
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)),
|
||||
) -> Response:
|
||||
user = request.user.object
|
||||
user: KhojUser = request.user.object
|
||||
|
||||
await is_ready_to_chat(user)
|
||||
conversation_command = get_conversation_command(query=q, any_references=True)
|
||||
|
@ -703,6 +703,11 @@ async def chat(
|
|||
media_type="text/event-stream",
|
||||
status_code=200,
|
||||
)
|
||||
elif conversation_command == ConversationCommand.Image:
|
||||
image, status_code = await text_to_image(q)
|
||||
await sync_to_async(save_to_conversation_log)(q, image, user, meta_log, intent_type="text-to-image")
|
||||
content_obj = {"image": image, "intentType": "text-to-image"}
|
||||
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
|
||||
|
||||
# Get the (streamed) chat response from the LLM of choice.
|
||||
llm_response, chat_metadata = await agenerate_chat_response(
|
||||
|
@ -786,7 +791,6 @@ async def extract_references_and_questions(
|
|||
conversation_config = await ConversationAdapters.aget_conversation_config(user)
|
||||
if conversation_config is None:
|
||||
conversation_config = await ConversationAdapters.aget_default_conversation_config()
|
||||
openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
|
||||
if (
|
||||
offline_chat_config
|
||||
and offline_chat_config.enabled
|
||||
|
@ -803,7 +807,7 @@ async def extract_references_and_questions(
|
|||
inferred_queries = extract_questions_offline(
|
||||
defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
|
||||
)
|
||||
elif openai_chat_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
elif conversation_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
|
||||
openai_chat = await ConversationAdapters.get_openai_chat()
|
||||
api_key = openai_chat_config.api_key
|
||||
|
|
|
@ -9,23 +9,23 @@ from functools import partial
|
|||
from time import time
|
||||
from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
# External Packages
|
||||
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
||||
import openai
|
||||
from starlette.authentication import has_required_scope
|
||||
from asgiref.sync import sync_to_async
|
||||
|
||||
|
||||
# Internal Packages
|
||||
from khoj.database.adapters import ConversationAdapters, EntryAdapters
|
||||
from khoj.database.models import KhojUser, Subscription
|
||||
from khoj.database.models import KhojUser, Subscription, TextToImageModelConfig
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.offline.chat_model import converse_offline, send_message_to_model_offline
|
||||
from khoj.processor.conversation.openai.gpt import converse, send_message_to_model
|
||||
from khoj.processor.conversation.utils import ThreadedGenerator, message_to_log
|
||||
|
||||
# Internal Packages
|
||||
from khoj.processor.conversation.utils import ThreadedGenerator, save_to_conversation_log
|
||||
from khoj.utils import state
|
||||
from khoj.utils.config import GPT4AllProcessorModel
|
||||
from khoj.utils.helpers import ConversationCommand, log_telemetry
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
|
@ -102,6 +102,8 @@ def get_conversation_command(query: str, any_references: bool = False) -> Conver
|
|||
return ConversationCommand.General
|
||||
elif query.startswith("/online"):
|
||||
return ConversationCommand.Online
|
||||
elif query.startswith("/image"):
|
||||
return ConversationCommand.Image
|
||||
# If no relevant notes found for the given query
|
||||
elif not any_references:
|
||||
return ConversationCommand.General
|
||||
|
@ -186,30 +188,7 @@ def generate_chat_response(
|
|||
conversation_command: ConversationCommand = ConversationCommand.Default,
|
||||
user: KhojUser = None,
|
||||
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
|
||||
def _save_to_conversation_log(
|
||||
q: str,
|
||||
chat_response: str,
|
||||
user_message_time: str,
|
||||
compiled_references: List[str],
|
||||
online_results: Dict[str, Any],
|
||||
inferred_queries: List[str],
|
||||
meta_log,
|
||||
):
|
||||
updated_conversation = message_to_log(
|
||||
user_message=q,
|
||||
chat_response=chat_response,
|
||||
user_message_metadata={"created": user_message_time},
|
||||
khoj_message_metadata={
|
||||
"context": compiled_references,
|
||||
"intent": {"inferred-queries": inferred_queries},
|
||||
"onlineContext": online_results,
|
||||
},
|
||||
conversation_log=meta_log.get("chat", []),
|
||||
)
|
||||
ConversationAdapters.save_conversation(user, {"chat": updated_conversation})
|
||||
|
||||
# Initialize Variables
|
||||
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
chat_response = None
|
||||
logger.debug(f"Conversation Type: {conversation_command.name}")
|
||||
|
||||
|
@ -217,13 +196,13 @@ def generate_chat_response(
|
|||
|
||||
try:
|
||||
partial_completion = partial(
|
||||
_save_to_conversation_log,
|
||||
save_to_conversation_log,
|
||||
q,
|
||||
user_message_time=user_message_time,
|
||||
user=user,
|
||||
meta_log=meta_log,
|
||||
compiled_references=compiled_references,
|
||||
online_results=online_results,
|
||||
inferred_queries=inferred_queries,
|
||||
meta_log=meta_log,
|
||||
)
|
||||
|
||||
conversation_config = ConversationAdapters.get_valid_conversation_config(user)
|
||||
|
@ -251,9 +230,9 @@ def generate_chat_response(
|
|||
chat_model = conversation_config.chat_model
|
||||
chat_response = converse(
|
||||
compiled_references,
|
||||
online_results,
|
||||
q,
|
||||
meta_log,
|
||||
online_results=online_results,
|
||||
conversation_log=meta_log,
|
||||
model=chat_model,
|
||||
api_key=api_key,
|
||||
completion_func=partial_completion,
|
||||
|
@ -271,6 +250,29 @@ def generate_chat_response(
|
|||
return chat_response, metadata
|
||||
|
||||
|
||||
async def text_to_image(message: str) -> Tuple[Optional[str], int]:
|
||||
status_code = 200
|
||||
image = None
|
||||
|
||||
# Send the audio data to the Whisper API
|
||||
text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
|
||||
if not text_to_image_config:
|
||||
# If the user has not configured a text to image model, return an unsupported on server error
|
||||
status_code = 501
|
||||
elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
|
||||
text2image_model = text_to_image_config.model_name
|
||||
try:
|
||||
response = state.openai_client.images.generate(
|
||||
prompt=message, model=text2image_model, response_format="b64_json"
|
||||
)
|
||||
image = response.data[0].b64_json
|
||||
except openai.OpenAIError as e:
|
||||
logger.error(f"Image Generation failed with {e.http_status}: {e.error}")
|
||||
status_code = 500
|
||||
|
||||
return image, status_code
|
||||
|
||||
|
||||
class ApiUserRateLimiter:
|
||||
def __init__(self, requests: int, subscribed_requests: int, window: int):
|
||||
self.requests = requests
|
||||
|
|
|
@ -273,6 +273,7 @@ class ConversationCommand(str, Enum):
|
|||
Notes = "notes"
|
||||
Help = "help"
|
||||
Online = "online"
|
||||
Image = "image"
|
||||
|
||||
|
||||
command_descriptions = {
|
||||
|
@ -280,6 +281,7 @@ command_descriptions = {
|
|||
ConversationCommand.Notes: "Only talk about information that is available in your knowledge base.",
|
||||
ConversationCommand.Default: "The default command when no command specified. It intelligently auto-switches between general and notes mode.",
|
||||
ConversationCommand.Online: "Look up information on the internet.",
|
||||
ConversationCommand.Image: "Generate images by describing your imagination in words.",
|
||||
ConversationCommand.Help: "Display a help message with all available commands and other metadata.",
|
||||
}
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@ from khoj.database.models import (
|
|||
OpenAIProcessorConversationConfig,
|
||||
ChatModelOptions,
|
||||
SpeechToTextModelOptions,
|
||||
TextToImageModelConfig,
|
||||
)
|
||||
|
||||
from khoj.utils.constants import default_offline_chat_model, default_online_chat_model
|
||||
|
@ -103,6 +104,15 @@ def initialization():
|
|||
model_name=openai_speech2text_model, model_type=SpeechToTextModelOptions.ModelType.OPENAI
|
||||
)
|
||||
|
||||
default_text_to_image_model = "dall-e-3"
|
||||
openai_text_to_image_model = input(
|
||||
f"Enter the OpenAI text to image model you want to use (default: {default_text_to_image_model}): "
|
||||
)
|
||||
openai_speech2text_model = openai_text_to_image_model or default_text_to_image_model
|
||||
TextToImageModelConfig.objects.create(
|
||||
model_name=openai_text_to_image_model, model_type=TextToImageModelConfig.ModelType.OPENAI
|
||||
)
|
||||
|
||||
if use_offline_model == "y" or use_openai_model == "y":
|
||||
logger.info("🗣️ Chat model configuration complete")
|
||||
|
||||
|
|
|
@ -22,17 +22,9 @@ class BaseEncoder(ABC):
|
|||
|
||||
|
||||
class OpenAI(BaseEncoder):
|
||||
def __init__(self, model_name, device=None):
|
||||
def __init__(self, model_name, client: openai.OpenAI, device=None):
|
||||
self.model_name = model_name
|
||||
if (
|
||||
not state.processor_config
|
||||
or not state.processor_config.conversation
|
||||
or not state.processor_config.conversation.openai_model
|
||||
):
|
||||
raise Exception(
|
||||
f"Set OpenAI API key under processor-config > conversation > openai-api-key in config file: {state.config_file}"
|
||||
)
|
||||
openai.api_key = state.processor_config.conversation.openai_model.api_key
|
||||
self.openai_client = client
|
||||
self.embedding_dimensions = None
|
||||
|
||||
def encode(self, entries, device=None, **kwargs):
|
||||
|
@ -43,7 +35,7 @@ class OpenAI(BaseEncoder):
|
|||
processed_entry = entries[index].replace("\n", " ")
|
||||
|
||||
try:
|
||||
response = openai.Embedding.create(input=processed_entry, model=self.model_name)
|
||||
response = self.openai_client.embeddings.create(input=processed_entry, model=self.model_name)
|
||||
embedding_tensors += [torch.tensor(response.data[0].embedding, device=device)]
|
||||
# Use current models embedding dimension, once available
|
||||
# Else default to embedding dimensions of the text-embedding-ada-002 model
|
||||
|
|
|
@ -1,15 +1,16 @@
|
|||
# Standard Packages
|
||||
from collections import defaultdict
|
||||
import os
|
||||
from pathlib import Path
|
||||
import threading
|
||||
from typing import List, Dict
|
||||
from collections import defaultdict
|
||||
|
||||
# External Packages
|
||||
from pathlib import Path
|
||||
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
|
||||
from openai import OpenAI
|
||||
from whisper import Whisper
|
||||
|
||||
# Internal Packages
|
||||
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
|
||||
from khoj.utils import config as utils_config
|
||||
from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel
|
||||
from khoj.utils.helpers import LRU, get_device
|
||||
|
@ -21,6 +22,7 @@ search_models = SearchModels()
|
|||
embeddings_model: EmbeddingsModel = None
|
||||
cross_encoder_model: CrossEncoderModel = None
|
||||
content_index = ContentIndex()
|
||||
openai_client: OpenAI = None
|
||||
gpt4all_processor_config: GPT4AllProcessorModel = None
|
||||
whisper_model: Whisper = None
|
||||
config_file: Path = None
|
||||
|
|
|
@ -68,10 +68,10 @@ def test_chat_with_online_content(chat_client):
|
|||
response_message = response_message.split("### compiled references")[0]
|
||||
|
||||
# Assert
|
||||
expected_responses = ["http://www.paulgraham.com/greatwork.html"]
|
||||
expected_responses = ["http://www.paulgraham.com/greatwork.html", "Please set your SERPER_DEV_API_KEY"]
|
||||
assert response.status_code == 200
|
||||
assert any([expected_response in response_message for expected_response in expected_responses]), (
|
||||
"Expected assistants name, [K|k]hoj, in response but got: " + response_message
|
||||
"Expected links or serper not setup in response but got: " + response_message
|
||||
)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue