diff --git a/pyproject.toml b/pyproject.toml index 519d2d60..5a206cce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,7 @@ dependencies = [ "fastapi >= 0.104.1", "python-multipart >= 0.0.5", "jinja2 == 3.1.2", - "openai >= 0.27.0, < 1.0.0", + "openai >= 1.0.0", "tiktoken >= 0.3.2", "tenacity >= 8.2.2", "pillow ~= 9.5.0", diff --git a/src/interface/desktop/chat.html b/src/interface/desktop/chat.html index 8c3fc49d..6b7fde07 100644 --- a/src/interface/desktop/chat.html +++ b/src/interface/desktop/chat.html @@ -179,7 +179,13 @@ return numOnlineReferences; } - function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null) { + function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null) { + if (intentType === "text-to-image") { + let imageMarkdown = `![](data:image/png;base64,${message})`; + renderMessage(imageMarkdown, by, dt); + return; + } + if (context == null && onlineContext == null) { renderMessage(message, by, dt); return; @@ -244,6 +250,17 @@ // Remove any text between [INST] and tags. These are spurious instructions for the AI chat model. newHTML = newHTML.replace(/\[INST\].+(<\/s>)?/g, ''); + // Customize the rendering of images + md.renderer.rules.image = function(tokens, idx, options, env, self) { + let token = tokens[idx]; + + // Add class="text-to-image" to images + token.attrPush(['class', 'text-to-image']); + + // Use the default renderer to render image markdown format + return self.renderToken(tokens, idx, options); + }; + // Render markdown newHTML = md.render(newHTML); // Get any elements with a class that starts with "language" @@ -328,109 +345,142 @@ let chatInput = document.getElementById("chat-input"); chatInput.classList.remove("option-enabled"); - // Call specified Khoj API which returns a streamed response of type text/plain - fetch(url, { headers }) - .then(response => { - const reader = response.body.getReader(); - const decoder = new TextDecoder(); - let rawResponse = ""; - let references = null; + // Call specified Khoj API + let response = await fetch(url, { headers }); + let rawResponse = ""; + const contentType = response.headers.get("content-type"); - function readStream() { - reader.read().then(({ done, value }) => { - if (done) { - // Append any references after all the data has been streamed - if (references != null) { - newResponseText.appendChild(references); - } - document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; - document.getElementById("chat-input").removeAttribute("disabled"); - return; + if (contentType === "application/json") { + // Handle JSON response + try { + const responseAsJson = await response.json(); + if (responseAsJson.image) { + // If response has image field, response is a generated image. + rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; + } + if (responseAsJson.detail) { + // If response has detail field, response is an error message. + rawResponse += responseAsJson.detail; + } + } catch (error) { + // If the chunk is not a JSON object, just display it as is + rawResponse += chunk; + } finally { + newResponseText.innerHTML = ""; + newResponseText.appendChild(formatHTMLMessage(rawResponse)); + + document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; + document.getElementById("chat-input").removeAttribute("disabled"); + } + } else { + // Handle streamed response of type text/event-stream or text/plain + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let references = null; + + readStream(); + + function readStream() { + reader.read().then(({ done, value }) => { + if (done) { + // Append any references after all the data has been streamed + if (references != null) { + newResponseText.appendChild(references); + } + document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; + document.getElementById("chat-input").removeAttribute("disabled"); + return; + } + + // Decode message chunk from stream + const chunk = decoder.decode(value, { stream: true }); + + if (chunk.includes("### compiled references:")) { + const additionalResponse = chunk.split("### compiled references:")[0]; + rawResponse += additionalResponse; + newResponseText.innerHTML = ""; + newResponseText.appendChild(formatHTMLMessage(rawResponse)); + + const rawReference = chunk.split("### compiled references:")[1]; + const rawReferenceAsJson = JSON.parse(rawReference); + references = document.createElement('div'); + references.classList.add("references"); + + let referenceExpandButton = document.createElement('button'); + referenceExpandButton.classList.add("reference-expand-button"); + + let referenceSection = document.createElement('div'); + referenceSection.classList.add("reference-section"); + referenceSection.classList.add("collapsed"); + + let numReferences = 0; + + // If rawReferenceAsJson is a list, then count the length + if (Array.isArray(rawReferenceAsJson)) { + numReferences = rawReferenceAsJson.length; + + rawReferenceAsJson.forEach((reference, index) => { + let polishedReference = generateReference(reference, index); + referenceSection.appendChild(polishedReference); + }); + } else { + numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson); } - // Decode message chunk from stream - const chunk = decoder.decode(value, { stream: true }); + references.appendChild(referenceExpandButton); - if (chunk.includes("### compiled references:")) { - const additionalResponse = chunk.split("### compiled references:")[0]; - rawResponse += additionalResponse; + referenceExpandButton.addEventListener('click', function() { + if (referenceSection.classList.contains("collapsed")) { + referenceSection.classList.remove("collapsed"); + referenceSection.classList.add("expanded"); + } else { + referenceSection.classList.add("collapsed"); + referenceSection.classList.remove("expanded"); + } + }); + + let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`; + referenceExpandButton.innerHTML = expandButtonText; + references.appendChild(referenceSection); + readStream(); + } else { + // Display response from Khoj + if (newResponseText.getElementsByClassName("spinner").length > 0) { + newResponseText.removeChild(loadingSpinner); + } + + // Try to parse the chunk as a JSON object. It will be a JSON object if there is an error. + if (chunk.startsWith("{") && chunk.endsWith("}")) { + try { + const responseAsJson = JSON.parse(chunk); + if (responseAsJson.image) { + rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; + } + if (responseAsJson.detail) { + rawResponse += responseAsJson.detail; + } + } catch (error) { + // If the chunk is not a JSON object, just display it as is + rawResponse += chunk; + } finally { + newResponseText.innerHTML = ""; + newResponseText.appendChild(formatHTMLMessage(rawResponse)); + } + } else { + // If the chunk is not a JSON object, just display it as is + rawResponse += chunk; newResponseText.innerHTML = ""; newResponseText.appendChild(formatHTMLMessage(rawResponse)); - const rawReference = chunk.split("### compiled references:")[1]; - const rawReferenceAsJson = JSON.parse(rawReference); - references = document.createElement('div'); - references.classList.add("references"); - - let referenceExpandButton = document.createElement('button'); - referenceExpandButton.classList.add("reference-expand-button"); - - let referenceSection = document.createElement('div'); - referenceSection.classList.add("reference-section"); - referenceSection.classList.add("collapsed"); - - let numReferences = 0; - - // If rawReferenceAsJson is a list, then count the length - if (Array.isArray(rawReferenceAsJson)) { - numReferences = rawReferenceAsJson.length; - - rawReferenceAsJson.forEach((reference, index) => { - let polishedReference = generateReference(reference, index); - referenceSection.appendChild(polishedReference); - }); - } else { - numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson); - } - - references.appendChild(referenceExpandButton); - - referenceExpandButton.addEventListener('click', function() { - if (referenceSection.classList.contains("collapsed")) { - referenceSection.classList.remove("collapsed"); - referenceSection.classList.add("expanded"); - } else { - referenceSection.classList.add("collapsed"); - referenceSection.classList.remove("expanded"); - } - }); - - let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`; - referenceExpandButton.innerHTML = expandButtonText; - references.appendChild(referenceSection); readStream(); - } else { - // Display response from Khoj - if (newResponseText.getElementsByClassName("spinner").length > 0) { - newResponseText.removeChild(loadingSpinner); - } - // Try to parse the chunk as a JSON object. It will be a JSON object if there is an error. - if (chunk.startsWith("{") && chunk.endsWith("}")) { - try { - const responseAsJson = JSON.parse(chunk); - if (responseAsJson.detail) { - newResponseText.innerHTML += responseAsJson.detail; - } - } catch (error) { - // If the chunk is not a JSON object, just display it as is - newResponseText.innerHTML += chunk; - } - } else { - // If the chunk is not a JSON object, just display it as is - rawResponse += chunk; - newResponseText.innerHTML = ""; - newResponseText.appendChild(formatHTMLMessage(rawResponse)); - - readStream(); - } } + } - // Scroll to bottom of chat window as chat response is streamed - document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; - }); - } - readStream(); - }); + // Scroll to bottom of chat window as chat response is streamed + document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; + }); + } + } } function incrementalChat(event) { @@ -522,7 +572,7 @@ .then(response => { // Render conversation history, if any response.forEach(chat_log => { - renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext); + renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext, chat_log.intent?.type); }); }) .catch(err => { @@ -625,9 +675,13 @@ .then(response => response.ok ? response.json() : Promise.reject(response)) .then(data => { chatInput.value += data.text; }) .catch(err => { - err.status == 422 - ? flashStatusInChatInput("⛔️ Configure speech-to-text model on server.") - : flashStatusInChatInput("⛔️ Failed to transcribe audio") + if (err.status === 501) { + flashStatusInChatInput("⛔️ Configure speech-to-text model on server.") + } else if (err.status === 422) { + flashStatusInChatInput("⛔️ Audio file to large to process.") + } else { + flashStatusInChatInput("⛔️ Failed to transcribe audio.") + } }); }; @@ -810,6 +864,9 @@ margin-top: -10px; transform: rotate(-60deg) } + img.text-to-image { + max-width: 60%; + } #chat-footer { padding: 0; @@ -1050,6 +1107,9 @@ margin: 4px; grid-template-columns: auto; } + img.text-to-image { + max-width: 100%; + } } @media only screen and (min-width: 600px) { body { diff --git a/src/interface/obsidian/src/chat_modal.ts b/src/interface/obsidian/src/chat_modal.ts index e5cff4d2..115f4c1f 100644 --- a/src/interface/obsidian/src/chat_modal.ts +++ b/src/interface/obsidian/src/chat_modal.ts @@ -2,6 +2,12 @@ import { App, MarkdownRenderer, Modal, request, requestUrl, setIcon } from 'obsi import { KhojSetting } from 'src/settings'; import fetch from "node-fetch"; +export interface ChatJsonResult { + image?: string; + detail?: string; +} + + export class KhojChatModal extends Modal { result: string; setting: KhojSetting; @@ -105,15 +111,19 @@ export class KhojChatModal extends Modal { return referenceButton; } - renderMessageWithReferences(chatEl: Element, message: string, sender: string, context?: string[], dt?: Date) { + renderMessageWithReferences(chatEl: Element, message: string, sender: string, context?: string[], dt?: Date, intentType?: string) { if (!message) { return; + } else if (intentType === "text-to-image") { + let imageMarkdown = `![](data:image/png;base64,${message})`; + this.renderMessage(chatEl, imageMarkdown, sender, dt); + return; } else if (!context) { this.renderMessage(chatEl, message, sender, dt); - return + return; } else if (!!context && context?.length === 0) { this.renderMessage(chatEl, message, sender, dt); - return + return; } let chatMessageEl = this.renderMessage(chatEl, message, sender, dt); let chatMessageBodyEl = chatMessageEl.getElementsByClassName("khoj-chat-message-text")[0] @@ -225,7 +235,7 @@ export class KhojChatModal extends Modal { let response = await request({ url: chatUrl, headers: headers }); let chatLogs = JSON.parse(response).response; chatLogs.forEach((chatLog: any) => { - this.renderMessageWithReferences(chatBodyEl, chatLog.message, chatLog.by, chatLog.context, new Date(chatLog.created)); + this.renderMessageWithReferences(chatBodyEl, chatLog.message, chatLog.by, chatLog.context, new Date(chatLog.created), chatLog.intent?.type); }); } @@ -266,8 +276,25 @@ export class KhojChatModal extends Modal { this.result = ""; responseElement.innerHTML = ""; + if (response.headers.get("content-type") == "application/json") { + let responseText = "" + try { + const responseAsJson = await response.json() as ChatJsonResult; + if (responseAsJson.image) { + responseText = `![${query}](data:image/png;base64,${responseAsJson.image})`; + } else if (responseAsJson.detail) { + responseText = responseAsJson.detail; + } + } catch (error) { + // If the chunk is not a JSON object, just display it as is + responseText = response.body.read().toString() + } finally { + this.renderIncrementalMessage(responseElement, responseText); + } + } + for await (const chunk of response.body) { - const responseText = chunk.toString(); + let responseText = chunk.toString(); if (responseText.includes("### compiled references:")) { const additionalResponse = responseText.split("### compiled references:")[0]; this.renderIncrementalMessage(responseElement, additionalResponse); @@ -310,6 +337,12 @@ export class KhojChatModal extends Modal { referenceExpandButton.innerHTML = expandButtonText; references.appendChild(referenceSection); } else { + if (responseText.startsWith("{") && responseText.endsWith("}")) { + } else { + // If the chunk is not a JSON object, just display it as is + continue; + } + this.renderIncrementalMessage(responseElement, responseText); } } @@ -389,10 +422,12 @@ export class KhojChatModal extends Modal { if (response.status === 200) { console.log(response); chatInput.value += response.json.text; - } else if (response.status === 422) { - throw new Error("⛔️ Failed to transcribe audio"); - } else { + } else if (response.status === 501) { throw new Error("⛔️ Configure speech-to-text model on server."); + } else if (response.status === 422) { + throw new Error("⛔️ Audio file to large to process."); + } else { + throw new Error("⛔️ Failed to transcribe audio."); } }; diff --git a/src/interface/obsidian/styles.css b/src/interface/obsidian/styles.css index 6a29e722..11fc0086 100644 --- a/src/interface/obsidian/styles.css +++ b/src/interface/obsidian/styles.css @@ -217,7 +217,9 @@ button.copy-button:hover { background: #f5f5f5; cursor: pointer; } - +img { + max-width: 60%; +} #khoj-chat-footer { padding: 0; diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 19e7d403..717ad859 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -7,6 +7,7 @@ import requests import os # External Packages +import openai import schedule from starlette.middleware.sessions import SessionMiddleware from starlette.middleware.authentication import AuthenticationMiddleware @@ -22,6 +23,7 @@ from starlette.authentication import ( # Internal Packages from khoj.database.models import KhojUser, Subscription from khoj.database.adapters import ( + ConversationAdapters, get_all_users, get_or_create_search_model, aget_user_subscription_state, @@ -138,6 +140,10 @@ def configure_server( config = FullConfig() state.config = config + if ConversationAdapters.has_valid_openai_conversation_config(): + openai_config = ConversationAdapters.get_openai_conversation_config() + state.openai_client = openai.OpenAI(api_key=openai_config.api_key) + # Initialize Search Models from Config and initialize content try: state.embeddings_model = EmbeddingsModel(get_or_create_search_model().bi_encoder) diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 9d18c815..6e17c15b 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -22,6 +22,7 @@ from khoj.database.models import ( GithubConfig, GithubRepoConfig, GoogleUser, + TextToImageModelConfig, KhojApiUser, KhojUser, NotionConfig, @@ -426,6 +427,10 @@ class ConversationAdapters: else: raise ValueError("Invalid conversation config - either configure offline chat or openai chat") + @staticmethod + async def aget_text_to_image_model_config(): + return await TextToImageModelConfig.objects.filter().afirst() + class EntryAdapters: word_filer = WordFilter() diff --git a/src/khoj/database/migrations/0022_texttoimagemodelconfig.py b/src/khoj/database/migrations/0022_texttoimagemodelconfig.py new file mode 100644 index 00000000..7450dc40 --- /dev/null +++ b/src/khoj/database/migrations/0022_texttoimagemodelconfig.py @@ -0,0 +1,25 @@ +# Generated by Django 4.2.7 on 2023-12-04 22:17 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0021_speechtotextmodeloptions_and_more"), + ] + + operations = [ + migrations.CreateModel( + name="TextToImageModelConfig", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ("model_name", models.CharField(default="dall-e-3", max_length=200)), + ("model_type", models.CharField(choices=[("openai", "Openai")], default="openai", max_length=200)), + ], + options={ + "abstract": False, + }, + ), + ] diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 82348fbe..00700f2f 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -112,6 +112,14 @@ class SearchModelConfig(BaseModel): cross_encoder = models.CharField(max_length=200, default="cross-encoder/ms-marco-MiniLM-L-6-v2") +class TextToImageModelConfig(BaseModel): + class ModelType(models.TextChoices): + OPENAI = "openai" + + model_name = models.CharField(max_length=200, default="dall-e-3") + model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI) + + class OpenAIProcessorConversationConfig(BaseModel): api_key = models.CharField(max_length=200) diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index a55d8e29..e85759fb 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -183,12 +183,18 @@ To get started, just start typing below. You can also type / to see a list of co referenceSection.appendChild(polishedReference); } } - } + } return numOnlineReferences; - } + } + + function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null) { + if (intentType === "text-to-image") { + let imageMarkdown = `![](data:image/png;base64,${message})`; + renderMessage(imageMarkdown, by, dt); + return; + } - function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null) { if (context == null && onlineContext == null) { renderMessage(message, by, dt); return; @@ -253,6 +259,17 @@ To get started, just start typing below. You can also type / to see a list of co // Remove any text between [INST] and tags. These are spurious instructions for the AI chat model. newHTML = newHTML.replace(/\[INST\].+(<\/s>)?/g, ''); + // Customize the rendering of images + md.renderer.rules.image = function(tokens, idx, options, env, self) { + let token = tokens[idx]; + + // Add class="text-to-image" to images + token.attrPush(['class', 'text-to-image']); + + // Use the default renderer to render image markdown format + return self.renderToken(tokens, idx, options); + }; + // Render markdown newHTML = md.render(newHTML); // Get any elements with a class that starts with "language" @@ -292,7 +309,7 @@ To get started, just start typing below. You can also type / to see a list of co return element } - function chat() { + async function chat() { // Extract required fields for search from form let query = document.getElementById("chat-input").value.trim(); let resultsCount = localStorage.getItem("khojResultsCount") || 5; @@ -333,113 +350,123 @@ To get started, just start typing below. You can also type / to see a list of co let chatInput = document.getElementById("chat-input"); chatInput.classList.remove("option-enabled"); - // Call specified Khoj API which returns a streamed response of type text/plain - fetch(url) - .then(response => { - const reader = response.body.getReader(); - const decoder = new TextDecoder(); - let rawResponse = ""; - let references = null; + // Call specified Khoj API + let response = await fetch(url); + let rawResponse = ""; + const contentType = response.headers.get("content-type"); - function readStream() { - reader.read().then(({ done, value }) => { - if (done) { - // Append any references after all the data has been streamed - if (references != null) { - newResponseText.appendChild(references); - } - document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; - document.getElementById("chat-input").removeAttribute("disabled"); - return; - } - - // Decode message chunk from stream - const chunk = decoder.decode(value, { stream: true }); - - if (chunk.includes("### compiled references:")) { - const additionalResponse = chunk.split("### compiled references:")[0]; - rawResponse += additionalResponse; - newResponseText.innerHTML = ""; - newResponseText.appendChild(formatHTMLMessage(rawResponse)); - - const rawReference = chunk.split("### compiled references:")[1]; - const rawReferenceAsJson = JSON.parse(rawReference); - references = document.createElement('div'); - references.classList.add("references"); - - let referenceExpandButton = document.createElement('button'); - referenceExpandButton.classList.add("reference-expand-button"); - - let referenceSection = document.createElement('div'); - referenceSection.classList.add("reference-section"); - referenceSection.classList.add("collapsed"); - - let numReferences = 0; - - // If rawReferenceAsJson is a list, then count the length - if (Array.isArray(rawReferenceAsJson)) { - numReferences = rawReferenceAsJson.length; - - rawReferenceAsJson.forEach((reference, index) => { - let polishedReference = generateReference(reference, index); - referenceSection.appendChild(polishedReference); - }); - } else { - numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson); - } - - references.appendChild(referenceExpandButton); - - referenceExpandButton.addEventListener('click', function() { - if (referenceSection.classList.contains("collapsed")) { - referenceSection.classList.remove("collapsed"); - referenceSection.classList.add("expanded"); - } else { - referenceSection.classList.add("collapsed"); - referenceSection.classList.remove("expanded"); - } - }); - - let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`; - referenceExpandButton.innerHTML = expandButtonText; - references.appendChild(referenceSection); - readStream(); - } else { - // Display response from Khoj - if (newResponseText.getElementsByClassName("spinner").length > 0) { - newResponseText.removeChild(loadingSpinner); - } - - // Try to parse the chunk as a JSON object. It will be a JSON object if there is an error. - if (chunk.startsWith("{") && chunk.endsWith("}")) { - try { - const responseAsJson = JSON.parse(chunk); - if (responseAsJson.detail) { - rawResponse += responseAsJson.detail; - } - } catch (error) { - // If the chunk is not a JSON object, just display it as is - rawResponse += chunk; - } finally { - newResponseText.innerHTML = ""; - newResponseText.appendChild(formatHTMLMessage(rawResponse)); - } - } else { - // If the chunk is not a JSON object, just display it as is - rawResponse += chunk; - newResponseText.innerHTML = ""; - newResponseText.appendChild(formatHTMLMessage(rawResponse)); - readStream(); - } - } - - // Scroll to bottom of chat window as chat response is streamed - document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; - }); + if (contentType === "application/json") { + // Handle JSON response + try { + const responseAsJson = await response.json(); + if (responseAsJson.image) { + // If response has image field, response is a generated image. + rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; } - readStream(); - }); - } + if (responseAsJson.detail) { + // If response has detail field, response is an error message. + rawResponse += responseAsJson.detail; + } + } catch (error) { + // If the chunk is not a JSON object, just display it as is + rawResponse += chunk; + } finally { + newResponseText.innerHTML = ""; + newResponseText.appendChild(formatHTMLMessage(rawResponse)); + + document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; + document.getElementById("chat-input").removeAttribute("disabled"); + } + } else { + // Handle streamed response of type text/event-stream or text/plain + const reader = response.body.getReader(); + const decoder = new TextDecoder(); + let references = null; + + readStream(); + + function readStream() { + reader.read().then(({ done, value }) => { + if (done) { + // Append any references after all the data has been streamed + if (references != null) { + newResponseText.appendChild(references); + } + document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; + document.getElementById("chat-input").removeAttribute("disabled"); + return; + } + + // Decode message chunk from stream + const chunk = decoder.decode(value, { stream: true }); + + if (chunk.includes("### compiled references:")) { + const additionalResponse = chunk.split("### compiled references:")[0]; + rawResponse += additionalResponse; + newResponseText.innerHTML = ""; + newResponseText.appendChild(formatHTMLMessage(rawResponse)); + + const rawReference = chunk.split("### compiled references:")[1]; + const rawReferenceAsJson = JSON.parse(rawReference); + references = document.createElement('div'); + references.classList.add("references"); + + let referenceExpandButton = document.createElement('button'); + referenceExpandButton.classList.add("reference-expand-button"); + + let referenceSection = document.createElement('div'); + referenceSection.classList.add("reference-section"); + referenceSection.classList.add("collapsed"); + + let numReferences = 0; + + // If rawReferenceAsJson is a list, then count the length + if (Array.isArray(rawReferenceAsJson)) { + numReferences = rawReferenceAsJson.length; + + rawReferenceAsJson.forEach((reference, index) => { + let polishedReference = generateReference(reference, index); + referenceSection.appendChild(polishedReference); + }); + } else { + numReferences += processOnlineReferences(referenceSection, rawReferenceAsJson); + } + + references.appendChild(referenceExpandButton); + + referenceExpandButton.addEventListener('click', function() { + if (referenceSection.classList.contains("collapsed")) { + referenceSection.classList.remove("collapsed"); + referenceSection.classList.add("expanded"); + } else { + referenceSection.classList.add("collapsed"); + referenceSection.classList.remove("expanded"); + } + }); + + let expandButtonText = numReferences == 1 ? "1 reference" : `${numReferences} references`; + referenceExpandButton.innerHTML = expandButtonText; + references.appendChild(referenceSection); + readStream(); + } else { + // Display response from Khoj + if (newResponseText.getElementsByClassName("spinner").length > 0) { + newResponseText.removeChild(loadingSpinner); + } + + // If the chunk is not a JSON object, just display it as is + rawResponse += chunk; + newResponseText.innerHTML = ""; + newResponseText.appendChild(formatHTMLMessage(rawResponse)); + readStream(); + } + }); + + // Scroll to bottom of chat window as chat response is streamed + document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; + }; + } + }; function incrementalChat(event) { if (!event.shiftKey && event.key === 'Enter') { @@ -516,7 +543,7 @@ To get started, just start typing below. You can also type / to see a list of co .then(response => { // Render conversation history, if any response.forEach(chat_log => { - renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext); + renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext, chat_log.intent?.type); }); }) .catch(err => { @@ -611,9 +638,13 @@ To get started, just start typing below. You can also type / to see a list of co .then(response => response.ok ? response.json() : Promise.reject(response)) .then(data => { chatInput.value += data.text; }) .catch(err => { - err.status == 422 - ? flashStatusInChatInput("⛔️ Configure speech-to-text model on server.") - : flashStatusInChatInput("⛔️ Failed to transcribe audio") + if (err.status === 501) { + flashStatusInChatInput("⛔️ Configure speech-to-text model on server.") + } else if (err.status === 422) { + flashStatusInChatInput("⛔️ Audio file to large to process.") + } else { + flashStatusInChatInput("⛔️ Failed to transcribe audio.") + } }); }; @@ -902,6 +933,9 @@ To get started, just start typing below. You can also type / to see a list of co margin-top: -10px; transform: rotate(-60deg) } + img.text-to-image { + max-width: 60%; + } #chat-footer { padding: 0; @@ -1029,6 +1063,9 @@ To get started, just start typing below. You can also type / to see a list of co margin: 4px; grid-template-columns: auto; } + img.text-to-image { + max-width: 100%; + } } @media only screen and (min-width: 700px) { body { diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py index 41b1844b..211e3cac 100644 --- a/src/khoj/processor/conversation/offline/chat_model.py +++ b/src/khoj/processor/conversation/offline/chat_model.py @@ -47,7 +47,7 @@ def extract_questions_offline( if use_history: for chat in conversation_log.get("chat", [])[-4:]: - if chat["by"] == "khoj": + if chat["by"] == "khoj" and chat["intent"].get("type") != "text-to-image": chat_history += f"Q: {chat['intent']['query']}\n" chat_history += f"A: {chat['message']}\n" diff --git a/src/khoj/processor/conversation/offline/utils.py b/src/khoj/processor/conversation/offline/utils.py index 0b876b26..3a1862f7 100644 --- a/src/khoj/processor/conversation/offline/utils.py +++ b/src/khoj/processor/conversation/offline/utils.py @@ -12,11 +12,12 @@ def download_model(model_name: str): logger.info("There was an error importing GPT4All. Please run pip install gpt4all in order to install it.") raise e - # Download the chat model - chat_model_config = gpt4all.GPT4All.retrieve_model(model_name=model_name, allow_download=True) - # Decide whether to load model to GPU or CPU + chat_model_config = None try: + # Download the chat model and its config + chat_model_config = gpt4all.GPT4All.retrieve_model(model_name=model_name, allow_download=True) + # Try load chat model to GPU if: # 1. Loading chat model to GPU isn't disabled via CLI and # 2. Machine has GPU @@ -26,6 +27,12 @@ def download_model(model_name: str): ) except ValueError: device = "cpu" + except Exception as e: + if chat_model_config is None: + device = "cpu" # Fallback to CPU as can't determine if GPU has enough memory + logger.debug(f"Unable to download model config from gpt4all website: {e}") + else: + raise e # Now load the downloaded chat model onto appropriate device chat_model = gpt4all.GPT4All(model_name=model_name, device=device, allow_download=False) diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index fed110f7..7bebc26c 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -41,7 +41,7 @@ def extract_questions( [ f'Q: {chat["intent"]["query"]}\n\n{chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}\n\n{chat["message"]}\n\n' for chat in conversation_log.get("chat", [])[-4:] - if chat["by"] == "khoj" + if chat["by"] == "khoj" and chat["intent"].get("type") != "text-to-image" ] ) @@ -123,8 +123,8 @@ def send_message_to_model( def converse( references, - online_results, user_query, + online_results=[], conversation_log={}, model: str = "gpt-3.5-turbo", api_key: Optional[str] = None, diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py index 69fab7e5..a5868cb8 100644 --- a/src/khoj/processor/conversation/openai/utils.py +++ b/src/khoj/processor/conversation/openai/utils.py @@ -36,11 +36,11 @@ class StreamingChatCallbackHandler(StreamingStdOutCallbackHandler): @retry( retry=( - retry_if_exception_type(openai.error.Timeout) - | retry_if_exception_type(openai.error.APIError) - | retry_if_exception_type(openai.error.APIConnectionError) - | retry_if_exception_type(openai.error.RateLimitError) - | retry_if_exception_type(openai.error.ServiceUnavailableError) + retry_if_exception_type(openai._exceptions.APITimeoutError) + | retry_if_exception_type(openai._exceptions.APIError) + | retry_if_exception_type(openai._exceptions.APIConnectionError) + | retry_if_exception_type(openai._exceptions.RateLimitError) + | retry_if_exception_type(openai._exceptions.APIStatusError) ), wait=wait_random_exponential(min=1, max=10), stop=stop_after_attempt(3), @@ -57,11 +57,11 @@ def completion_with_backoff(**kwargs): @retry( retry=( - retry_if_exception_type(openai.error.Timeout) - | retry_if_exception_type(openai.error.APIError) - | retry_if_exception_type(openai.error.APIConnectionError) - | retry_if_exception_type(openai.error.RateLimitError) - | retry_if_exception_type(openai.error.ServiceUnavailableError) + retry_if_exception_type(openai._exceptions.APITimeoutError) + | retry_if_exception_type(openai._exceptions.APIError) + | retry_if_exception_type(openai._exceptions.APIConnectionError) + | retry_if_exception_type(openai._exceptions.RateLimitError) + | retry_if_exception_type(openai._exceptions.APIStatusError) ), wait=wait_exponential(multiplier=1, min=4, max=10), stop=stop_after_attempt(3), diff --git a/src/khoj/processor/conversation/openai/whisper.py b/src/khoj/processor/conversation/openai/whisper.py index 72834d92..bd0e66df 100644 --- a/src/khoj/processor/conversation/openai/whisper.py +++ b/src/khoj/processor/conversation/openai/whisper.py @@ -3,13 +3,13 @@ from io import BufferedReader # External Packages from asgiref.sync import sync_to_async -import openai +from openai import OpenAI -async def transcribe_audio(audio_file: BufferedReader, model, api_key) -> str: +async def transcribe_audio(audio_file: BufferedReader, model, client: OpenAI) -> str: """ Transcribe audio file using Whisper model via OpenAI's API """ # Send the audio data to the Whisper API - response = await sync_to_async(openai.Audio.translate)(model=model, file=audio_file, api_key=api_key) - return response["text"] + response = await sync_to_async(client.audio.translations.create)(model=model, file=audio_file) + return response.text diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index e2164719..4606efd9 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -4,6 +4,7 @@ from time import perf_counter import json from datetime import datetime import queue +from typing import Any, Dict, List import tiktoken # External packages @@ -11,6 +12,8 @@ from langchain.schema import ChatMessage from transformers import AutoTokenizer # Internal Packages +from khoj.database.adapters import ConversationAdapters +from khoj.database.models import KhojUser from khoj.utils.helpers import merge_dicts @@ -89,6 +92,32 @@ def message_to_log( return conversation_log +def save_to_conversation_log( + q: str, + chat_response: str, + user: KhojUser, + meta_log: Dict, + user_message_time: str = None, + compiled_references: List[str] = [], + online_results: Dict[str, Any] = {}, + inferred_queries: List[str] = [], + intent_type: str = "remember", +): + user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S") + updated_conversation = message_to_log( + user_message=q, + chat_response=chat_response, + user_message_metadata={"created": user_message_time}, + khoj_message_metadata={ + "context": compiled_references, + "intent": {"inferred-queries": inferred_queries, "type": intent_type}, + "onlineContext": online_results, + }, + conversation_log=meta_log.get("chat", []), + ) + ConversationAdapters.save_conversation(user, {"chat": updated_conversation}) + + def generate_chatml_messages_with_context( user_message, system_message, diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index ae125980..7efd8bfd 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -19,7 +19,7 @@ from starlette.authentication import requires from khoj.configure import configure_server from khoj.database import adapters from khoj.database.adapters import ConversationAdapters, EntryAdapters -from khoj.database.models import ChatModelOptions +from khoj.database.models import ChatModelOptions, SpeechToTextModelOptions from khoj.database.models import Entry as DbEntry from khoj.database.models import ( GithubConfig, @@ -35,12 +35,14 @@ from khoj.processor.conversation.offline.whisper import transcribe_audio_offline from khoj.processor.conversation.openai.gpt import extract_questions from khoj.processor.conversation.openai.whisper import transcribe_audio from khoj.processor.conversation.prompts import help_message, no_entries_found +from khoj.processor.conversation.utils import save_to_conversation_log from khoj.processor.tools.online_search import search_with_google from khoj.routers.helpers import ( ApiUserRateLimiter, CommonQueryParams, agenerate_chat_response, get_conversation_command, + text_to_image, is_ready_to_chat, update_telemetry_state, validate_conversation_config, @@ -622,17 +624,15 @@ async def transcribe(request: Request, common: CommonQueryParams, file: UploadFi # Send the audio data to the Whisper API speech_to_text_config = await ConversationAdapters.get_speech_to_text_config() - openai_chat_config = await ConversationAdapters.get_openai_chat_config() if not speech_to_text_config: - # If the user has not configured a speech to text model, return an unprocessable entity error - status_code = 422 - elif openai_chat_config and speech_to_text_config.model_type == ChatModelOptions.ModelType.OPENAI: - api_key = openai_chat_config.api_key + # If the user has not configured a speech to text model, return an unsupported on server error + status_code = 501 + elif state.openai_client and speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OPENAI: speech2text_model = speech_to_text_config.model_name - user_message = await transcribe_audio(audio_file, model=speech2text_model, api_key=api_key) - elif speech_to_text_config.model_type == ChatModelOptions.ModelType.OFFLINE: + user_message = await transcribe_audio(audio_file, speech2text_model, client=state.openai_client) + elif speech_to_text_config.model_type == SpeechToTextModelOptions.ModelType.OFFLINE: speech2text_model = speech_to_text_config.model_name - user_message = await transcribe_audio_offline(audio_filename, model=speech2text_model) + user_message = await transcribe_audio_offline(audio_filename, speech2text_model) finally: # Close and Delete the temporary audio file audio_file.close() @@ -665,7 +665,7 @@ async def chat( rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=60, window=60)), rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)), ) -> Response: - user = request.user.object + user: KhojUser = request.user.object await is_ready_to_chat(user) conversation_command = get_conversation_command(query=q, any_references=True) @@ -703,6 +703,11 @@ async def chat( media_type="text/event-stream", status_code=200, ) + elif conversation_command == ConversationCommand.Image: + image, status_code = await text_to_image(q) + await sync_to_async(save_to_conversation_log)(q, image, user, meta_log, intent_type="text-to-image") + content_obj = {"image": image, "intentType": "text-to-image"} + return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code) # Get the (streamed) chat response from the LLM of choice. llm_response, chat_metadata = await agenerate_chat_response( @@ -786,7 +791,6 @@ async def extract_references_and_questions( conversation_config = await ConversationAdapters.aget_conversation_config(user) if conversation_config is None: conversation_config = await ConversationAdapters.aget_default_conversation_config() - openai_chat_config = await ConversationAdapters.aget_openai_conversation_config() if ( offline_chat_config and offline_chat_config.enabled @@ -803,7 +807,7 @@ async def extract_references_and_questions( inferred_queries = extract_questions_offline( defiltered_query, loaded_model=loaded_model, conversation_log=meta_log, should_extract_questions=False ) - elif openai_chat_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI: + elif conversation_config and conversation_config.model_type == ChatModelOptions.ModelType.OPENAI: openai_chat_config = await ConversationAdapters.get_openai_chat_config() openai_chat = await ConversationAdapters.get_openai_chat() api_key = openai_chat_config.api_key diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 22b0d037..4e883f35 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -9,23 +9,23 @@ from functools import partial from time import time from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union +# External Packages from fastapi import Depends, Header, HTTPException, Request, UploadFile +import openai from starlette.authentication import has_required_scope -from asgiref.sync import sync_to_async - +# Internal Packages from khoj.database.adapters import ConversationAdapters, EntryAdapters -from khoj.database.models import KhojUser, Subscription +from khoj.database.models import KhojUser, Subscription, TextToImageModelConfig from khoj.processor.conversation import prompts from khoj.processor.conversation.offline.chat_model import converse_offline, send_message_to_model_offline from khoj.processor.conversation.openai.gpt import converse, send_message_to_model -from khoj.processor.conversation.utils import ThreadedGenerator, message_to_log - -# Internal Packages +from khoj.processor.conversation.utils import ThreadedGenerator, save_to_conversation_log from khoj.utils import state from khoj.utils.config import GPT4AllProcessorModel from khoj.utils.helpers import ConversationCommand, log_telemetry + logger = logging.getLogger(__name__) executor = ThreadPoolExecutor(max_workers=1) @@ -102,6 +102,8 @@ def get_conversation_command(query: str, any_references: bool = False) -> Conver return ConversationCommand.General elif query.startswith("/online"): return ConversationCommand.Online + elif query.startswith("/image"): + return ConversationCommand.Image # If no relevant notes found for the given query elif not any_references: return ConversationCommand.General @@ -186,30 +188,7 @@ def generate_chat_response( conversation_command: ConversationCommand = ConversationCommand.Default, user: KhojUser = None, ) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]: - def _save_to_conversation_log( - q: str, - chat_response: str, - user_message_time: str, - compiled_references: List[str], - online_results: Dict[str, Any], - inferred_queries: List[str], - meta_log, - ): - updated_conversation = message_to_log( - user_message=q, - chat_response=chat_response, - user_message_metadata={"created": user_message_time}, - khoj_message_metadata={ - "context": compiled_references, - "intent": {"inferred-queries": inferred_queries}, - "onlineContext": online_results, - }, - conversation_log=meta_log.get("chat", []), - ) - ConversationAdapters.save_conversation(user, {"chat": updated_conversation}) - # Initialize Variables - user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") chat_response = None logger.debug(f"Conversation Type: {conversation_command.name}") @@ -217,13 +196,13 @@ def generate_chat_response( try: partial_completion = partial( - _save_to_conversation_log, + save_to_conversation_log, q, - user_message_time=user_message_time, + user=user, + meta_log=meta_log, compiled_references=compiled_references, online_results=online_results, inferred_queries=inferred_queries, - meta_log=meta_log, ) conversation_config = ConversationAdapters.get_valid_conversation_config(user) @@ -251,9 +230,9 @@ def generate_chat_response( chat_model = conversation_config.chat_model chat_response = converse( compiled_references, - online_results, q, - meta_log, + online_results=online_results, + conversation_log=meta_log, model=chat_model, api_key=api_key, completion_func=partial_completion, @@ -271,6 +250,29 @@ def generate_chat_response( return chat_response, metadata +async def text_to_image(message: str) -> Tuple[Optional[str], int]: + status_code = 200 + image = None + + # Send the audio data to the Whisper API + text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config() + if not text_to_image_config: + # If the user has not configured a text to image model, return an unsupported on server error + status_code = 501 + elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI: + text2image_model = text_to_image_config.model_name + try: + response = state.openai_client.images.generate( + prompt=message, model=text2image_model, response_format="b64_json" + ) + image = response.data[0].b64_json + except openai.OpenAIError as e: + logger.error(f"Image Generation failed with {e.http_status}: {e.error}") + status_code = 500 + + return image, status_code + + class ApiUserRateLimiter: def __init__(self, requests: int, subscribed_requests: int, window: int): self.requests = requests diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py index 42e3835d..21fe7e98 100644 --- a/src/khoj/utils/helpers.py +++ b/src/khoj/utils/helpers.py @@ -273,6 +273,7 @@ class ConversationCommand(str, Enum): Notes = "notes" Help = "help" Online = "online" + Image = "image" command_descriptions = { @@ -280,6 +281,7 @@ command_descriptions = { ConversationCommand.Notes: "Only talk about information that is available in your knowledge base.", ConversationCommand.Default: "The default command when no command specified. It intelligently auto-switches between general and notes mode.", ConversationCommand.Online: "Look up information on the internet.", + ConversationCommand.Image: "Generate images by describing your imagination in words.", ConversationCommand.Help: "Display a help message with all available commands and other metadata.", } diff --git a/src/khoj/utils/initialization.py b/src/khoj/utils/initialization.py index 313b18fc..0bb78dbe 100644 --- a/src/khoj/utils/initialization.py +++ b/src/khoj/utils/initialization.py @@ -7,6 +7,7 @@ from khoj.database.models import ( OpenAIProcessorConversationConfig, ChatModelOptions, SpeechToTextModelOptions, + TextToImageModelConfig, ) from khoj.utils.constants import default_offline_chat_model, default_online_chat_model @@ -103,6 +104,15 @@ def initialization(): model_name=openai_speech2text_model, model_type=SpeechToTextModelOptions.ModelType.OPENAI ) + default_text_to_image_model = "dall-e-3" + openai_text_to_image_model = input( + f"Enter the OpenAI text to image model you want to use (default: {default_text_to_image_model}): " + ) + openai_speech2text_model = openai_text_to_image_model or default_text_to_image_model + TextToImageModelConfig.objects.create( + model_name=openai_text_to_image_model, model_type=TextToImageModelConfig.ModelType.OPENAI + ) + if use_offline_model == "y" or use_openai_model == "y": logger.info("🗣️ Chat model configuration complete") diff --git a/src/khoj/utils/models.py b/src/khoj/utils/models.py index b5bbe292..e1298b08 100644 --- a/src/khoj/utils/models.py +++ b/src/khoj/utils/models.py @@ -22,17 +22,9 @@ class BaseEncoder(ABC): class OpenAI(BaseEncoder): - def __init__(self, model_name, device=None): + def __init__(self, model_name, client: openai.OpenAI, device=None): self.model_name = model_name - if ( - not state.processor_config - or not state.processor_config.conversation - or not state.processor_config.conversation.openai_model - ): - raise Exception( - f"Set OpenAI API key under processor-config > conversation > openai-api-key in config file: {state.config_file}" - ) - openai.api_key = state.processor_config.conversation.openai_model.api_key + self.openai_client = client self.embedding_dimensions = None def encode(self, entries, device=None, **kwargs): @@ -43,7 +35,7 @@ class OpenAI(BaseEncoder): processed_entry = entries[index].replace("\n", " ") try: - response = openai.Embedding.create(input=processed_entry, model=self.model_name) + response = self.openai_client.embeddings.create(input=processed_entry, model=self.model_name) embedding_tensors += [torch.tensor(response.data[0].embedding, device=device)] # Use current models embedding dimension, once available # Else default to embedding dimensions of the text-embedding-ada-002 model diff --git a/src/khoj/utils/state.py b/src/khoj/utils/state.py index b54cf4b3..d5358868 100644 --- a/src/khoj/utils/state.py +++ b/src/khoj/utils/state.py @@ -1,15 +1,16 @@ # Standard Packages +from collections import defaultdict import os +from pathlib import Path import threading from typing import List, Dict -from collections import defaultdict # External Packages -from pathlib import Path -from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel +from openai import OpenAI from whisper import Whisper # Internal Packages +from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel from khoj.utils import config as utils_config from khoj.utils.config import ContentIndex, SearchModels, GPT4AllProcessorModel from khoj.utils.helpers import LRU, get_device @@ -21,6 +22,7 @@ search_models = SearchModels() embeddings_model: EmbeddingsModel = None cross_encoder_model: CrossEncoderModel = None content_index = ContentIndex() +openai_client: OpenAI = None gpt4all_processor_config: GPT4AllProcessorModel = None whisper_model: Whisper = None config_file: Path = None diff --git a/tests/test_openai_chat_director.py b/tests/test_openai_chat_director.py index 7b98d4da..50ad60af 100644 --- a/tests/test_openai_chat_director.py +++ b/tests/test_openai_chat_director.py @@ -68,10 +68,10 @@ def test_chat_with_online_content(chat_client): response_message = response_message.split("### compiled references")[0] # Assert - expected_responses = ["http://www.paulgraham.com/greatwork.html"] + expected_responses = ["http://www.paulgraham.com/greatwork.html", "Please set your SERPER_DEV_API_KEY"] assert response.status_code == 200 assert any([expected_response in response_message for expected_response in expected_responses]), ( - "Expected assistants name, [K|k]hoj, in response but got: " + response_message + "Expected links or serper not setup in response but got: " + response_message )