Upload generated images to s3, if AWS credentials and bucket is available (#667)

* Upload generated images to s3, if AWS credentials and bucket is available.
- In clients, render the images via the URL if it's returned with a text-to-image2 intent type
* Make the loading screen more intuitve, less jerky and update the programmatic copy button
* Update the loading icon when waiting for a chat response
This commit is contained in:
sabaimran 2024-03-08 10:54:13 +05:30 committed by GitHub
parent 13894e1fd5
commit 81beb7940c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 392 additions and 126 deletions

View file

@ -93,6 +93,7 @@ prod = [
"google-auth == 2.23.3", "google-auth == 2.23.3",
"stripe == 7.3.0", "stripe == 7.3.0",
"twilio == 8.11", "twilio == 8.11",
"boto3 >= 1.34.57",
] ]
dev = [ dev = [
"khoj-assistant[prod]", "khoj-assistant[prod]",

View file

@ -198,8 +198,14 @@
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) { function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) {
if ((context == null || context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) { if ((context == null || context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) {
if (intentType === "text-to-image") { if (intentType?.includes("text-to-image")) {
let imageMarkdown = `![](data:image/png;base64,${message})`; let imageMarkdown;
if (intentType === "text-to-image") {
imageMarkdown = `![](data:image/png;base64,${message})`;
} else if (intentType === "text-to-image2") {
imageMarkdown = `![](${message})`;
}
const inferredQuery = inferredQueries?.[0]; const inferredQuery = inferredQueries?.[0];
if (inferredQuery) { if (inferredQuery) {
imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`; imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
@ -266,8 +272,13 @@
references.appendChild(referenceSection); references.appendChild(referenceSection);
if (intentType === "text-to-image") { if (intentType?.includes("text-to-image")) {
let imageMarkdown = `![](data:image/png;base64,${message})`; let imageMarkdown;
if (intentType === "text-to-image") {
imageMarkdown = `![](data:image/png;base64,${message})`;
} else if (intentType === "text-to-image2") {
imageMarkdown = `![](${message})`;
}
const inferredQuery = inferredQueries?.[0]; const inferredQuery = inferredQueries?.[0];
if (inferredQuery) { if (inferredQuery) {
imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`; imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
@ -423,9 +434,27 @@
new_response.appendChild(newResponseText); new_response.appendChild(newResponseText);
// Temporary status message to indicate that Khoj is thinking // Temporary status message to indicate that Khoj is thinking
let loadingSpinner = document.createElement("div"); let loadingEllipsis = document.createElement("div");
loadingSpinner.classList.add("spinner"); loadingEllipsis.classList.add("lds-ellipsis");
newResponseText.appendChild(loadingSpinner);
let firstEllipsis = document.createElement("div");
firstEllipsis.classList.add("lds-ellipsis-item");
let secondEllipsis = document.createElement("div");
secondEllipsis.classList.add("lds-ellipsis-item");
let thirdEllipsis = document.createElement("div");
thirdEllipsis.classList.add("lds-ellipsis-item");
let fourthEllipsis = document.createElement("div");
fourthEllipsis.classList.add("lds-ellipsis-item");
loadingEllipsis.appendChild(firstEllipsis);
loadingEllipsis.appendChild(secondEllipsis);
loadingEllipsis.appendChild(thirdEllipsis);
loadingEllipsis.appendChild(fourthEllipsis);
newResponseText.appendChild(loadingEllipsis);
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
let chatTooltip = document.getElementById("chat-tooltip"); let chatTooltip = document.getElementById("chat-tooltip");
@ -446,7 +475,11 @@
const responseAsJson = await response.json(); const responseAsJson = await response.json();
if (responseAsJson.image) { if (responseAsJson.image) {
// If response has image field, response is a generated image. // If response has image field, response is a generated image.
rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; if (responseAsJson.intentType === "text-to-image") {
rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
} else if (responseAsJson.intentType === "text-to-image2") {
rawResponse += `![${query}](${responseAsJson.image})`;
}
const inferredQueries = responseAsJson.inferredQueries?.[0]; const inferredQueries = responseAsJson.inferredQueries?.[0];
if (inferredQueries) { if (inferredQueries) {
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQueries}`; rawResponse += `\n\n**Inferred Query**:\n\n${inferredQueries}`;
@ -509,40 +542,16 @@
readStream(); readStream();
} else { } else {
// Display response from Khoj // Display response from Khoj
if (newResponseText.getElementsByClassName("spinner").length > 0) { if (newResponseText.getElementsByClassName("lds-ellipsis").length > 0) {
newResponseText.removeChild(loadingSpinner); newResponseText.removeChild(loadingEllipsis);
} }
// Try to parse the chunk as a JSON object. It will be a JSON object if there is an error. // If the chunk is not a JSON object, just display it as is
if (chunk.startsWith("{") && chunk.endsWith("}")) { rawResponse += chunk;
try { newResponseText.innerHTML = "";
const responseAsJson = JSON.parse(chunk); newResponseText.appendChild(formatHTMLMessage(rawResponse));
if (responseAsJson.image) {
// If response has image field, response is a generated image.
rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
const inferredQuery = responseAsJson.inferredQueries?.[0];
if (inferredQuery) {
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
}
}
if (responseAsJson.detail) {
rawResponse += responseAsJson.detail;
}
} catch (error) {
// If the chunk is not a JSON object, just display it as is
rawResponse += chunk;
} finally {
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
}
} else {
// If the chunk is not a JSON object, just display it as is
rawResponse += chunk;
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
readStream(); readStream();
}
} }
// Scroll to bottom of chat window as chat response is streamed // Scroll to bottom of chat window as chat response is streamed
@ -1575,13 +1584,27 @@
} }
button.copy-button { button.copy-button {
display: block;
border-radius: 4px; border-radius: 4px;
background-color: var(--background-color); background-color: var(--background-color);
} border: 1px solid var(--main-text-color);
button.copy-button:hover { text-align: center;
background: #f5f5f5; font-size: 16px;
transition: all 0.5s;
cursor: pointer; cursor: pointer;
padding: 4px;
float: right;
}
button.copy-button span {
cursor: pointer;
display: inline-block;
position: relative;
transition: 0.5s;
}
button.copy-button:hover {
background-color: black;
color: #f5f5f5;
} }
pre { pre {
@ -1815,5 +1838,61 @@
padding: 10px; padding: 10px;
white-space: pre-wrap; white-space: pre-wrap;
} }
.lds-ellipsis {
display: inline-block;
position: relative;
width: 60px;
height: 32px;
}
.lds-ellipsis div {
position: absolute;
top: 12px;
width: 12px;
height: 12px;
border-radius: 50%;
background: var(--main-text-color);
animation-timing-function: cubic-bezier(0, 1, 1, 0);
}
.lds-ellipsis div:nth-child(1) {
left: 8px;
animation: lds-ellipsis1 0.6s infinite;
}
.lds-ellipsis div:nth-child(2) {
left: 8px;
animation: lds-ellipsis2 0.6s infinite;
}
.lds-ellipsis div:nth-child(3) {
left: 32px;
animation: lds-ellipsis2 0.6s infinite;
}
.lds-ellipsis div:nth-child(4) {
left: 56px;
animation: lds-ellipsis3 0.6s infinite;
}
@keyframes lds-ellipsis1 {
0% {
transform: scale(0);
}
100% {
transform: scale(1);
}
}
@keyframes lds-ellipsis3 {
0% {
transform: scale(1);
}
100% {
transform: scale(0);
}
}
@keyframes lds-ellipsis2 {
0% {
transform: translate(0, 0);
}
100% {
transform: translate(24px, 0);
}
}
</style> </style>
</html> </html>

View file

@ -150,8 +150,13 @@ export class KhojChatModal extends Modal {
renderMessageWithReferences(chatEl: Element, message: string, sender: string, context?: string[], dt?: Date, intentType?: string, inferredQueries?: string) { renderMessageWithReferences(chatEl: Element, message: string, sender: string, context?: string[], dt?: Date, intentType?: string, inferredQueries?: string) {
if (!message) { if (!message) {
return; return;
} else if (intentType === "text-to-image") { } else if (intentType?.includes("text-to-image")) {
let imageMarkdown = `![](data:image/png;base64,${message})`; let imageMarkdown = "";
if (intentType === "text-to-image") {
imageMarkdown = `![](data:image/png;base64,${message})`;
} else if (intentType === "text-to-image2") {
imageMarkdown = `![](${message})`;
}
if (inferredQueries) { if (inferredQueries) {
imageMarkdown += "\n\n**Inferred Query**:"; imageMarkdown += "\n\n**Inferred Query**:";
for (let inferredQuery of inferredQueries) { for (let inferredQuery of inferredQueries) {
@ -419,7 +424,12 @@ export class KhojChatModal extends Modal {
try { try {
const responseAsJson = await response.json() as ChatJsonResult; const responseAsJson = await response.json() as ChatJsonResult;
if (responseAsJson.image) { if (responseAsJson.image) {
responseText = `![${query}](data:image/png;base64,${responseAsJson.image})`; // If response has image field, response is a generated image.
if (responseAsJson.intentType === "text-to-image") {
responseText += `![${query}](data:image/png;base64,${responseAsJson.image})`;
} else if (responseAsJson.intentType === "text-to-image2") {
responseText += `![${query}](${responseAsJson.image})`;
}
const inferredQuery = responseAsJson.inferredQueries?.[0]; const inferredQuery = responseAsJson.inferredQueries?.[0];
if (inferredQuery) { if (inferredQuery) {
responseText += `\n\n**Inferred Query**:\n\n${inferredQuery}`; responseText += `\n\n**Inferred Query**:\n\n${inferredQuery}`;

View file

@ -0,0 +1 @@
<svg viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg"><g id="SVGRepo_bgCarrier" stroke-width="0"></g><g id="SVGRepo_tracerCarrier" stroke-linecap="round" stroke-linejoin="round"></g><g id="SVGRepo_iconCarrier"> <path d="M6 11C6 8.17157 6 6.75736 6.87868 5.87868C7.75736 5 9.17157 5 12 5H15C17.8284 5 19.2426 5 20.1213 5.87868C21 6.75736 21 8.17157 21 11V16C21 18.8284 21 20.2426 20.1213 21.1213C19.2426 22 17.8284 22 15 22H12C9.17157 22 7.75736 22 6.87868 21.1213C6 20.2426 6 18.8284 6 16V11Z" stroke="#1C274C" stroke-width="1.5"></path> <path d="M6 19C4.34315 19 3 17.6569 3 16V10C3 6.22876 3 4.34315 4.17157 3.17157C5.34315 2 7.22876 2 11 2H15C16.6569 2 18 3.34315 18 5" stroke="#1C274C" stroke-width="1.5"></path> </g></svg>

After

Width:  |  Height:  |  Size: 746 B

View file

@ -28,18 +28,17 @@ To get started, just start typing below. You can also type / to see a list of co
let chatOptions = []; let chatOptions = [];
function copyProgrammaticOutput(event) { function copyProgrammaticOutput(event) {
// Remove the first 4 characters which are the "Copy" button // Remove the first 4 characters which are the "Copy" button
const originalCopyText = event.target.parentNode.textContent.trim().slice(0, 4); const programmaticOutput = event.target.parentNode.textContent.trim();
const programmaticOutput = event.target.parentNode.textContent.trim().slice(4);
navigator.clipboard.writeText(programmaticOutput).then(() => { navigator.clipboard.writeText(programmaticOutput).then(() => {
event.target.textContent = "✅ Copied to clipboard!"; event.target.textContent = "✅ Copied to clipboard!";
setTimeout(() => { setTimeout(() => {
event.target.textContent = originalCopyText; event.target.textContent = "✅";
}, 1000); }, 1000);
}).catch((error) => { }).catch((error) => {
console.error("Error copying programmatic output to clipboard:", error); console.error("Error copying programmatic output to clipboard:", error);
event.target.textContent = "⛔️ Failed to copy!"; event.target.textContent = "⛔️ Failed to copy!";
setTimeout(() => { setTimeout(() => {
event.target.textContent = originalCopyText; event.target.textContent = "⛔️";
}, 1000); }, 1000);
}); });
} }
@ -211,8 +210,13 @@ To get started, just start typing below. You can also type / to see a list of co
function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) { function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) {
if ((context == null || context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) { if ((context == null || context.length == 0) && (onlineContext == null || (onlineContext && Object.keys(onlineContext).length == 0))) {
if (intentType === "text-to-image") { if (intentType?.includes("text-to-image")) {
let imageMarkdown = `![](data:image/png;base64,${message})`; let imageMarkdown;
if (intentType === "text-to-image") {
imageMarkdown = `![](data:image/png;base64,${message})`;
} else if (intentType === "text-to-image2") {
imageMarkdown = `![](${message})`;
}
const inferredQuery = inferredQueries?.[0]; const inferredQuery = inferredQueries?.[0];
if (inferredQuery) { if (inferredQuery) {
imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`; imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
@ -274,8 +278,13 @@ To get started, just start typing below. You can also type / to see a list of co
references.appendChild(referenceSection); references.appendChild(referenceSection);
if (intentType === "text-to-image") { if (intentType?.includes("text-to-image")) {
let imageMarkdown = `![](data:image/png;base64,${message})`; let imageMarkdown;
if (intentType === "text-to-image") {
imageMarkdown = `![](data:image/png;base64,${message})`;
} else if (intentType === "text-to-image2") {
imageMarkdown = `![](${message})`;
}
const inferredQuery = inferredQueries?.[0]; const inferredQuery = inferredQueries?.[0];
if (inferredQuery) { if (inferredQuery) {
imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`; imageMarkdown += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
@ -326,7 +335,10 @@ To get started, just start typing below. You can also type / to see a list of co
// Add a copy button to each element // Add a copy button to each element
let copyButton = document.createElement('button'); let copyButton = document.createElement('button');
copyButton.classList.add("copy-button"); copyButton.classList.add("copy-button");
copyButton.innerHTML = "Copy"; let copyIcon = document.createElement("img");
copyIcon.src = "/static/assets/icons/copy_button.svg";
copyIcon.classList.add("copy-icon");
copyButton.appendChild(copyIcon);
copyButton.addEventListener('click', copyProgrammaticOutput); copyButton.addEventListener('click', copyProgrammaticOutput);
codeElement.prepend(copyButton); codeElement.prepend(copyButton);
}); });
@ -427,9 +439,27 @@ To get started, just start typing below. You can also type / to see a list of co
new_response.appendChild(newResponseText); new_response.appendChild(newResponseText);
// Temporary status message to indicate that Khoj is thinking // Temporary status message to indicate that Khoj is thinking
let loadingSpinner = document.createElement("div"); let loadingEllipsis = document.createElement("div");
loadingSpinner.classList.add("spinner"); loadingEllipsis.classList.add("lds-ellipsis");
newResponseText.appendChild(loadingSpinner);
let firstEllipsis = document.createElement("div");
firstEllipsis.classList.add("lds-ellipsis-item");
let secondEllipsis = document.createElement("div");
secondEllipsis.classList.add("lds-ellipsis-item");
let thirdEllipsis = document.createElement("div");
thirdEllipsis.classList.add("lds-ellipsis-item");
let fourthEllipsis = document.createElement("div");
fourthEllipsis.classList.add("lds-ellipsis-item");
loadingEllipsis.appendChild(firstEllipsis);
loadingEllipsis.appendChild(secondEllipsis);
loadingEllipsis.appendChild(thirdEllipsis);
loadingEllipsis.appendChild(fourthEllipsis);
newResponseText.appendChild(loadingEllipsis);
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight; document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
let chatTooltip = document.getElementById("chat-tooltip"); let chatTooltip = document.getElementById("chat-tooltip");
@ -450,7 +480,11 @@ To get started, just start typing below. You can also type / to see a list of co
const responseAsJson = await response.json(); const responseAsJson = await response.json();
if (responseAsJson.image) { if (responseAsJson.image) {
// If response has image field, response is a generated image. // If response has image field, response is a generated image.
rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; if (responseAsJson.intentType === "text-to-image") {
rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
} else if (responseAsJson.intentType === "text-to-image2") {
rawResponse += `![${query}](${responseAsJson.image})`;
}
const inferredQuery = responseAsJson.inferredQueries?.[0]; const inferredQuery = responseAsJson.inferredQueries?.[0];
if (inferredQuery) { if (inferredQuery) {
rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`; rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
@ -513,8 +547,8 @@ To get started, just start typing below. You can also type / to see a list of co
readStream(); readStream();
} else { } else {
// Display response from Khoj // Display response from Khoj
if (newResponseText.getElementsByClassName("spinner").length > 0) { if (newResponseText.getElementsByClassName("lds-ellipsis").length > 0) {
newResponseText.removeChild(loadingSpinner); newResponseText.removeChild(loadingEllipsis);
} }
// If the chunk is not a JSON object, just display it as is // If the chunk is not a JSON object, just display it as is
@ -745,7 +779,9 @@ To get started, just start typing below. You can also type / to see a list of co
// Create loading screen and add it to chat-body // Create loading screen and add it to chat-body
let loadingScreen = document.createElement('div'); let loadingScreen = document.createElement('div');
loadingScreen.classList.add('loading-screen', 'gradient-animation'); loadingScreen.classList.add("loading-spinner");
let yellowOrb = document.createElement('div');
loadingScreen.appendChild(yellowOrb);
chatBody.appendChild(loadingScreen); chatBody.appendChild(loadingScreen);
fetch(chatHistoryUrl, { method: "GET" }) fetch(chatHistoryUrl, { method: "GET" })
@ -792,8 +828,6 @@ To get started, just start typing below. You can also type / to see a list of co
}); });
// Add fade out animation to loading screen and remove it after the animation ends // Add fade out animation to loading screen and remove it after the animation ends
loadingScreen.classList.remove('gradient-animation');
loadingScreen.classList.add('fade-out-animation');
chatBodyWrapperHeight = chatBodyWrapper.clientHeight; chatBodyWrapperHeight = chatBodyWrapper.clientHeight;
chatBody.style.height = chatBodyWrapperHeight; chatBody.style.height = chatBodyWrapperHeight;
setTimeout(() => { setTimeout(() => {
@ -1493,34 +1527,6 @@ To get started, just start typing below. You can also type / to see a list of co
font-size: 2rem; font-size: 2rem;
color: #333; color: #333;
z-index: 9999; /* This is the important part */ z-index: 9999; /* This is the important part */
/* Adding gradient effect */
background: radial-gradient(circle, var(--primary-hover) 0%, var(--flower) 100%);
background-size: 200% 200%;
}
div.loading-screen::after {
content: "Loading...";
}
.gradient-animation {
animation: gradient 2s ease infinite;
}
@keyframes gradient {
0% {background-position: 0% 50%;}
50% {background-position: 100% 50%;}
100% {background-position: 0% 50%;}
}
.fade-out-animation {
animation-name: fadeOut;
animation-duration: 1.5s;
}
@keyframes fadeOut {
from {opacity: 1;}
to {opacity: 0;}
} }
/* add chat metatdata to bottom of bubble */ /* add chat metatdata to bottom of bubble */
@ -1757,13 +1763,32 @@ To get started, just start typing below. You can also type / to see a list of co
} }
button.copy-button { button.copy-button {
display: block;
border-radius: 4px; border-radius: 4px;
background-color: var(--background-color); background-color: var(--background-color);
} border: 1px solid var(--main-text-color);
button.copy-button:hover { text-align: center;
background: #f5f5f5; font-size: 16px;
transition: all 0.5s;
cursor: pointer; cursor: pointer;
padding: 4px;
float: right;
}
button.copy-button span {
cursor: pointer;
display: inline-block;
position: relative;
transition: 0.5s;
}
img.copy-icon {
width: 16px;
height: 16px;
}
button.copy-button:hover {
background-color: black;
color: #f5f5f5;
} }
pre { pre {
@ -1997,5 +2022,104 @@ To get started, just start typing below. You can also type / to see a list of co
padding: 10px; padding: 10px;
white-space: pre-wrap; white-space: pre-wrap;
} }
.loading-spinner {
display: inline-block;
position: relative;
width: 80px;
height: 80px;
}
.loading-spinner div {
position: absolute;
border: 4px solid var(--primary-hover);
opacity: 1;
border-radius: 50%;
animation: lds-ripple 0.5s cubic-bezier(0, 0.2, 0.8, 1) infinite;
}
.loading-spinner div:nth-child(2) {
animation-delay: -0.5s;
}
@keyframes lds-ripple {
0% {
top: 36px;
left: 36px;
width: 0;
height: 0;
opacity: 1;
border-color: var(--primary-hover);
}
50% {
border-color: var(--flower);
}
100% {
top: 0px;
left: 0px;
width: 72px;
height: 72px;
opacity: 0;
border-color: var(--water);
}
}
.lds-ellipsis {
display: inline-block;
position: relative;
width: 60px;
height: 32px;
}
.lds-ellipsis div {
position: absolute;
top: 12px;
width: 12px;
height: 12px;
border-radius: 50%;
background: var(--main-text-color);
animation-timing-function: cubic-bezier(0, 1, 1, 0);
}
.lds-ellipsis div:nth-child(1) {
left: 8px;
animation: lds-ellipsis1 0.6s infinite;
}
.lds-ellipsis div:nth-child(2) {
left: 8px;
animation: lds-ellipsis2 0.6s infinite;
}
.lds-ellipsis div:nth-child(3) {
left: 32px;
animation: lds-ellipsis2 0.6s infinite;
}
.lds-ellipsis div:nth-child(4) {
left: 56px;
animation: lds-ellipsis3 0.6s infinite;
}
@keyframes lds-ellipsis1 {
0% {
transform: scale(0);
}
100% {
transform: scale(1);
}
}
@keyframes lds-ellipsis3 {
0% {
transform: scale(1);
}
100% {
transform: scale(0);
}
}
@keyframes lds-ellipsis2 {
0% {
transform: translate(0, 0);
}
100% {
transform: translate(24px, 0);
}
}
</style> </style>
</html> </html>

View file

@ -234,7 +234,7 @@ class NotionToEntries(TextToEntries):
elif "Event" in properties: elif "Event" in properties:
title_field = "Event" title_field = "Event"
elif title_field not in properties: elif title_field not in properties:
logger.warning(f"Title field not found for page {page_id}. Setting title as None...") logger.debug(f"Title field not found for page {page_id}. Setting title as None...")
title = None title = None
return title, content return title, content
try: try:

View file

@ -128,7 +128,7 @@ def save_to_conversation_log(
Saved Conversation Turn Saved Conversation Turn
You ({user.username}): "{q}" You ({user.username}): "{q}"
Khoj: "{inferred_queries if intent_type == "text-to-image" else chat_response}" Khoj: "{inferred_queries if ("text_to_image" in intent_type) else chat_response}"
""".strip() """.strip()
) )

View file

@ -300,25 +300,30 @@ async def chat(
metadata={"conversation_command": conversation_commands[0].value}, metadata={"conversation_command": conversation_commands[0].value},
**common.__dict__, **common.__dict__,
) )
image, status_code, improved_image_prompt = await text_to_image( intent_type = "text-to-image"
q, meta_log, location_data=location, references=compiled_references, online_results=online_results image, status_code, improved_image_prompt, image_url = await text_to_image(
q, user, meta_log, location_data=location, references=compiled_references, online_results=online_results
) )
if image is None: if image is None:
content_obj = {"image": image, "intentType": "text-to-image", "detail": improved_image_prompt} content_obj = {"image": image, "intentType": intent_type, "detail": improved_image_prompt}
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code) return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
if image_url:
intent_type = "text-to-image2"
image = image_url
await sync_to_async(save_to_conversation_log)( await sync_to_async(save_to_conversation_log)(
q, q,
image, image,
user, user,
meta_log, meta_log,
intent_type="text-to-image", intent_type=intent_type,
inferred_queries=[improved_image_prompt], inferred_queries=[improved_image_prompt],
client_application=request.user.client_app, client_application=request.user.client_app,
conversation_id=conversation_id, conversation_id=conversation_id,
compiled_references=compiled_references, compiled_references=compiled_references,
online_results=online_results, online_results=online_results,
) )
content_obj = {"image": image, "intentType": "text-to-image", "inferredQueries": [improved_image_prompt], "context": compiled_references, "online_results": online_results} # type: ignore content_obj = {"image": image, "intentType": intent_type, "inferredQueries": [improved_image_prompt], "context": compiled_references, "online_results": online_results} # type: ignore
return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code) return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
# Get the (streamed) chat response from the LLM of choice. # Get the (streamed) chat response from the LLM of choice.

View file

@ -4,11 +4,9 @@ import logging
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from functools import partial from functools import partial
from time import time
from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union
import openai import openai
import requests
from fastapi import Depends, Header, HTTPException, Request, UploadFile from fastapi import Depends, Header, HTTPException, Request, UploadFile
from starlette.authentication import has_required_scope from starlette.authentication import has_required_scope
@ -32,6 +30,7 @@ from khoj.processor.conversation.utils import (
generate_chatml_messages_with_context, generate_chatml_messages_with_context,
save_to_conversation_log, save_to_conversation_log,
) )
from khoj.routers.storage import upload_image
from khoj.utils import state from khoj.utils import state
from khoj.utils.config import GPT4AllProcessorModel from khoj.utils.config import GPT4AllProcessorModel
from khoj.utils.helpers import ( from khoj.utils.helpers import (
@ -39,6 +38,7 @@ from khoj.utils.helpers import (
is_none_or_empty, is_none_or_empty,
log_telemetry, log_telemetry,
mode_descriptions_for_llm, mode_descriptions_for_llm,
timer,
tool_descriptions_for_llm, tool_descriptions_for_llm,
) )
from khoj.utils.rawconfig import LocationData from khoj.utils.rawconfig import LocationData
@ -439,53 +439,65 @@ def generate_chat_response(
async def text_to_image( async def text_to_image(
message: str, message: str,
user: KhojUser,
conversation_log: dict, conversation_log: dict,
location_data: LocationData, location_data: LocationData,
references: List[str], references: List[str],
online_results: Dict[str, Any], online_results: Dict[str, Any],
) -> Tuple[Optional[str], int, Optional[str]]: ) -> Tuple[Optional[str], int, Optional[str], Optional[str]]:
status_code = 200 status_code = 200
image = None image = None
response = None response = None
image_url = None
text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config() text_to_image_config = await ConversationAdapters.aget_text_to_image_model_config()
if not text_to_image_config: if not text_to_image_config:
# If the user has not configured a text to image model, return an unsupported on server error # If the user has not configured a text to image model, return an unsupported on server error
status_code = 501 status_code = 501
message = "Failed to generate image. Setup image generation on the server." message = "Failed to generate image. Setup image generation on the server."
return image, status_code, message return image, status_code, message, image_url
elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI: elif state.openai_client and text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
logger.info("Generating image with OpenAI")
text2image_model = text_to_image_config.model_name text2image_model = text_to_image_config.model_name
chat_history = "" chat_history = ""
for chat in conversation_log.get("chat", [])[-4:]: for chat in conversation_log.get("chat", [])[-4:]:
if chat["by"] == "khoj" and chat["intent"].get("type") == "remember": if chat["by"] == "khoj" and chat["intent"].get("type") == "remember":
chat_history += f"Q: {chat['intent']['query']}\n" chat_history += f"Q: {chat['intent']['query']}\n"
chat_history += f"A: {chat['message']}\n" chat_history += f"A: {chat['message']}\n"
improved_image_prompt = await generate_better_image_prompt( elif chat["by"] == "khoj" and "text-to-image" in chat["intent"].get("type"):
message, chat_history += f"Q: {chat['intent']['query']}\n"
chat_history, chat_history += f"A: [generated image redacted by admin]. Enhanced image prompt: {chat['intent']['inferred-queries'][0]}\n"
location_data=location_data,
note_references=references, with timer("Improve the original user query", logger):
online_results=online_results, improved_image_prompt = await generate_better_image_prompt(
) message,
try: chat_history,
response = state.openai_client.images.generate( location_data=location_data,
prompt=improved_image_prompt, model=text2image_model, response_format="b64_json" note_references=references,
online_results=online_results,
) )
image = response.data[0].b64_json try:
with timer("Generate image with OpenAI", logger):
response = state.openai_client.images.generate(
prompt=improved_image_prompt, model=text2image_model, response_format="b64_json"
)
image = response.data[0].b64_json
with timer("Upload image to S3", logger):
image_url = upload_image(image, user.uuid)
return image, status_code, improved_image_prompt, image_url
except openai.OpenAIError or openai.BadRequestError as e: except openai.OpenAIError or openai.BadRequestError as e:
if "content_policy_violation" in e.message: if "content_policy_violation" in e.message:
logger.error(f"Image Generation blocked by OpenAI: {e}") logger.error(f"Image Generation blocked by OpenAI: {e}")
status_code = e.status_code status_code = e.status_code # type: ignore
message = f"Image generation blocked by OpenAI: {e.message}" message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore
return image, status_code, message return image, status_code, message, image_url
else: else:
logger.error(f"Image Generation failed with {e}", exc_info=True) logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation failed with OpenAI error: {e.message}" message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore
status_code = e.status_code status_code = e.status_code # type: ignore
return image, status_code, message return image, status_code, message, image_url
return image, status_code, response, image_url
return image, status_code, improved_image_prompt
class ApiUserRateLimiter: class ApiUserRateLimiter:

View file

@ -0,0 +1,34 @@
import base64
import logging
import os
import uuid
logger = logging.getLogger(__name__)
AWS_ACCESS_KEY = os.getenv("AWS_ACCESS_KEY")
AWS_SECRET_KEY = os.getenv("AWS_SECRET_KEY")
AWS_UPLOAD_IMAGE_BUCKET_NAME = os.getenv("AWS_IMAGE_UPLOAD_BUCKET")
aws_enabled = AWS_ACCESS_KEY is not None and AWS_SECRET_KEY is not None and AWS_UPLOAD_IMAGE_BUCKET_NAME is not None
if aws_enabled:
from boto3 import client
s3_client = client("s3", aws_access_key_id=AWS_ACCESS_KEY, aws_secret_access_key=AWS_SECRET_KEY)
def upload_image(image: str, user_id: uuid.UUID):
"""Upload the image to the S3 bucket"""
if not aws_enabled:
logger.info("AWS is not enabled. Skipping image upload")
return None
decoded_image = base64.b64decode(image)
image_key = f"{user_id}/{uuid.uuid4()}.png"
try:
s3_client.put_object(Bucket=AWS_UPLOAD_IMAGE_BUCKET_NAME, Key=image_key, Body=decoded_image, ACL="public-read")
url = f"https://{AWS_UPLOAD_IMAGE_BUCKET_NAME}.s3.amazonaws.com/{image_key}"
return url
except Exception as e:
logger.error(f"Failed to upload image to S3: {e}")
return None