From 81beb7940c6a2ada305e536aca2fbb9e9231b3df Mon Sep 17 00:00:00 2001 From: sabaimran <65192171+sabaimran@users.noreply.github.com> Date: Fri, 8 Mar 2024 10:54:13 +0530 Subject: [PATCH] Upload generated images to s3, if AWS credentials and bucket is available (#667) * Upload generated images to s3, if AWS credentials and bucket is available. - In clients, render the images via the URL if it's returned with a text-to-image2 intent type * Make the loading screen more intuitive, less jerky and update the programmatic copy button * Update the loading icon when waiting for a chat response --- pyproject.toml | 1 + src/interface/desktop/chat.html | 165 +++++++++---- src/interface/obsidian/src/chat_modal.ts | 16 +- .../web/assets/icons/copy_button.svg | 1 + src/khoj/interface/web/chat.html | 224 ++++++++++++++---- .../content/notion/notion_to_entries.py | 2 +- src/khoj/processor/conversation/utils.py | 2 +- src/khoj/routers/api_chat.py | 15 +- src/khoj/routers/helpers.py | 58 +++-- src/khoj/routers/storage.py | 34 +++ 10 files changed, 392 insertions(+), 126 deletions(-) create mode 100644 src/khoj/interface/web/assets/icons/copy_button.svg create mode 100644 src/khoj/routers/storage.py diff --git a/pyproject.toml b/pyproject.toml index c47a5cb2..17003c6c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,6 +93,7 @@ prod = [ "google-auth == 2.23.3", "stripe == 7.3.0", "twilio == 8.11", + "boto3 >= 1.34.57", ] dev = [ "khoj-assistant[prod]", diff --git a/src/interface/desktop/chat.html b/src/interface/desktop/chat.html index 9bdc7cef..567b9f8c 100644 --- a/src/interface/desktop/chat.html +++ b/src/interface/desktop/chat.html @@ -198,8 +198,14 @@ function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) { if ((context == null || context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) { - if (intentType === "text-to-image") { - let imageMarkdown
= `![](data:image/png;base64,${message})`; + if (intentType?.includes("text-to-image")) { + let imageMarkdown; + if (intentType === "text-to-image") { + imageMarkdown = `![](data:image/png;base64,${message})`; + } else if (intentType === "text-to-image2") { + imageMarkdown = `![](${message})`; + } + const inferredQuery = inferredQueries?.[0]; if (inferredQuery) { imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`; @@ -266,8 +272,13 @@ references.appendChild(referenceSection); - if (intentType === "text-to-image") { - let imageMarkdown = `![](data:image/png;base64,${message})`; + if (intentType?.includes("text-to-image")) { + let imageMarkdown; + if (intentType === "text-to-image") { + imageMarkdown = `![](data:image/png;base64,${message})`; + } else if (intentType === "text-to-image2") { + imageMarkdown = `![](${message})`; + } const inferredQuery = inferredQueries?.[0]; if (inferredQuery) { imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`; @@ -423,9 +434,27 @@ new_response.appendChild(newResponseText); // Temporary status message to indicate that Khoj is thinking - let loadingSpinner = document.createElement("div"); - loadingSpinner.classList.add("spinner"); - newResponseText.appendChild(loadingSpinner); + let loadingEllipsis = document.createElement("div"); + loadingEllipsis.classList.add("lds-ellipsis"); + + let firstEllipsis = document.createElement("div"); + firstEllipsis.classList.add("lds-ellipsis-item"); + + let secondEllipsis = document.createElement("div"); + secondEllipsis.classList.add("lds-ellipsis-item"); + + let thirdEllipsis = document.createElement("div"); + thirdEllipsis.classList.add("lds-ellipsis-item"); + + let fourthEllipsis = document.createElement("div"); + fourthEllipsis.classList.add("lds-ellipsis-item"); + + loadingEllipsis.appendChild(firstEllipsis); + loadingEllipsis.appendChild(secondEllipsis); + loadingEllipsis.appendChild(thirdEllipsis); + loadingEllipsis.appendChild(fourthEllipsis); + + 
newResponseText.appendChild(loadingEllipsis); document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; let chatTooltip = document.getElementById("chat-tooltip"); @@ -446,7 +475,11 @@ const responseAsJson = await response.json(); if (responseAsJson.image) { // If response has image field, response is a generated image. - rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; + if (responseAsJson.intentType === "text-to-image") { + rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; + } else if (responseAsJson.intentType === "text-to-image2") { + rawResponse += `![${query}](${responseAsJson.image})`; + } const inferredQueries = responseAsJson.inferredQueries?.[0]; if (inferredQueries) { rawResponse += `\n\n**Inferred Query**:\n\n${inferredQueries}`; @@ -509,40 +542,16 @@ readStream(); } else { // Display response from Khoj - if (newResponseText.getElementsByClassName("spinner").length > 0) { - newResponseText.removeChild(loadingSpinner); + if (newResponseText.getElementsByClassName("lds-ellipsis").length > 0) { + newResponseText.removeChild(loadingEllipsis); } - // Try to parse the chunk as a JSON object. It will be a JSON object if there is an error. - if (chunk.startsWith("{") && chunk.endsWith("}")) { - try { - const responseAsJson = JSON.parse(chunk); - if (responseAsJson.image) { - // If response has image field, response is a generated image. 
- rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; - const inferredQuery = responseAsJson.inferredQueries?.[0]; - if (inferredQuery) { - rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`; - } - } - if (responseAsJson.detail) { - rawResponse += responseAsJson.detail; - } - } catch (error) { - // If the chunk is not a JSON object, just display it as is - rawResponse += chunk; - } finally { - newResponseText.innerHTML = ""; - newResponseText.appendChild(formatHTMLMessage(rawResponse)); - } - } else { - // If the chunk is not a JSON object, just display it as is - rawResponse += chunk; - newResponseText.innerHTML = ""; - newResponseText.appendChild(formatHTMLMessage(rawResponse)); + // If the chunk is not a JSON object, just display it as is + rawResponse += chunk; + newResponseText.innerHTML = ""; + newResponseText.appendChild(formatHTMLMessage(rawResponse)); - readStream(); - } + readStream(); } // Scroll to bottom of chat window as chat response is streamed @@ -1575,13 +1584,27 @@ } button.copy-button { - display: block; border-radius: 4px; background-color: var(--background-color); - } - button.copy-button:hover { - background: #f5f5f5; + border: 1px solid var(--main-text-color); + text-align: center; + font-size: 16px; + transition: all 0.5s; cursor: pointer; + padding: 4px; + float: right; + } + + button.copy-button span { + cursor: pointer; + display: inline-block; + position: relative; + transition: 0.5s; + } + + button.copy-button:hover { + background-color: black; + color: #f5f5f5; } pre { @@ -1815,5 +1838,61 @@ padding: 10px; white-space: pre-wrap; } + + .lds-ellipsis { + display: inline-block; + position: relative; + width: 60px; + height: 32px; + } + .lds-ellipsis div { + position: absolute; + top: 12px; + width: 12px; + height: 12px; + border-radius: 50%; + background: var(--main-text-color); + animation-timing-function: cubic-bezier(0, 1, 1, 0); + } + .lds-ellipsis div:nth-child(1) { + left: 8px; + animation: 
lds-ellipsis1 0.6s infinite; + } + .lds-ellipsis div:nth-child(2) { + left: 8px; + animation: lds-ellipsis2 0.6s infinite; + } + .lds-ellipsis div:nth-child(3) { + left: 32px; + animation: lds-ellipsis2 0.6s infinite; + } + .lds-ellipsis div:nth-child(4) { + left: 56px; + animation: lds-ellipsis3 0.6s infinite; + } + @keyframes lds-ellipsis1 { + 0% { + transform: scale(0); + } + 100% { + transform: scale(1); + } + } + @keyframes lds-ellipsis3 { + 0% { + transform: scale(1); + } + 100% { + transform: scale(0); + } + } + @keyframes lds-ellipsis2 { + 0% { + transform: translate(0, 0); + } + 100% { + transform: translate(24px, 0); + } + } diff --git a/src/interface/obsidian/src/chat_modal.ts b/src/interface/obsidian/src/chat_modal.ts index 765a43a8..328ce299 100644 --- a/src/interface/obsidian/src/chat_modal.ts +++ b/src/interface/obsidian/src/chat_modal.ts @@ -150,8 +150,13 @@ export class KhojChatModal extends Modal { renderMessageWithReferences(chatEl: Element, message: string, sender: string, context?: string[], dt?: Date, intentType?: string, inferredQueries?: string) { if (!message) { return; - } else if (intentType === "text-to-image") { - let imageMarkdown = `![](data:image/png;base64,${message})`; + } else if (intentType?.includes("text-to-image")) { + let imageMarkdown = ""; + if (intentType === "text-to-image") { + imageMarkdown = `![](data:image/png;base64,${message})`; + } else if (intentType === "text-to-image2") { + imageMarkdown = `![](${message})`; + } if (inferredQueries) { imageMarkdown += "\n\n**Inferred Query**:"; for (let inferredQuery of inferredQueries) { @@ -419,7 +424,12 @@ export class KhojChatModal extends Modal { try { const responseAsJson = await response.json() as ChatJsonResult; if (responseAsJson.image) { - responseText = `![${query}](data:image/png;base64,${responseAsJson.image})`; + // If response has image field, response is a generated image. 
+ if (responseAsJson.intentType === "text-to-image") { + responseText += `![${query}](data:image/png;base64,${responseAsJson.image})`; + } else if (responseAsJson.intentType === "text-to-image2") { + responseText += `![${query}](${responseAsJson.image})`; + } const inferredQuery = responseAsJson.inferredQueries?.[0]; if (inferredQuery) { responseText += `\n\n**Inferred Query**:\n\n${inferredQuery}`; diff --git a/src/khoj/interface/web/assets/icons/copy_button.svg b/src/khoj/interface/web/assets/icons/copy_button.svg new file mode 100644 index 00000000..fadf344a --- /dev/null +++ b/src/khoj/interface/web/assets/icons/copy_button.svg @@ -0,0 +1 @@ + diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index accc83a5..6cb61d85 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -28,18 +28,17 @@ To get started, just start typing below. You can also type / to see a list of co let chatOptions = []; function copyProgrammaticOutput(event) { // Remove the first 4 characters which are the "Copy" button - const originalCopyText = event.target.parentNode.textContent.trim().slice(0, 4); - const programmaticOutput = event.target.parentNode.textContent.trim().slice(4); + const programmaticOutput = event.target.parentNode.textContent.trim(); navigator.clipboard.writeText(programmaticOutput).then(() => { event.target.textContent = "✅ Copied to clipboard!"; setTimeout(() => { - event.target.textContent = originalCopyText; + event.target.textContent = "✅"; }, 1000); }).catch((error) => { console.error("Error copying programmatic output to clipboard:", error); event.target.textContent = "⛔️ Failed to copy!"; setTimeout(() => { - event.target.textContent = originalCopyText; + event.target.textContent = "⛔️"; }, 1000); }); } @@ -211,8 +210,13 @@ To get started, just start typing below. 
You can also type / to see a list of co function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) { if ((context == null || context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) { - if (intentType === "text-to-image") { - let imageMarkdown = `![](data:image/png;base64,${message})`; + if (intentType?.includes("text-to-image")) { + let imageMarkdown; + if (intentType === "text-to-image") { + imageMarkdown = `![](data:image/png;base64,${message})`; + } else if (intentType === "text-to-image2") { + imageMarkdown = `![](${message})`; + } const inferredQuery = inferredQueries?.[0]; if (inferredQuery) { imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`; @@ -274,8 +278,13 @@ To get started, just start typing below. You can also type / to see a list of co references.appendChild(referenceSection); - if (intentType === "text-to-image") { - let imageMarkdown = `![](data:image/png;base64,${message})`; + if (intentType?.includes("text-to-image")) { + let imageMarkdown; + if (intentType === "text-to-image") { + imageMarkdown = `![](data:image/png;base64,${message})`; + } else if (intentType === "text-to-image2") { + imageMarkdown = `![](${message})`; + } const inferredQuery = inferredQueries?.[0]; if (inferredQuery) { imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`; @@ -326,7 +335,10 @@ To get started, just start typing below. 
You can also type / to see a list of co // Add a copy button to each element let copyButton = document.createElement('button'); copyButton.classList.add("copy-button"); - copyButton.innerHTML = "Copy"; + let copyIcon = document.createElement("img"); + copyIcon.src = "/static/assets/icons/copy_button.svg"; + copyIcon.classList.add("copy-icon"); + copyButton.appendChild(copyIcon); copyButton.addEventListener('click', copyProgrammaticOutput); codeElement.prepend(copyButton); }); @@ -427,9 +439,27 @@ To get started, just start typing below. You can also type / to see a list of co new_response.appendChild(newResponseText); // Temporary status message to indicate that Khoj is thinking - let loadingSpinner = document.createElement("div"); - loadingSpinner.classList.add("spinner"); - newResponseText.appendChild(loadingSpinner); + let loadingEllipsis = document.createElement("div"); + loadingEllipsis.classList.add("lds-ellipsis"); + + let firstEllipsis = document.createElement("div"); + firstEllipsis.classList.add("lds-ellipsis-item"); + + let secondEllipsis = document.createElement("div"); + secondEllipsis.classList.add("lds-ellipsis-item"); + + let thirdEllipsis = document.createElement("div"); + thirdEllipsis.classList.add("lds-ellipsis-item"); + + let fourthEllipsis = document.createElement("div"); + fourthEllipsis.classList.add("lds-ellipsis-item"); + + loadingEllipsis.appendChild(firstEllipsis); + loadingEllipsis.appendChild(secondEllipsis); + loadingEllipsis.appendChild(thirdEllipsis); + loadingEllipsis.appendChild(fourthEllipsis); + + newResponseText.appendChild(loadingEllipsis); document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; let chatTooltip = document.getElementById("chat-tooltip"); @@ -450,7 +480,11 @@ To get started, just start typing below. 
You can also type / to see a list of co const responseAsJson = await response.json(); if (responseAsJson.image) { // If response has image field, response is a generated image. - rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; + if (responseAsJson.intentType === "text-to-image") { + rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; + } else if (responseAsJson.intentType === "text-to-image2") { + rawResponse += `![${query}](${responseAsJson.image})`; + } const inferredQuery = responseAsJson.inferredQueries?.[0]; if (inferredQuery) { rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`; @@ -513,8 +547,8 @@ To get started, just start typing below. You can also type / to see a list of co readStream(); } else { // Display response from Khoj - if (newResponseText.getElementsByClassName("spinner").length > 0) { - newResponseText.removeChild(loadingSpinner); + if (newResponseText.getElementsByClassName("lds-ellipsis").length > 0) { + newResponseText.removeChild(loadingEllipsis); } // If the chunk is not a JSON object, just display it as is @@ -745,7 +779,9 @@ To get started, just start typing below. You can also type / to see a list of co // Create loading screen and add it to chat-body let loadingScreen = document.createElement('div'); - loadingScreen.classList.add('loading-screen', 'gradient-animation'); + loadingScreen.classList.add("loading-spinner"); + let yellowOrb = document.createElement('div'); + loadingScreen.appendChild(yellowOrb); chatBody.appendChild(loadingScreen); fetch(chatHistoryUrl, { method: "GET" }) @@ -792,8 +828,6 @@ To get started, just start typing below. 
You can also type / to see a list of co }); // Add fade out animation to loading screen and remove it after the animation ends - loadingScreen.classList.remove('gradient-animation'); - loadingScreen.classList.add('fade-out-animation'); chatBodyWrapperHeight = chatBodyWrapper.clientHeight; chatBody.style.height = chatBodyWrapperHeight; setTimeout(() => { @@ -1493,34 +1527,6 @@ To get started, just start typing below. You can also type / to see a list of co font-size: 2rem; color: #333; z-index: 9999; /* This is the important part */ - - /* Adding gradient effect */ - background: radial-gradient(circle, var(--primary-hover) 0%, var(--flower) 100%); - background-size: 200% 200%; - } - - div.loading-screen::after { - content: "Loading..."; - } - - .gradient-animation { - animation: gradient 2s ease infinite; - } - - @keyframes gradient { - 0% {background-position: 0% 50%;} - 50% {background-position: 100% 50%;} - 100% {background-position: 0% 50%;} - } - - .fade-out-animation { - animation-name: fadeOut; - animation-duration: 1.5s; - } - - @keyframes fadeOut { - from {opacity: 1;} - to {opacity: 0;} } /* add chat metatdata to bottom of bubble */ @@ -1757,13 +1763,32 @@ To get started, just start typing below. You can also type / to see a list of co } button.copy-button { - display: block; border-radius: 4px; background-color: var(--background-color); - } - button.copy-button:hover { - background: #f5f5f5; + border: 1px solid var(--main-text-color); + text-align: center; + font-size: 16px; + transition: all 0.5s; cursor: pointer; + padding: 4px; + float: right; + } + + button.copy-button span { + cursor: pointer; + display: inline-block; + position: relative; + transition: 0.5s; + } + + img.copy-icon { + width: 16px; + height: 16px; + } + + button.copy-button:hover { + background-color: black; + color: #f5f5f5; } pre { @@ -1997,5 +2022,104 @@ To get started, just start typing below. 
You can also type / to see a list of co padding: 10px; white-space: pre-wrap; } + + .loading-spinner { + display: inline-block; + position: relative; + width: 80px; + height: 80px; + } + + .loading-spinner div { + position: absolute; + border: 4px solid var(--primary-hover); + opacity: 1; + border-radius: 50%; + animation: lds-ripple 0.5s cubic-bezier(0, 0.2, 0.8, 1) infinite; + } + + .loading-spinner div:nth-child(2) { + animation-delay: -0.5s; + } + + @keyframes lds-ripple { + 0% { + top: 36px; + left: 36px; + width: 0; + height: 0; + opacity: 1; + border-color: var(--primary-hover); + } + 50% { + border-color: var(--flower); + } + 100% { + top: 0px; + left: 0px; + width: 72px; + height: 72px; + opacity: 0; + border-color: var(--water); + } + } + + .lds-ellipsis { + display: inline-block; + position: relative; + width: 60px; + height: 32px; + } + .lds-ellipsis div { + position: absolute; + top: 12px; + width: 12px; + height: 12px; + border-radius: 50%; + background: var(--main-text-color); + animation-timing-function: cubic-bezier(0, 1, 1, 0); + } + .lds-ellipsis div:nth-child(1) { + left: 8px; + animation: lds-ellipsis1 0.6s infinite; + } + .lds-ellipsis div:nth-child(2) { + left: 8px; + animation: lds-ellipsis2 0.6s infinite; + } + .lds-ellipsis div:nth-child(3) { + left: 32px; + animation: lds-ellipsis2 0.6s infinite; + } + .lds-ellipsis div:nth-child(4) { + left: 56px; + animation: lds-ellipsis3 0.6s infinite; + } + @keyframes lds-ellipsis1 { + 0% { + transform: scale(0); + } + 100% { + transform: scale(1); + } + } + @keyframes lds-ellipsis3 { + 0% { + transform: scale(1); + } + 100% { + transform: scale(0); + } + } + @keyframes lds-ellipsis2 { + 0% { + transform: translate(0, 0); + } + 100% { + transform: translate(24px, 0); + } + } + + diff --git a/src/khoj/processor/content/notion/notion_to_entries.py b/src/khoj/processor/content/notion/notion_to_entries.py index 541a1732..d8d40689 100644 --- a/src/khoj/processor/content/notion/notion_to_entries.py +++ 
b/src/khoj/processor/content/notion/notion_to_entries.py @@ -234,7 +234,7 @@ class NotionToEntries(TextToEntries): elif "Event" in properties: title_field = "Event" elif title_field not in properties: - logger.warning(f"Title field not found for page {page_id}. Setting title as None...") + logger.debug(f"Title field not found for page {page_id}. Setting title as None...") title = None return title, content try: diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py index a1a439dd..35fc03d7 100644 --- a/src/khoj/processor/conversation/utils.py +++ b/src/khoj/processor/conversation/utils.py @@ -128,7 +128,7 @@ def save_to_conversation_log( Saved Conversation Turn You ({user.username}): "{q}" -Khoj: "{inferred_queries if intent_type == "text-to-image" else chat_response}" +Khoj: "{inferred_queries if ("text-to-image" in intent_type) else chat_response}" """.strip() ) diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 093e8cc9..3c170f6f 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -300,25 +300,30 @@ async def chat( metadata={"conversation_command": conversation_commands[0].value}, **common.__dict__, ) - image, status_code, improved_image_prompt = await text_to_image( - q, meta_log, location_data=location, references=compiled_references, online_results=online_results + intent_type = "text-to-image" + image, status_code, improved_image_prompt, image_url = await text_to_image( + q, user, meta_log, location_data=location, references=compiled_references, online_results=online_results ) if image is None: - content_obj = {"image": image, "intentType": "text-to-image", "detail": improved_image_prompt} + content_obj = {"image": image, "intentType": intent_type, "detail": improved_image_prompt} return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code) + + if image_url: + intent_type = "text-to-image2" + image = image_url await
sync_to_async(save_to_conversation_log)( q, image, user, meta_log, - intent_type="text-to-image", + intent_type=intent_type, inferred_queries=[improved_image_prompt], client_application=request.user.client_app, conversation_id=conversation_id, compiled_references=compiled_references, online_results=online_results, ) - content_obj = {"image": image, "intentType": "text-to-image", "inferredQueries": [improved_image_prompt], "context": compiled_references, "online_results": online_results} # type: ignore + content_obj = {"image": image, "intentType": intent_type, "inferredQueries": [improved_image_prompt], "context": compiled_references, "online_results": online_results} # type: ignore return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code) # Get the (streamed) chat response from the LLM of choice. diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index f24ff688..25709ecf 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -4,11 +4,9 @@ import logging from concurrent.futures import ThreadPoolExecutor from datetime import datetime, timedelta, timezone from functools import partial -from time import time from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union import openai -import requests from fastapi import Depends, Header, HTTPException, Request, UploadFile from starlette.authentication import has_required_scope @@ -32,6 +30,7 @@ from khoj.processor.conversation.utils import ( generate_chatml_messages_with_context, save_to_conversation_log, ) +from khoj.routers.storage import upload_image from khoj.utils import state from khoj.utils.config import GPT4AllProcessorModel from khoj.utils.helpers import ( @@ -39,6 +38,7 @@ from khoj.utils.helpers import ( is_none_or_empty, log_telemetry, mode_descriptions_for_llm, + timer, tool_descriptions_for_llm, ) from khoj.utils.rawconfig import LocationData @@ -439,53 +439,65 @@ def generate_chat_response( async 
def text_to_image( message: str, + user: KhojUser, conversation_log: dict, location_data: LocationData, references: List[str], online_results: Dict[str, Any], -) -> Tuple[Optional[str], int, Optional[str]]: +) -> Tuple[Optional[str], int, Optional[str], Optional[str]]: status_code = 200 image = None response = None + image_url = None text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config() if not text_to_image_config: # If the user has not configured a text to image model, return an unsupported on server error status_code = 501 message = "Failed to generate image. Setup image generation on the server." - return image, status_code, message + return image, status_code, message, image_url elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI: + logger.info("Generating image with OpenAI") text2image_model = text_to_image_config.model_name chat_history = "" for chat in conversation_log.get("chat", [])[-4:]: if chat["by"] == "khoj" and chat["intent"].get("type") == "remember": chat_history += f"Q: {chat['intent']['query']}\n" chat_history += f"A: {chat['message']}\n" - improved_image_prompt = await generate_better_image_prompt( - message, - chat_history, - location_data=location_data, - note_references=references, - online_results=online_results, - ) - try: - response = state.openai_client.images.generate( - prompt=improved_image_prompt, model=text2image_model, response_format="b64_json" + elif chat["by"] == "khoj" and "text-to-image" in chat["intent"].get("type"): + chat_history += f"Q: {chat['intent']['query']}\n" + chat_history += f"A: [generated image redacted by admin]. 
Enhanced image prompt: {chat['intent']['inferred-queries'][0]}\n" + + with timer("Improve the original user query", logger): + improved_image_prompt = await generate_better_image_prompt( + message, + chat_history, + location_data=location_data, + note_references=references, + online_results=online_results, ) - image = response.data[0].b64_json + try: + with timer("Generate image with OpenAI", logger): + response = state.openai_client.images.generate( + prompt=improved_image_prompt, model=text2image_model, response_format="b64_json" + ) + image = response.data[0].b64_json + + with timer("Upload image to S3", logger): + image_url = upload_image(image, user.uuid) + return image, status_code, improved_image_prompt, image_url except openai.OpenAIError or openai.BadRequestError as e: if "content_policy_violation" in e.message: logger.error(f"Image Generation blocked by OpenAI: {e}") - status_code = e.status_code - message = f"Image generation blocked by OpenAI: {e.message}" - return image, status_code, message + status_code = e.status_code # type: ignore + message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore + return image, status_code, message, image_url else: logger.error(f"Image Generation failed with {e}", exc_info=True) - message = f"Image generation failed with OpenAI error: {e.message}" - status_code = e.status_code - return image, status_code, message - - return image, status_code, improved_image_prompt + message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore + status_code = e.status_code # type: ignore + return image, status_code, message, image_url + return image, status_code, response, image_url class ApiUserRateLimiter: diff --git a/src/khoj/routers/storage.py b/src/khoj/routers/storage.py new file mode 100644 index 00000000..57c28c5a --- /dev/null +++ b/src/khoj/routers/storage.py @@ -0,0 +1,34 @@ +import base64 +import logging +import os +import uuid + +logger = logging.getLogger(__name__) + +AWS_ACCESS_KEY 
= os.getenv("AWS_ACCESS_KEY") +AWS_SECRET_KEY = os.getenv("AWS_SECRET_KEY") +AWS_UPLOAD_IMAGE_BUCKET_NAME = os.getenv("AWS_IMAGE_UPLOAD_BUCKET") + +aws_enabled = AWS_ACCESS_KEY is not None and AWS_SECRET_KEY is not None and AWS_UPLOAD_IMAGE_BUCKET_NAME is not None + +if aws_enabled: + from boto3 import client + + s3_client = client("s3", aws_access_key_id=AWS_ACCESS_KEY, aws_secret_access_key=AWS_SECRET_KEY) + + +def upload_image(image: str, user_id: uuid.UUID): + """Upload the image to the S3 bucket""" + if not aws_enabled: + logger.info("AWS is not enabled. Skipping image upload") + return None + + decoded_image = base64.b64decode(image) + image_key = f"{user_id}/{uuid.uuid4()}.png" + try: + s3_client.put_object(Bucket=AWS_UPLOAD_IMAGE_BUCKET_NAME, Key=image_key, Body=decoded_image, ACL="public-read") + url = f"https://{AWS_UPLOAD_IMAGE_BUCKET_NAME}.s3.amazonaws.com/{image_key}" + return url + except Exception as e: + logger.error(f"Failed to upload image to S3: {e}") + return None