Render the inferred query along with the image that Khoj returns

This commit is contained in:
sabaimran 2023-12-17 21:02:55 +05:30
parent 49af2148fe
commit 0288804f2e
4 changed files with 37 additions and 9 deletions

View file

@@ -179,9 +179,14 @@
     return numOnlineReferences;
 }
-function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null) {
+function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) {
     if (intentType === "text-to-image") {
         let imageMarkdown = `![](data:image/png;base64,${message})`;
+        imageMarkdown += "\n\n";
+        if (inferredQueries) {
+            const inferredQuery = inferredQueries?.[0];
+            imageMarkdown += `**Inferred Query**: ${inferredQuery}`;
+        }
         renderMessage(imageMarkdown, by, dt);
         return;
     }
@@ -357,6 +362,11 @@
     if (responseAsJson.image) {
         // If response has image field, response is a generated image.
         rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
+        rawResponse += "\n\n";
+        const inferredQueries = responseAsJson.inferredQueries?.[0];
+        if (inferredQueries) {
+            rawResponse += `**Inferred Query**: ${inferredQueries}`;
+        }
     }
     if (responseAsJson.detail) {
         // If response has detail field, response is an error message.
@ -454,7 +464,13 @@
try { try {
const responseAsJson = JSON.parse(chunk); const responseAsJson = JSON.parse(chunk);
if (responseAsJson.image) { if (responseAsJson.image) {
// If response has image field, response is a generated image.
rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`; rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
rawResponse += "\n\n";
const inferredQueries = responseAsJson.inferredQueries?.[0];
if (inferredQueries) {
rawResponse += `**Inferred Query**: ${inferredQueries}`;
}
} }
if (responseAsJson.detail) { if (responseAsJson.detail) {
rawResponse += responseAsJson.detail; rawResponse += responseAsJson.detail;
@@ -572,7 +588,7 @@
     .then(response => {
         // Render conversation history, if any
         response.forEach(chat_log => {
-            renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext, chat_log.intent?.type);
+            renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext, chat_log.intent?.type, chat_log.intent?.["inferred-queries"]);
        });
    })
    .catch(err => {
@@ -906,9 +922,9 @@
     border: 1px solid var(--main-text-color);
     box-shadow: 0 0 11px #aaa;
     border-radius: 5px;
-    padding: 5px;
     font-size: 14px;
     font-weight: 300;
+    padding: 0;
     line-height: 1.5em;
     cursor: pointer;
     transition: background 0.3s ease-in-out;

View file

@@ -188,9 +188,14 @@ To get started, just start typing below. You can also type / to see a list of co
     return numOnlineReferences;
 }
-function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null) {
+function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) {
     if (intentType === "text-to-image") {
         let imageMarkdown = `![](data:image/png;base64,${message})`;
+        imageMarkdown += "\n\n";
+        if (inferredQueries) {
+            const inferredQuery = inferredQueries?.[0];
+            imageMarkdown += `**Inferred Query**: ${inferredQuery}`;
+        }
         renderMessage(imageMarkdown, by, dt);
         return;
     }
@@ -362,6 +367,11 @@ To get started, just start typing below. You can also type / to see a list of co
     if (responseAsJson.image) {
         // If response has image field, response is a generated image.
         rawResponse += `![${query}](data:image/png;base64,${responseAsJson.image})`;
+        rawResponse += "\n\n";
+        const inferredQueries = responseAsJson.inferredQueries?.[0];
+        if (inferredQueries) {
+            rawResponse += `**Inferred Query**: ${inferredQueries}`;
+        }
     }
     if (responseAsJson.detail) {
         // If response has detail field, response is an error message.
@@ -543,7 +553,7 @@ To get started, just start typing below. You can also type / to see a list of co
     .then(response => {
         // Render conversation history, if any
         response.forEach(chat_log => {
-            renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext, chat_log.intent?.type);
+            renderMessageWithReference(chat_log.message, chat_log.by, chat_log.context, new Date(chat_log.created), chat_log.onlineContext, chat_log.intent?.type, chat_log.intent?.["inferred-queries"]);
        });
    })
    .catch(err => {

View file

@@ -721,7 +721,7 @@ async def chat(
         metadata={"conversation_command": conversation_command.value},
         **common.__dict__,
     )
-    image, status_code = await text_to_image(q)
+    image, status_code, improved_image_prompt = await text_to_image(q)
     if image is None:
         content_obj = {
             "image": image,
@@ -729,8 +729,10 @@
             "detail": "Failed to generate image. Make sure your image generation configuration is set.",
         }
         return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)
-    await sync_to_async(save_to_conversation_log)(q, image, user, meta_log, intent_type="text-to-image")
-    content_obj = {"image": image, "intentType": "text-to-image"}
+    await sync_to_async(save_to_conversation_log)(
+        q, image, user, meta_log, intent_type="text-to-image", inferred_queries=[improved_image_prompt]
+    )
+    content_obj = {"image": image, "intentType": "text-to-image", "inferredQueries": [improved_image_prompt]}
     return Response(content=json.dumps(content_obj), media_type="application/json", status_code=status_code)

     # Get the (streamed) chat response from the LLM of choice.

View file

@@ -286,7 +286,7 @@ async def text_to_image(message: str) -> Tuple[Optional[str], int]:
         logger.error(f"Image Generation failed with {e}", exc_info=True)
         status_code = 500
-    return image, status_code
+    return image, status_code, improved_image_prompt


class ApiUserRateLimiter: