mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-27 17:35:07 +01:00
Chat with Multiple Images. Support Vision with Gemini (#942)
## Overview - Add vision support for Gemini models in Khoj - Allow sharing multiple images as part of user query from the web app - Handle multiple images shared in query to chat API
This commit is contained in:
commit
c6f3253ebd
17 changed files with 454 additions and 294 deletions
|
@ -27,14 +27,14 @@ interface ChatBodyDataProps {
|
|||
setUploadedFiles: (files: string[]) => void;
|
||||
isMobileWidth?: boolean;
|
||||
isLoggedIn: boolean;
|
||||
setImage64: (image64: string) => void;
|
||||
setImages: (images: string[]) => void;
|
||||
}
|
||||
|
||||
function ChatBodyData(props: ChatBodyDataProps) {
|
||||
const searchParams = useSearchParams();
|
||||
const conversationId = searchParams.get("conversationId");
|
||||
const [message, setMessage] = useState("");
|
||||
const [image, setImage] = useState<string | null>(null);
|
||||
const [images, setImages] = useState<string[]>([]);
|
||||
const [processingMessage, setProcessingMessage] = useState(false);
|
||||
const [agentMetadata, setAgentMetadata] = useState<AgentData | null>(null);
|
||||
|
||||
|
@ -44,17 +44,20 @@ function ChatBodyData(props: ChatBodyDataProps) {
|
|||
const chatHistoryCustomClassName = props.isMobileWidth ? "w-full" : "w-4/6";
|
||||
|
||||
useEffect(() => {
|
||||
if (image) {
|
||||
props.setImage64(encodeURIComponent(image));
|
||||
if (images.length > 0) {
|
||||
const encodedImages = images.map((image) => encodeURIComponent(image));
|
||||
props.setImages(encodedImages);
|
||||
}
|
||||
}, [image, props.setImage64]);
|
||||
}, [images, props.setImages]);
|
||||
|
||||
useEffect(() => {
|
||||
const storedImage = localStorage.getItem("image");
|
||||
if (storedImage) {
|
||||
setImage(storedImage);
|
||||
props.setImage64(encodeURIComponent(storedImage));
|
||||
localStorage.removeItem("image");
|
||||
const storedImages = localStorage.getItem("images");
|
||||
if (storedImages) {
|
||||
const parsedImages: string[] = JSON.parse(storedImages);
|
||||
setImages(parsedImages);
|
||||
const encodedImages = parsedImages.map((img: string) => encodeURIComponent(img));
|
||||
props.setImages(encodedImages);
|
||||
localStorage.removeItem("images");
|
||||
}
|
||||
|
||||
const storedMessage = localStorage.getItem("message");
|
||||
|
@ -62,7 +65,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
|
|||
setProcessingMessage(true);
|
||||
setQueryToProcess(storedMessage);
|
||||
}
|
||||
}, [setQueryToProcess]);
|
||||
}, [setQueryToProcess, props.setImages]);
|
||||
|
||||
useEffect(() => {
|
||||
if (message) {
|
||||
|
@ -84,6 +87,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
|
|||
props.streamedMessages[props.streamedMessages.length - 1].completed
|
||||
) {
|
||||
setProcessingMessage(false);
|
||||
setImages([]); // Reset images after processing
|
||||
} else {
|
||||
setMessage("");
|
||||
}
|
||||
|
@ -113,7 +117,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
|
|||
agentColor={agentMetadata?.color}
|
||||
isLoggedIn={props.isLoggedIn}
|
||||
sendMessage={(message) => setMessage(message)}
|
||||
sendImage={(image) => setImage(image)}
|
||||
sendImage={(image) => setImages((prevImages) => [...prevImages, image])}
|
||||
sendDisabled={processingMessage}
|
||||
chatOptionsData={props.chatOptionsData}
|
||||
conversationId={conversationId}
|
||||
|
@ -135,7 +139,7 @@ export default function Chat() {
|
|||
const [queryToProcess, setQueryToProcess] = useState<string>("");
|
||||
const [processQuerySignal, setProcessQuerySignal] = useState(false);
|
||||
const [uploadedFiles, setUploadedFiles] = useState<string[]>([]);
|
||||
const [image64, setImage64] = useState<string>("");
|
||||
const [images, setImages] = useState<string[]>([]);
|
||||
|
||||
const locationData = useIPLocationData() || {
|
||||
timezone: Intl.DateTimeFormat().resolvedOptions().timeZone,
|
||||
|
@ -171,7 +175,7 @@ export default function Chat() {
|
|||
completed: false,
|
||||
timestamp: new Date().toISOString(),
|
||||
rawQuery: queryToProcess || "",
|
||||
uploadedImageData: decodeURIComponent(image64),
|
||||
images: images,
|
||||
};
|
||||
setMessages((prevMessages) => [...prevMessages, newStreamMessage]);
|
||||
setProcessQuerySignal(true);
|
||||
|
@ -202,7 +206,7 @@ export default function Chat() {
|
|||
if (done) {
|
||||
setQueryToProcess("");
|
||||
setProcessQuerySignal(false);
|
||||
setImage64("");
|
||||
setImages([]);
|
||||
break;
|
||||
}
|
||||
|
||||
|
@ -250,7 +254,7 @@ export default function Chat() {
|
|||
country_code: locationData.countryCode,
|
||||
timezone: locationData.timezone,
|
||||
}),
|
||||
...(image64 && { image: image64 }),
|
||||
...(images.length > 0 && { images: images }),
|
||||
};
|
||||
|
||||
const response = await fetch(chatAPI, {
|
||||
|
@ -264,7 +268,8 @@ export default function Chat() {
|
|||
try {
|
||||
await readChatStream(response);
|
||||
} catch (err) {
|
||||
console.error(err);
|
||||
const apiError = await response.json();
|
||||
console.error(apiError);
|
||||
// Retrieve latest message being processed
|
||||
const currentMessage = messages.find((message) => !message.completed);
|
||||
if (!currentMessage) return;
|
||||
|
@ -273,7 +278,11 @@ export default function Chat() {
|
|||
const errorMessage = (err as Error).message;
|
||||
if (errorMessage.includes("Error in input stream"))
|
||||
currentMessage.rawResponse = `Woops! The connection broke while I was writing my thoughts down. Maybe try again in a bit or dislike this message if the issue persists?`;
|
||||
else
|
||||
else if (response.status === 429) {
|
||||
"detail" in apiError
|
||||
? (currentMessage.rawResponse = `${apiError.detail}`)
|
||||
: (currentMessage.rawResponse = `I'm a bit overwhelmed at the moment. Could you try again in a bit or dislike this message if the issue persists?`);
|
||||
} else
|
||||
currentMessage.rawResponse = `Umm, not sure what just happened. I see this error message: ${errorMessage}. Could you try again or dislike this message if the issue persists?`;
|
||||
|
||||
// Complete message streaming teardown properly
|
||||
|
@ -332,7 +341,7 @@ export default function Chat() {
|
|||
setUploadedFiles={setUploadedFiles}
|
||||
isMobileWidth={isMobileWidth}
|
||||
onConversationIdChange={handleConversationIdChange}
|
||||
setImage64={setImage64}
|
||||
setImages={setImages}
|
||||
/>
|
||||
</Suspense>
|
||||
</div>
|
||||
|
|
|
@ -299,7 +299,7 @@ export default function ChatHistory(props: ChatHistoryProps) {
|
|||
created: message.timestamp,
|
||||
by: "you",
|
||||
automationId: "",
|
||||
uploadedImageData: message.uploadedImageData,
|
||||
images: message.images,
|
||||
}}
|
||||
customClassName="fullHistory"
|
||||
borderLeftColor={`${data?.agent?.color}-500`}
|
||||
|
@ -348,7 +348,6 @@ export default function ChatHistory(props: ChatHistoryProps) {
|
|||
created: new Date().getTime().toString(),
|
||||
by: "you",
|
||||
automationId: "",
|
||||
uploadedImageData: props.pendingMessage,
|
||||
}}
|
||||
customClassName="fullHistory"
|
||||
borderLeftColor={`${data?.agent?.color}-500`}
|
||||
|
|
|
@ -3,23 +3,7 @@ import React, { useEffect, useRef, useState } from "react";
|
|||
|
||||
import DOMPurify from "dompurify";
|
||||
import "katex/dist/katex.min.css";
|
||||
import {
|
||||
ArrowRight,
|
||||
ArrowUp,
|
||||
Browser,
|
||||
ChatsTeardrop,
|
||||
GlobeSimple,
|
||||
Gps,
|
||||
Image,
|
||||
Microphone,
|
||||
Notebook,
|
||||
Paperclip,
|
||||
X,
|
||||
Question,
|
||||
Robot,
|
||||
Shapes,
|
||||
Stop,
|
||||
} from "@phosphor-icons/react";
|
||||
import { ArrowUp, Microphone, Paperclip, X, Stop } from "@phosphor-icons/react";
|
||||
|
||||
import {
|
||||
Command,
|
||||
|
@ -78,10 +62,11 @@ export default function ChatInputArea(props: ChatInputProps) {
|
|||
const [loginRedirectMessage, setLoginRedirectMessage] = useState<string | null>(null);
|
||||
const [showLoginPrompt, setShowLoginPrompt] = useState(false);
|
||||
|
||||
const [recording, setRecording] = useState(false);
|
||||
const [imageUploaded, setImageUploaded] = useState(false);
|
||||
const [imagePath, setImagePath] = useState<string>("");
|
||||
const [imageData, setImageData] = useState<string | null>(null);
|
||||
const [imagePaths, setImagePaths] = useState<string[]>([]);
|
||||
const [imageData, setImageData] = useState<string[]>([]);
|
||||
|
||||
const [recording, setRecording] = useState(false);
|
||||
const [mediaRecorder, setMediaRecorder] = useState<MediaRecorder | null>(null);
|
||||
|
||||
const [progressValue, setProgressValue] = useState(0);
|
||||
|
@ -106,27 +91,31 @@ export default function ChatInputArea(props: ChatInputProps) {
|
|||
|
||||
useEffect(() => {
|
||||
async function fetchImageData() {
|
||||
if (imagePath) {
|
||||
const response = await fetch(imagePath);
|
||||
if (imagePaths.length > 0) {
|
||||
const newImageData = await Promise.all(
|
||||
imagePaths.map(async (path) => {
|
||||
const response = await fetch(path);
|
||||
const blob = await response.blob();
|
||||
return new Promise<string>((resolve) => {
|
||||
const reader = new FileReader();
|
||||
reader.onload = function () {
|
||||
const base64data = reader.result;
|
||||
setImageData(base64data as string);
|
||||
};
|
||||
reader.onload = () => resolve(reader.result as string);
|
||||
reader.readAsDataURL(blob);
|
||||
});
|
||||
}),
|
||||
);
|
||||
setImageData(newImageData);
|
||||
}
|
||||
setUploading(false);
|
||||
}
|
||||
setUploading(true);
|
||||
fetchImageData();
|
||||
}, [imagePath]);
|
||||
}, [imagePaths]);
|
||||
|
||||
function onSendMessage() {
|
||||
if (imageUploaded) {
|
||||
setImageUploaded(false);
|
||||
setImagePath("");
|
||||
props.sendImage(imageData || "");
|
||||
setImagePaths([]);
|
||||
imageData.forEach((data) => props.sendImage(data));
|
||||
}
|
||||
if (!message.trim()) return;
|
||||
|
||||
|
@ -172,18 +161,23 @@ export default function ChatInputArea(props: ChatInputProps) {
|
|||
setShowLoginPrompt(true);
|
||||
return;
|
||||
}
|
||||
// check for image file
|
||||
// check for image files
|
||||
const image_endings = ["jpg", "jpeg", "png", "webp"];
|
||||
const newImagePaths: string[] = [];
|
||||
for (let i = 0; i < files.length; i++) {
|
||||
const file = files[i];
|
||||
const file_extension = file.name.split(".").pop();
|
||||
if (image_endings.includes(file_extension || "")) {
|
||||
setImageUploaded(true);
|
||||
setImagePath(DOMPurify.sanitize(URL.createObjectURL(file)));
|
||||
return;
|
||||
newImagePaths.push(DOMPurify.sanitize(URL.createObjectURL(file)));
|
||||
}
|
||||
}
|
||||
|
||||
if (newImagePaths.length > 0) {
|
||||
setImageUploaded(true);
|
||||
setImagePaths((prevPaths) => [...prevPaths, ...newImagePaths]);
|
||||
return;
|
||||
}
|
||||
|
||||
uploadDataForIndexing(
|
||||
files,
|
||||
setWarning,
|
||||
|
@ -288,9 +282,12 @@ export default function ChatInputArea(props: ChatInputProps) {
|
|||
setIsDragAndDropping(false);
|
||||
}
|
||||
|
||||
function removeImageUpload() {
|
||||
function removeImageUpload(index: number) {
|
||||
setImagePaths((prevPaths) => prevPaths.filter((_, i) => i !== index));
|
||||
setImageData((prevData) => prevData.filter((_, i) => i !== index));
|
||||
if (imagePaths.length === 1) {
|
||||
setImageUploaded(false);
|
||||
setImagePath("");
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
|
@ -407,24 +404,11 @@ export default function ChatInputArea(props: ChatInputProps) {
|
|||
</div>
|
||||
)}
|
||||
<div
|
||||
className={`${styles.actualInputArea} items-center justify-between dark:bg-neutral-700 relative ${isDragAndDropping && "animate-pulse"}`}
|
||||
className={`${styles.actualInputArea} justify-between dark:bg-neutral-700 relative ${isDragAndDropping && "animate-pulse"}`}
|
||||
onDragOver={handleDragOver}
|
||||
onDragLeave={handleDragLeave}
|
||||
onDrop={handleDragAndDropFiles}
|
||||
>
|
||||
{imageUploaded && (
|
||||
<div className="absolute bottom-[80px] left-0 right-0 dark:bg-neutral-700 bg-white pt-5 pb-5 w-full rounded-lg border dark:border-none grid grid-cols-2">
|
||||
<div className="pl-4 pr-4">
|
||||
<img src={imagePath} alt="img" className="w-auto max-h-[100px]" />
|
||||
</div>
|
||||
<div className="pl-4 pr-4">
|
||||
<X
|
||||
className="w-6 h-6 float-right dark:hover:bg-[hsl(var(--background))] hover:bg-neutral-100 rounded-sm"
|
||||
onClick={removeImageUpload}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
<input
|
||||
type="file"
|
||||
multiple={true}
|
||||
|
@ -432,6 +416,7 @@ export default function ChatInputArea(props: ChatInputProps) {
|
|||
onChange={handleFileChange}
|
||||
style={{ display: "none" }}
|
||||
/>
|
||||
<div className="flex items-end pb-4">
|
||||
<Button
|
||||
variant={"ghost"}
|
||||
className="!bg-none p-0 m-2 h-auto text-3xl rounded-full text-gray-300 hover:text-gray-500"
|
||||
|
@ -440,7 +425,28 @@ export default function ChatInputArea(props: ChatInputProps) {
|
|||
>
|
||||
<Paperclip className="w-8 h-8" />
|
||||
</Button>
|
||||
<div className="grid w-full gap-1.5 relative">
|
||||
</div>
|
||||
<div className="flex-grow flex flex-col w-full gap-1.5 relative pb-2">
|
||||
<div className="flex items-center gap-2 overflow-x-auto">
|
||||
{imageUploaded &&
|
||||
imagePaths.map((path, index) => (
|
||||
<div key={index} className="relative flex-shrink-0 pb-3 pt-2 group">
|
||||
<img
|
||||
src={path}
|
||||
alt={`img-${index}`}
|
||||
className="w-auto h-16 object-cover rounded-xl"
|
||||
/>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
className="absolute -top-0 -right-2 h-5 w-5 rounded-full bg-neutral-200 dark:bg-neutral-600 hover:bg-neutral-300 dark:hover:bg-neutral-500 opacity-0 group-hover:opacity-100 transition-opacity"
|
||||
onClick={() => removeImageUpload(index)}
|
||||
>
|
||||
<X className="h-3 w-3" />
|
||||
</Button>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
<Textarea
|
||||
ref={chatInputRef}
|
||||
className={`border-none w-full h-16 min-h-16 max-h-[128px] md:py-4 rounded-lg resize-none dark:bg-neutral-700 ${props.isMobileWidth ? "text-md" : "text-lg"}`}
|
||||
|
@ -451,7 +457,7 @@ export default function ChatInputArea(props: ChatInputProps) {
|
|||
onKeyDown={(e) => {
|
||||
if (e.key === "Enter" && !e.shiftKey) {
|
||||
setImageUploaded(false);
|
||||
setImagePath("");
|
||||
setImagePaths([]);
|
||||
e.preventDefault();
|
||||
onSendMessage();
|
||||
}
|
||||
|
@ -460,6 +466,7 @@ export default function ChatInputArea(props: ChatInputProps) {
|
|||
disabled={props.sendDisabled || recording}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex items-end pb-4">
|
||||
{recording ? (
|
||||
<TooltipProvider>
|
||||
<Tooltip>
|
||||
|
@ -512,6 +519,7 @@ export default function ChatInputArea(props: ChatInputProps) {
|
|||
<ArrowUp className="w-6 h-6" weight="bold" />
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
|
|
@ -57,7 +57,26 @@ div.emptyChatMessage {
|
|||
display: none;
|
||||
}
|
||||
|
||||
div.chatMessageContainer img {
|
||||
div.imagesContainer {
|
||||
display: flex;
|
||||
overflow-x: auto;
|
||||
padding-bottom: 8px;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
|
||||
div.imageWrapper {
|
||||
flex: 0 0 auto;
|
||||
margin-right: 8px;
|
||||
}
|
||||
|
||||
div.imageWrapper img {
|
||||
width: auto;
|
||||
height: 128px;
|
||||
object-fit: cover;
|
||||
border-radius: 8px;
|
||||
}
|
||||
|
||||
div.chatMessageContainer > img {
|
||||
width: auto;
|
||||
height: auto;
|
||||
max-width: 100%;
|
||||
|
|
|
@ -116,7 +116,7 @@ export interface SingleChatMessage {
|
|||
rawQuery?: string;
|
||||
intent?: Intent;
|
||||
agent?: AgentData;
|
||||
uploadedImageData?: string;
|
||||
images?: string[];
|
||||
}
|
||||
|
||||
export interface StreamMessage {
|
||||
|
@ -128,10 +128,9 @@ export interface StreamMessage {
|
|||
rawQuery: string;
|
||||
timestamp: string;
|
||||
agent?: AgentData;
|
||||
uploadedImageData?: string;
|
||||
images?: string[];
|
||||
intentType?: string;
|
||||
inferredQueries?: string[];
|
||||
image?: string;
|
||||
}
|
||||
|
||||
export interface ChatHistoryData {
|
||||
|
@ -213,7 +212,6 @@ interface ChatMessageProps {
|
|||
borderLeftColor?: string;
|
||||
isLastMessage?: boolean;
|
||||
agent?: AgentData;
|
||||
uploadedImageData?: string;
|
||||
}
|
||||
|
||||
interface TrainOfThoughtProps {
|
||||
|
@ -343,8 +341,17 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
|
|||
.replace(/\\\[/g, "LEFTBRACKET")
|
||||
.replace(/\\\]/g, "RIGHTBRACKET");
|
||||
|
||||
if (props.chatMessage.uploadedImageData) {
|
||||
message = `![uploaded image](${props.chatMessage.uploadedImageData})\n\n${message}`;
|
||||
if (props.chatMessage.images && props.chatMessage.images.length > 0) {
|
||||
const imagesInMd = props.chatMessage.images
|
||||
.map((image, index) => {
|
||||
const decodedImage = image.startsWith("data%3Aimage")
|
||||
? decodeURIComponent(image)
|
||||
: image;
|
||||
const sanitizedImage = DOMPurify.sanitize(decodedImage);
|
||||
return `<div class="${styles.imageWrapper}"><img src="${sanitizedImage}" alt="uploaded image ${index + 1}" /></div>`;
|
||||
})
|
||||
.join("");
|
||||
message = `<div class="${styles.imagesContainer}">${imagesInMd}</div>${message}`;
|
||||
}
|
||||
|
||||
const intentTypeHandlers = {
|
||||
|
@ -384,7 +391,7 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
|
|||
|
||||
// Sanitize and set the rendered markdown
|
||||
setMarkdownRendered(DOMPurify.sanitize(markdownRendered));
|
||||
}, [props.chatMessage.message, props.chatMessage.intent]);
|
||||
}, [props.chatMessage.message, props.chatMessage.images, props.chatMessage.intent]);
|
||||
|
||||
useEffect(() => {
|
||||
if (copySuccess) {
|
||||
|
|
|
@ -44,7 +44,7 @@ function FisherYatesShuffle(array: any[]) {
|
|||
|
||||
function ChatBodyData(props: ChatBodyDataProps) {
|
||||
const [message, setMessage] = useState("");
|
||||
const [image, setImage] = useState<string | null>(null);
|
||||
const [images, setImages] = useState<string[]>([]);
|
||||
const [processingMessage, setProcessingMessage] = useState(false);
|
||||
const [greeting, setGreeting] = useState("");
|
||||
const [shuffledOptions, setShuffledOptions] = useState<Suggestion[]>([]);
|
||||
|
@ -138,20 +138,21 @@ function ChatBodyData(props: ChatBodyDataProps) {
|
|||
try {
|
||||
const newConversationId = await createNewConversation(selectedAgent || "khoj");
|
||||
onConversationIdChange?.(newConversationId);
|
||||
window.location.href = `/chat?conversationId=${newConversationId}`;
|
||||
localStorage.setItem("message", message);
|
||||
if (image) {
|
||||
localStorage.setItem("image", image);
|
||||
if (images.length > 0) {
|
||||
localStorage.setItem("images", JSON.stringify(images));
|
||||
}
|
||||
window.location.href = `/chat?conversationId=${newConversationId}`;
|
||||
} catch (error) {
|
||||
console.error("Error creating new conversation:", error);
|
||||
setProcessingMessage(false);
|
||||
}
|
||||
setMessage("");
|
||||
setImages([]);
|
||||
}
|
||||
};
|
||||
processMessage();
|
||||
if (message) {
|
||||
if (message || images.length > 0) {
|
||||
setProcessingMessage(true);
|
||||
}
|
||||
}, [selectedAgent, message, processingMessage, onConversationIdChange]);
|
||||
|
@ -224,7 +225,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
|
|||
</div>
|
||||
)}
|
||||
</div>
|
||||
<div className={`mx-auto ${props.isMobileWidth ? "w-full" : "w-fit"}`}>
|
||||
<div className={`mx-auto ${props.isMobileWidth ? "w-full" : "w-fit max-w-screen-md"}`}>
|
||||
{!props.isMobileWidth && (
|
||||
<div
|
||||
className={`w-full ${styles.inputBox} shadow-lg bg-background align-middle items-center justify-center px-3 py-1 dark:bg-neutral-700 border-stone-100 dark:border-none dark:shadow-none rounded-2xl`}
|
||||
|
@ -232,7 +233,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
|
|||
<ChatInputArea
|
||||
isLoggedIn={props.isLoggedIn}
|
||||
sendMessage={(message) => setMessage(message)}
|
||||
sendImage={(image) => setImage(image)}
|
||||
sendImage={(image) => setImages((prevImages) => [...prevImages, image])}
|
||||
sendDisabled={processingMessage}
|
||||
chatOptionsData={props.chatOptionsData}
|
||||
conversationId={null}
|
||||
|
@ -313,7 +314,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
|
|||
<ChatInputArea
|
||||
isLoggedIn={props.isLoggedIn}
|
||||
sendMessage={(message) => setMessage(message)}
|
||||
sendImage={(image) => setImage(image)}
|
||||
sendImage={(image) => setImages((prevImages) => [...prevImages, image])}
|
||||
sendDisabled={processingMessage}
|
||||
chatOptionsData={props.chatOptionsData}
|
||||
conversationId={null}
|
||||
|
|
|
@ -28,12 +28,12 @@ interface ChatBodyDataProps {
|
|||
isLoggedIn: boolean;
|
||||
conversationId?: string;
|
||||
setQueryToProcess: (query: string) => void;
|
||||
setImage64: (image64: string) => void;
|
||||
setImages: (images: string[]) => void;
|
||||
}
|
||||
|
||||
function ChatBodyData(props: ChatBodyDataProps) {
|
||||
const [message, setMessage] = useState("");
|
||||
const [image, setImage] = useState<string | null>(null);
|
||||
const [images, setImages] = useState<string[]>([]);
|
||||
const [processingMessage, setProcessingMessage] = useState(false);
|
||||
const [agentMetadata, setAgentMetadata] = useState<AgentData | null>(null);
|
||||
const setQueryToProcess = props.setQueryToProcess;
|
||||
|
@ -42,10 +42,28 @@ function ChatBodyData(props: ChatBodyDataProps) {
|
|||
const chatHistoryCustomClassName = props.isMobileWidth ? "w-full" : "w-4/6";
|
||||
|
||||
useEffect(() => {
|
||||
if (image) {
|
||||
props.setImage64(encodeURIComponent(image));
|
||||
if (images.length > 0) {
|
||||
const encodedImages = images.map((image) => encodeURIComponent(image));
|
||||
props.setImages(encodedImages);
|
||||
}
|
||||
}, [image, props.setImage64]);
|
||||
}, [images, props.setImages]);
|
||||
|
||||
useEffect(() => {
|
||||
const storedImages = localStorage.getItem("images");
|
||||
if (storedImages) {
|
||||
const parsedImages: string[] = JSON.parse(storedImages);
|
||||
setImages(parsedImages);
|
||||
const encodedImages = parsedImages.map((img: string) => encodeURIComponent(img));
|
||||
props.setImages(encodedImages);
|
||||
localStorage.removeItem("images");
|
||||
}
|
||||
|
||||
const storedMessage = localStorage.getItem("message");
|
||||
if (storedMessage) {
|
||||
setProcessingMessage(true);
|
||||
setQueryToProcess(storedMessage);
|
||||
}
|
||||
}, [setQueryToProcess, props.setImages]);
|
||||
|
||||
useEffect(() => {
|
||||
if (message) {
|
||||
|
@ -89,7 +107,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
|
|||
<ChatInputArea
|
||||
isLoggedIn={props.isLoggedIn}
|
||||
sendMessage={(message) => setMessage(message)}
|
||||
sendImage={(image) => setImage(image)}
|
||||
sendImage={(image) => setImages((prevImages) => [...prevImages, image])}
|
||||
sendDisabled={processingMessage}
|
||||
chatOptionsData={props.chatOptionsData}
|
||||
conversationId={props.conversationId}
|
||||
|
@ -112,7 +130,7 @@ export default function SharedChat() {
|
|||
const [processQuerySignal, setProcessQuerySignal] = useState(false);
|
||||
const [uploadedFiles, setUploadedFiles] = useState<string[]>([]);
|
||||
const [paramSlug, setParamSlug] = useState<string | undefined>(undefined);
|
||||
const [image64, setImage64] = useState<string>("");
|
||||
const [images, setImages] = useState<string[]>([]);
|
||||
|
||||
const locationData = useIPLocationData() || {
|
||||
timezone: Intl.DateTimeFormat().resolvedOptions().timeZone,
|
||||
|
@ -170,7 +188,7 @@ export default function SharedChat() {
|
|||
completed: false,
|
||||
timestamp: new Date().toISOString(),
|
||||
rawQuery: queryToProcess || "",
|
||||
uploadedImageData: decodeURIComponent(image64),
|
||||
images: images,
|
||||
};
|
||||
setMessages((prevMessages) => [...prevMessages, newStreamMessage]);
|
||||
setProcessQuerySignal(true);
|
||||
|
@ -197,7 +215,7 @@ export default function SharedChat() {
|
|||
if (done) {
|
||||
setQueryToProcess("");
|
||||
setProcessQuerySignal(false);
|
||||
setImage64("");
|
||||
setImages([]);
|
||||
break;
|
||||
}
|
||||
|
||||
|
@ -239,7 +257,7 @@ export default function SharedChat() {
|
|||
country_code: locationData.countryCode,
|
||||
timezone: locationData.timezone,
|
||||
}),
|
||||
...(image64 && { image: image64 }),
|
||||
...(images.length > 0 && { image: images }),
|
||||
};
|
||||
|
||||
const response = await fetch(chatAPI, {
|
||||
|
@ -302,7 +320,7 @@ export default function SharedChat() {
|
|||
setTitle={setTitle}
|
||||
setUploadedFiles={setUploadedFiles}
|
||||
isMobileWidth={isMobileWidth}
|
||||
setImage64={setImage64}
|
||||
setImages={setImages}
|
||||
/>
|
||||
</Suspense>
|
||||
</div>
|
||||
|
|
|
@ -6,14 +6,17 @@ from typing import Dict, Optional
|
|||
|
||||
from langchain.schema import ChatMessage
|
||||
|
||||
from khoj.database.models import Agent, KhojUser
|
||||
from khoj.database.models import Agent, ChatModelOptions, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.google.utils import (
|
||||
format_messages_for_gemini,
|
||||
gemini_chat_completion_with_backoff,
|
||||
gemini_completion_with_backoff,
|
||||
)
|
||||
from khoj.processor.conversation.utils import generate_chatml_messages_with_context
|
||||
from khoj.processor.conversation.utils import (
|
||||
construct_structured_message,
|
||||
generate_chatml_messages_with_context,
|
||||
)
|
||||
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
|
||||
from khoj.utils.rawconfig import LocationData
|
||||
|
||||
|
@ -29,6 +32,8 @@ def extract_questions_gemini(
|
|||
max_tokens=None,
|
||||
location_data: LocationData = None,
|
||||
user: KhojUser = None,
|
||||
query_images: Optional[list[str]] = None,
|
||||
vision_enabled: bool = False,
|
||||
personality_context: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
|
@ -70,17 +75,17 @@ def extract_questions_gemini(
|
|||
text=text,
|
||||
)
|
||||
|
||||
messages = [ChatMessage(content=prompt, role="user")]
|
||||
prompt = construct_structured_message(
|
||||
message=prompt,
|
||||
images=query_images,
|
||||
model_type=ChatModelOptions.ModelType.GOOGLE,
|
||||
vision_enabled=vision_enabled,
|
||||
)
|
||||
|
||||
model_kwargs = {"response_mime_type": "application/json"}
|
||||
messages = [ChatMessage(content=prompt, role="user"), ChatMessage(content=system_prompt, role="system")]
|
||||
|
||||
response = gemini_completion_with_backoff(
|
||||
messages=messages,
|
||||
system_prompt=system_prompt,
|
||||
model_name=model,
|
||||
temperature=temperature,
|
||||
api_key=api_key,
|
||||
model_kwargs=model_kwargs,
|
||||
response = gemini_send_message_to_model(
|
||||
messages, api_key, model, response_type="json_object", temperature=temperature
|
||||
)
|
||||
|
||||
# Extract, Clean Message from Gemini's Response
|
||||
|
@ -102,7 +107,7 @@ def extract_questions_gemini(
|
|||
return questions
|
||||
|
||||
|
||||
def gemini_send_message_to_model(messages, api_key, model, response_type="text"):
|
||||
def gemini_send_message_to_model(messages, api_key, model, response_type="text", temperature=0, model_kwargs=None):
|
||||
"""
|
||||
Send message to model
|
||||
"""
|
||||
|
@ -114,7 +119,12 @@ def gemini_send_message_to_model(messages, api_key, model, response_type="text")
|
|||
|
||||
# Get Response from Gemini
|
||||
return gemini_completion_with_backoff(
|
||||
messages=messages, system_prompt=system_prompt, model_name=model, api_key=api_key, model_kwargs=model_kwargs
|
||||
messages=messages,
|
||||
system_prompt=system_prompt,
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
temperature=temperature,
|
||||
model_kwargs=model_kwargs,
|
||||
)
|
||||
|
||||
|
||||
|
@ -133,6 +143,8 @@ def converse_gemini(
|
|||
location_data: LocationData = None,
|
||||
user_name: str = None,
|
||||
agent: Agent = None,
|
||||
query_images: Optional[list[str]] = None,
|
||||
vision_available: bool = False,
|
||||
):
|
||||
"""
|
||||
Converse with user using Google's Gemini
|
||||
|
@ -187,6 +199,9 @@ def converse_gemini(
|
|||
model_name=model,
|
||||
max_prompt_size=max_prompt_size,
|
||||
tokenizer_name=tokenizer_name,
|
||||
query_images=query_images,
|
||||
vision_enabled=vision_available,
|
||||
model_type=ChatModelOptions.ModelType.GOOGLE,
|
||||
)
|
||||
|
||||
messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
import logging
|
||||
import random
|
||||
from io import BytesIO
|
||||
from threading import Thread
|
||||
|
||||
import google.generativeai as genai
|
||||
import PIL.Image
|
||||
import requests
|
||||
from google.generativeai.types.answer_types import FinishReason
|
||||
from google.generativeai.types.generation_types import StopCandidateException
|
||||
from google.generativeai.types.safety_types import (
|
||||
|
@ -53,14 +56,14 @@ def gemini_completion_with_backoff(
|
|||
},
|
||||
)
|
||||
|
||||
formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
|
||||
formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]
|
||||
|
||||
# Start chat session. All messages up to the last are considered to be part of the chat history
|
||||
chat_session = model.start_chat(history=formatted_messages[0:-1])
|
||||
|
||||
try:
|
||||
# Generate the response. The last message is considered to be the current prompt
|
||||
aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"][0])
|
||||
aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"])
|
||||
return aggregated_response.text
|
||||
except StopCandidateException as e:
|
||||
response_message, _ = handle_gemini_response(e.args)
|
||||
|
@ -117,11 +120,11 @@ def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_k
|
|||
},
|
||||
)
|
||||
|
||||
formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
|
||||
formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]
|
||||
# all messages up to the last are considered to be part of the chat history
|
||||
chat_session = model.start_chat(history=formatted_messages[0:-1])
|
||||
# the last message is considered to be the current prompt
|
||||
for chunk in chat_session.send_message(formatted_messages[-1]["parts"][0], stream=True):
|
||||
for chunk in chat_session.send_message(formatted_messages[-1]["parts"], stream=True):
|
||||
message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
|
||||
message = message or chunk.text
|
||||
g.send(message)
|
||||
|
@ -191,14 +194,6 @@ def generate_safety_response(safety_ratings):
|
|||
|
||||
|
||||
def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str = None) -> tuple[list[str], str]:
|
||||
if len(messages) == 1:
|
||||
messages[0].role = "user"
|
||||
return messages, system_prompt
|
||||
|
||||
for message in messages:
|
||||
if message.role == "assistant":
|
||||
message.role = "model"
|
||||
|
||||
# Extract system message
|
||||
system_prompt = system_prompt or ""
|
||||
for message in messages.copy():
|
||||
|
@ -207,4 +202,31 @@ def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str =
|
|||
messages.remove(message)
|
||||
system_prompt = None if is_none_or_empty(system_prompt) else system_prompt
|
||||
|
||||
for message in messages:
|
||||
# Convert message content to string list from chatml dictionary list
|
||||
if isinstance(message.content, list):
|
||||
# Convert image_urls to PIL.Image and place them at beginning of list (better for Gemini)
|
||||
message.content = [
|
||||
get_image_from_url(item["image_url"]["url"]) if item["type"] == "image_url" else item["text"]
|
||||
for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1)
|
||||
]
|
||||
elif isinstance(message.content, str):
|
||||
message.content = [message.content]
|
||||
|
||||
if message.role == "assistant":
|
||||
message.role = "model"
|
||||
|
||||
if len(messages) == 1:
|
||||
messages[0].role = "user"
|
||||
|
||||
return messages, system_prompt
|
||||
|
||||
|
||||
def get_image_from_url(image_url: str) -> PIL.Image:
|
||||
try:
|
||||
response = requests.get(image_url)
|
||||
response.raise_for_status() # Check if the request was successful
|
||||
return PIL.Image.open(BytesIO(response.content))
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"Failed to get image from URL {image_url}: {e}")
|
||||
return None
|
||||
|
|
|
@ -30,7 +30,7 @@ def extract_questions(
|
|||
api_base_url=None,
|
||||
location_data: LocationData = None,
|
||||
user: KhojUser = None,
|
||||
uploaded_image_url: Optional[str] = None,
|
||||
query_images: Optional[list[str]] = None,
|
||||
vision_enabled: bool = False,
|
||||
personality_context: Optional[str] = None,
|
||||
):
|
||||
|
@ -74,7 +74,7 @@ def extract_questions(
|
|||
|
||||
prompt = construct_structured_message(
|
||||
message=prompt,
|
||||
image_url=uploaded_image_url,
|
||||
images=query_images,
|
||||
model_type=ChatModelOptions.ModelType.OPENAI,
|
||||
vision_enabled=vision_enabled,
|
||||
)
|
||||
|
@ -135,7 +135,7 @@ def converse(
|
|||
location_data: LocationData = None,
|
||||
user_name: str = None,
|
||||
agent: Agent = None,
|
||||
image_url: Optional[str] = None,
|
||||
query_images: Optional[list[str]] = None,
|
||||
vision_available: bool = False,
|
||||
):
|
||||
"""
|
||||
|
@ -191,7 +191,7 @@ def converse(
|
|||
model_name=model,
|
||||
max_prompt_size=max_prompt_size,
|
||||
tokenizer_name=tokenizer_name,
|
||||
uploaded_image_url=image_url,
|
||||
query_images=query_images,
|
||||
vision_enabled=vision_available,
|
||||
model_type=ChatModelOptions.ModelType.OPENAI,
|
||||
)
|
||||
|
|
|
@ -109,7 +109,7 @@ def save_to_conversation_log(
|
|||
client_application: ClientApplication = None,
|
||||
conversation_id: str = None,
|
||||
automation_id: str = None,
|
||||
uploaded_image_url: str = None,
|
||||
query_images: List[str] = None,
|
||||
):
|
||||
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
updated_conversation = message_to_log(
|
||||
|
@ -117,7 +117,7 @@ def save_to_conversation_log(
|
|||
chat_response=chat_response,
|
||||
user_message_metadata={
|
||||
"created": user_message_time,
|
||||
"uploadedImageData": uploaded_image_url,
|
||||
"images": query_images,
|
||||
},
|
||||
khoj_message_metadata={
|
||||
"context": compiled_references,
|
||||
|
@ -145,10 +145,18 @@ Khoj: "{inferred_queries if ("text-to-image" in intent_type) else chat_response}
|
|||
)
|
||||
|
||||
|
||||
# Format user and system messages to chatml format
|
||||
def construct_structured_message(message, image_url, model_type, vision_enabled):
|
||||
if image_url and vision_enabled and model_type == ChatModelOptions.ModelType.OPENAI:
|
||||
return [{"type": "text", "text": message}, {"type": "image_url", "image_url": {"url": image_url}}]
|
||||
def construct_structured_message(message: str, images: list[str], model_type: str, vision_enabled: bool):
|
||||
"""
|
||||
Format messages into appropriate multimedia format for supported chat model types
|
||||
"""
|
||||
if not images or not vision_enabled:
|
||||
return message
|
||||
|
||||
if model_type in [ChatModelOptions.ModelType.OPENAI, ChatModelOptions.ModelType.GOOGLE]:
|
||||
return [
|
||||
{"type": "text", "text": message},
|
||||
*[{"type": "image_url", "image_url": {"url": image}} for image in images],
|
||||
]
|
||||
return message
|
||||
|
||||
|
||||
|
@ -160,7 +168,7 @@ def generate_chatml_messages_with_context(
|
|||
loaded_model: Optional[Llama] = None,
|
||||
max_prompt_size=None,
|
||||
tokenizer_name=None,
|
||||
uploaded_image_url=None,
|
||||
query_images=None,
|
||||
vision_enabled=False,
|
||||
model_type="",
|
||||
):
|
||||
|
@ -186,9 +194,7 @@ def generate_chatml_messages_with_context(
|
|||
else:
|
||||
message_content = chat["message"] + message_notes
|
||||
|
||||
message_content = construct_structured_message(
|
||||
message_content, chat.get("uploadedImageData"), model_type, vision_enabled
|
||||
)
|
||||
message_content = construct_structured_message(message_content, chat.get("images"), model_type, vision_enabled)
|
||||
|
||||
reconstructed_message = ChatMessage(content=message_content, role=role)
|
||||
|
||||
|
@ -201,7 +207,7 @@ def generate_chatml_messages_with_context(
|
|||
if not is_none_or_empty(user_message):
|
||||
messages.append(
|
||||
ChatMessage(
|
||||
content=construct_structured_message(user_message, uploaded_image_url, model_type, vision_enabled),
|
||||
content=construct_structured_message(user_message, query_images, model_type, vision_enabled),
|
||||
role="user",
|
||||
)
|
||||
)
|
||||
|
@ -225,7 +231,6 @@ def truncate_messages(
|
|||
tokenizer_name=None,
|
||||
) -> list[ChatMessage]:
|
||||
"""Truncate messages to fit within max prompt size supported by model"""
|
||||
|
||||
default_tokenizer = "gpt-4o"
|
||||
|
||||
try:
|
||||
|
@ -255,6 +260,7 @@ def truncate_messages(
|
|||
system_message = messages.pop(idx)
|
||||
break
|
||||
|
||||
# TODO: Handle truncation of multi-part message.content, i.e when message.content is a list[dict] rather than a string
|
||||
system_message_tokens = (
|
||||
len(encoder.encode(system_message.content)) if system_message and type(system_message.content) == str else 0
|
||||
)
|
||||
|
|
|
@ -26,7 +26,7 @@ async def text_to_image(
|
|||
references: List[Dict[str, Any]],
|
||||
online_results: Dict[str, Any],
|
||||
send_status_func: Optional[Callable] = None,
|
||||
uploaded_image_url: Optional[str] = None,
|
||||
query_images: Optional[List[str]] = None,
|
||||
agent: Agent = None,
|
||||
):
|
||||
status_code = 200
|
||||
|
@ -65,7 +65,7 @@ async def text_to_image(
|
|||
note_references=references,
|
||||
online_results=online_results,
|
||||
model_type=text_to_image_config.model_type,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=query_images,
|
||||
user=user,
|
||||
agent=agent,
|
||||
)
|
||||
|
|
|
@ -62,7 +62,7 @@ async def search_online(
|
|||
user: KhojUser,
|
||||
send_status_func: Optional[Callable] = None,
|
||||
custom_filters: List[str] = [],
|
||||
uploaded_image_url: str = None,
|
||||
query_images: List[str] = None,
|
||||
agent: Agent = None,
|
||||
):
|
||||
query += " ".join(custom_filters)
|
||||
|
@ -73,7 +73,7 @@ async def search_online(
|
|||
|
||||
# Breakdown the query into subqueries to get the correct answer
|
||||
subqueries = await generate_online_subqueries(
|
||||
query, conversation_history, location, user, uploaded_image_url=uploaded_image_url, agent=agent
|
||||
query, conversation_history, location, user, query_images=query_images, agent=agent
|
||||
)
|
||||
response_dict = {}
|
||||
|
||||
|
@ -151,7 +151,7 @@ async def read_webpages(
|
|||
location: LocationData,
|
||||
user: KhojUser,
|
||||
send_status_func: Optional[Callable] = None,
|
||||
uploaded_image_url: str = None,
|
||||
query_images: List[str] = None,
|
||||
agent: Agent = None,
|
||||
):
|
||||
"Infer web pages to read from the query and extract relevant information from them"
|
||||
|
@ -159,7 +159,7 @@ async def read_webpages(
|
|||
if send_status_func:
|
||||
async for event in send_status_func(f"**Inferring web pages to read**"):
|
||||
yield {ChatEvent.STATUS: event}
|
||||
urls = await infer_webpage_urls(query, conversation_history, location, user, uploaded_image_url)
|
||||
urls = await infer_webpage_urls(query, conversation_history, location, user, query_images)
|
||||
|
||||
logger.info(f"Reading web pages at: {urls}")
|
||||
if send_status_func:
|
||||
|
|
|
@ -347,7 +347,7 @@ async def extract_references_and_questions(
|
|||
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
|
||||
location_data: LocationData = None,
|
||||
send_status_func: Optional[Callable] = None,
|
||||
uploaded_image_url: Optional[str] = None,
|
||||
query_images: Optional[List[str]] = None,
|
||||
agent: Agent = None,
|
||||
):
|
||||
user = request.user.object if request.user.is_authenticated else None
|
||||
|
@ -438,7 +438,7 @@ async def extract_references_and_questions(
|
|||
conversation_log=meta_log,
|
||||
location_data=location_data,
|
||||
user=user,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=query_images,
|
||||
vision_enabled=vision_enabled,
|
||||
personality_context=personality_context,
|
||||
)
|
||||
|
@ -459,12 +459,14 @@ async def extract_references_and_questions(
|
|||
chat_model = conversation_config.chat_model
|
||||
inferred_queries = extract_questions_gemini(
|
||||
defiltered_query,
|
||||
query_images=query_images,
|
||||
model=chat_model,
|
||||
api_key=api_key,
|
||||
conversation_log=meta_log,
|
||||
location_data=location_data,
|
||||
max_tokens=conversation_config.max_prompt_size,
|
||||
user=user,
|
||||
vision_enabled=vision_enabled,
|
||||
personality_context=personality_context,
|
||||
)
|
||||
|
||||
|
|
|
@ -30,8 +30,10 @@ from khoj.processor.speech.text_to_speech import generate_text_to_speech
|
|||
from khoj.processor.tools.online_search import read_webpages, search_online
|
||||
from khoj.routers.api import extract_references_and_questions
|
||||
from khoj.routers.helpers import (
|
||||
ApiImageRateLimiter,
|
||||
ApiUserRateLimiter,
|
||||
ChatEvent,
|
||||
ChatRequestBody,
|
||||
CommonQueryParams,
|
||||
ConversationCommandRateLimiter,
|
||||
agenerate_chat_response,
|
||||
|
@ -524,22 +526,6 @@ async def set_conversation_title(
|
|||
)
|
||||
|
||||
|
||||
class ChatRequestBody(BaseModel):
|
||||
q: str
|
||||
n: Optional[int] = 7
|
||||
d: Optional[float] = None
|
||||
stream: Optional[bool] = False
|
||||
title: Optional[str] = None
|
||||
conversation_id: Optional[str] = None
|
||||
city: Optional[str] = None
|
||||
region: Optional[str] = None
|
||||
country: Optional[str] = None
|
||||
country_code: Optional[str] = None
|
||||
timezone: Optional[str] = None
|
||||
image: Optional[str] = None
|
||||
create_new: Optional[bool] = False
|
||||
|
||||
|
||||
@api_chat.post("")
|
||||
@requires(["authenticated"])
|
||||
async def chat(
|
||||
|
@ -552,6 +538,7 @@ async def chat(
|
|||
rate_limiter_per_day=Depends(
|
||||
ApiUserRateLimiter(requests=600, subscribed_requests=6000, window=60 * 60 * 24, slug="chat_day")
|
||||
),
|
||||
image_rate_limiter=Depends(ApiImageRateLimiter(max_images=10, max_combined_size_mb=20)),
|
||||
):
|
||||
# Access the parameters from the body
|
||||
q = body.q
|
||||
|
@ -565,9 +552,9 @@ async def chat(
|
|||
country = body.country or get_country_name_from_timezone(body.timezone)
|
||||
country_code = body.country_code or get_country_code_from_timezone(body.timezone)
|
||||
timezone = body.timezone
|
||||
image = body.image
|
||||
raw_images = body.images
|
||||
|
||||
async def event_generator(q: str, image: str):
|
||||
async def event_generator(q: str, images: list[str]):
|
||||
start_time = time.perf_counter()
|
||||
ttft = None
|
||||
chat_metadata: dict = {}
|
||||
|
@ -577,16 +564,16 @@ async def chat(
|
|||
q = unquote(q)
|
||||
nonlocal conversation_id
|
||||
|
||||
uploaded_image_url = None
|
||||
if image:
|
||||
uploaded_images: list[str] = []
|
||||
if images:
|
||||
for image in images:
|
||||
decoded_string = unquote(image)
|
||||
base64_data = decoded_string.split(",", 1)[1]
|
||||
image_bytes = base64.b64decode(base64_data)
|
||||
webp_image_bytes = convert_image_to_webp(image_bytes)
|
||||
try:
|
||||
uploaded_image_url = upload_image_to_bucket(webp_image_bytes, request.user.object.id)
|
||||
except:
|
||||
uploaded_image_url = None
|
||||
uploaded_image = upload_image_to_bucket(webp_image_bytes, request.user.object.id)
|
||||
if uploaded_image:
|
||||
uploaded_images.append(uploaded_image)
|
||||
|
||||
async def send_event(event_type: ChatEvent, data: str | dict):
|
||||
nonlocal connection_alive, ttft
|
||||
|
@ -693,7 +680,7 @@ async def chat(
|
|||
meta_log,
|
||||
is_automated_task,
|
||||
user=user,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
)
|
||||
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
|
||||
|
@ -702,7 +689,7 @@ async def chat(
|
|||
):
|
||||
yield result
|
||||
|
||||
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_image_url, agent)
|
||||
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_images, agent)
|
||||
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
|
||||
yield result
|
||||
if mode not in conversation_commands:
|
||||
|
@ -765,7 +752,7 @@ async def chat(
|
|||
q,
|
||||
contextual_data,
|
||||
conversation_history=meta_log,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=uploaded_images,
|
||||
user=user,
|
||||
agent=agent,
|
||||
)
|
||||
|
@ -786,7 +773,7 @@ async def chat(
|
|||
intent_type="summarize",
|
||||
client_application=request.user.client_app,
|
||||
conversation_id=conversation_id,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=uploaded_images,
|
||||
)
|
||||
return
|
||||
|
||||
|
@ -829,7 +816,7 @@ async def chat(
|
|||
conversation_id=conversation_id,
|
||||
inferred_queries=[query_to_run],
|
||||
automation_id=automation.id,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=uploaded_images,
|
||||
)
|
||||
async for result in send_llm_response(llm_response):
|
||||
yield result
|
||||
|
@ -849,7 +836,7 @@ async def chat(
|
|||
conversation_commands,
|
||||
location,
|
||||
partial(send_event, ChatEvent.STATUS),
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
|
@ -893,7 +880,7 @@ async def chat(
|
|||
user,
|
||||
partial(send_event, ChatEvent.STATUS),
|
||||
custom_filters,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
|
@ -917,7 +904,7 @@ async def chat(
|
|||
location,
|
||||
user,
|
||||
partial(send_event, ChatEvent.STATUS),
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
|
@ -967,20 +954,20 @@ async def chat(
|
|||
references=compiled_references,
|
||||
online_results=online_results,
|
||||
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
else:
|
||||
image, status_code, improved_image_prompt, intent_type = result
|
||||
generated_image, status_code, improved_image_prompt, intent_type = result
|
||||
|
||||
if image is None or status_code != 200:
|
||||
if generated_image is None or status_code != 200:
|
||||
content_obj = {
|
||||
"content-type": "application/json",
|
||||
"intentType": intent_type,
|
||||
"detail": improved_image_prompt,
|
||||
"image": image,
|
||||
"image": None,
|
||||
}
|
||||
async for result in send_llm_response(json.dumps(content_obj)):
|
||||
yield result
|
||||
|
@ -988,7 +975,7 @@ async def chat(
|
|||
|
||||
await sync_to_async(save_to_conversation_log)(
|
||||
q,
|
||||
image,
|
||||
generated_image,
|
||||
user,
|
||||
meta_log,
|
||||
user_message_time,
|
||||
|
@ -998,12 +985,12 @@ async def chat(
|
|||
conversation_id=conversation_id,
|
||||
compiled_references=compiled_references,
|
||||
online_results=online_results,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=uploaded_images,
|
||||
)
|
||||
content_obj = {
|
||||
"intentType": intent_type,
|
||||
"inferredQueries": [improved_image_prompt],
|
||||
"image": image,
|
||||
"image": generated_image,
|
||||
}
|
||||
async for result in send_llm_response(json.dumps(content_obj)):
|
||||
yield result
|
||||
|
@ -1023,7 +1010,7 @@ async def chat(
|
|||
location_data=location,
|
||||
note_references=compiled_references,
|
||||
online_results=online_results,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=uploaded_images,
|
||||
user=user,
|
||||
agent=agent,
|
||||
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||
|
@ -1053,7 +1040,7 @@ async def chat(
|
|||
conversation_id=conversation_id,
|
||||
compiled_references=compiled_references,
|
||||
online_results=online_results,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=uploaded_images,
|
||||
)
|
||||
|
||||
async for result in send_llm_response(json.dumps(content_obj)):
|
||||
|
@ -1076,7 +1063,7 @@ async def chat(
|
|||
conversation_id,
|
||||
location,
|
||||
user_name,
|
||||
uploaded_image_url,
|
||||
uploaded_images,
|
||||
)
|
||||
|
||||
# Send Response
|
||||
|
@ -1102,9 +1089,9 @@ async def chat(
|
|||
|
||||
## Stream Text Response
|
||||
if stream:
|
||||
return StreamingResponse(event_generator(q, image=image), media_type="text/plain")
|
||||
return StreamingResponse(event_generator(q, images=raw_images), media_type="text/plain")
|
||||
## Non-Streaming Text Response
|
||||
else:
|
||||
response_iterator = event_generator(q, image=image)
|
||||
response_iterator = event_generator(q, images=raw_images)
|
||||
response_data = await read_chat_stream(response_iterator)
|
||||
return Response(content=json.dumps(response_data), media_type="application/json", status_code=200)
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
|
@ -22,7 +23,7 @@ from typing import (
|
|||
Tuple,
|
||||
Union,
|
||||
)
|
||||
from urllib.parse import parse_qs, quote, urljoin, urlparse
|
||||
from urllib.parse import parse_qs, quote, unquote, urljoin, urlparse
|
||||
|
||||
import cron_descriptor
|
||||
import pytz
|
||||
|
@ -31,6 +32,7 @@ from apscheduler.job import Job
|
|||
from apscheduler.triggers.cron import CronTrigger
|
||||
from asgiref.sync import sync_to_async
|
||||
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
||||
from pydantic import BaseModel
|
||||
from starlette.authentication import has_required_scope
|
||||
from starlette.requests import URL
|
||||
|
||||
|
@ -296,7 +298,7 @@ async def aget_relevant_information_sources(
|
|||
conversation_history: dict,
|
||||
is_task: bool,
|
||||
user: KhojUser,
|
||||
uploaded_image_url: str = None,
|
||||
query_images: List[str] = None,
|
||||
agent: Agent = None,
|
||||
):
|
||||
"""
|
||||
|
@ -315,8 +317,8 @@ async def aget_relevant_information_sources(
|
|||
|
||||
chat_history = construct_chat_history(conversation_history)
|
||||
|
||||
if uploaded_image_url:
|
||||
query = f"[placeholder for user attached image]\n{query}"
|
||||
if query_images:
|
||||
query = f"[placeholder for {len(query_images)} user attached images]\n{query}"
|
||||
|
||||
personality_context = (
|
||||
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
|
||||
|
@ -373,7 +375,7 @@ async def aget_relevant_output_modes(
|
|||
conversation_history: dict,
|
||||
is_task: bool = False,
|
||||
user: KhojUser = None,
|
||||
uploaded_image_url: str = None,
|
||||
query_images: List[str] = None,
|
||||
agent: Agent = None,
|
||||
):
|
||||
"""
|
||||
|
@ -395,8 +397,8 @@ async def aget_relevant_output_modes(
|
|||
|
||||
chat_history = construct_chat_history(conversation_history)
|
||||
|
||||
if uploaded_image_url:
|
||||
query = f"[placeholder for user attached image]\n{query}"
|
||||
if query_images:
|
||||
query = f"[placeholder for {len(query_images)} user attached images]\n{query}"
|
||||
|
||||
personality_context = (
|
||||
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
|
||||
|
@ -439,7 +441,7 @@ async def infer_webpage_urls(
|
|||
conversation_history: dict,
|
||||
location_data: LocationData,
|
||||
user: KhojUser,
|
||||
uploaded_image_url: str = None,
|
||||
query_images: List[str] = None,
|
||||
agent: Agent = None,
|
||||
) -> List[str]:
|
||||
"""
|
||||
|
@ -465,7 +467,7 @@ async def infer_webpage_urls(
|
|||
|
||||
with timer("Chat actor: Infer webpage urls to read", logger):
|
||||
response = await send_message_to_model_wrapper(
|
||||
online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
|
||||
online_queries_prompt, query_images=query_images, response_type="json_object", user=user
|
||||
)
|
||||
|
||||
# Validate that the response is a non-empty, JSON-serializable list of URLs
|
||||
|
@ -485,7 +487,7 @@ async def generate_online_subqueries(
|
|||
conversation_history: dict,
|
||||
location_data: LocationData,
|
||||
user: KhojUser,
|
||||
uploaded_image_url: str = None,
|
||||
query_images: List[str] = None,
|
||||
agent: Agent = None,
|
||||
) -> List[str]:
|
||||
"""
|
||||
|
@ -511,7 +513,7 @@ async def generate_online_subqueries(
|
|||
|
||||
with timer("Chat actor: Generate online search subqueries", logger):
|
||||
response = await send_message_to_model_wrapper(
|
||||
online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
|
||||
online_queries_prompt, query_images=query_images, response_type="json_object", user=user
|
||||
)
|
||||
|
||||
# Validate that the response is a non-empty, JSON-serializable list
|
||||
|
@ -530,7 +532,7 @@ async def generate_online_subqueries(
|
|||
|
||||
|
||||
async def schedule_query(
|
||||
q: str, conversation_history: dict, user: KhojUser, uploaded_image_url: str = None
|
||||
q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None
|
||||
) -> Tuple[str, ...]:
|
||||
"""
|
||||
Schedule the date, time to run the query. Assume the server timezone is UTC.
|
||||
|
@ -543,7 +545,7 @@ async def schedule_query(
|
|||
)
|
||||
|
||||
raw_response = await send_message_to_model_wrapper(
|
||||
crontime_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
|
||||
crontime_prompt, query_images=query_images, response_type="json_object", user=user
|
||||
)
|
||||
|
||||
# Validate that the response is a non-empty, JSON-serializable list
|
||||
|
@ -589,7 +591,7 @@ async def extract_relevant_summary(
|
|||
q: str,
|
||||
corpus: str,
|
||||
conversation_history: dict,
|
||||
uploaded_image_url: str = None,
|
||||
query_images: List[str] = None,
|
||||
user: KhojUser = None,
|
||||
agent: Agent = None,
|
||||
) -> Union[str, None]:
|
||||
|
@ -618,7 +620,7 @@ async def extract_relevant_summary(
|
|||
extract_relevant_information,
|
||||
prompts.system_prompt_extract_relevant_summary,
|
||||
user=user,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=query_images,
|
||||
)
|
||||
return response.strip()
|
||||
|
||||
|
@ -629,7 +631,7 @@ async def generate_excalidraw_diagram(
|
|||
location_data: LocationData,
|
||||
note_references: List[Dict[str, Any]],
|
||||
online_results: Optional[dict] = None,
|
||||
uploaded_image_url: Optional[str] = None,
|
||||
query_images: List[str] = None,
|
||||
user: KhojUser = None,
|
||||
agent: Agent = None,
|
||||
send_status_func: Optional[Callable] = None,
|
||||
|
@ -644,7 +646,7 @@ async def generate_excalidraw_diagram(
|
|||
location_data=location_data,
|
||||
note_references=note_references,
|
||||
online_results=online_results,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=query_images,
|
||||
user=user,
|
||||
agent=agent,
|
||||
)
|
||||
|
@ -668,7 +670,7 @@ async def generate_better_diagram_description(
|
|||
location_data: LocationData,
|
||||
note_references: List[Dict[str, Any]],
|
||||
online_results: Optional[dict] = None,
|
||||
uploaded_image_url: Optional[str] = None,
|
||||
query_images: List[str] = None,
|
||||
user: KhojUser = None,
|
||||
agent: Agent = None,
|
||||
) -> str:
|
||||
|
@ -711,7 +713,7 @@ async def generate_better_diagram_description(
|
|||
|
||||
with timer("Chat actor: Generate better diagram description", logger):
|
||||
response = await send_message_to_model_wrapper(
|
||||
improve_diagram_description_prompt, uploaded_image_url=uploaded_image_url, user=user
|
||||
improve_diagram_description_prompt, query_images=query_images, user=user
|
||||
)
|
||||
response = response.strip()
|
||||
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
|
||||
|
@ -753,7 +755,7 @@ async def generate_better_image_prompt(
|
|||
note_references: List[Dict[str, Any]],
|
||||
online_results: Optional[dict] = None,
|
||||
model_type: Optional[str] = None,
|
||||
uploaded_image_url: Optional[str] = None,
|
||||
query_images: Optional[List[str]] = None,
|
||||
user: KhojUser = None,
|
||||
agent: Agent = None,
|
||||
) -> str:
|
||||
|
@ -805,7 +807,7 @@ async def generate_better_image_prompt(
|
|||
)
|
||||
|
||||
with timer("Chat actor: Generate contextual image prompt", logger):
|
||||
response = await send_message_to_model_wrapper(image_prompt, uploaded_image_url=uploaded_image_url, user=user)
|
||||
response = await send_message_to_model_wrapper(image_prompt, query_images=query_images, user=user)
|
||||
response = response.strip()
|
||||
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
|
||||
response = response[1:-1]
|
||||
|
@ -818,11 +820,11 @@ async def send_message_to_model_wrapper(
|
|||
system_message: str = "",
|
||||
response_type: str = "text",
|
||||
user: KhojUser = None,
|
||||
uploaded_image_url: str = None,
|
||||
query_images: List[str] = None,
|
||||
):
|
||||
conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
|
||||
vision_available = conversation_config.vision_enabled
|
||||
if not vision_available and uploaded_image_url:
|
||||
if not vision_available and query_images:
|
||||
vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
|
||||
if vision_enabled_config:
|
||||
conversation_config = vision_enabled_config
|
||||
|
@ -875,7 +877,7 @@ async def send_message_to_model_wrapper(
|
|||
max_prompt_size=max_tokens,
|
||||
tokenizer_name=tokenizer,
|
||||
vision_enabled=vision_available,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=query_images,
|
||||
model_type=conversation_config.model_type,
|
||||
)
|
||||
|
||||
|
@ -895,7 +897,7 @@ async def send_message_to_model_wrapper(
|
|||
max_prompt_size=max_tokens,
|
||||
tokenizer_name=tokenizer,
|
||||
vision_enabled=vision_available,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=query_images,
|
||||
model_type=conversation_config.model_type,
|
||||
)
|
||||
|
||||
|
@ -913,7 +915,8 @@ async def send_message_to_model_wrapper(
|
|||
max_prompt_size=max_tokens,
|
||||
tokenizer_name=tokenizer,
|
||||
vision_enabled=vision_available,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=query_images,
|
||||
model_type=conversation_config.model_type,
|
||||
)
|
||||
|
||||
return gemini_send_message_to_model(
|
||||
|
@ -1004,6 +1007,7 @@ def send_message_to_model_wrapper_sync(
|
|||
model_name=chat_model,
|
||||
max_prompt_size=max_tokens,
|
||||
vision_enabled=vision_available,
|
||||
model_type=conversation_config.model_type,
|
||||
)
|
||||
|
||||
return gemini_send_message_to_model(
|
||||
|
@ -1029,7 +1033,7 @@ def generate_chat_response(
|
|||
conversation_id: str = None,
|
||||
location_data: LocationData = None,
|
||||
user_name: Optional[str] = None,
|
||||
uploaded_image_url: Optional[str] = None,
|
||||
query_images: Optional[List[str]] = None,
|
||||
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
|
||||
# Initialize Variables
|
||||
chat_response = None
|
||||
|
@ -1048,12 +1052,12 @@ def generate_chat_response(
|
|||
inferred_queries=inferred_queries,
|
||||
client_application=client_application,
|
||||
conversation_id=conversation_id,
|
||||
uploaded_image_url=uploaded_image_url,
|
||||
query_images=query_images,
|
||||
)
|
||||
|
||||
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
|
||||
vision_available = conversation_config.vision_enabled
|
||||
if not vision_available and uploaded_image_url:
|
||||
if not vision_available and query_images:
|
||||
vision_enabled_config = ConversationAdapters.get_vision_enabled_config()
|
||||
if vision_enabled_config:
|
||||
conversation_config = vision_enabled_config
|
||||
|
@ -1084,7 +1088,7 @@ def generate_chat_response(
|
|||
chat_response = converse(
|
||||
compiled_references,
|
||||
q,
|
||||
image_url=uploaded_image_url,
|
||||
query_images=query_images,
|
||||
online_results=online_results,
|
||||
conversation_log=meta_log,
|
||||
model=chat_model,
|
||||
|
@ -1122,8 +1126,9 @@ def generate_chat_response(
|
|||
chat_response = converse_gemini(
|
||||
compiled_references,
|
||||
q,
|
||||
online_results,
|
||||
meta_log,
|
||||
query_images=query_images,
|
||||
online_results=online_results,
|
||||
conversation_log=meta_log,
|
||||
model=conversation_config.chat_model,
|
||||
api_key=api_key,
|
||||
completion_func=partial_completion,
|
||||
|
@ -1133,6 +1138,7 @@ def generate_chat_response(
|
|||
location_data=location_data,
|
||||
user_name=user_name,
|
||||
agent=agent,
|
||||
vision_available=vision_available,
|
||||
)
|
||||
|
||||
metadata.update({"chat_model": conversation_config.chat_model})
|
||||
|
@ -1144,6 +1150,22 @@ def generate_chat_response(
|
|||
return chat_response, metadata
|
||||
|
||||
|
||||
class ChatRequestBody(BaseModel):
|
||||
q: str
|
||||
n: Optional[int] = 7
|
||||
d: Optional[float] = None
|
||||
stream: Optional[bool] = False
|
||||
title: Optional[str] = None
|
||||
conversation_id: Optional[str] = None
|
||||
city: Optional[str] = None
|
||||
region: Optional[str] = None
|
||||
country: Optional[str] = None
|
||||
country_code: Optional[str] = None
|
||||
timezone: Optional[str] = None
|
||||
images: Optional[list[str]] = None
|
||||
create_new: Optional[bool] = False
|
||||
|
||||
|
||||
class ApiUserRateLimiter:
|
||||
def __init__(self, requests: int, subscribed_requests: int, window: int, slug: str):
|
||||
self.requests = requests
|
||||
|
@ -1189,13 +1211,58 @@ class ApiUserRateLimiter:
|
|||
)
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail="We're glad you're enjoying Khoj! You've exceeded your usage limit for today. Come back tomorrow or subscribe to increase your usage limit via [your settings](https://app.khoj.dev/settings).",
|
||||
detail="I'm glad you're enjoying interacting with me! But you've exceeded your usage limit for today. Come back tomorrow or subscribe to increase your usage limit via [your settings](https://app.khoj.dev/settings).",
|
||||
)
|
||||
|
||||
# Add the current request to the cache
|
||||
UserRequests.objects.create(user=user, slug=self.slug)
|
||||
|
||||
|
||||
class ApiImageRateLimiter:
|
||||
def __init__(self, max_images: int = 10, max_combined_size_mb: float = 10):
|
||||
self.max_images = max_images
|
||||
self.max_combined_size_mb = max_combined_size_mb
|
||||
|
||||
def __call__(self, request: Request, body: ChatRequestBody):
|
||||
if state.billing_enabled is False:
|
||||
return
|
||||
|
||||
# Rate limiting is disabled if user unauthenticated.
|
||||
# Other systems handle authentication
|
||||
if not request.user.is_authenticated:
|
||||
return
|
||||
|
||||
if not body.images:
|
||||
return
|
||||
|
||||
# Check number of images
|
||||
if len(body.images) > self.max_images:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"Those are way too many images for me! I can handle up to {self.max_images} images per message.",
|
||||
)
|
||||
|
||||
# Check total size of images
|
||||
total_size_mb = 0.0
|
||||
for image in body.images:
|
||||
# Unquote the image in case it's URL encoded
|
||||
image = unquote(image)
|
||||
# Assuming the image is a base64 encoded string
|
||||
# Remove the data:image/jpeg;base64, part if present
|
||||
if "," in image:
|
||||
image = image.split(",", 1)[1]
|
||||
|
||||
# Decode base64 to get the actual size
|
||||
image_bytes = base64.b64decode(image)
|
||||
total_size_mb += len(image_bytes) / (1024 * 1024) # Convert bytes to MB
|
||||
|
||||
if total_size_mb > self.max_combined_size_mb:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"Those images are way too large for me! I can handle up to {self.max_combined_size_mb}MB of images per message.",
|
||||
)
|
||||
|
||||
|
||||
class ConversationCommandRateLimiter:
|
||||
def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int, slug: str):
|
||||
self.slug = slug
|
||||
|
|
|
@ -352,9 +352,9 @@ tool_descriptions_for_llm = {
|
|||
}
|
||||
|
||||
mode_descriptions_for_llm = {
|
||||
ConversationCommand.Image: "Use this if the user is requesting you to generate a picture based on their description.",
|
||||
ConversationCommand.Image: "Use this if the user is requesting you to create a new picture based on their description.",
|
||||
ConversationCommand.Automation: "Use this if you are confident the user is requesting a response at a scheduled date, time and frequency",
|
||||
ConversationCommand.Text: "Use this if the other response modes don't seem to fit the query.",
|
||||
ConversationCommand.Text: "Use this if a normal text response would be sufficient for accurately responding to the query.",
|
||||
ConversationCommand.Diagram: "Use this if the user is requesting a visual representation that requires primitives like lines, rectangles, and text.",
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue