Merge pull request #580 from khoj-ai/fix-upgrade-chat-to-create-images

Support Image Generation with Khoj
This commit is contained in:
sabaimran 2023-12-07 21:17:58 +05:30 committed by GitHub
commit 9b961ed496
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 529 additions and 303 deletions

View file

@ -42,7 +42,7 @@ dependencies = [
"fastapi >= 0.104.1",
"python-multipart >= 0.0.5",
"jinja2 == 3.1.2",
"openai >= 0.27.0, < 1.0.0",
"openai >= 1.0.0",
"tiktoken >= 0.3.2",
"tenacity >= 8.2.2",
"pillow ~= 9.5.0",

View file

@ -179,7 +179,13 @@
return numOnlineReferences;
}
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null) {
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null) {
if (intentType === "text-to-image") {
let imageMarkdown = `![](data:image/png;base64,${message})`;
renderMessage(imageMarkdown, by, dt);
return;
}
if (context == null && onlineContext == null) {
renderMessage(message, by, dt);
return;
@ -244,6 +250,17 @@
// Remove any text between <s>[INST] and </s> tags. These are spurious instructions for the AI chat model.
newHTML = newHTML.replace(/<s>\[INST\].+(<\/s>)?/g, '');
// Customize the rendering of images
md.renderer.rules.image = function(tokens, idx, options, env, self) {
let token = tokens[idx];
// Add class="text-to-image" to images
token.attrPush(['class', 'text-to-image']);
// Use the default renderer to render image markdown format
return self.renderToken(tokens, idx, options);
};
// Render markdown
newHTML = md.render(newHTML);
// Get any elements with a class that starts with "language"
@ -328,14 +345,41 @@
let chatInput = document.getElementById("chat-input");
chatInput.classList.remove("option-enabled");
// Call specified Khoj API which returns a streamed response of type text/plain
fetch(url, { headers })
.then(response => {
// Call specified Khoj API
let response = await fetch(url, { headers });
let rawResponse = "";
const contentType = response.headers.get("content-type");
if (contentType === "application/json") {
// Handle JSON response
try {
const responseAsJson = await response.json();
if (responseAsJson.image) {
// If response has image field, response is a generated image.
rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
}
if (responseAsJson.detail) {
// If response has detail field, response is an error message.
rawResponse += responseAsJson.detail;
}
} catch (error) {
// If the chunk is not a JSON object, just display it as is
rawResponse += chunk;
} finally {
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
document.getElementById("chat-input").removeAttribute("disabled");
}
} else {
// Handle streamed response of type text/event-stream or text/plain
const reader = response.body.getReader();
const decoder = new TextDecoder();
let rawResponse = "";
let references = null;
readStream();
function readStream() {
reader.read().then(({ done, value }) => {
if (done) {
@ -404,16 +448,23 @@
if (newResponseText.getElementsByClassName("spinner").length > 0) {
newResponseText.removeChild(loadingSpinner);
}
// Try to parse the chunk as a JSON object. It will be a JSON object if there is an error.
if (chunk.startsWith("{") && chunk.endsWith("}")) {
try {
const responseAsJson = JSON.parse(chunk);
if (responseAsJson.image) {
rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
}
if (responseAsJson.detail) {
newResponseText.innerHTML += responseAsJson.detail;
rawResponse += responseAsJson.detail;
}
} catch (error) {
// If the chunk is not a JSON object, just display it as is
newResponseText.innerHTML += chunk;
rawResponse += chunk;
} finally {
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
}
} else {
// If the chunk is not a JSON object, just display it as is
@ -429,8 +480,7 @@
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
});
}
readStream();
});
}
}
function incrementalChat(event) {
@ -522,7 +572,7 @@
.then(response => {
// Render conversation history, if any
response.forEach(chat_log => {
renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext);
renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext, chat_log.intent?.type);
});
})
.catch(err => {
@ -625,9 +675,13 @@
.then(response => response.ok ? response.json() : Promise.reject(response))
.then(data => { chatInput.value += data.text; })
.catch(err => {
err.status == 422
? flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
: flashStatusInChatInput("⛔️ Failed to transcribe audio")
if (err.status === 501) {
flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
} else if (err.status === 422) {
flashStatusInChatInput("⛔️ Audio file to large to process.")
} else {
flashStatusInChatInput("⛔️ Failed to transcribe audio.")
}
});
};
@ -810,6 +864,9 @@
margin-top: -10px;
transform: rotate(-60deg)
}
img.text-to-image {
max-width: 60%;
}
#chat-footer {
padding: 0;
@ -1050,6 +1107,9 @@
margin: 4px;
grid-template-columns: auto;
}
img.text-to-image {
max-width: 100%;
}
}
@media only screen and (min-width: 600px) {
body {

View file

@ -2,6 +2,12 @@ import { App, MarkdownRenderer, Modal, request, requestUrl, setIcon } from 'obsi
import { KhojSetting } from 'src/settings';
import fetch from "node-fetch";
export interface ChatJsonResult {
image?: string;
detail?: string;
}
export class KhojChatModal extends Modal {
result: string;
setting: KhojSetting;
@ -105,15 +111,19 @@ export class KhojChatModal extends Modal {
return referenceButton;
}
renderMessageWithReferences(chatEl: Element, message: string, sender: string, context?: string[], dt?: Date) {
renderMessageWithReferences(chatEl: Element, message: string, sender: string, context?: string[], dt?: Date, intentType?: string) {
if (!message) {
return;
} else if (intentType === "text-to-image") {
let imageMarkdown = `![](data:image/png;base64,${message})`;
this.renderMessage(chatEl, imageMarkdown, sender, dt);
return;
} else if (!context) {
this.renderMessage(chatEl, message, sender, dt);
return
return;
} else if (!!context && context?.length === 0) {
this.renderMessage(chatEl, message, sender, dt);
return
return;
}
let chatMessageEl = this.renderMessage(chatEl, message, sender, dt);
let chatMessageBodyEl = chatMessageEl.getElementsByClassName("khoj-chat-message-text")[0]
@ -225,7 +235,7 @@ export class KhojChatModal extends Modal {
let response = await request({ url: chatUrl, headers: headers });
let chatLogs = JSON.parse(response).response;
chatLogs.forEach((chatLog: any) => {
this.renderMessageWithReferences(chatBodyEl, chatLog.message, chatLog.by, chatLog.context, new Date(chatLog.created));
this.renderMessageWithReferences(chatBodyEl, chatLog.message, chatLog.by, chatLog.context, new Date(chatLog.created), chatLog.intent?.type);
});
}
@ -266,8 +276,25 @@ export class KhojChatModal extends Modal {
this.result = "";
responseElement.innerHTML = "";
if (response.headers.get("content-type") == "application/json") {
let responseText = ""
try {
const responseAsJson = await response.json() as ChatJsonResult;
if (responseAsJson.image) {
responseText = `![${query}](data:image/png;base64,${responseAsJson.image})`;
} else if (responseAsJson.detail) {
responseText = responseAsJson.detail;
}
} catch (error) {
// If the chunk is not a JSON object, just display it as is
responseText = response.body.read().toString()
} finally {
this.renderIncrementalMessage(responseElement, responseText);
}
}
for await (const chunk of response.body) {
const responseText = chunk.toString();
let responseText = chunk.toString();
if (responseText.includes("### compiled references:")) {
const additionalResponse = responseText.split("### compiled references:")[0];
this.renderIncrementalMessage(responseElement, additionalResponse);
@ -310,6 +337,12 @@ export class KhojChatModal extends Modal {
referenceExpandButton.innerHTML = expandButtonText;
references.appendChild(referenceSection);
} else {
if (responseText.startsWith("{") && responseText.endsWith("}")) {
} else {
// If the chunk is not a JSON object, just display it as is
continue;
}
this.renderIncrementalMessage(responseElement, responseText);
}
}
@ -389,10 +422,12 @@ export class KhojChatModal extends Modal {
if (response.status === 200) {
console.log(response);
chatInput.value += response.json.text;
} else if (response.status === 422) {
throw new Error("⛔️ Failed to transcribe audio");
} else {
} else if (response.status === 501) {
throw new Error("⛔️ Configure speech-to-text model on server.");
} else if (response.status === 422) {
throw new Error("⛔️ Audio file to large to process.");
} else {
throw new Error("⛔️ Failed to transcribe audio.");
}
};

View file

@ -217,7 +217,9 @@ button.copy-button:hover {
background: #f5f5f5;
cursor: pointer;
}
img {
max-width: 60%;
}
#khoj-chat-footer {
padding: 0;

View file

@ -7,6 +7,7 @@ import requests
import os
# External Packages
import openai
import schedule
from starlette.middleware.sessions import SessionMiddleware
from starlette.middleware.authentication import AuthenticationMiddleware
@ -22,6 +23,7 @@ from starlette.authentication import (
# Internal Packages
from khoj.database.models import KhojUser, Subscription
from khoj.database.adapters import (
ConversationAdapters,
get_all_users,
get_or_create_search_model,
aget_user_subscription_state,
@ -138,6 +140,10 @@ def configure_server(
config = FullConfig()
state.config = config
if ConversationAdapters.has_valid_openai_conversation_config():
openai_config = ConversationAdapters.get_openai_conversation_config()
state.openai_client = openai.OpenAI(api_key=openai_config.api_key)
# Initialize Search Models from Config and initialize content
try:
state.embeddings_model = EmbeddingsModel(get_or_create_search_model().bi_encoder)

View file

@ -22,6 +22,7 @@ from khoj.database.models import (
GithubConfig,
GithubRepoConfig,
GoogleUser,
TextToImageModelConfig,
KhojApiUser,
KhojUser,
NotionConfig,
@ -426,6 +427,10 @@ class ConversationAdapters:
else:
raise ValueError("Invalid conversation config - either configure offline chat or openai chat")
@staticmethod
async def aget_text_to_image_model_config():
return await TextToImageModelConfig.objects.filter().afirst()
class EntryAdapters:
word_filer = WordFilter()

View file

@ -0,0 +1,25 @@
# Generated by Django 4.2.7 on 2023-12-04 22:17
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0021_speechtotextmodeloptions_and_more"),
]
operations = [
migrations.CreateModel(
name="TextToImageModelConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("model_name", models.CharField(default="dall-e-3", max_length=200)),
("model_type", models.CharField(choices=[("openai", "Openai")], default="openai", max_length=200)),
],
options={
"abstract": False,
},
),
]

View file

@ -112,6 +112,14 @@ class SearchModelConfig(BaseModel):
cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2")
class TextToImageModelConfig(BaseModel):
class ModelType(models.TextChoices):
OPENAI = "openai"
model_name = models.CharField(max_length=200, default="dall-e-3")
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
class OpenAIProcessorConversationConfig(BaseModel):
api_key = models.CharField(max_length=200)

View file

@ -188,7 +188,13 @@ To get started, just start typing below. You can also type / to see a list of co
return numOnlineReferences;
}
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null) {
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null) {
if (intentType === "text-to-image") {
let imageMarkdown = `![](data:image/png;base64,${message})`;
renderMessage(imageMarkdown, by, dt);
return;
}
if (context == null && onlineContext == null) {
renderMessage(message, by, dt);
return;
@ -253,6 +259,17 @@ To get started, just start typing below. You can also type / to see a list of co
// Remove any text between <s>[INST] and </s> tags. These are spurious instructions for the AI chat model.
newHTML = newHTML.replace(/<s>\[INST\].+(<\/s>)?/g, '');
// Customize the rendering of images
md.renderer.rules.image = function(tokens, idx, options, env, self) {
let token = tokens[idx];
// Add class="text-to-image" to images
token.attrPush(['class', 'text-to-image']);
// Use the default renderer to render image markdown format
return self.renderToken(tokens, idx, options);
};
// Render markdown
newHTML = md.render(newHTML);
// Get any elements with a class that starts with "language"
@ -292,7 +309,7 @@ To get started, just start typing below. You can also type / to see a list of co
return element
}
function chat() {
async function chat() {
// Extract required fields for search from form
let query = document.getElementById("chat-input").value.trim();
let resultsCount = localStorage.getItem("khojResultsCount") || 5;
@ -333,14 +350,41 @@ To get started, just start typing below. You can also type / to see a list of co
let chatInput = document.getElementById("chat-input");
chatInput.classList.remove("option-enabled");
// Call specified Khoj API which returns a streamed response of type text/plain
fetch(url)
.then(response => {
// Call specified Khoj API
let response = await fetch(url);
let rawResponse = "";
const contentType = response.headers.get("content-type");
if (contentType === "application/json") {
// Handle JSON response
try {
const responseAsJson = await response.json();
if (responseAsJson.image) {
// If response has image field, response is a generated image.
rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
}
if (responseAsJson.detail) {
// If response has detail field, response is an error message.
rawResponse += responseAsJson.detail;
}
} catch (error) {
// If the chunk is not a JSON object, just display it as is
rawResponse += chunk;
} finally {
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
document.getElementById("chat-input").removeAttribute("disabled");
}
} else {
// Handle streamed response of type text/event-stream or text/plain
const reader = response.body.getReader();
const decoder = new TextDecoder();
let rawResponse = "";
let references = null;
readStream();
function readStream() {
reader.read().then(({ done, value }) => {
if (done) {
@ -410,36 +454,19 @@ To get started, just start typing below. You can also type / to see a list of co
newResponseText.removeChild(loadingSpinner);
}
// Try to parse the chunk as a JSON object. It will be a JSON object if there is an error.
if (chunk.startsWith("{") && chunk.endsWith("}")) {
try {
const responseAsJson = JSON.parse(chunk);
if (responseAsJson.detail) {
rawResponse += responseAsJson.detail;
}
} catch (error) {
// If the chunk is not a JSON object, just display it as is
rawResponse += chunk;
} finally {
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
}
} else {
// If the chunk is not a JSON object, just display it as is
rawResponse += chunk;
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
readStream();
}
}
});
// Scroll to bottom of chat window as chat response is streamed
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
});
}
readStream();
});
};
}
};
function incrementalChat(event) {
if (!event.shiftKey && event.key === 'Enter') {
@ -516,7 +543,7 @@ To get started, just start typing below. You can also type / to see a list of co
.then(response => {
// Render conversation history, if any
response.forEach(chat_log => {
renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext);
renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext, chat_log.intent?.type);
});
})
.catch(err => {
@ -611,9 +638,13 @@ To get started, just start typing below. You can also type / to see a list of co
.then(response => response.ok ? response.json() : Promise.reject(response))
.then(data => { chatInput.value += data.text; })
.catch(err => {
err.status == 422
? flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
: flashStatusInChatInput("⛔️ Failed to transcribe audio")
if (err.status === 501) {
flashStatusInChatInput("⛔️ Configure speech-to-text model on server.")
} else if (err.status === 422) {
flashStatusInChatInput("⛔️ Audio file to large to process.")
} else {
flashStatusInChatInput("⛔️ Failed to transcribe audio.")
}
});
};
@ -902,6 +933,9 @@ To get started, just start typing below. You can also type / to see a list of co
margin-top: -10px;
transform: rotate(-60deg)
}
img.text-to-image {
max-width: 60%;
}
#chat-footer {
padding: 0;
@ -1029,6 +1063,9 @@ To get started, just start typing below. You can also type / to see a list of co
margin: 4px;
grid-template-columns: auto;
}
img.text-to-image {
max-width: 100%;
}
}
@media only screen and (min-width: 700px) {
body {

View file

@ -47,7 +47,7 @@ def extract_questions_offline(
if use_history:
for chat in conversation_log.get("chat", [])[-4:]:
if chat["by"] == "khoj":
if chat["by"] == "khoj" and chat["intent"].get("type") != "text-to-image":
chat_history += f"Q: {chat['intent']['query']}\n"
chat_history += f"A: {chat['message']}\n"

View file

@ -12,11 +12,12 @@ def download_model(model_name: str):
logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.")
raise e
# Download the chat model
# Decide whether to load model to GPU or CPU
chat_model_config = None
try:
# Download the chat model and its config
chat_model_config = gpt4all.GPT4All.retrieve_model(model_name=model_name, allow_download=True)
# Decide whether to load model to GPU or CPU
try:
# Try load chat model to GPU if:
# 1. Loading chat model to GPU isn't disabled via CLI and
# 2. Machine has GPU
@ -26,6 +27,12 @@ def download_model(model_name: str):
)
except ValueError:
device = "cpu"
except Exception as e:
if chat_model_config is None:
device = "cpu" # Fallback to CPU as can't determine if GPU has enough memory
logger.debug(f"Unable to download model config from gpt4all website: {e}")
else:
raise e
# Now load the downloaded chat model onto appropriate device
chat_model = gpt4all.GPT4All(model_name=model_name, device=device, allow_download=False)

View file

@ -41,7 +41,7 @@ def extract_questions(
[
f'Q: {chat["intent"]["query"]}\n\n{chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}\n\n{chat["message"]}\n\n'
for chat in conversation_log.get("chat", [])[-4:]
if chat["by"] == "khoj"
if chat["by"] == "khoj" and chat["intent"].get("type") != "text-to-image"
]
)
@ -123,8 +123,8 @@ def send_message_to_model(
def converse(
references,
online_results,
user_query,
online_results=[],
conversation_log={},
model: str = "gpt-3.5-turbo",
api_key: Optional[str] = None,

View file

@ -36,11 +36,11 @@ class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler):
@retry(
retry=(
retry_if_exception_type(openai.error.Timeout)
| retry_if_exception_type(openai.error.APIError)
| retry_if_exception_type(openai.error.APIConnectionError)
| retry_if_exception_type(openai.error.RateLimitError)
| retry_if_exception_type(openai.error.ServiceUnavailableError)
retry_if_exception_type(openai._exceptions.APITimeoutError)
| retry_if_exception_type(openai._exceptions.APIError)
| retry_if_exception_type(openai._exceptions.APIConnectionError)
| retry_if_exception_type(openai._exceptions.RateLimitError)
| retry_if_exception_type(openai._exceptions.APIStatusError)
),
wait=wait_random_exponential(min=1, max=10),
stop=stop_after_attempt(3),
@ -57,11 +57,11 @@ def completion_with_backoff(**kwargs):
@retry(
retry=(
retry_if_exception_type(openai.error.Timeout)
| retry_if_exception_type(openai.error.APIError)
| retry_if_exception_type(openai.error.APIConnectionError)
| retry_if_exception_type(openai.error.RateLimitError)
| retry_if_exception_type(openai.error.ServiceUnavailableError)
retry_if_exception_type(openai._exceptions.APITimeoutError)
| retry_if_exception_type(openai._exceptions.APIError)
| retry_if_exception_type(openai._exceptions.APIConnectionError)
| retry_if_exception_type(openai._exceptions.RateLimitError)
| retry_if_exception_type(openai._exceptions.APIStatusError)
),
wait=wait_exponential(multiplier=1, min=4, max=10),
stop=stop_after_attempt(3),

View file

@ -3,13 +3,13 @@ from io import BufferedReader
# External Packages
from asgiref.sync import sync_to_async
import openai
from openai import OpenAI
async def transcribe_audio(audio_file: BufferedReader, model, api_key) -> str:
async def transcribe_audio(audio_file: BufferedReader, model, client: OpenAI) -> str:
"""
Transcribe audio file using Whisper model via OpenAI's API
"""
# Send the audio data to the Whisper API
response = await sync_to_async(openai.Audio.translate)(model=model, file=audio_file, api_key=api_key)
return response["text"]
response = await sync_to_async(client.audio.translations.create)(model=model, file=audio_file)
return response.text

View file

@ -4,6 +4,7 @@ from time import perf_counter
import json
from datetime import datetime
import queue
from typing import Any, Dict, List
import tiktoken
# External packages
@ -11,6 +12,8 @@ from langchain.schema import ChatMessage
from transformers import AutoTokenizer
# Internal Packages
from khoj.database.adapters import ConversationAdapters
from khoj.database.models import KhojUser
from khoj.utils.helpers import merge_dicts
@ -89,6 +92,32 @@ def message_to_log(
return conversation_log
def save_to_conversation_log(
q: str,
chat_response: str,
user: KhojUser,
meta_log: Dict,
user_message_time: str = None,
compiled_references: List[str] = [],
online_results: Dict[str, Any] = {},
inferred_queries: List[str] = [],
intent_type: str = "remember",
):
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
updated_conversation = message_to_log(
user_message=q,
chat_response=chat_response,
user_message_metadata={"created": user_message_time},
khoj_message_metadata={
"context": compiled_references,
"intent": {"inferred-queries": inferred_queries, "type": intent_type},
"onlineContext": online_results,
},
conversation_log=meta_log.get("chat", []),
)
ConversationAdapters.save_conversation(user, {"chat": updated_conversation})
def generate_chatml_messages_with_context(
user_message,
system_message,

View file

@ -19,7 +19,7 @@ from starlette.authentication import requires
from khoj.configure import configure_server
from khoj.database import adapters
from khoj.database.adapters import ConversationAdapters, EntryAdapters
from khoj.database.models import ChatModelOptions
from khoj.database.models import ChatModelOptions, SpeechToTextModelOptions
from khoj.database.models import Entry as DbEntry
from khoj.database.models import (
GithubConfig,
@ -35,12 +35,14 @@ from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
from khoj.processor.conversation.openai.gpt import extract_questions
from khoj.processor.conversation.openai.whisper import transcribe_audio
from khoj.processor.conversation.prompts import help_message, no_entries_found
from khoj.processor.conversation.utils import save_to_conversation_log
from khoj.processor.tools.online_search import search_with_google
from khoj.routers.helpers import (
ApiUserRateLimiter,
CommonQueryParams,
agenerate_chat_response,
get_conversation_command,
text_to_image,
is_ready_to_chat,
update_telemetry_state,
validate_conversation_config,
@ -622,17 +624,15 @@ async def transcribe(request: Request, common: CommonQueryParams, file: UploadFi
# Send the audio data to the Whisper API
speech_to_text_config = await ConversationAdapters.get_speech_to_text_config()
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
if not speech_to_text_config:
# If the user has not configured a speech to text model, return an unprocessable entity error
status_code = 422
elif openai_chat_config and speech_to_text_config.model_type == ChatModelOptions.ModelType.OPENAI:
api_key = openai_chat_config.api_key
# If the user has not configured a speech to text model, return an unsupported on server error
status_code = 501
elif state.openai_client and speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OPENAI:
speech2text_model = speech_to_text_config.model_name
user_message = await transcribe_audio(audio_file, model=speech2text_model, api_key=api_key)
elif speech_to_text_config.model_type == ChatModelOptions.ModelType.OFFLINE:
user_message = await transcribe_audio(audio_file, speech2text_model, client=state.openai_client)
elif speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OFFLINE:
speech2text_model = speech_to_text_config.model_name
user_message = await transcribe_audio_offline(audio_filename, model=speech2text_model)
user_message = await transcribe_audio_offline(audio_filename, speech2text_model)
finally:
# Close and Delete the temporary audio file
audio_file.close()
@ -665,7 +665,7 @@ async def chat(
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=60, window=60)),
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)),
) -> Response:
user = request.user.object
user: KhojUser = request.user.object
await is_ready_to_chat(user)
conversation_command = get_conversation_command(query=q, any_references=True)
@ -703,6 +703,11 @@ async def chat(
media_type="text/event-stream",
status_code=200,
)
elif conversation_command == ConversationCommand.Image:
image, status_code = await text_to_image(q)
await sync_to_async(save_to_conversation_log)(q, image, user, meta_log, intent_type="text-to-image")
content_obj = {"image": image, "intentType": "text-to-image"}
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
# Get the (streamed) chat response from the LLM of choice.
llm_response, chat_metadata = await agenerate_chat_response(
@ -786,7 +791,6 @@ async def extract_references_and_questions(
conversation_config = await ConversationAdapters.aget_conversation_config(user)
if conversation_config is None:
conversation_config = await ConversationAdapters.aget_default_conversation_config()
openai_chat_config = await ConversationAdapters.aget_openai_conversation_config()
if (
offline_chat_config
and offline_chat_config.enabled
@ -803,7 +807,7 @@ async def extract_references_and_questions(
inferred_queries = extract_questions_offline(
defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False
)
elif openai_chat_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
elif conversation_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
openai_chat_config = await ConversationAdapters.get_openai_chat_config()
openai_chat = await ConversationAdapters.get_openai_chat()
api_key = openai_chat_config.api_key

View file

@ -9,23 +9,23 @@ from functools import partial
from time import time
from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union
# External Packages
from fastapi import Depends, Header, HTTPException, Request, UploadFile
import openai
from starlette.authentication import has_required_scope
from asgiref.sync import sync_to_async
# Internal Packages
from khoj.database.adapters import ConversationAdapters, EntryAdapters
from khoj.database.models import KhojUser, Subscription
from khoj.database.models import KhojUser, Subscription, TextToImageModelConfig
from khoj.processor.conversation import prompts
from khoj.processor.conversation.offline.chat_model import converse_offline, send_message_to_model_offline
from khoj.processor.conversation.openai.gpt import converse, send_message_to_model
from khoj.processor.conversation.utils import ThreadedGenerator, message_to_log
# Internal Packages
from khoj.processor.conversation.utils import ThreadedGenerator, save_to_conversation_log
from khoj.utils import state
from khoj.utils.config import GPT4AllProcessorModel
from khoj.utils.helpers import ConversationCommand, log_telemetry
logger = logging.getLogger(__name__)
executor = ThreadPoolExecutor(max_workers=1)
@ -102,6 +102,8 @@ def get_conversation_command(query: str, any_references: bool = False) -> Conver
return ConversationCommand.General
elif query.startswith("/online"):
return ConversationCommand.Online
elif query.startswith("/image"):
return ConversationCommand.Image
# If no relevant notes found for the given query
elif not any_references:
return ConversationCommand.General
@ -186,30 +188,7 @@ def generate_chat_response(
conversation_command: ConversationCommand = ConversationCommand.Default,
user: KhojUser = None,
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
def _save_to_conversation_log(
q: str,
chat_response: str,
user_message_time: str,
compiled_references: List[str],
online_results: Dict[str, Any],
inferred_queries: List[str],
meta_log,
):
updated_conversation = message_to_log(
user_message=q,
chat_response=chat_response,
user_message_metadata={"created": user_message_time},
khoj_message_metadata={
"context": compiled_references,
"intent": {"inferred-queries": inferred_queries},
"onlineContext": online_results,
},
conversation_log=meta_log.get("chat", []),
)
ConversationAdapters.save_conversation(user, {"chat": updated_conversation})
# Initialize Variables
user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
chat_response = None
logger.debug(f"Conversation Type: {conversation_command.name}")
@ -217,13 +196,13 @@ def generate_chat_response(
try:
partial_completion = partial(
_save_to_conversation_log,
save_to_conversation_log,
q,
user_message_time=user_message_time,
user=user,
meta_log=meta_log,
compiled_references=compiled_references,
online_results=online_results,
inferred_queries=inferred_queries,
meta_log=meta_log,
)
conversation_config = ConversationAdapters.get_valid_conversation_config(user)
@ -251,9 +230,9 @@ def generate_chat_response(
chat_model = conversation_config.chat_model
chat_response = converse(
compiled_references,
online_results,
q,
meta_log,
online_results=online_results,
conversation_log=meta_log,
model=chat_model,
api_key=api_key,
completion_func=partial_completion,
@ -271,6 +250,29 @@ def generate_chat_response(
return chat_response, metadata
async def text_to_image(message: str) -> Tuple[Optional[str], int]:
status_code = 200
image = None
# Send the audio data to the Whisper API
text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
if not text_to_image_config:
# If the user has not configured a text to image model, return an unsupported on server error
status_code = 501
elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
text2image_model = text_to_image_config.model_name
try:
response = state.openai_client.images.generate(
prompt=message, model=text2image_model, response_format="b64_json"
)
image = response.data[0].b64_json
except openai.OpenAIError as e:
logger.error(f"Image Generation failed with {e.http_status}: {e.error}")
status_code = 500
return image, status_code
class ApiUserRateLimiter:
def __init__(self, requests: int, subscribed_requests: int, window: int):
self.requests = requests

View file

@ -273,6 +273,7 @@ class ConversationCommand(str, Enum):
Notes = "notes"
Help = "help"
Online = "online"
Image = "image"
command_descriptions = {
@ -280,6 +281,7 @@ command_descriptions = {
ConversationCommand.Notes: "Only talk about information that is available in your knowledge base.",
ConversationCommand.Default: "The default command when no command specified. It intelligently auto-switches between general and notes mode.",
ConversationCommand.Online: "Look up information on the internet.",
ConversationCommand.Image: "Generate images by describing your imagination in words.",
ConversationCommand.Help: "Display a help message with all available commands and other metadata.",
}

View file

@ -7,6 +7,7 @@ from khoj.database.models import (
OpenAIProcessorConversationConfig,
ChatModelOptions,
SpeechToTextModelOptions,
TextToImageModelConfig,
)
from khoj.utils.constants import default_offline_chat_model, default_online_chat_model
@ -103,6 +104,15 @@ def initialization():
model_name=openai_speech2text_model, model_type=SpeechToTextModelOptions.ModelType.OPENAI
)
default_text_to_image_model = "dall-e-3"
openai_text_to_image_model = input(
f"Enter the OpenAI text to image model you want to use (default: {default_text_to_image_model}): "
)
openai_speech2text_model = openai_text_to_image_model or default_text_to_image_model
TextToImageModelConfig.objects.create(
model_name=openai_text_to_image_model, model_type=TextToImageModelConfig.ModelType.OPENAI
)
if use_offline_model == "y" or use_openai_model == "y":
logger.info("🗣️ Chat model configuration complete")

View file

@ -22,17 +22,9 @@ class BaseEncoder(ABC):
class OpenAI(BaseEncoder):
def __init__(self, model_name, device=None):
def __init__(self, model_name, client: openai.OpenAI, device=None):
self.model_name = model_name
if (
not state.processor_config
or not state.processor_config.conversation
or not state.processor_config.conversation.openai_model
):
raise Exception(
f"Set OpenAI API key under processor-config > conversation > openai-api-key in config file: {state.config_file}"
)
openai.api_key = state.processor_config.conversation.openai_model.api_key
self.openai_client = client
self.embedding_dimensions = None
def encode(self, entries, device=None, **kwargs):
@ -43,7 +35,7 @@ class OpenAI(BaseEncoder):
processed_entry = entries[index].replace("\n", " ")
try:
response = openai.Embedding.create(input=processed_entry, model=self.model_name)
response = self.openai_client.embeddings.create(input=processed_entry, model=self.model_name)
embedding_tensors += [torch.tensor(response.data[0].embedding, device=device)]
# Use current models embedding dimension, once available
# Else default to embedding dimensions of the text-embedding-ada-002 model

View file

@ -1,15 +1,16 @@
# Standard Packages
from collections import defaultdict
import os
from pathlib import Path
import threading
from typing import List, Dict
from collections import defaultdict
# External Packages
from pathlib import Path
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from openai import OpenAI
from whisper import Whisper
# Internal Packages
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from khoj.utils import config as utils_config
from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel
from khoj.utils.helpers import LRU, get_device
@ -21,6 +22,7 @@ search_models = SearchModels()
embeddings_model: EmbeddingsModel = None
cross_encoder_model: CrossEncoderModel = None
content_index = ContentIndex()
openai_client: OpenAI = None
gpt4all_processor_config: GPT4AllProcessorModel = None
whisper_model: Whisper = None
config_file: Path = None

View file

@ -68,10 +68,10 @@ def test_chat_with_online_content(chat_client):
response_message = response_message.split("### compiled references")[0]
# Assert
expected_responses = ["http://www.paulgraham.com/greatwork.html"]
expected_responses = ["http://www.paulgraham.com/greatwork.html", "Please set your SERPER_DEV_API_KEY"]
assert response.status_code == 200
assert any([expected_response in response_message for expected_response in expected_responses]), (
"Expected assistants name, [K|k]hoj, in response but got: " + response_message
"Expected links or serper not setup in response but got: " + response_message
)