# Chat with Multiple Images. Support Vision with Gemini (#942)

## Overview
- Add vision support for Gemini models in Khoj
- Allow sharing multiple images as part of a user query from the web app
- Handle multiple images shared in a query to the chat API (request body sketched below)
Commit c6f3253ebd by Debanjum, 2024-10-22 19:59:18 -07:00, committed by GitHub
17 changed files with 454 additions and 294 deletions
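
Before the per-file diffs, here is a minimal sketch of the request body the web app now sends to the chat endpoint: the previous single `image` field becomes an `images` list of URL-encoded base64 data URIs. The endpoint path, port, and session cookie below are assumptions for illustration, not part of this commit.

```python
import base64
from urllib.parse import quote

import requests


def build_chat_body(query: str, image_paths: list[str]) -> dict:
    """Build a chat request body with multiple attached images, mirroring the web app."""
    images = []
    for path in image_paths:
        with open(path, "rb") as f:
            encoded = base64.b64encode(f.read()).decode()
        # The web app sends data URIs, URL-encoded with encodeURIComponent
        images.append(quote(f"data:image/png;base64,{encoded}", safe=""))

    body = {"q": query, "stream": False}
    if images:
        body["images"] = images  # new: a list, instead of the old single `image` string
    return body


if __name__ == "__main__":
    # Assumed local Khoj server and session cookie; adjust to your deployment.
    body = build_chat_body("What is in these photos?", ["photo-1.png", "photo-2.png"])
    response = requests.post(
        "http://localhost:42110/api/chat",
        json=body,
        cookies={"session": "<your-session-cookie>"},
    )
    print(response.json())
```

On the server side (see the chat API diff below), each entry is URL-decoded, converted to webp, and uploaded to object storage before being passed to the vision-enabled chat models.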


@ -27,14 +27,14 @@ interface ChatBodyDataProps {
setUploadedFiles: (files: string[]) => void; setUploadedFiles: (files: string[]) => void;
isMobileWidth?: boolean; isMobileWidth?: boolean;
isLoggedIn: boolean; isLoggedIn: boolean;
setImage64: (image64: string) => void; setImages: (images: string[]) => void;
} }
function ChatBodyData(props: ChatBodyDataProps) { function ChatBodyData(props: ChatBodyDataProps) {
const searchParams = useSearchParams(); const searchParams = useSearchParams();
const conversationId = searchParams.get("conversationId"); const conversationId = searchParams.get("conversationId");
const [message, setMessage] = useState(""); const [message, setMessage] = useState("");
const [image, setImage] = useState<string | null>(null); const [images, setImages] = useState<string[]>([]);
const [processingMessage, setProcessingMessage] = useState(false); const [processingMessage, setProcessingMessage] = useState(false);
const [agentMetadata, setAgentMetadata] = useState<AgentData | null>(null); const [agentMetadata, setAgentMetadata] = useState<AgentData | null>(null);
@ -44,17 +44,20 @@ function ChatBodyData(props: ChatBodyDataProps) {
const chatHistoryCustomClassName = props.isMobileWidth ? "w-full" : "w-4/6"; const chatHistoryCustomClassName = props.isMobileWidth ? "w-full" : "w-4/6";
useEffect(() => { useEffect(() => {
if (image) { if (images.length > 0) {
props.setImage64(encodeURIComponent(image)); const encodedImages = images.map((image) => encodeURIComponent(image));
props.setImages(encodedImages);
} }
}, [image, props.setImage64]); }, [images, props.setImages]);
useEffect(() => { useEffect(() => {
const storedImage = localStorage.getItem("image"); const storedImages = localStorage.getItem("images");
if (storedImage) { if (storedImages) {
setImage(storedImage); const parsedImages: string[] = JSON.parse(storedImages);
props.setImage64(encodeURIComponent(storedImage)); setImages(parsedImages);
localStorage.removeItem("image"); const encodedImages = parsedImages.map((img: string) => encodeURIComponent(img));
props.setImages(encodedImages);
localStorage.removeItem("images");
} }
const storedMessage = localStorage.getItem("message"); const storedMessage = localStorage.getItem("message");
@ -62,7 +65,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
setProcessingMessage(true); setProcessingMessage(true);
setQueryToProcess(storedMessage); setQueryToProcess(storedMessage);
} }
}, [setQueryToProcess]); }, [setQueryToProcess, props.setImages]);
useEffect(() => { useEffect(() => {
if (message) { if (message) {
@ -84,6 +87,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
props.streamedMessages[props.streamedMessages.length - 1].completed props.streamedMessages[props.streamedMessages.length - 1].completed
) { ) {
setProcessingMessage(false); setProcessingMessage(false);
setImages([]); // Reset images after processing
} else { } else {
setMessage(""); setMessage("");
} }
@ -113,7 +117,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
agentColor={agentMetadata?.color} agentColor={agentMetadata?.color}
isLoggedIn={props.isLoggedIn} isLoggedIn={props.isLoggedIn}
sendMessage={(message) => setMessage(message)} sendMessage={(message) => setMessage(message)}
sendImage={(image) => setImage(image)} sendImage={(image) => setImages((prevImages) => [...prevImages, image])}
sendDisabled={processingMessage} sendDisabled={processingMessage}
chatOptionsData={props.chatOptionsData} chatOptionsData={props.chatOptionsData}
conversationId={conversationId} conversationId={conversationId}
@ -135,7 +139,7 @@ export default function Chat() {
const [queryToProcess, setQueryToProcess] = useState<string>(""); const [queryToProcess, setQueryToProcess] = useState<string>("");
const [processQuerySignal, setProcessQuerySignal] = useState(false); const [processQuerySignal, setProcessQuerySignal] = useState(false);
const [uploadedFiles, setUploadedFiles] = useState<string[]>([]); const [uploadedFiles, setUploadedFiles] = useState<string[]>([]);
const [image64, setImage64] = useState<string>(""); const [images, setImages] = useState<string[]>([]);
const locationData = useIPLocationData() || { const locationData = useIPLocationData() || {
timezone: Intl.DateTimeFormat().resolvedOptions().timeZone, timezone: Intl.DateTimeFormat().resolvedOptions().timeZone,
@ -171,7 +175,7 @@ export default function Chat() {
completed: false, completed: false,
timestamp: new Date().toISOString(), timestamp: new Date().toISOString(),
rawQuery: queryToProcess || "", rawQuery: queryToProcess || "",
uploadedImageData: decodeURIComponent(image64), images: images,
}; };
setMessages((prevMessages) => [...prevMessages, newStreamMessage]); setMessages((prevMessages) => [...prevMessages, newStreamMessage]);
setProcessQuerySignal(true); setProcessQuerySignal(true);
@ -202,7 +206,7 @@ export default function Chat() {
if (done) { if (done) {
setQueryToProcess(""); setQueryToProcess("");
setProcessQuerySignal(false); setProcessQuerySignal(false);
setImage64(""); setImages([]);
break; break;
} }
@ -250,7 +254,7 @@ export default function Chat() {
country_code: locationData.countryCode, country_code: locationData.countryCode,
timezone: locationData.timezone, timezone: locationData.timezone,
}), }),
...(image64 && { image: image64 }), ...(images.length > 0 && { images: images }),
}; };
const response = await fetch(chatAPI, { const response = await fetch(chatAPI, {
@ -264,7 +268,8 @@ export default function Chat() {
try { try {
await readChatStream(response); await readChatStream(response);
} catch (err) { } catch (err) {
console.error(err); const apiError = await response.json();
console.error(apiError);
// Retrieve latest message being processed // Retrieve latest message being processed
const currentMessage = messages.find((message) => !message.completed); const currentMessage = messages.find((message) => !message.completed);
if (!currentMessage) return; if (!currentMessage) return;
@ -273,7 +278,11 @@ export default function Chat() {
const errorMessage = (err as Error).message; const errorMessage = (err as Error).message;
if (errorMessage.includes("Error in input stream")) if (errorMessage.includes("Error in input stream"))
currentMessage.rawResponse = `Woops! The connection broke while I was writing my thoughts down. Maybe try again in a bit or dislike this message if the issue persists?`; currentMessage.rawResponse = `Woops! The connection broke while I was writing my thoughts down. Maybe try again in a bit or dislike this message if the issue persists?`;
else else if (response.status === 429) {
"detail" in apiError
? (currentMessage.rawResponse = `${apiError.detail}`)
: (currentMessage.rawResponse = `I'm a bit overwhelmed at the moment. Could you try again in a bit or dislike this message if the issue persists?`);
} else
currentMessage.rawResponse = `Umm, not sure what just happened. I see this error message: ${errorMessage}. Could you try again or dislike this message if the issue persists?`; currentMessage.rawResponse = `Umm, not sure what just happened. I see this error message: ${errorMessage}. Could you try again or dislike this message if the issue persists?`;
// Complete message streaming teardown properly // Complete message streaming teardown properly
@ -332,7 +341,7 @@ export default function Chat() {
setUploadedFiles={setUploadedFiles} setUploadedFiles={setUploadedFiles}
isMobileWidth={isMobileWidth} isMobileWidth={isMobileWidth}
onConversationIdChange={handleConversationIdChange} onConversationIdChange={handleConversationIdChange}
setImage64={setImage64} setImages={setImages}
/> />
</Suspense> </Suspense>
</div> </div>


@ -299,7 +299,7 @@ export default function ChatHistory(props: ChatHistoryProps) {
created: message.timestamp, created: message.timestamp,
by: "you", by: "you",
automationId: "", automationId: "",
uploadedImageData: message.uploadedImageData, images: message.images,
}} }}
customClassName="fullHistory" customClassName="fullHistory"
borderLeftColor={`${data?.agent?.color}-500`} borderLeftColor={`${data?.agent?.color}-500`}
@ -348,7 +348,6 @@ export default function ChatHistory(props: ChatHistoryProps) {
created: new Date().getTime().toString(), created: new Date().getTime().toString(),
by: "you", by: "you",
automationId: "", automationId: "",
uploadedImageData: props.pendingMessage,
}} }}
customClassName="fullHistory" customClassName="fullHistory"
borderLeftColor={`${data?.agent?.color}-500`} borderLeftColor={`${data?.agent?.color}-500`}


@ -3,23 +3,7 @@ import React, { useEffect, useRef, useState } from "react";
import DOMPurify from "dompurify"; import DOMPurify from "dompurify";
import "katex/dist/katex.min.css"; import "katex/dist/katex.min.css";
import {
    ArrowRight,
    ArrowUp,
    Browser,
    ChatsTeardrop,
    GlobeSimple,
    Gps,
    Image,
    Microphone,
    Notebook,
    Paperclip,
    X,
    Question,
    Robot,
    Shapes,
    Stop,
} from "@phosphor-icons/react";
import { ArrowUp, Microphone, Paperclip, X, Stop } from "@phosphor-icons/react";
import { import {
Command, Command,
@ -78,10 +62,11 @@ export default function ChatInputArea(props: ChatInputProps) {
const [loginRedirectMessage, setLoginRedirectMessage] = useState<string | null>(null); const [loginRedirectMessage, setLoginRedirectMessage] = useState<string | null>(null);
const [showLoginPrompt, setShowLoginPrompt] = useState(false); const [showLoginPrompt, setShowLoginPrompt] = useState(false);
const [recording, setRecording] = useState(false);
const [imageUploaded, setImageUploaded] = useState(false); const [imageUploaded, setImageUploaded] = useState(false);
const [imagePath, setImagePath] = useState<string>(""); const [imagePaths, setImagePaths] = useState<string[]>([]);
const [imageData, setImageData] = useState<string | null>(null); const [imageData, setImageData] = useState<string[]>([]);
const [recording, setRecording] = useState(false);
const [mediaRecorder, setMediaRecorder] = useState<MediaRecorder | null>(null); const [mediaRecorder, setMediaRecorder] = useState<MediaRecorder | null>(null);
const [progressValue, setProgressValue] = useState(0); const [progressValue, setProgressValue] = useState(0);
@ -106,27 +91,31 @@ export default function ChatInputArea(props: ChatInputProps) {
useEffect(() => { useEffect(() => {
async function fetchImageData() { async function fetchImageData() {
if (imagePath) { if (imagePaths.length > 0) {
const response = await fetch(imagePath); const newImageData = await Promise.all(
imagePaths.map(async (path) => {
const response = await fetch(path);
const blob = await response.blob(); const blob = await response.blob();
return new Promise<string>((resolve) => {
const reader = new FileReader(); const reader = new FileReader();
reader.onload = function () { reader.onload = () => resolve(reader.result as string);
const base64data = reader.result;
setImageData(base64data as string);
};
reader.readAsDataURL(blob); reader.readAsDataURL(blob);
});
}),
);
setImageData(newImageData);
} }
setUploading(false); setUploading(false);
} }
setUploading(true); setUploading(true);
fetchImageData(); fetchImageData();
}, [imagePath]); }, [imagePaths]);
function onSendMessage() { function onSendMessage() {
if (imageUploaded) { if (imageUploaded) {
setImageUploaded(false); setImageUploaded(false);
setImagePath(""); setImagePaths([]);
props.sendImage(imageData || ""); imageData.forEach((data) => props.sendImage(data));
} }
if (!message.trim()) return; if (!message.trim()) return;
@ -172,18 +161,23 @@ export default function ChatInputArea(props: ChatInputProps) {
setShowLoginPrompt(true); setShowLoginPrompt(true);
return; return;
} }
// check for image file // check for image files
const image_endings = ["jpg", "jpeg", "png", "webp"]; const image_endings = ["jpg", "jpeg", "png", "webp"];
const newImagePaths: string[] = [];
for (let i = 0; i < files.length; i++) { for (let i = 0; i < files.length; i++) {
const file = files[i]; const file = files[i];
const file_extension = file.name.split(".").pop(); const file_extension = file.name.split(".").pop();
if (image_endings.includes(file_extension || "")) { if (image_endings.includes(file_extension || "")) {
setImageUploaded(true); newImagePaths.push(DOMPurify.sanitize(URL.createObjectURL(file)));
setImagePath(DOMPurify.sanitize(URL.createObjectURL(file)));
return;
} }
} }
if (newImagePaths.length > 0) {
setImageUploaded(true);
setImagePaths((prevPaths) => [...prevPaths, ...newImagePaths]);
return;
}
uploadDataForIndexing( uploadDataForIndexing(
files, files,
setWarning, setWarning,
@ -288,9 +282,12 @@ export default function ChatInputArea(props: ChatInputProps) {
setIsDragAndDropping(false); setIsDragAndDropping(false);
} }
function removeImageUpload() { function removeImageUpload(index: number) {
setImagePaths((prevPaths) => prevPaths.filter((_, i) => i !== index));
setImageData((prevData) => prevData.filter((_, i) => i !== index));
if (imagePaths.length === 1) {
setImageUploaded(false); setImageUploaded(false);
setImagePath(""); }
} }
return ( return (
@ -407,24 +404,11 @@ export default function ChatInputArea(props: ChatInputProps) {
</div> </div>
)} )}
<div <div
className={`${styles.actualInputArea} items-center justify-between dark:bg-neutral-700 relative ${isDragAndDropping && "animate-pulse"}`} className={`${styles.actualInputArea} justify-between dark:bg-neutral-700 relative ${isDragAndDropping && "animate-pulse"}`}
onDragOver={handleDragOver} onDragOver={handleDragOver}
onDragLeave={handleDragLeave} onDragLeave={handleDragLeave}
onDrop={handleDragAndDropFiles} onDrop={handleDragAndDropFiles}
> >
{imageUploaded && (
<div className="absolute bottom-[80px] left-0 right-0 dark:bg-neutral-700 bg-white pt-5 pb-5 w-full rounded-lg border dark:border-none grid grid-cols-2">
<div className="pl-4 pr-4">
<img src={imagePath} alt="img" className="w-auto max-h-[100px]" />
</div>
<div className="pl-4 pr-4">
<X
className="w-6 h-6 float-right dark:hover:bg-[hsl(var(--background))] hover:bg-neutral-100 rounded-sm"
onClick={removeImageUpload}
/>
</div>
</div>
)}
<input <input
type="file" type="file"
multiple={true} multiple={true}
@ -432,6 +416,7 @@ export default function ChatInputArea(props: ChatInputProps) {
onChange={handleFileChange} onChange={handleFileChange}
style={{ display: "none" }} style={{ display: "none" }}
/> />
<div className="flex items-end pb-4">
<Button <Button
variant={"ghost"} variant={"ghost"}
className="!bg-none p-0 m-2 h-auto text-3xl rounded-full text-gray-300 hover:text-gray-500" className="!bg-none p-0 m-2 h-auto text-3xl rounded-full text-gray-300 hover:text-gray-500"
@ -440,7 +425,28 @@ export default function ChatInputArea(props: ChatInputProps) {
> >
<Paperclip className="w-8 h-8" /> <Paperclip className="w-8 h-8" />
</Button> </Button>
<div className="grid w-full gap-1.5 relative"> </div>
<div className="flex-grow flex flex-col w-full gap-1.5 relative pb-2">
<div className="flex items-center gap-2 overflow-x-auto">
{imageUploaded &&
imagePaths.map((path, index) => (
<div key={index} className="relative flex-shrink-0 pb-3 pt-2 group">
<img
src={path}
alt={`img-${index}`}
className="w-auto h-16 object-cover rounded-xl"
/>
<Button
variant="ghost"
size="icon"
className="absolute -top-0 -right-2 h-5 w-5 rounded-full bg-neutral-200 dark:bg-neutral-600 hover:bg-neutral-300 dark:hover:bg-neutral-500 opacity-0 group-hover:opacity-100 transition-opacity"
onClick={() => removeImageUpload(index)}
>
<X className="h-3 w-3" />
</Button>
</div>
))}
</div>
<Textarea <Textarea
ref={chatInputRef} ref={chatInputRef}
className={`border-none w-full h-16 min-h-16 max-h-[128px] md:py-4 rounded-lg resize-none dark:bg-neutral-700 ${props.isMobileWidth ? "text-md" : "text-lg"}`} className={`border-none w-full h-16 min-h-16 max-h-[128px] md:py-4 rounded-lg resize-none dark:bg-neutral-700 ${props.isMobileWidth ? "text-md" : "text-lg"}`}
@ -451,7 +457,7 @@ export default function ChatInputArea(props: ChatInputProps) {
onKeyDown={(e) => { onKeyDown={(e) => {
if (e.key === "Enter" && !e.shiftKey) { if (e.key === "Enter" && !e.shiftKey) {
setImageUploaded(false); setImageUploaded(false);
setImagePath(""); setImagePaths([]);
e.preventDefault(); e.preventDefault();
onSendMessage(); onSendMessage();
} }
@ -460,6 +466,7 @@ export default function ChatInputArea(props: ChatInputProps) {
disabled={props.sendDisabled || recording} disabled={props.sendDisabled || recording}
/> />
</div> </div>
<div className="flex items-end pb-4">
{recording ? ( {recording ? (
<TooltipProvider> <TooltipProvider>
<Tooltip> <Tooltip>
@ -512,6 +519,7 @@ export default function ChatInputArea(props: ChatInputProps) {
<ArrowUp className="w-6 h-6" weight="bold" /> <ArrowUp className="w-6 h-6" weight="bold" />
</Button> </Button>
</div> </div>
</div>
</> </>
); );
} }


@ -57,7 +57,26 @@ div.emptyChatMessage {
display: none; display: none;
} }
div.chatMessageContainer img { div.imagesContainer {
display: flex;
overflow-x: auto;
padding-bottom: 8px;
margin-bottom: 8px;
}
div.imageWrapper {
flex: 0 0 auto;
margin-right: 8px;
}
div.imageWrapper img {
width: auto;
height: 128px;
object-fit: cover;
border-radius: 8px;
}
div.chatMessageContainer > img {
width: auto; width: auto;
height: auto; height: auto;
max-width: 100%; max-width: 100%;


@ -116,7 +116,7 @@ export interface SingleChatMessage {
rawQuery?: string; rawQuery?: string;
intent?: Intent; intent?: Intent;
agent?: AgentData; agent?: AgentData;
uploadedImageData?: string; images?: string[];
} }
export interface StreamMessage { export interface StreamMessage {
@ -128,10 +128,9 @@ export interface StreamMessage {
rawQuery: string; rawQuery: string;
timestamp: string; timestamp: string;
agent?: AgentData; agent?: AgentData;
uploadedImageData?: string; images?: string[];
intentType?: string; intentType?: string;
inferredQueries?: string[]; inferredQueries?: string[];
image?: string;
} }
export interface ChatHistoryData { export interface ChatHistoryData {
@ -213,7 +212,6 @@ interface ChatMessageProps {
borderLeftColor?: string; borderLeftColor?: string;
isLastMessage?: boolean; isLastMessage?: boolean;
agent?: AgentData; agent?: AgentData;
uploadedImageData?: string;
} }
interface TrainOfThoughtProps { interface TrainOfThoughtProps {
@ -343,8 +341,17 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
.replace(/\\\[/g, "LEFTBRACKET") .replace(/\\\[/g, "LEFTBRACKET")
.replace(/\\\]/g, "RIGHTBRACKET"); .replace(/\\\]/g, "RIGHTBRACKET");
if (props.chatMessage.uploadedImageData) { if (props.chatMessage.images && props.chatMessage.images.length > 0) {
message = `![uploaded image](${props.chatMessage.uploadedImageData})\n\n${message}`; const imagesInMd = props.chatMessage.images
.map((image, index) => {
const decodedImage = image.startsWith("data%3Aimage")
? decodeURIComponent(image)
: image;
const sanitizedImage = DOMPurify.sanitize(decodedImage);
return `<div class="${styles.imageWrapper}"><img src="${sanitizedImage}" alt="uploaded image ${index + 1}" /></div>`;
})
.join("");
message = `<div class="${styles.imagesContainer}">${imagesInMd}</div>${message}`;
} }
const intentTypeHandlers = { const intentTypeHandlers = {
@ -384,7 +391,7 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
// Sanitize and set the rendered markdown // Sanitize and set the rendered markdown
setMarkdownRendered(DOMPurify.sanitize(markdownRendered)); setMarkdownRendered(DOMPurify.sanitize(markdownRendered));
}, [props.chatMessage.message, props.chatMessage.intent]); }, [props.chatMessage.message, props.chatMessage.images, props.chatMessage.intent]);
useEffect(() => { useEffect(() => {
if (copySuccess) { if (copySuccess) {


@ -44,7 +44,7 @@ function FisherYatesShuffle(array: any[]) {
function ChatBodyData(props: ChatBodyDataProps) { function ChatBodyData(props: ChatBodyDataProps) {
const [message, setMessage] = useState(""); const [message, setMessage] = useState("");
const [image, setImage] = useState<string | null>(null); const [images, setImages] = useState<string[]>([]);
const [processingMessage, setProcessingMessage] = useState(false); const [processingMessage, setProcessingMessage] = useState(false);
const [greeting, setGreeting] = useState(""); const [greeting, setGreeting] = useState("");
const [shuffledOptions, setShuffledOptions] = useState<Suggestion[]>([]); const [shuffledOptions, setShuffledOptions] = useState<Suggestion[]>([]);
@ -138,20 +138,21 @@ function ChatBodyData(props: ChatBodyDataProps) {
try { try {
const newConversationId = await createNewConversation(selectedAgent || "khoj"); const newConversationId = await createNewConversation(selectedAgent || "khoj");
onConversationIdChange?.(newConversationId); onConversationIdChange?.(newConversationId);
window.location.href = `/chat?conversationId=${newConversationId}`;
localStorage.setItem("message", message); localStorage.setItem("message", message);
if (image) { if (images.length > 0) {
localStorage.setItem("image", image); localStorage.setItem("images", JSON.stringify(images));
} }
window.location.href = `/chat?conversationId=${newConversationId}`;
} catch (error) { } catch (error) {
console.error("Error creating new conversation:", error); console.error("Error creating new conversation:", error);
setProcessingMessage(false); setProcessingMessage(false);
} }
setMessage(""); setMessage("");
setImages([]);
} }
}; };
processMessage(); processMessage();
if (message) { if (message || images.length > 0) {
setProcessingMessage(true); setProcessingMessage(true);
} }
}, [selectedAgent, message, processingMessage, onConversationIdChange]); }, [selectedAgent, message, processingMessage, onConversationIdChange]);
@ -224,7 +225,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
</div> </div>
)} )}
</div> </div>
<div className={`mx-auto ${props.isMobileWidth ? "w-full" : "w-fit"}`}> <div className={`mx-auto ${props.isMobileWidth ? "w-full" : "w-fit max-w-screen-md"}`}>
{!props.isMobileWidth && ( {!props.isMobileWidth && (
<div <div
className={`w-full ${styles.inputBox} shadow-lg bg-background align-middle items-center justify-center px-3 py-1 dark:bg-neutral-700 border-stone-100 dark:border-none dark:shadow-none rounded-2xl`} className={`w-full ${styles.inputBox} shadow-lg bg-background align-middle items-center justify-center px-3 py-1 dark:bg-neutral-700 border-stone-100 dark:border-none dark:shadow-none rounded-2xl`}
@ -232,7 +233,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
<ChatInputArea <ChatInputArea
isLoggedIn={props.isLoggedIn} isLoggedIn={props.isLoggedIn}
sendMessage={(message) => setMessage(message)} sendMessage={(message) => setMessage(message)}
sendImage={(image) => setImage(image)} sendImage={(image) => setImages((prevImages) => [...prevImages, image])}
sendDisabled={processingMessage} sendDisabled={processingMessage}
chatOptionsData={props.chatOptionsData} chatOptionsData={props.chatOptionsData}
conversationId={null} conversationId={null}
@ -313,7 +314,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
<ChatInputArea <ChatInputArea
isLoggedIn={props.isLoggedIn} isLoggedIn={props.isLoggedIn}
sendMessage={(message) => setMessage(message)} sendMessage={(message) => setMessage(message)}
sendImage={(image) => setImage(image)} sendImage={(image) => setImages((prevImages) => [...prevImages, image])}
sendDisabled={processingMessage} sendDisabled={processingMessage}
chatOptionsData={props.chatOptionsData} chatOptionsData={props.chatOptionsData}
conversationId={null} conversationId={null}


@ -28,12 +28,12 @@ interface ChatBodyDataProps {
isLoggedIn: boolean; isLoggedIn: boolean;
conversationId?: string; conversationId?: string;
setQueryToProcess: (query: string) => void; setQueryToProcess: (query: string) => void;
setImage64: (image64: string) => void; setImages: (images: string[]) => void;
} }
function ChatBodyData(props: ChatBodyDataProps) { function ChatBodyData(props: ChatBodyDataProps) {
const [message, setMessage] = useState(""); const [message, setMessage] = useState("");
const [image, setImage] = useState<string | null>(null); const [images, setImages] = useState<string[]>([]);
const [processingMessage, setProcessingMessage] = useState(false); const [processingMessage, setProcessingMessage] = useState(false);
const [agentMetadata, setAgentMetadata] = useState<AgentData | null>(null); const [agentMetadata, setAgentMetadata] = useState<AgentData | null>(null);
const setQueryToProcess = props.setQueryToProcess; const setQueryToProcess = props.setQueryToProcess;
@ -42,10 +42,28 @@ function ChatBodyData(props: ChatBodyDataProps) {
const chatHistoryCustomClassName = props.isMobileWidth ? "w-full" : "w-4/6"; const chatHistoryCustomClassName = props.isMobileWidth ? "w-full" : "w-4/6";
useEffect(() => { useEffect(() => {
if (image) { if (images.length > 0) {
props.setImage64(encodeURIComponent(image)); const encodedImages = images.map((image) => encodeURIComponent(image));
props.setImages(encodedImages);
} }
}, [image, props.setImage64]); }, [images, props.setImages]);
useEffect(() => {
const storedImages = localStorage.getItem("images");
if (storedImages) {
const parsedImages: string[] = JSON.parse(storedImages);
setImages(parsedImages);
const encodedImages = parsedImages.map((img: string) => encodeURIComponent(img));
props.setImages(encodedImages);
localStorage.removeItem("images");
}
const storedMessage = localStorage.getItem("message");
if (storedMessage) {
setProcessingMessage(true);
setQueryToProcess(storedMessage);
}
}, [setQueryToProcess, props.setImages]);
useEffect(() => { useEffect(() => {
if (message) { if (message) {
@ -89,7 +107,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
<ChatInputArea <ChatInputArea
isLoggedIn={props.isLoggedIn} isLoggedIn={props.isLoggedIn}
sendMessage={(message) => setMessage(message)} sendMessage={(message) => setMessage(message)}
sendImage={(image) => setImage(image)} sendImage={(image) => setImages((prevImages) => [...prevImages, image])}
sendDisabled={processingMessage} sendDisabled={processingMessage}
chatOptionsData={props.chatOptionsData} chatOptionsData={props.chatOptionsData}
conversationId={props.conversationId} conversationId={props.conversationId}
@ -112,7 +130,7 @@ export default function SharedChat() {
const [processQuerySignal, setProcessQuerySignal] = useState(false); const [processQuerySignal, setProcessQuerySignal] = useState(false);
const [uploadedFiles, setUploadedFiles] = useState<string[]>([]); const [uploadedFiles, setUploadedFiles] = useState<string[]>([]);
const [paramSlug, setParamSlug] = useState<string | undefined>(undefined); const [paramSlug, setParamSlug] = useState<string | undefined>(undefined);
const [image64, setImage64] = useState<string>(""); const [images, setImages] = useState<string[]>([]);
const locationData = useIPLocationData() || { const locationData = useIPLocationData() || {
timezone: Intl.DateTimeFormat().resolvedOptions().timeZone, timezone: Intl.DateTimeFormat().resolvedOptions().timeZone,
@ -170,7 +188,7 @@ export default function SharedChat() {
completed: false, completed: false,
timestamp: new Date().toISOString(), timestamp: new Date().toISOString(),
rawQuery: queryToProcess || "", rawQuery: queryToProcess || "",
uploadedImageData: decodeURIComponent(image64), images: images,
}; };
setMessages((prevMessages) => [...prevMessages, newStreamMessage]); setMessages((prevMessages) => [...prevMessages, newStreamMessage]);
setProcessQuerySignal(true); setProcessQuerySignal(true);
@ -197,7 +215,7 @@ export default function SharedChat() {
if (done) { if (done) {
setQueryToProcess(""); setQueryToProcess("");
setProcessQuerySignal(false); setProcessQuerySignal(false);
setImage64(""); setImages([]);
break; break;
} }
@ -239,7 +257,7 @@ export default function SharedChat() {
country_code: locationData.countryCode, country_code: locationData.countryCode,
timezone: locationData.timezone, timezone: locationData.timezone,
}), }),
...(image64 && { image: image64 }), ...(images.length > 0 && { image: images }),
}; };
const response = await fetch(chatAPI, { const response = await fetch(chatAPI, {
@ -302,7 +320,7 @@ export default function SharedChat() {
setTitle={setTitle} setTitle={setTitle}
setUploadedFiles={setUploadedFiles} setUploadedFiles={setUploadedFiles}
isMobileWidth={isMobileWidth} isMobileWidth={isMobileWidth}
setImage64={setImage64} setImages={setImages}
/> />
</Suspense> </Suspense>
</div> </div>


@ -6,14 +6,17 @@ from typing import Dict, Optional
from langchain.schema import ChatMessage from langchain.schema import ChatMessage
from khoj.database.models import Agent, KhojUser from khoj.database.models import Agent, ChatModelOptions, KhojUser
from khoj.processor.conversation import prompts from khoj.processor.conversation import prompts
from khoj.processor.conversation.google.utils import ( from khoj.processor.conversation.google.utils import (
format_messages_for_gemini, format_messages_for_gemini,
gemini_chat_completion_with_backoff, gemini_chat_completion_with_backoff,
gemini_completion_with_backoff, gemini_completion_with_backoff,
) )
from khoj.processor.conversation.utils import generate_chatml_messages_with_context from khoj.processor.conversation.utils import (
construct_structured_message,
generate_chatml_messages_with_context,
)
from khoj.utils.helpers import ConversationCommand, is_none_or_empty from khoj.utils.helpers import ConversationCommand, is_none_or_empty
from khoj.utils.rawconfig import LocationData from khoj.utils.rawconfig import LocationData
@ -29,6 +32,8 @@ def extract_questions_gemini(
max_tokens=None, max_tokens=None,
location_data: LocationData = None, location_data: LocationData = None,
user: KhojUser = None, user: KhojUser = None,
query_images: Optional[list[str]] = None,
vision_enabled: bool = False,
personality_context: Optional[str] = None, personality_context: Optional[str] = None,
): ):
""" """
@ -70,17 +75,17 @@ def extract_questions_gemini(
text=text, text=text,
) )
messages = [ChatMessage(content=prompt, role="user")]
model_kwargs = {"response_mime_type": "application/json"}
response = gemini_completion_with_backoff(
    messages=messages,
    system_prompt=system_prompt,
    model_name=model,
    temperature=temperature,
    api_key=api_key,
    model_kwargs=model_kwargs,
)
prompt = construct_structured_message(
    message=prompt,
    images=query_images,
    model_type=ChatModelOptions.ModelType.GOOGLE,
    vision_enabled=vision_enabled,
)
messages = [ChatMessage(content=prompt, role="user"), ChatMessage(content=system_prompt, role="system")]
response = gemini_send_message_to_model(
    messages, api_key, model, response_type="json_object", temperature=temperature
)
# Extract, Clean Message from Gemini's Response # Extract, Clean Message from Gemini's Response
@ -102,7 +107,7 @@ def extract_questions_gemini(
return questions return questions
def gemini_send_message_to_model(messages, api_key, model, response_type="text"): def gemini_send_message_to_model(messages, api_key, model, response_type="text", temperature=0, model_kwargs=None):
""" """
Send message to model Send message to model
""" """
@ -114,7 +119,12 @@ def gemini_send_message_to_model(messages, api_key, model, response_type="text")
# Get Response from Gemini # Get Response from Gemini
return gemini_completion_with_backoff( return gemini_completion_with_backoff(
messages=messages, system_prompt=system_prompt, model_name=model, api_key=api_key, model_kwargs=model_kwargs messages=messages,
system_prompt=system_prompt,
model_name=model,
api_key=api_key,
temperature=temperature,
model_kwargs=model_kwargs,
) )
@ -133,6 +143,8 @@ def converse_gemini(
location_data: LocationData = None, location_data: LocationData = None,
user_name: str = None, user_name: str = None,
agent: Agent = None, agent: Agent = None,
query_images: Optional[list[str]] = None,
vision_available: bool = False,
): ):
""" """
Converse with user using Google's Gemini Converse with user using Google's Gemini
@ -187,6 +199,9 @@ def converse_gemini(
model_name=model, model_name=model,
max_prompt_size=max_prompt_size, max_prompt_size=max_prompt_size,
tokenizer_name=tokenizer_name, tokenizer_name=tokenizer_name,
query_images=query_images,
vision_enabled=vision_available,
model_type=ChatModelOptions.ModelType.GOOGLE,
) )
messages, system_prompt = format_messages_for_gemini(messages, system_prompt) messages, system_prompt = format_messages_for_gemini(messages, system_prompt)


@ -1,8 +1,11 @@
import logging import logging
import random import random
from io import BytesIO
from threading import Thread from threading import Thread
import google.generativeai as genai import google.generativeai as genai
import PIL.Image
import requests
from google.generativeai.types.answer_types import FinishReason from google.generativeai.types.answer_types import FinishReason
from google.generativeai.types.generation_types import StopCandidateException from google.generativeai.types.generation_types import StopCandidateException
from google.generativeai.types.safety_types import ( from google.generativeai.types.safety_types import (
@ -53,14 +56,14 @@ def gemini_completion_with_backoff(
}, },
) )
formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages] formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]
# Start chat session. All messages up to the last are considered to be part of the chat history # Start chat session. All messages up to the last are considered to be part of the chat history
chat_session = model.start_chat(history=formatted_messages[0:-1]) chat_session = model.start_chat(history=formatted_messages[0:-1])
try: try:
# Generate the response. The last message is considered to be the current prompt # Generate the response. The last message is considered to be the current prompt
aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"][0]) aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"])
return aggregated_response.text return aggregated_response.text
except StopCandidateException as e: except StopCandidateException as e:
response_message, _ = handle_gemini_response(e.args) response_message, _ = handle_gemini_response(e.args)
@ -117,11 +120,11 @@ def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_k
}, },
) )
formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages] formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]
# all messages up to the last are considered to be part of the chat history # all messages up to the last are considered to be part of the chat history
chat_session = model.start_chat(history=formatted_messages[0:-1]) chat_session = model.start_chat(history=formatted_messages[0:-1])
# the last message is considered to be the current prompt # the last message is considered to be the current prompt
for chunk in chat_session.send_message(formatted_messages[-1]["parts"][0], stream=True): for chunk in chat_session.send_message(formatted_messages[-1]["parts"], stream=True):
message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback) message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
message = message or chunk.text message = message or chunk.text
g.send(message) g.send(message)
@ -191,14 +194,6 @@ def generate_safety_response(safety_ratings):
def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str = None) -> tuple[list[str], str]: def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str = None) -> tuple[list[str], str]:
if len(messages) == 1:
messages[0].role = "user"
return messages, system_prompt
for message in messages:
if message.role == "assistant":
message.role = "model"
# Extract system message # Extract system message
system_prompt = system_prompt or "" system_prompt = system_prompt or ""
for message in messages.copy(): for message in messages.copy():
@ -207,4 +202,31 @@ def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str =
messages.remove(message) messages.remove(message)
system_prompt = None if is_none_or_empty(system_prompt) else system_prompt system_prompt = None if is_none_or_empty(system_prompt) else system_prompt
for message in messages:
# Convert message content to string list from chatml dictionary list
if isinstance(message.content, list):
# Convert image_urls to PIL.Image and place them at beginning of list (better for Gemini)
message.content = [
get_image_from_url(item["image_url"]["url"]) if item["type"] == "image_url" else item["text"]
for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1)
]
elif isinstance(message.content, str):
message.content = [message.content]
if message.role == "assistant":
message.role = "model"
if len(messages) == 1:
messages[0].role = "user"
return messages, system_prompt return messages, system_prompt
def get_image_from_url(image_url: str) -> PIL.Image:
try:
response = requests.get(image_url)
response.raise_for_status() # Check if the request was successful
return PIL.Image.open(BytesIO(response.content))
except requests.exceptions.RequestException as e:
logger.error(f"Failed to get image from URL {image_url}: {e}")
return None
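
Under the hood, each formatted message is handed to the Gemini SDK as a flat list of parts, with PIL images placed before the text (the comment above notes Gemini handles leading images better). A self-contained sketch of that call path; the model name, package versions, and `GEMINI_API_KEY` environment variable are assumptions, not values from this commit.

```python
import os
from io import BytesIO

import google.generativeai as genai
import PIL.Image
import requests

genai.configure(api_key=os.environ["GEMINI_API_KEY"])
model = genai.GenerativeModel("gemini-1.5-flash")


def get_image_from_url(image_url: str) -> PIL.Image.Image:
    # Same idea as the helper above: fetch the image and open it with PIL
    response = requests.get(image_url)
    response.raise_for_status()
    return PIL.Image.open(BytesIO(response.content))


# One message is a flat list of parts; images first, then the text prompt,
# matching the ordering applied by format_messages_for_gemini above.
parts = [
    get_image_from_url("https://example.com/uploads/diagram.webp"),
    "Explain what this diagram shows.",
]

chat_session = model.start_chat(history=[])  # earlier turns would go in history
response = chat_session.send_message(parts)
print(response.text)
```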


@ -30,7 +30,7 @@ def extract_questions(
api_base_url=None, api_base_url=None,
location_data: LocationData = None, location_data: LocationData = None,
user: KhojUser = None, user: KhojUser = None,
uploaded_image_url: Optional[str] = None, query_images: Optional[list[str]] = None,
vision_enabled: bool = False, vision_enabled: bool = False,
personality_context: Optional[str] = None, personality_context: Optional[str] = None,
): ):
@ -74,7 +74,7 @@ def extract_questions(
prompt = construct_structured_message( prompt = construct_structured_message(
message=prompt, message=prompt,
image_url=uploaded_image_url, images=query_images,
model_type=ChatModelOptions.ModelType.OPENAI, model_type=ChatModelOptions.ModelType.OPENAI,
vision_enabled=vision_enabled, vision_enabled=vision_enabled,
) )
@ -135,7 +135,7 @@ def converse(
location_data: LocationData = None, location_data: LocationData = None,
user_name: str = None, user_name: str = None,
agent: Agent = None, agent: Agent = None,
image_url: Optional[str] = None, query_images: Optional[list[str]] = None,
vision_available: bool = False, vision_available: bool = False,
): ):
""" """
@ -191,7 +191,7 @@ def converse(
model_name=model, model_name=model,
max_prompt_size=max_prompt_size, max_prompt_size=max_prompt_size,
tokenizer_name=tokenizer_name, tokenizer_name=tokenizer_name,
uploaded_image_url=image_url, query_images=query_images,
vision_enabled=vision_available, vision_enabled=vision_available,
model_type=ChatModelOptions.ModelType.OPENAI, model_type=ChatModelOptions.ModelType.OPENAI,
) )


@ -109,7 +109,7 @@ def save_to_conversation_log(
client_application: ClientApplication = None, client_application: ClientApplication = None,
conversation_id: str = None, conversation_id: str = None,
automation_id: str = None, automation_id: str = None,
uploaded_image_url: str = None, query_images: List[str] = None,
): ):
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S") user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
updated_conversation = message_to_log( updated_conversation = message_to_log(
@ -117,7 +117,7 @@ def save_to_conversation_log(
chat_response=chat_response, chat_response=chat_response,
user_message_metadata={ user_message_metadata={
"created": user_message_time, "created": user_message_time,
"uploadedImageData": uploaded_image_url, "images": query_images,
}, },
khoj_message_metadata={ khoj_message_metadata={
"context": compiled_references, "context": compiled_references,
@ -145,10 +145,18 @@ Khoj: "{inferred_queries if ("text-to-image" in intent_type) else chat_response}
) )
# Format user and system messages to chatml format
def construct_structured_message(message, image_url, model_type, vision_enabled):
    if image_url and vision_enabled and model_type == ChatModelOptions.ModelType.OPENAI:
        return [{"type": "text", "text": message}, {"type": "image_url", "image_url": {"url": image_url}}]
    return message
def construct_structured_message(message: str, images: list[str], model_type: str, vision_enabled: bool):
    """
    Format messages into appropriate multimedia format for supported chat model types
    """
    if not images or not vision_enabled:
        return message

    if model_type in [ChatModelOptions.ModelType.OPENAI, ChatModelOptions.ModelType.GOOGLE]:
        return [
            {"type": "text", "text": message},
            *[{"type": "image_url", "image_url": {"url": image}} for image in images],
        ]
    return message
@ -160,7 +168,7 @@ def generate_chatml_messages_with_context(
loaded_model: Optional[Llama] = None, loaded_model: Optional[Llama] = None,
max_prompt_size=None, max_prompt_size=None,
tokenizer_name=None, tokenizer_name=None,
uploaded_image_url=None, query_images=None,
vision_enabled=False, vision_enabled=False,
model_type="", model_type="",
): ):
@ -186,9 +194,7 @@ def generate_chatml_messages_with_context(
else: else:
message_content = chat["message"] + message_notes message_content = chat["message"] + message_notes
message_content = construct_structured_message( message_content = construct_structured_message(message_content, chat.get("images"), model_type, vision_enabled)
message_content, chat.get("uploadedImageData"), model_type, vision_enabled
)
reconstructed_message = ChatMessage(content=message_content, role=role) reconstructed_message = ChatMessage(content=message_content, role=role)
@ -201,7 +207,7 @@ def generate_chatml_messages_with_context(
if not is_none_or_empty(user_message): if not is_none_or_empty(user_message):
messages.append( messages.append(
ChatMessage( ChatMessage(
content=construct_structured_message(user_message, uploaded_image_url, model_type, vision_enabled), content=construct_structured_message(user_message, query_images, model_type, vision_enabled),
role="user", role="user",
) )
) )
@ -225,7 +231,6 @@ def truncate_messages(
tokenizer_name=None, tokenizer_name=None,
) -> list[ChatMessage]: ) -> list[ChatMessage]:
"""Truncate messages to fit within max prompt size supported by model""" """Truncate messages to fit within max prompt size supported by model"""
default_tokenizer = "gpt-4o" default_tokenizer = "gpt-4o"
try: try:
@ -255,6 +260,7 @@ def truncate_messages(
system_message = messages.pop(idx) system_message = messages.pop(idx)
break break
# TODO: Handle truncation of multi-part message.content, i.e when message.content is a list[dict] rather than a string
system_message_tokens = ( system_message_tokens = (
len(encoder.encode(system_message.content)) if system_message and type(system_message.content) == str else 0 len(encoder.encode(system_message.content)) if system_message and type(system_message.content) == str else 0
) )
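
The TODO above flags that prompt truncation still assumes string content. One way to budget tokens for the new list-style content is to encode only the text parts and charge a fixed estimate per image; the sketch below follows that assumption and is not this commit's implementation. The tokenizer choice and per-image constant are illustrative only.

```python
from typing import Union

import tiktoken

encoder = tiktoken.get_encoding("cl100k_base")  # illustrative tokenizer choice

# Rough per-image token overhead; an assumed constant for illustration only.
IMAGE_TOKEN_ESTIMATE = 765


def count_content_tokens(content: Union[str, list]) -> int:
    """Count tokens for plain-string or multi-part chatml message content."""
    if isinstance(content, str):
        return len(encoder.encode(content))
    tokens = 0
    for part in content:
        if isinstance(part, str):
            tokens += len(encoder.encode(part))
        elif part.get("type") == "text":
            tokens += len(encoder.encode(part["text"]))
        elif part.get("type") == "image_url":
            tokens += IMAGE_TOKEN_ESTIMATE
    return tokens


# Example: one text part and two attached images
content = [
    {"type": "text", "text": "What changed between these two screenshots?"},
    {"type": "image_url", "image_url": {"url": "https://example.com/a.webp"}},
    {"type": "image_url", "image_url": {"url": "https://example.com/b.webp"}},
]
print(count_content_tokens(content))
```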


@ -26,7 +26,7 @@ async def text_to_image(
references: List[Dict[str, Any]], references: List[Dict[str, Any]],
online_results: Dict[str, Any], online_results: Dict[str, Any],
send_status_func: Optional[Callable] = None, send_status_func: Optional[Callable] = None,
uploaded_image_url: Optional[str] = None, query_images: Optional[List[str]] = None,
agent: Agent = None, agent: Agent = None,
): ):
status_code = 200 status_code = 200
@ -65,7 +65,7 @@ async def text_to_image(
note_references=references, note_references=references,
online_results=online_results, online_results=online_results,
model_type=text_to_image_config.model_type, model_type=text_to_image_config.model_type,
uploaded_image_url=uploaded_image_url, query_images=query_images,
user=user, user=user,
agent=agent, agent=agent,
) )


@ -62,7 +62,7 @@ async def search_online(
user: KhojUser, user: KhojUser,
send_status_func: Optional[Callable] = None, send_status_func: Optional[Callable] = None,
custom_filters: List[str] = [], custom_filters: List[str] = [],
uploaded_image_url: str = None, query_images: List[str] = None,
agent: Agent = None, agent: Agent = None,
): ):
query += " ".join(custom_filters) query += " ".join(custom_filters)
@ -73,7 +73,7 @@ async def search_online(
# Breakdown the query into subqueries to get the correct answer # Breakdown the query into subqueries to get the correct answer
subqueries = await generate_online_subqueries( subqueries = await generate_online_subqueries(
query, conversation_history, location, user, uploaded_image_url=uploaded_image_url, agent=agent query, conversation_history, location, user, query_images=query_images, agent=agent
) )
response_dict = {} response_dict = {}
@ -151,7 +151,7 @@ async def read_webpages(
location: LocationData, location: LocationData,
user: KhojUser, user: KhojUser,
send_status_func: Optional[Callable] = None, send_status_func: Optional[Callable] = None,
uploaded_image_url: str = None, query_images: List[str] = None,
agent: Agent = None, agent: Agent = None,
): ):
"Infer web pages to read from the query and extract relevant information from them" "Infer web pages to read from the query and extract relevant information from them"
@ -159,7 +159,7 @@ async def read_webpages(
if send_status_func: if send_status_func:
async for event in send_status_func(f"**Inferring web pages to read**"): async for event in send_status_func(f"**Inferring web pages to read**"):
yield {ChatEvent.STATUS: event} yield {ChatEvent.STATUS: event}
urls = await infer_webpage_urls(query, conversation_history, location, user, uploaded_image_url) urls = await infer_webpage_urls(query, conversation_history, location, user, query_images)
logger.info(f"Reading web pages at: {urls}") logger.info(f"Reading web pages at: {urls}")
if send_status_func: if send_status_func:


@ -347,7 +347,7 @@ async def extract_references_and_questions(
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default], conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
location_data: LocationData = None, location_data: LocationData = None,
send_status_func: Optional[Callable] = None, send_status_func: Optional[Callable] = None,
uploaded_image_url: Optional[str] = None, query_images: Optional[List[str]] = None,
agent: Agent = None, agent: Agent = None,
): ):
user = request.user.object if request.user.is_authenticated else None user = request.user.object if request.user.is_authenticated else None
@ -438,7 +438,7 @@ async def extract_references_and_questions(
conversation_log=meta_log, conversation_log=meta_log,
location_data=location_data, location_data=location_data,
user=user, user=user,
uploaded_image_url=uploaded_image_url, query_images=query_images,
vision_enabled=vision_enabled, vision_enabled=vision_enabled,
personality_context=personality_context, personality_context=personality_context,
) )
@ -459,12 +459,14 @@ async def extract_references_and_questions(
chat_model = conversation_config.chat_model chat_model = conversation_config.chat_model
inferred_queries = extract_questions_gemini( inferred_queries = extract_questions_gemini(
defiltered_query, defiltered_query,
query_images=query_images,
model=chat_model, model=chat_model,
api_key=api_key, api_key=api_key,
conversation_log=meta_log, conversation_log=meta_log,
location_data=location_data, location_data=location_data,
max_tokens=conversation_config.max_prompt_size, max_tokens=conversation_config.max_prompt_size,
user=user, user=user,
vision_enabled=vision_enabled,
personality_context=personality_context, personality_context=personality_context,
) )


@ -30,8 +30,10 @@ from khoj.processor.speech.text_to_speech import generate_text_to_speech
from khoj.processor.tools.online_search import read_webpages, search_online from khoj.processor.tools.online_search import read_webpages, search_online
from khoj.routers.api import extract_references_and_questions from khoj.routers.api import extract_references_and_questions
from khoj.routers.helpers import ( from khoj.routers.helpers import (
ApiImageRateLimiter,
ApiUserRateLimiter, ApiUserRateLimiter,
ChatEvent, ChatEvent,
ChatRequestBody,
CommonQueryParams, CommonQueryParams,
ConversationCommandRateLimiter, ConversationCommandRateLimiter,
agenerate_chat_response, agenerate_chat_response,
@ -524,22 +526,6 @@ async def set_conversation_title(
) )
class ChatRequestBody(BaseModel):
q: str
n: Optional[int] = 7
d: Optional[float] = None
stream: Optional[bool] = False
title: Optional[str] = None
conversation_id: Optional[str] = None
city: Optional[str] = None
region: Optional[str] = None
country: Optional[str] = None
country_code: Optional[str] = None
timezone: Optional[str] = None
image: Optional[str] = None
create_new: Optional[bool] = False
@api_chat.post("") @api_chat.post("")
@requires(["authenticated"]) @requires(["authenticated"])
async def chat( async def chat(
@ -552,6 +538,7 @@ async def chat(
rate_limiter_per_day=Depends( rate_limiter_per_day=Depends(
ApiUserRateLimiter(requests=600, subscribed_requests=6000, window=60 * 60 * 24, slug="chat_day") ApiUserRateLimiter(requests=600, subscribed_requests=6000, window=60 * 60 * 24, slug="chat_day")
), ),
image_rate_limiter=Depends(ApiImageRateLimiter(max_images=10, max_combined_size_mb=20)),
): ):
# Access the parameters from the body # Access the parameters from the body
q = body.q q = body.q
@ -565,9 +552,9 @@ async def chat(
country = body.country or get_country_name_from_timezone(body.timezone) country = body.country or get_country_name_from_timezone(body.timezone)
country_code = body.country_code or get_country_code_from_timezone(body.timezone) country_code = body.country_code or get_country_code_from_timezone(body.timezone)
timezone = body.timezone timezone = body.timezone
image = body.image raw_images = body.images
async def event_generator(q: str, image: str): async def event_generator(q: str, images: list[str]):
start_time = time.perf_counter() start_time = time.perf_counter()
ttft = None ttft = None
chat_metadata: dict = {} chat_metadata: dict = {}
@ -577,16 +564,16 @@ async def chat(
q = unquote(q) q = unquote(q)
nonlocal conversation_id nonlocal conversation_id
uploaded_image_url = None uploaded_images: list[str] = []
if image: if images:
for image in images:
decoded_string = unquote(image) decoded_string = unquote(image)
base64_data = decoded_string.split(",", 1)[1] base64_data = decoded_string.split(",", 1)[1]
image_bytes = base64.b64decode(base64_data) image_bytes = base64.b64decode(base64_data)
webp_image_bytes = convert_image_to_webp(image_bytes) webp_image_bytes = convert_image_to_webp(image_bytes)
try: uploaded_image = upload_image_to_bucket(webp_image_bytes, request.user.object.id)
uploaded_image_url = upload_image_to_bucket(webp_image_bytes, request.user.object.id) if uploaded_image:
except: uploaded_images.append(uploaded_image)
uploaded_image_url = None
async def send_event(event_type: ChatEvent, data: str | dict): async def send_event(event_type: ChatEvent, data: str | dict):
nonlocal connection_alive, ttft nonlocal connection_alive, ttft
@ -693,7 +680,7 @@ async def chat(
meta_log, meta_log,
is_automated_task, is_automated_task,
user=user, user=user,
uploaded_image_url=uploaded_image_url, query_images=uploaded_images,
agent=agent, agent=agent,
) )
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands]) conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
@ -702,7 +689,7 @@ async def chat(
): ):
yield result yield result
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_image_url, agent) mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_images, agent)
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"): async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
yield result yield result
if mode not in conversation_commands: if mode not in conversation_commands:
@ -765,7 +752,7 @@ async def chat(
q, q,
contextual_data, contextual_data,
conversation_history=meta_log, conversation_history=meta_log,
uploaded_image_url=uploaded_image_url, query_images=uploaded_images,
user=user, user=user,
agent=agent, agent=agent,
) )
@ -786,7 +773,7 @@ async def chat(
intent_type="summarize", intent_type="summarize",
client_application=request.user.client_app, client_application=request.user.client_app,
conversation_id=conversation_id, conversation_id=conversation_id,
uploaded_image_url=uploaded_image_url, query_images=uploaded_images,
) )
return return
@ -829,7 +816,7 @@ async def chat(
conversation_id=conversation_id, conversation_id=conversation_id,
inferred_queries=[query_to_run], inferred_queries=[query_to_run],
automation_id=automation.id, automation_id=automation.id,
uploaded_image_url=uploaded_image_url, query_images=uploaded_images,
) )
async for result in send_llm_response(llm_response): async for result in send_llm_response(llm_response):
yield result yield result
@ -849,7 +836,7 @@ async def chat(
conversation_commands, conversation_commands,
location, location,
partial(send_event, ChatEvent.STATUS), partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url, query_images=uploaded_images,
agent=agent, agent=agent,
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
@ -893,7 +880,7 @@ async def chat(
user, user,
partial(send_event, ChatEvent.STATUS), partial(send_event, ChatEvent.STATUS),
custom_filters, custom_filters,
uploaded_image_url=uploaded_image_url, query_images=uploaded_images,
agent=agent, agent=agent,
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
@ -917,7 +904,7 @@ async def chat(
location, location,
user, user,
partial(send_event, ChatEvent.STATUS), partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url, query_images=uploaded_images,
agent=agent, agent=agent,
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
@ -967,20 +954,20 @@ async def chat(
references=compiled_references, references=compiled_references,
online_results=online_results, online_results=online_results,
send_status_func=partial(send_event, ChatEvent.STATUS), send_status_func=partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url, query_images=uploaded_images,
agent=agent, agent=agent,
): ):
if isinstance(result, dict) and ChatEvent.STATUS in result: if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS] yield result[ChatEvent.STATUS]
else: else:
image, status_code, improved_image_prompt, intent_type = result generated_image, status_code, improved_image_prompt, intent_type = result
if image is None or status_code != 200: if generated_image is None or status_code != 200:
content_obj = { content_obj = {
"content-type": "application/json", "content-type": "application/json",
"intentType": intent_type, "intentType": intent_type,
"detail": improved_image_prompt, "detail": improved_image_prompt,
"image": image, "image": None,
} }
async for result in send_llm_response(json.dumps(content_obj)): async for result in send_llm_response(json.dumps(content_obj)):
yield result yield result
@ -988,7 +975,7 @@ async def chat(
await sync_to_async(save_to_conversation_log)( await sync_to_async(save_to_conversation_log)(
q, q,
image, generated_image,
user, user,
meta_log, meta_log,
user_message_time, user_message_time,
@ -998,12 +985,12 @@ async def chat(
conversation_id=conversation_id, conversation_id=conversation_id,
compiled_references=compiled_references, compiled_references=compiled_references,
online_results=online_results, online_results=online_results,
uploaded_image_url=uploaded_image_url, query_images=uploaded_images,
) )
content_obj = { content_obj = {
"intentType": intent_type, "intentType": intent_type,
"inferredQueries": [improved_image_prompt], "inferredQueries": [improved_image_prompt],
"image": image, "image": generated_image,
} }
async for result in send_llm_response(json.dumps(content_obj)): async for result in send_llm_response(json.dumps(content_obj)):
yield result yield result
@ -1023,7 +1010,7 @@ async def chat(
location_data=location, location_data=location,
note_references=compiled_references, note_references=compiled_references,
online_results=online_results, online_results=online_results,
uploaded_image_url=uploaded_image_url, query_images=uploaded_images,
user=user, user=user,
agent=agent, agent=agent,
send_status_func=partial(send_event, ChatEvent.STATUS), send_status_func=partial(send_event, ChatEvent.STATUS),
@ -1053,7 +1040,7 @@ async def chat(
conversation_id=conversation_id, conversation_id=conversation_id,
compiled_references=compiled_references, compiled_references=compiled_references,
online_results=online_results, online_results=online_results,
uploaded_image_url=uploaded_image_url, query_images=uploaded_images,
) )
async for result in send_llm_response(json.dumps(content_obj)): async for result in send_llm_response(json.dumps(content_obj)):
@ -1076,7 +1063,7 @@ async def chat(
conversation_id, conversation_id,
location, location,
user_name, user_name,
uploaded_image_url, uploaded_images,
) )
# Send Response # Send Response
@ -1102,9 +1089,9 @@ async def chat(
## Stream Text Response ## Stream Text Response
if stream: if stream:
return StreamingResponse(event_generator(q, image=image), media_type="text/plain") return StreamingResponse(event_generator(q, images=raw_images), media_type="text/plain")
## Non-Streaming Text Response ## Non-Streaming Text Response
else: else:
response_iterator = event_generator(q, image=image) response_iterator = event_generator(q, images=raw_images)
response_data = await read_chat_stream(response_iterator) response_data = await read_chat_stream(response_iterator)
return Response(content=json.dumps(response_data), media_type="application/json", status_code=200) return Response(content=json.dumps(response_data), media_type="application/json", status_code=200)
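
With this change the chat endpoint accepts an `images` list in its request body (see `ChatRequestBody` below). A minimal client-side sketch, assuming a JSON POST to the chat route; the URL, file names, and data-URI prefix here are illustrative and not taken from this diff:

```python
# Hypothetical client call: base64-encode two local files and send them
# with the query. Only `q`, `images`, and `stream` come from the request
# model added in this commit; the endpoint URL is an assumption.
import base64
import requests

def encode_image(path: str) -> str:
    with open(path, "rb") as f:
        return "data:image/png;base64," + base64.b64encode(f.read()).decode()

payload = {
    "q": "What changed between these two screenshots?",
    "images": [encode_image("before.png"), encode_image("after.png")],
    "stream": False,
}
response = requests.post("https://app.khoj.dev/api/chat", json=payload, timeout=60)
print(response.json())
```
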


@ -1,4 +1,5 @@
import asyncio import asyncio
import base64
import hashlib import hashlib
import json import json
import logging import logging
@ -22,7 +23,7 @@ from typing import (
Tuple, Tuple,
Union, Union,
) )
from urllib.parse import parse_qs, quote, urljoin, urlparse from urllib.parse import parse_qs, quote, unquote, urljoin, urlparse
import cron_descriptor import cron_descriptor
import pytz import pytz
@ -31,6 +32,7 @@ from apscheduler.job import Job
from apscheduler.triggers.cron import CronTrigger from apscheduler.triggers.cron import CronTrigger
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
from fastapi import Depends, Header, HTTPException, Request, UploadFile from fastapi import Depends, Header, HTTPException, Request, UploadFile
from pydantic import BaseModel
from starlette.authentication import has_required_scope from starlette.authentication import has_required_scope
from starlette.requests import URL from starlette.requests import URL
@ -296,7 +298,7 @@ async def aget_relevant_information_sources(
conversation_history: dict, conversation_history: dict,
is_task: bool, is_task: bool,
user: KhojUser, user: KhojUser,
uploaded_image_url: str = None, query_images: List[str] = None,
agent: Agent = None, agent: Agent = None,
): ):
""" """
@ -315,8 +317,8 @@ async def aget_relevant_information_sources(
chat_history = construct_chat_history(conversation_history) chat_history = construct_chat_history(conversation_history)
if uploaded_image_url: if query_images:
query = f"[placeholder for user attached image]\n{query}" query = f"[placeholder for {len(query_images)} user attached images]\n{query}"
personality_context = ( personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else "" prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
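
Since the routing prompts are text-only, attached images are not forwarded to the source selector; they are summarized by a count placeholder prepended to the query, as the f-string above shows. A standalone sketch of that substitution, mirroring the diff rather than calling Khoj internals:

```python
from typing import List, Optional

def annotate_query_with_images(query: str, query_images: Optional[List[str]] = None) -> str:
    """Prepend a textual marker so text-only prompts know images were attached."""
    if query_images:
        return f"[placeholder for {len(query_images)} user attached images]\n{query}"
    return query

# annotate_query_with_images("what is this?", ["<b64-1>", "<b64-2>"])
# -> "[placeholder for 2 user attached images]\nwhat is this?"
```
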
@ -373,7 +375,7 @@ async def aget_relevant_output_modes(
conversation_history: dict, conversation_history: dict,
is_task: bool = False, is_task: bool = False,
user: KhojUser = None, user: KhojUser = None,
uploaded_image_url: str = None, query_images: List[str] = None,
agent: Agent = None, agent: Agent = None,
): ):
""" """
@ -395,8 +397,8 @@ async def aget_relevant_output_modes(
chat_history = construct_chat_history(conversation_history) chat_history = construct_chat_history(conversation_history)
if uploaded_image_url: if query_images:
query = f"[placeholder for user attached image]\n{query}" query = f"[placeholder for {len(query_images)} user attached images]\n{query}"
personality_context = ( personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else "" prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
@ -439,7 +441,7 @@ async def infer_webpage_urls(
conversation_history: dict, conversation_history: dict,
location_data: LocationData, location_data: LocationData,
user: KhojUser, user: KhojUser,
uploaded_image_url: str = None, query_images: List[str] = None,
agent: Agent = None, agent: Agent = None,
) -> List[str]: ) -> List[str]:
""" """
@ -465,7 +467,7 @@ async def infer_webpage_urls(
with timer("Chat actor: Infer webpage urls to read", logger): with timer("Chat actor: Infer webpage urls to read", logger):
response = await send_message_to_model_wrapper( response = await send_message_to_model_wrapper(
online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user online_queries_prompt, query_images=query_images, response_type="json_object", user=user
) )
# Validate that the response is a non-empty, JSON-serializable list of URLs # Validate that the response is a non-empty, JSON-serializable list of URLs
@ -485,7 +487,7 @@ async def generate_online_subqueries(
conversation_history: dict, conversation_history: dict,
location_data: LocationData, location_data: LocationData,
user: KhojUser, user: KhojUser,
uploaded_image_url: str = None, query_images: List[str] = None,
agent: Agent = None, agent: Agent = None,
) -> List[str]: ) -> List[str]:
""" """
@ -511,7 +513,7 @@ async def generate_online_subqueries(
with timer("Chat actor: Generate online search subqueries", logger): with timer("Chat actor: Generate online search subqueries", logger):
response = await send_message_to_model_wrapper( response = await send_message_to_model_wrapper(
online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user online_queries_prompt, query_images=query_images, response_type="json_object", user=user
) )
# Validate that the response is a non-empty, JSON-serializable list # Validate that the response is a non-empty, JSON-serializable list
@ -530,7 +532,7 @@ async def generate_online_subqueries(
async def schedule_query( async def schedule_query(
q: str, conversation_history: dict, user: KhojUser, uploaded_image_url: str = None q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None
) -> Tuple[str, ...]: ) -> Tuple[str, ...]:
""" """
Schedule the date, time to run the query. Assume the server timezone is UTC. Schedule the date, time to run the query. Assume the server timezone is UTC.
@ -543,7 +545,7 @@ async def schedule_query(
) )
raw_response = await send_message_to_model_wrapper( raw_response = await send_message_to_model_wrapper(
crontime_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user crontime_prompt, query_images=query_images, response_type="json_object", user=user
) )
# Validate that the response is a non-empty, JSON-serializable list # Validate that the response is a non-empty, JSON-serializable list
@ -589,7 +591,7 @@ async def extract_relevant_summary(
q: str, q: str,
corpus: str, corpus: str,
conversation_history: dict, conversation_history: dict,
uploaded_image_url: str = None, query_images: List[str] = None,
user: KhojUser = None, user: KhojUser = None,
agent: Agent = None, agent: Agent = None,
) -> Union[str, None]: ) -> Union[str, None]:
@ -618,7 +620,7 @@ async def extract_relevant_summary(
extract_relevant_information, extract_relevant_information,
prompts.system_prompt_extract_relevant_summary, prompts.system_prompt_extract_relevant_summary,
user=user, user=user,
uploaded_image_url=uploaded_image_url, query_images=query_images,
) )
return response.strip() return response.strip()
@ -629,7 +631,7 @@ async def generate_excalidraw_diagram(
location_data: LocationData, location_data: LocationData,
note_references: List[Dict[str, Any]], note_references: List[Dict[str, Any]],
online_results: Optional[dict] = None, online_results: Optional[dict] = None,
uploaded_image_url: Optional[str] = None, query_images: List[str] = None,
user: KhojUser = None, user: KhojUser = None,
agent: Agent = None, agent: Agent = None,
send_status_func: Optional[Callable] = None, send_status_func: Optional[Callable] = None,
@ -644,7 +646,7 @@ async def generate_excalidraw_diagram(
location_data=location_data, location_data=location_data,
note_references=note_references, note_references=note_references,
online_results=online_results, online_results=online_results,
uploaded_image_url=uploaded_image_url, query_images=query_images,
user=user, user=user,
agent=agent, agent=agent,
) )
@ -668,7 +670,7 @@ async def generate_better_diagram_description(
location_data: LocationData, location_data: LocationData,
note_references: List[Dict[str, Any]], note_references: List[Dict[str, Any]],
online_results: Optional[dict] = None, online_results: Optional[dict] = None,
uploaded_image_url: Optional[str] = None, query_images: List[str] = None,
user: KhojUser = None, user: KhojUser = None,
agent: Agent = None, agent: Agent = None,
) -> str: ) -> str:
@ -711,7 +713,7 @@ async def generate_better_diagram_description(
with timer("Chat actor: Generate better diagram description", logger): with timer("Chat actor: Generate better diagram description", logger):
response = await send_message_to_model_wrapper( response = await send_message_to_model_wrapper(
improve_diagram_description_prompt, uploaded_image_url=uploaded_image_url, user=user improve_diagram_description_prompt, query_images=query_images, user=user
) )
response = response.strip() response = response.strip()
if response.startswith(('"', "'")) and response.endswith(('"', "'")): if response.startswith(('"', "'")) and response.endswith(('"', "'")):
@ -753,7 +755,7 @@ async def generate_better_image_prompt(
note_references: List[Dict[str, Any]], note_references: List[Dict[str, Any]],
online_results: Optional[dict] = None, online_results: Optional[dict] = None,
model_type: Optional[str] = None, model_type: Optional[str] = None,
uploaded_image_url: Optional[str] = None, query_images: Optional[List[str]] = None,
user: KhojUser = None, user: KhojUser = None,
agent: Agent = None, agent: Agent = None,
) -> str: ) -> str:
@ -805,7 +807,7 @@ async def generate_better_image_prompt(
) )
with timer("Chat actor: Generate contextual image prompt", logger): with timer("Chat actor: Generate contextual image prompt", logger):
response = await send_message_to_model_wrapper(image_prompt, uploaded_image_url=uploaded_image_url, user=user) response = await send_message_to_model_wrapper(image_prompt, query_images=query_images, user=user)
response = response.strip() response = response.strip()
if response.startswith(('"', "'")) and response.endswith(('"', "'")): if response.startswith(('"', "'")) and response.endswith(('"', "'")):
response = response[1:-1] response = response[1:-1]
@ -818,11 +820,11 @@ async def send_message_to_model_wrapper(
system_message: str = "", system_message: str = "",
response_type: str = "text", response_type: str = "text",
user: KhojUser = None, user: KhojUser = None,
uploaded_image_url: str = None, query_images: List[str] = None,
): ):
conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user) conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
vision_available = conversation_config.vision_enabled vision_available = conversation_config.vision_enabled
if not vision_available and uploaded_image_url: if not vision_available and query_images:
vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config() vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
if vision_enabled_config: if vision_enabled_config:
conversation_config = vision_enabled_config conversation_config = vision_enabled_config
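
This fallback pattern recurs in `generate_chat_response` below: keep the user's default chat model unless the query carries images and that model is not vision capable, in which case swap in the configured vision-enabled model. Condensed into a small sketch; the helper name is illustrative, and the real code goes through `ConversationAdapters`:

```python
from typing import List, Optional

def pick_conversation_config(default_config, vision_config, query_images: Optional[List[str]]):
    """Prefer the default model; use the vision-enabled model only when images require it."""
    if query_images and not getattr(default_config, "vision_enabled", False) and vision_config:
        return vision_config
    return default_config
```
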
@ -875,7 +877,7 @@ async def send_message_to_model_wrapper(
max_prompt_size=max_tokens, max_prompt_size=max_tokens,
tokenizer_name=tokenizer, tokenizer_name=tokenizer,
vision_enabled=vision_available, vision_enabled=vision_available,
uploaded_image_url=uploaded_image_url, query_images=query_images,
model_type=conversation_config.model_type, model_type=conversation_config.model_type,
) )
@ -895,7 +897,7 @@ async def send_message_to_model_wrapper(
max_prompt_size=max_tokens, max_prompt_size=max_tokens,
tokenizer_name=tokenizer, tokenizer_name=tokenizer,
vision_enabled=vision_available, vision_enabled=vision_available,
uploaded_image_url=uploaded_image_url, query_images=query_images,
model_type=conversation_config.model_type, model_type=conversation_config.model_type,
) )
@ -913,7 +915,8 @@ async def send_message_to_model_wrapper(
max_prompt_size=max_tokens, max_prompt_size=max_tokens,
tokenizer_name=tokenizer, tokenizer_name=tokenizer,
vision_enabled=vision_available, vision_enabled=vision_available,
uploaded_image_url=uploaded_image_url, query_images=query_images,
model_type=conversation_config.model_type,
) )
return gemini_send_message_to_model( return gemini_send_message_to_model(
@ -1004,6 +1007,7 @@ def send_message_to_model_wrapper_sync(
model_name=chat_model, model_name=chat_model,
max_prompt_size=max_tokens, max_prompt_size=max_tokens,
vision_enabled=vision_available, vision_enabled=vision_available,
model_type=conversation_config.model_type,
) )
return gemini_send_message_to_model( return gemini_send_message_to_model(
@ -1029,7 +1033,7 @@ def generate_chat_response(
conversation_id: str = None, conversation_id: str = None,
location_data: LocationData = None, location_data: LocationData = None,
user_name: Optional[str] = None, user_name: Optional[str] = None,
uploaded_image_url: Optional[str] = None, query_images: Optional[List[str]] = None,
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]: ) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
# Initialize Variables # Initialize Variables
chat_response = None chat_response = None
@ -1048,12 +1052,12 @@ def generate_chat_response(
inferred_queries=inferred_queries, inferred_queries=inferred_queries,
client_application=client_application, client_application=client_application,
conversation_id=conversation_id, conversation_id=conversation_id,
uploaded_image_url=uploaded_image_url, query_images=query_images,
) )
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation) conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
vision_available = conversation_config.vision_enabled vision_available = conversation_config.vision_enabled
if not vision_available and uploaded_image_url: if not vision_available and query_images:
vision_enabled_config = ConversationAdapters.get_vision_enabled_config() vision_enabled_config = ConversationAdapters.get_vision_enabled_config()
if vision_enabled_config: if vision_enabled_config:
conversation_config = vision_enabled_config conversation_config = vision_enabled_config
@ -1084,7 +1088,7 @@ def generate_chat_response(
chat_response = converse( chat_response = converse(
compiled_references, compiled_references,
q, q,
image_url=uploaded_image_url, query_images=query_images,
online_results=online_results, online_results=online_results,
conversation_log=meta_log, conversation_log=meta_log,
model=chat_model, model=chat_model,
@ -1122,8 +1126,9 @@ def generate_chat_response(
chat_response = converse_gemini( chat_response = converse_gemini(
compiled_references, compiled_references,
q, q,
online_results, query_images=query_images,
meta_log, online_results=online_results,
conversation_log=meta_log,
model=conversation_config.chat_model, model=conversation_config.chat_model,
api_key=api_key, api_key=api_key,
completion_func=partial_completion, completion_func=partial_completion,
@ -1133,6 +1138,7 @@ def generate_chat_response(
location_data=location_data, location_data=location_data,
user_name=user_name, user_name=user_name,
agent=agent, agent=agent,
vision_available=vision_available,
) )
metadata.update({"chat_model": conversation_config.chat_model}) metadata.update({"chat_model": conversation_config.chat_model})
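
For context on what the Gemini path ultimately does with `query_images`, the `google-generativeai` SDK accepts a mixed list of text and PIL images in `generate_content`. The sketch below is an illustrative, standalone use of that SDK, not Khoj's `converse_gemini`, which additionally threads in chat history, prompts, and streaming:

```python
# Illustrative only: decode base64 data URIs and pass them to Gemini
# alongside the text query.
import base64
import io
from typing import List

import google.generativeai as genai
from PIL import Image

def ask_gemini_with_images(query: str, query_images: List[str], api_key: str, model: str = "gemini-1.5-flash") -> str:
    genai.configure(api_key=api_key)
    parts: list = [query]
    for encoded in query_images:
        data = encoded.split(",", 1)[1] if "," in encoded else encoded  # strip any data URI prefix
        parts.append(Image.open(io.BytesIO(base64.b64decode(data))))
    return genai.GenerativeModel(model).generate_content(parts).text
```
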
@ -1144,6 +1150,22 @@ def generate_chat_response(
return chat_response, metadata return chat_response, metadata
class ChatRequestBody(BaseModel):
q: str
n: Optional[int] = 7
d: Optional[float] = None
stream: Optional[bool] = False
title: Optional[str] = None
conversation_id: Optional[str] = None
city: Optional[str] = None
region: Optional[str] = None
country: Optional[str] = None
country_code: Optional[str] = None
timezone: Optional[str] = None
images: Optional[list[str]] = None
create_new: Optional[bool] = False
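
Moving the query parameters into a Pydantic body keeps older clients working, since `images` (and every field except `q`) is optional. A quick validation sketch with made-up values:

```python
from typing import Optional
from pydantic import BaseModel

class ChatRequestBody(BaseModel):
    q: str
    stream: Optional[bool] = False
    images: Optional[list[str]] = None  # other fields omitted; see the full model above

print(ChatRequestBody(q="hello").images)  # -> None, text-only clients still validate
print(len(ChatRequestBody(q="compare these", images=["data:image/png;base64,AAAA", "data:image/png;base64,BBBB"]).images))  # -> 2
```
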
class ApiUserRateLimiter: class ApiUserRateLimiter:
def __init__(self, requests: int, subscribed_requests: int, window: int, slug: str): def __init__(self, requests: int, subscribed_requests: int, window: int, slug: str):
self.requests = requests self.requests = requests
@ -1189,13 +1211,58 @@ class ApiUserRateLimiter:
) )
raise HTTPException( raise HTTPException(
status_code=429, status_code=429,
detail="We're glad you're enjoying Khoj! You've exceeded your usage limit for today. Come back tomorrow or subscribe to increase your usage limit via [your settings](https://app.khoj.dev/settings).", detail="I'm glad you're enjoying interacting with me! But you've exceeded your usage limit for today. Come back tomorrow or subscribe to increase your usage limit via [your settings](https://app.khoj.dev/settings).",
) )
# Add the current request to the cache # Add the current request to the cache
UserRequests.objects.create(user=user, slug=self.slug) UserRequests.objects.create(user=user, slug=self.slug)
class ApiImageRateLimiter:
def __init__(self, max_images: int = 10, max_combined_size_mb: float = 10):
self.max_images = max_images
self.max_combined_size_mb = max_combined_size_mb
def __call__(self, request: Request, body: ChatRequestBody):
if state.billing_enabled is False:
return
# Rate limiting is disabled if user unauthenticated.
# Other systems handle authentication
if not request.user.is_authenticated:
return
if not body.images:
return
# Check number of images
if len(body.images) > self.max_images:
raise HTTPException(
status_code=429,
detail=f"Those are way too many images for me! I can handle up to {self.max_images} images per message.",
)
# Check total size of images
total_size_mb = 0.0
for image in body.images:
# Unquote the image in case it's URL encoded
image = unquote(image)
# Assuming the image is a base64 encoded string
# Remove the data:image/jpeg;base64, part if present
if "," in image:
image = image.split(",", 1)[1]
# Decode base64 to get the actual size
image_bytes = base64.b64decode(image)
total_size_mb += len(image_bytes) / (1024 * 1024) # Convert bytes to MB
if total_size_mb > self.max_combined_size_mb:
raise HTTPException(
status_code=429,
detail=f"Those images are way too large for me! I can handle up to {self.max_combined_size_mb}MB of images per message.",
)
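
This hunk does not show how the limiter is attached to the route; as a FastAPI dependency, an instance's `__call__` signature lets the framework inject both the raw request and the parsed body. A hedged sketch of that wiring, assuming the chat route and that `ChatRequestBody` and `ApiImageRateLimiter` are the classes defined above in this module:

```python
# Hypothetical wiring: FastAPI inspects ApiImageRateLimiter.__call__ and
# injects the Request and the parsed ChatRequestBody before the endpoint runs.
from fastapi import APIRouter, Depends, Request

api = APIRouter()

@api.post("/chat")
async def chat(
    request: Request,
    body: ChatRequestBody,
    _image_limit: None = Depends(ApiImageRateLimiter(max_images=10, max_combined_size_mb=10)),
):
    uploaded_images = body.images or []
    ...  # continue with the chat flow shown earlier
```
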
class ConversationCommandRateLimiter: class ConversationCommandRateLimiter:
def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int, slug: str): def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int, slug: str):
self.slug = slug self.slug = slug


@ -352,9 +352,9 @@ tool_descriptions_for_llm = {
} }
mode_descriptions_for_llm = { mode_descriptions_for_llm = {
ConversationCommand.Image: "Use this if the user is requesting you to generate a picture based on their description.", ConversationCommand.Image: "Use this if the user is requesting you to create a new picture based on their description.",
ConversationCommand.Automation: "Use this if you are confident the user is requesting a response at a scheduled date, time and frequency", ConversationCommand.Automation: "Use this if you are confident the user is requesting a response at a scheduled date, time and frequency",
ConversationCommand.Text: "Use this if the other response modes don't seem to fit the query.", ConversationCommand.Text: "Use this if a normal text response would be sufficient for accurately responding to the query.",
ConversationCommand.Diagram: "Use this if the user is requesting a visual representation that requires primitives like lines, rectangles, and text.", ConversationCommand.Diagram: "Use this if the user is requesting a visual representation that requires primitives like lines, rectangles, and text.",
} }