Upload generated images to S3, if AWS credentials and bucket are available (#667)

* Upload generated images to S3, if AWS credentials and bucket are available.
- In clients, render the image via its URL when it is returned with a text-to-image2 intent type
* Make the loading screen more intuitive and less jerky, and update the programmatic copy button
* Update the loading icon when waiting for a chat response
sabaimran 2024-03-08 10:54:13 +05:30 committed by GitHub
parent 13894e1fd5
commit 81beb7940c
10 changed files with 392 additions and 126 deletions
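
Deployment note: the upload path is gated on three environment variables read by the new storage module at the bottom of this diff. A minimal, standalone sketch of that gating check (the variable names come from the diff; the script itself is illustrative):

import os

# S3 upload only activates when all three variables are set (mirrors the storage module below).
required = ("AWS_ACCESS_KEY", "AWS_SECRET_KEY", "AWS_IMAGE_UPLOAD_BUCKET")
missing = [name for name in required if not os.getenv(name)]
if missing:
    print(f"S3 upload disabled; missing: {', '.join(missing)}")
else:
    print("S3 upload enabled; generated images will be returned as URLs")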

View file

@@ -93,6 +93,7 @@ prod = [
"google-auth == 2.23.3",
"stripe == 7.3.0",
"twilio == 8.11",
"boto3 >= 1.34.57",
]
dev = [
"khoj-assistant[prod]",

View file

@@ -198,8 +198,14 @@
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) {
if ((context == null || context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) {
if (intentType === "text-to-image") {
let imageMarkdown = `![](data:image/png;base64,${message})`;
if (intentType?.includes("text-to-image")) {
let imageMarkdown;
if (intentType === "text-to-image") {
imageMarkdown = `![](data:image/png;base64,${message})`;
} else if (intentType === "text-to-image2") {
imageMarkdown = `![](${message})`;
}
const inferredQuery = inferredQueries?.[0];
if (inferredQuery) {
imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
@@ -266,8 +272,13 @@
references.appendChild(referenceSection);
if (intentType === "text-to-image") {
let imageMarkdown = `![](data:image/png;base64,${message})`;
if (intentType?.includes("text-to-image")) {
let imageMarkdown;
if (intentType === "text-to-image") {
imageMarkdown = `![](data:image/png;base64,${message})`;
} else if (intentType === "text-to-image2") {
imageMarkdown = `![](${message})`;
}
const inferredQuery = inferredQueries?.[0];
if (inferredQuery) {
imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
@@ -423,9 +434,27 @@
new_response.appendChild(newResponseText);
// Temporary status message to indicate that Khoj is thinking
let loadingSpinner = document.createElement("div");
loadingSpinner.classList.add("spinner");
newResponseText.appendChild(loadingSpinner);
let loadingEllipsis = document.createElement("div");
loadingEllipsis.classList.add("lds-ellipsis");
let firstEllipsis = document.createElement("div");
firstEllipsis.classList.add("lds-ellipsis-item");
let secondEllipsis = document.createElement("div");
secondEllipsis.classList.add("lds-ellipsis-item");
let thirdEllipsis = document.createElement("div");
thirdEllipsis.classList.add("lds-ellipsis-item");
let fourthEllipsis = document.createElement("div");
fourthEllipsis.classList.add("lds-ellipsis-item");
loadingEllipsis.appendChild(firstEllipsis);
loadingEllipsis.appendChild(secondEllipsis);
loadingEllipsis.appendChild(thirdEllipsis);
loadingEllipsis.appendChild(fourthEllipsis);
newResponseText.appendChild(loadingEllipsis);
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
let chatTooltip = document.getElementById("chat-tooltip");
@@ -446,7 +475,11 @@
const responseAsJson = await response.json();
if (responseAsJson.image) {
// If response has image field, response is a generated image.
rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
if (responseAsJson.intentType === "text-to-image") {
rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
} else if (responseAsJson.intentType === "text-to-image2") {
rawResponse += `![${query}](${responseAsJson.image})`;
}
const inferredQueries = responseAsJson.inferredQueries?.[0];
if (inferredQueries) {
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQueries}`;
@@ -509,40 +542,16 @@
readStream();
} else {
// Display response from Khoj
if (newResponseText.getElementsByClassName("spinner").length > 0) {
newResponseText.removeChild(loadingSpinner);
if (newResponseText.getElementsByClassName("lds-ellipsis").length > 0) {
newResponseText.removeChild(loadingEllipsis);
}
// Try to parse the chunk as a JSON object. It will be a JSON object if there is an error.
if (chunk.startsWith("{") && chunk.endsWith("}")) {
try {
const responseAsJson = JSON.parse(chunk);
if (responseAsJson.image) {
// If response has image field, response is a generated image.
rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
const inferredQuery = responseAsJson.inferredQueries?.[0];
if (inferredQuery) {
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
}
}
if (responseAsJson.detail) {
rawResponse += responseAsJson.detail;
}
} catch (error) {
// If the chunk is not a JSON object, just display it as is
rawResponse += chunk;
} finally {
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
}
} else {
// If the chunk is not a JSON object, just display it as is
rawResponse += chunk;
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
// If the chunk is not a JSON object, just display it as is
rawResponse += chunk;
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
readStream();
}
readStream();
}
// Scroll to bottom of chat window as chat response is streamed
@@ -1575,13 +1584,27 @@
}
button.copy-button {
display: block;
border-radius: 4px;
background-color: var(--background-color);
}
button.copy-button:hover {
background: #f5f5f5;
border: 1px solid var(--main-text-color);
text-align: center;
font-size: 16px;
transition: all 0.5s;
cursor: pointer;
padding: 4px;
float: right;
}
button.copy-button span {
cursor: pointer;
display: inline-block;
position: relative;
transition: 0.5s;
}
button.copy-button:hover {
background-color: black;
color: #f5f5f5;
}
pre {
@@ -1815,5 +1838,61 @@
padding: 10px;
white-space: pre-wrap;
}
.lds-ellipsis {
display: inline-block;
position: relative;
width: 60px;
height: 32px;
}
.lds-ellipsis div {
position: absolute;
top: 12px;
width: 12px;
height: 12px;
border-radius: 50%;
background: var(--main-text-color);
animation-timing-function: cubic-bezier(0, 1, 1, 0);
}
.lds-ellipsis div:nth-child(1) {
left: 8px;
animation: lds-ellipsis1 0.6s infinite;
}
.lds-ellipsis div:nth-child(2) {
left: 8px;
animation: lds-ellipsis2 0.6s infinite;
}
.lds-ellipsis div:nth-child(3) {
left: 32px;
animation: lds-ellipsis2 0.6s infinite;
}
.lds-ellipsis div:nth-child(4) {
left: 56px;
animation: lds-ellipsis3 0.6s infinite;
}
@keyframes lds-ellipsis1 {
0% {
transform: scale(0);
}
100% {
transform: scale(1);
}
}
@keyframes lds-ellipsis3 {
0% {
transform: scale(1);
}
100% {
transform: scale(0);
}
}
@keyframes lds-ellipsis2 {
0% {
transform: translate(0, 0);
}
100% {
transform: translate(24px, 0);
}
}
</style>
</html>

View file

@@ -150,8 +150,13 @@ export class KhojChatModal extends Modal {
renderMessageWithReferences(chatEl: Element, message: string, sender: string, context?: string[], dt?: Date, intentType?: string, inferredQueries?: string[]) {
if (!message) {
return;
} else if (intentType === "text-to-image") {
let imageMarkdown = `![](data:image/png;base64,${message})`;
} else if (intentType?.includes("text-to-image")) {
let imageMarkdown = "";
if (intentType === "text-to-image") {
imageMarkdown = `![](data:image/png;base64,${message})`;
} else if (intentType === "text-to-image2") {
imageMarkdown = `![](${message})`;
}
if (inferredQueries) {
imageMarkdown += "\n\n**Inferred Query**:";
for (let inferredQuery of inferredQueries) {
@@ -419,7 +424,12 @@ export class KhojChatModal extends Modal {
try {
const responseAsJson = await response.json() as ChatJsonResult;
if (responseAsJson.image) {
responseText = `![${query}](data:image/png;base64,${responseAsJson.image})`;
// If response has image field, response is a generated image.
if (responseAsJson.intentType === "text-to-image") {
responseText += `![${query}](data:image/png;base64,${responseAsJson.image})`;
} else if (responseAsJson.intentType === "text-to-image2") {
responseText += `![${query}](${responseAsJson.image})`;
}
const inferredQuery = responseAsJson.inferredQueries?.[0];
if (inferredQuery) {
responseText += `\n\n**Inferred Query**:\n\n${inferredQuery}`;

View file

@@ -0,0 +1 @@
<svg viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"><g id="SVGRepo_bgCarrier" stroke-width="0"></g><g id="SVGRepo_tracerCarrier" stroke-linecap="round" stroke-linejoin="round"></g><g id="SVGRepo_iconCarrier"> <path d="M6 11C6 8.17157 6 6.75736 6.87868 5.87868C7.75736 5 9.17157 5 12 5H15C17.8284 5 19.2426 5 20.1213 5.87868C21 6.75736 21 8.17157 21 11V16C21 18.8284 21 20.2426 20.1213 21.1213C19.2426 22 17.8284 22 15 22H12C9.17157 22 7.75736 22 6.87868 21.1213C6 20.2426 6 18.8284 6 16V11Z" stroke="#1C274C" stroke-width="1.5"></path> <path d="M6 19C4.34315 19 3 17.6569 3 16V10C3 6.22876 3 4.34315 4.17157 3.17157C5.34315 2 7.22876 2 11 2H15C16.6569 2 18 3.34315 18 5" stroke="#1C274C" stroke-width="1.5"></path> </g></svg>


View file

@@ -28,18 +28,17 @@ To get started, just start typing below. You can also type / to see a list of co
let chatOptions = [];
function copyProgrammaticOutput(event) {
// Remove the first 4 characters which are the "Copy" button
const originalCopyText = event.target.parentNode.textContent.trim().slice(0, 4);
const programmaticOutput = event.target.parentNode.textContent.trim().slice(4);
const programmaticOutput = event.target.parentNode.textContent.trim();
navigator.clipboard.writeText(programmaticOutput).then(() => {
event.target.textContent = "✅ Copied to clipboard!";
setTimeout(() => {
event.target.textContent = originalCopyText;
event.target.textContent = "✅";
}, 1000);
}).catch((error) => {
console.error("Error copying programmatic output to clipboard:", error);
event.target.textContent = "⛔️ Failed to copy!";
setTimeout(() => {
event.target.textContent = originalCopyText;
event.target.textContent = "⛔️";
}, 1000);
});
}
@@ -211,8 +210,13 @@ To get started, just start typing below. You can also type / to see a list of co
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) {
if ((context == null || context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) {
if (intentType === "text-to-image") {
let imageMarkdown = `![](data:image/png;base64,${message})`;
if (intentType?.includes("text-to-image")) {
let imageMarkdown;
if (intentType === "text-to-image") {
imageMarkdown = `![](data:image/png;base64,${message})`;
} else if (intentType === "text-to-image2") {
imageMarkdown = `![](${message})`;
}
const inferredQuery = inferredQueries?.[0];
if (inferredQuery) {
imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
@@ -274,8 +278,13 @@ To get started, just start typing below. You can also type / to see a list of co
references.appendChild(referenceSection);
if (intentType === "text-to-image") {
let imageMarkdown = `![](data:image/png;base64,${message})`;
if (intentType?.includes("text-to-image")) {
let imageMarkdown;
if (intentType === "text-to-image") {
imageMarkdown = `![](data:image/png;base64,${message})`;
} else if (intentType === "text-to-image2") {
imageMarkdown = `![](${message})`;
}
const inferredQuery = inferredQueries?.[0];
if (inferredQuery) {
imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
@@ -326,7 +335,10 @@ To get started, just start typing below. You can also type / to see a list of co
// Add a copy button to each element
let copyButton = document.createElement('button');
copyButton.classList.add("copy-button");
copyButton.innerHTML = "Copy";
let copyIcon = document.createElement("img");
copyIcon.src = "/static/assets/icons/copy_button.svg";
copyIcon.classList.add("copy-icon");
copyButton.appendChild(copyIcon);
copyButton.addEventListener('click', copyProgrammaticOutput);
codeElement.prepend(copyButton);
});
@@ -427,9 +439,27 @@ To get started, just start typing below. You can also type / to see a list of co
new_response.appendChild(newResponseText);
// Temporary status message to indicate that Khoj is thinking
let loadingSpinner = document.createElement("div");
loadingSpinner.classList.add("spinner");
newResponseText.appendChild(loadingSpinner);
let loadingEllipsis = document.createElement("div");
loadingEllipsis.classList.add("lds-ellipsis");
let firstEllipsis = document.createElement("div");
firstEllipsis.classList.add("lds-ellipsis-item");
let secondEllipsis = document.createElement("div");
secondEllipsis.classList.add("lds-ellipsis-item");
let thirdEllipsis = document.createElement("div");
thirdEllipsis.classList.add("lds-ellipsis-item");
let fourthEllipsis = document.createElement("div");
fourthEllipsis.classList.add("lds-ellipsis-item");
loadingEllipsis.appendChild(firstEllipsis);
loadingEllipsis.appendChild(secondEllipsis);
loadingEllipsis.appendChild(thirdEllipsis);
loadingEllipsis.appendChild(fourthEllipsis);
newResponseText.appendChild(loadingEllipsis);
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
let chatTooltip = document.getElementById("chat-tooltip");
@@ -450,7 +480,11 @@ To get started, just start typing below. You can also type / to see a list of co
const responseAsJson = await response.json();
if (responseAsJson.image) {
// If response has image field, response is a generated image.
rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
if (responseAsJson.intentType === "text-to-image") {
rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
} else if (responseAsJson.intentType === "text-to-image2") {
rawResponse += `![${query}](${responseAsJson.image})`;
}
const inferredQuery = responseAsJson.inferredQueries?.[0];
if (inferredQuery) {
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
@@ -513,8 +547,8 @@ To get started, just start typing below. You can also type / to see a list of co
readStream();
} else {
// Display response from Khoj
if (newResponseText.getElementsByClassName("spinner").length > 0) {
newResponseText.removeChild(loadingSpinner);
if (newResponseText.getElementsByClassName("lds-ellipsis").length > 0) {
newResponseText.removeChild(loadingEllipsis);
}
// If the chunk is not a JSON object, just display it as is
@@ -745,7 +779,9 @@ To get started, just start typing below. You can also type / to see a list of co
// Create loading screen and add it to chat-body
let loadingScreen = document.createElement('div');
loadingScreen.classList.add('loading-screen', 'gradient-animation');
loadingScreen.classList.add("loading-spinner");
let yellowOrb = document.createElement('div');
loadingScreen.appendChild(yellowOrb);
chatBody.appendChild(loadingScreen);
fetch(chatHistoryUrl, { method: "GET" })
@@ -792,8 +828,6 @@ To get started, just start typing below. You can also type / to see a list of co
});
// Add fade out animation to loading screen and remove it after the animation ends
loadingScreen.classList.remove('gradient-animation');
loadingScreen.classList.add('fade-out-animation');
chatBodyWrapperHeight = chatBodyWrapper.clientHeight;
chatBody.style.height = chatBodyWrapperHeight;
setTimeout(() => {
@@ -1493,34 +1527,6 @@ To get started, just start typing below. You can also type / to see a list of co
font-size: 2rem;
color: #333;
z-index: 9999; /* This is the important part */
/* Adding gradient effect */
background: radial-gradient(circle, var(--primary-hover) 0%, var(--flower) 100%);
background-size: 200% 200%;
}
div.loading-screen::after {
content: "Loading...";
}
.gradient-animation {
animation: gradient 2s ease infinite;
}
@keyframes gradient {
0% {background-position: 0% 50%;}
50% {background-position: 100% 50%;}
100% {background-position: 0% 50%;}
}
.fade-out-animation {
animation-name: fadeOut;
animation-duration: 1.5s;
}
@keyframes fadeOut {
from {opacity: 1;}
to {opacity: 0;}
}
/* add chat metadata to bottom of bubble */
@@ -1757,13 +1763,32 @@ To get started, just start typing below. You can also type / to see a list of co
}
button.copy-button {
display: block;
border-radius: 4px;
background-color: var(--background-color);
}
button.copy-button:hover {
background: #f5f5f5;
border: 1px solid var(--main-text-color);
text-align: center;
font-size: 16px;
transition: all 0.5s;
cursor: pointer;
padding: 4px;
float: right;
}
button.copy-button span {
cursor: pointer;
display: inline-block;
position: relative;
transition: 0.5s;
}
img.copy-icon {
width: 16px;
height: 16px;
}
button.copy-button:hover {
background-color: black;
color: #f5f5f5;
}
pre {
@@ -1997,5 +2022,104 @@ To get started, just start typing below. You can also type / to see a list of co
padding: 10px;
white-space: pre-wrap;
}
.loading-spinner {
display: inline-block;
position: relative;
width: 80px;
height: 80px;
}
.loading-spinner div {
position: absolute;
border: 4px solid var(--primary-hover);
opacity: 1;
border-radius: 50%;
animation: lds-ripple 0.5s cubic-bezier(0, 0.2, 0.8, 1) infinite;
}
.loading-spinner div:nth-child(2) {
animation-delay: -0.5s;
}
@keyframes lds-ripple {
0% {
top: 36px;
left: 36px;
width: 0;
height: 0;
opacity: 1;
border-color: var(--primary-hover);
}
50% {
border-color: var(--flower);
}
100% {
top: 0px;
left: 0px;
width: 72px;
height: 72px;
opacity: 0;
border-color: var(--water);
}
}
.lds-ellipsis {
display: inline-block;
position: relative;
width: 60px;
height: 32px;
}
.lds-ellipsis div {
position: absolute;
top: 12px;
width: 12px;
height: 12px;
border-radius: 50%;
background: var(--main-text-color);
animation-timing-function: cubic-bezier(0, 1, 1, 0);
}
.lds-ellipsis div:nth-child(1) {
left: 8px;
animation: lds-ellipsis1 0.6s infinite;
}
.lds-ellipsis div:nth-child(2) {
left: 8px;
animation: lds-ellipsis2 0.6s infinite;
}
.lds-ellipsis div:nth-child(3) {
left: 32px;
animation: lds-ellipsis2 0.6s infinite;
}
.lds-ellipsis div:nth-child(4) {
left: 56px;
animation: lds-ellipsis3 0.6s infinite;
}
@keyframes lds-ellipsis1 {
0% {
transform: scale(0);
}
100% {
transform: scale(1);
}
}
@keyframes lds-ellipsis3 {
0% {
transform: scale(1);
}
100% {
transform: scale(0);
}
}
@keyframes lds-ellipsis2 {
0% {
transform: translate(0, 0);
}
100% {
transform: translate(24px, 0);
}
}
</style>
</html>

View file

@@ -234,7 +234,7 @@ class NotionToEntries(TextToEntries):
elif "Event" in properties:
title_field = "Event"
elif title_field not in properties:
logger.warning(f"Title field not found for page {page_id}. Setting title as None...")
logger.debug(f"Title field not found for page {page_id}. Setting title as None...")
title = None
return title, content
try:

View file

@@ -128,7 +128,7 @@ def save_to_conversation_log(
Saved Conversation Turn
You ({user.username}): "{q}"
Khoj: "{inferred_queries if intent_type == "text-to-image" else chat_response}"
Khoj: "{inferred_queries if ("text_to_image" in intent_type) else chat_response}"
""".strip()
)
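
The switch from an equality check to a substring check matters because there are now two image intent types; a quick illustration:

# Both image intents log the inferred query; regular chat logs the response.
for intent_type in ("text-to-image", "text-to-image2", "remember"):
    print(intent_type, "text-to-image" in intent_type)
# -> text-to-image True, text-to-image2 True, remember False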

View file

@@ -300,25 +300,30 @@ async def chat(
metadata={"conversation_command": conversation_commands[0].value},
**common.__dict__,
)
image, status_code, improved_image_prompt = await text_to_image(
q, meta_log, location_data=location, references=compiled_references, online_results=online_results
intent_type = "text-to-image"
image, status_code, improved_image_prompt, image_url = await text_to_image(
q, user, meta_log, location_data=location, references=compiled_references, online_results=online_results
)
if image is None:
content_obj = {"image": image, "intentType": "text-to-image", "detail": improved_image_prompt}
content_obj = {"image": image, "intentType": intent_type, "detail": improved_image_prompt}
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
if image_url:
intent_type = "text-to-image2"
image = image_url
await sync_to_async(save_to_conversation_log)(
q,
image,
user,
meta_log,
intent_type="text-to-image",
intent_type=intent_type,
inferred_queries=[improved_image_prompt],
client_application=request.user.client_app,
conversation_id=conversation_id,
compiled_references=compiled_references,
online_results=online_results,
)
content_obj = {"image": image, "intentType": "text-to-image", "inferredQueries": [improved_image_prompt], "context": compiled_references, "online_results": online_results} # type: ignore
content_obj = {"image": image, "intentType": intent_type, "inferredQueries": [improved_image_prompt], "context": compiled_references, "online_results": online_results} # type: ignore
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
# Get the (streamed) chat response from the LLM of choice.
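
For reference, the JSON payload shape the web and Obsidian clients above branch on when S3 is configured; the field names come from this diff, the values are illustrative placeholders:

# Illustrative "text-to-image2" response body (all values hypothetical).
content_obj = {
    "image": "https://<bucket>.s3.amazonaws.com/<user-uuid>/<image-uuid>.png",
    "intentType": "text-to-image2",
    "inferredQueries": ["an improved, more detailed image prompt"],
    "context": [],
    "online_results": {},
}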

View file

@@ -4,11 +4,9 @@ import logging
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta, timezone
from functools import partial
from time import time
from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union
import openai
import requests
from fastapi import Depends, Header, HTTPException, Request, UploadFile
from starlette.authentication import has_required_scope
@@ -32,6 +30,7 @@ from khoj.processor.conversation.utils import (
generate_chatml_messages_with_context,
save_to_conversation_log,
)
from khoj.routers.storage import upload_image
from khoj.utils import state
from khoj.utils.config import GPT4AllProcessorModel
from khoj.utils.helpers import (
@@ -39,6 +38,7 @@ from khoj.utils.helpers import (
is_none_or_empty,
log_telemetry,
mode_descriptions_for_llm,
timer,
tool_descriptions_for_llm,
)
from khoj.utils.rawconfig import LocationData
@@ -439,53 +439,65 @@ def generate_chat_response(
async def text_to_image(
message: str,
user: KhojUser,
conversation_log: dict,
location_data: LocationData,
references: List[str],
online_results: Dict[str, Any],
) -> Tuple[Optional[str], int, Optional[str]]:
) -> Tuple[Optional[str], int, Optional[str], Optional[str]]:
status_code = 200
image = None
response = None
image_url = None
text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
if not text_to_image_config:
# If the user has not configured a text-to-image model, return a 501 "unsupported on this server" error
status_code = 501
message = "Failed to generate image. Setup image generation on the server."
return image, status_code, message
return image, status_code, message, image_url
elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
logger.info("Generating image with OpenAI")
text2image_model = text_to_image_config.model_name
chat_history = ""
for chat in conversation_log.get("chat", [])[-4:]:
if chat["by"] == "khoj" and chat["intent"].get("type") == "remember":
chat_history += f"Q: {chat['intent']['query']}\n"
chat_history += f"A: {chat['message']}\n"
improved_image_prompt = await generate_better_image_prompt(
message,
chat_history,
location_data=location_data,
note_references=references,
online_results=online_results,
)
try:
response = state.openai_client.images.generate(
prompt=improved_image_prompt, model=text2image_model, response_format="b64_json"
elif chat["by"] == "khoj" and "text-to-image" in chat["intent"].get("type"):
chat_history += f"Q: {chat['intent']['query']}\n"
chat_history += f"A: [generated image redacted by admin]. Enhanced image prompt: {chat['intent']['inferred-queries'][0]}\n"
with timer("Improve the original user query", logger):
improved_image_prompt = await generate_better_image_prompt(
message,
chat_history,
location_data=location_data,
note_references=references,
online_results=online_results,
)
image = response.data[0].b64_json
try:
with timer("Generate image with OpenAI", logger):
response = state.openai_client.images.generate(
prompt=improved_image_prompt, model=text2image_model, response_format="b64_json"
)
image = response.data[0].b64_json
with timer("Upload image to S3", logger):
image_url = upload_image(image, user.uuid)
return image, status_code, improved_image_prompt, image_url
except (openai.OpenAIError, openai.BadRequestError) as e:
if "content_policy_violation" in e.message:
logger.error(f"Image Generation blocked by OpenAI: {e}")
status_code = e.status_code
message = f"Image generation blocked by OpenAI: {e.message}"
return image, status_code, message
status_code = e.status_code # type: ignore
message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore
return image, status_code, message, image_url
else:
logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation failed with OpenAI error: {e.message}"
status_code = e.status_code
return image, status_code, message
return image, status_code, improved_image_prompt
message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore
status_code = e.status_code # type: ignore
return image, status_code, message, image_url
return image, status_code, response, image_url
class ApiUserRateLimiter:
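
Note the widened contract: text_to_image now returns a four-tuple, and the router above switches the intent type whenever the final element is set. A condensed, synchronous stand-in for that branch (simplified from the router change; not the actual implementation):

def handle_image_result(image, status_code, detail, image_url):
    # Mirrors the router: if the image landed in S3, send its URL instead of base64 data.
    intent_type = "text-to-image"
    if image_url:
        intent_type = "text-to-image2"
        image = image_url
    return {"image": image, "intentType": intent_type, "detail": detail, "status": status_code}

print(handle_image_result("aGk=", 200, "prompt", None)["intentType"])  # text-to-image
print(handle_image_result("aGk=", 200, "prompt", "https://example.com/x.png")["intentType"])  # text-to-image2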

View file

@@ -0,0 +1,34 @@
import base64
import logging
import os
import uuid
logger = logging.getLogger(__name__)
AWS_ACCESS_KEY = os.getenv("AWS_ACCESS_KEY")
AWS_SECRET_KEY = os.getenv("AWS_SECRET_KEY")
AWS_UPLOAD_IMAGE_BUCKET_NAME = os.getenv("AWS_IMAGE_UPLOAD_BUCKET")
aws_enabled = AWS_ACCESS_KEY is not None and AWS_SECRET_KEY is not None and AWS_UPLOAD_IMAGE_BUCKET_NAME is not None
if aws_enabled:
from boto3 import client
s3_client = client("s3", aws_access_key_id=AWS_ACCESS_KEY, aws_secret_access_key=AWS_SECRET_KEY)
def upload_image(image: str, user_id: uuid.UUID):
"""Upload the image to the S3 bucket"""
if not aws_enabled:
logger.info("AWS is not enabled. Skipping image upload")
return None
decoded_image = base64.b64decode(image)
image_key = f"{user_id}/{uuid.uuid4()}.png"
try:
s3_client.put_object(Bucket=AWS_UPLOAD_IMAGE_BUCKET_NAME, Key=image_key, Body=decoded_image, ACL="public-read")
url = f"https://{AWS_UPLOAD_IMAGE_BUCKET_NAME}.s3.amazonaws.com/{image_key}"
return url
except Exception as e:
logger.error(f"Failed to upload image to S3: {e}")
return None
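
A minimal usage sketch for the new helper, assuming it is appended to this module (the payload is a placeholder, not a real PNG; upload_image only requires valid base64):

import base64
import uuid

b64_image = base64.b64encode(b"placeholder-image-bytes").decode()
url = upload_image(b64_image, uuid.uuid4())
# Returns a public S3 URL, or None when AWS is unconfigured or the upload fails.
print(url or "S3 not configured; client falls back to inline base64 rendering")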