Mirror of https://github.com/khoj-ai/khoj.git, synced 2024-11-30 19:03:01 +01:00

# Chat with Multiple Images. Support Vision with Gemini (#942)

## Overview

- Add vision support for Gemini models in Khoj
- Allow sharing multiple images as part of user query from the web app
- Handle multiple images shared in query to chat API

Commit c6f3253ebd: 17 changed files with 454 additions and 294 deletions.
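With these changes, the web client attaches images to the standard chat request as a list of base64 data URLs. A minimal sketch of the resulting call against a local Khoj server (the endpoint path, port, and field values are illustrative assumptions; the field names come from the diffs below):

```python
import requests

# Illustrative request; the Khoj server is assumed at its default local address
# and the chat endpoint at /api/chat (the diff only shows `@api_chat.post("")`).
payload = {
    "q": "What is different between these two diagrams?",
    "conversation_id": "example-conversation-id",  # placeholder value
    "stream": False,
    # One base64 data URL per attached image (truncated here for brevity).
    "images": [
        "data:image/png;base64,iVBORw0KGgo...",
        "data:image/jpeg;base64,/9j/4AAQSkZJRg...",
    ],
}
response = requests.post("http://localhost:42110/api/chat", json=payload)
print(response.status_code, response.text[:200])
```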
**Web app: chat page**

@@ -27,14 +27,14 @@ interface ChatBodyDataProps {
     setUploadedFiles: (files: string[]) => void;
     isMobileWidth?: boolean;
     isLoggedIn: boolean;
-    setImage64: (image64: string) => void;
+    setImages: (images: string[]) => void;
 }

 function ChatBodyData(props: ChatBodyDataProps) {
     const searchParams = useSearchParams();
     const conversationId = searchParams.get("conversationId");
     const [message, setMessage] = useState("");
-    const [image, setImage] = useState<string | null>(null);
+    const [images, setImages] = useState<string[]>([]);
     const [processingMessage, setProcessingMessage] = useState(false);
     const [agentMetadata, setAgentMetadata] = useState<AgentData | null>(null);

@@ -44,17 +44,20 @@ function ChatBodyData(props: ChatBodyDataProps) {
     const chatHistoryCustomClassName = props.isMobileWidth ? "w-full" : "w-4/6";

     useEffect(() => {
-        if (image) {
-            props.setImage64(encodeURIComponent(image));
+        if (images.length > 0) {
+            const encodedImages = images.map((image) => encodeURIComponent(image));
+            props.setImages(encodedImages);
         }
-    }, [image, props.setImage64]);
+    }, [images, props.setImages]);

     useEffect(() => {
-        const storedImage = localStorage.getItem("image");
-        if (storedImage) {
-            setImage(storedImage);
-            props.setImage64(encodeURIComponent(storedImage));
-            localStorage.removeItem("image");
+        const storedImages = localStorage.getItem("images");
+        if (storedImages) {
+            const parsedImages: string[] = JSON.parse(storedImages);
+            setImages(parsedImages);
+            const encodedImages = parsedImages.map((img: string) => encodeURIComponent(img));
+            props.setImages(encodedImages);
+            localStorage.removeItem("images");
         }

         const storedMessage = localStorage.getItem("message");
@@ -62,7 +65,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
             setProcessingMessage(true);
             setQueryToProcess(storedMessage);
         }
-    }, [setQueryToProcess]);
+    }, [setQueryToProcess, props.setImages]);

     useEffect(() => {
         if (message) {
@@ -84,6 +87,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
             props.streamedMessages[props.streamedMessages.length - 1].completed
         ) {
             setProcessingMessage(false);
+            setImages([]); // Reset images after processing
         } else {
             setMessage("");
         }
@@ -113,7 +117,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
                     agentColor={agentMetadata?.color}
                     isLoggedIn={props.isLoggedIn}
                     sendMessage={(message) => setMessage(message)}
-                    sendImage={(image) => setImage(image)}
+                    sendImage={(image) => setImages((prevImages) => [...prevImages, image])}
                     sendDisabled={processingMessage}
                     chatOptionsData={props.chatOptionsData}
                     conversationId={conversationId}
@@ -135,7 +139,7 @@ export default function Chat() {
     const [queryToProcess, setQueryToProcess] = useState<string>("");
     const [processQuerySignal, setProcessQuerySignal] = useState(false);
     const [uploadedFiles, setUploadedFiles] = useState<string[]>([]);
-    const [image64, setImage64] = useState<string>("");
+    const [images, setImages] = useState<string[]>([]);

     const locationData = useIPLocationData() || {
         timezone: Intl.DateTimeFormat().resolvedOptions().timeZone,
@@ -171,7 +175,7 @@ export default function Chat() {
             completed: false,
             timestamp: new Date().toISOString(),
             rawQuery: queryToProcess || "",
-            uploadedImageData: decodeURIComponent(image64),
+            images: images,
         };
         setMessages((prevMessages) => [...prevMessages, newStreamMessage]);
         setProcessQuerySignal(true);
@@ -202,7 +206,7 @@ export default function Chat() {
                 if (done) {
                     setQueryToProcess("");
                     setProcessQuerySignal(false);
-                    setImage64("");
+                    setImages([]);
                     break;
                 }

@@ -250,7 +254,7 @@ export default function Chat() {
                 country_code: locationData.countryCode,
                 timezone: locationData.timezone,
             }),
-            ...(image64 && { image: image64 }),
+            ...(images.length > 0 && { images: images }),
        };

        const response = await fetch(chatAPI, {
@@ -264,7 +268,8 @@ export default function Chat() {
        try {
            await readChatStream(response);
        } catch (err) {
-            console.error(err);
+            const apiError = await response.json();
+            console.error(apiError);
            // Retrieve latest message being processed
            const currentMessage = messages.find((message) => !message.completed);
            if (!currentMessage) return;
@@ -273,7 +278,11 @@ export default function Chat() {
            const errorMessage = (err as Error).message;
            if (errorMessage.includes("Error in input stream"))
                currentMessage.rawResponse = `Woops! The connection broke while I was writing my thoughts down. Maybe try again in a bit or dislike this message if the issue persists?`;
-            else
+            else if (response.status === 429) {
+                "detail" in apiError
+                    ? (currentMessage.rawResponse = `${apiError.detail}`)
+                    : (currentMessage.rawResponse = `I'm a bit overwhelmed at the moment. Could you try again in a bit or dislike this message if the issue persists?`);
+            } else
                currentMessage.rawResponse = `Umm, not sure what just happened. I see this error message: ${errorMessage}. Could you try again or dislike this message if the issue persists?`;

            // Complete message streaming teardown properly
@@ -332,7 +341,7 @@ export default function Chat() {
                        setUploadedFiles={setUploadedFiles}
                        isMobileWidth={isMobileWidth}
                        onConversationIdChange={handleConversationIdChange}
-                        setImage64={setImage64}
+                        setImages={setImages}
                    />
                </Suspense>
            </div>
**Web app: chat history**

@@ -299,7 +299,7 @@ export default function ChatHistory(props: ChatHistoryProps) {
                                     created: message.timestamp,
                                     by: "you",
                                     automationId: "",
-                                    uploadedImageData: message.uploadedImageData,
+                                    images: message.images,
                                 }}
                                 customClassName="fullHistory"
                                 borderLeftColor={`${data?.agent?.color}-500`}
@@ -348,7 +348,6 @@ export default function ChatHistory(props: ChatHistoryProps) {
                                 created: new Date().getTime().toString(),
                                 by: "you",
                                 automationId: "",
-                                uploadedImageData: props.pendingMessage,
                             }}
                             customClassName="fullHistory"
                             borderLeftColor={`${data?.agent?.color}-500`}
**Web app: chat input area**

@@ -3,23 +3,7 @@ import React, { useEffect, useRef, useState } from "react";

 import DOMPurify from "dompurify";
 import "katex/dist/katex.min.css";
-import {
-    ArrowRight,
-    ArrowUp,
-    Browser,
-    ChatsTeardrop,
-    GlobeSimple,
-    Gps,
-    Image,
-    Microphone,
-    Notebook,
-    Paperclip,
-    X,
-    Question,
-    Robot,
-    Shapes,
-    Stop,
-} from "@phosphor-icons/react";
+import { ArrowUp, Microphone, Paperclip, X, Stop } from "@phosphor-icons/react";

 import {
     Command,
@@ -78,10 +62,11 @@ export default function ChatInputArea(props: ChatInputProps) {
     const [loginRedirectMessage, setLoginRedirectMessage] = useState<string | null>(null);
     const [showLoginPrompt, setShowLoginPrompt] = useState(false);

-    const [recording, setRecording] = useState(false);
     const [imageUploaded, setImageUploaded] = useState(false);
-    const [imagePath, setImagePath] = useState<string>("");
-    const [imageData, setImageData] = useState<string | null>(null);
+    const [imagePaths, setImagePaths] = useState<string[]>([]);
+    const [imageData, setImageData] = useState<string[]>([]);
+
+    const [recording, setRecording] = useState(false);
     const [mediaRecorder, setMediaRecorder] = useState<MediaRecorder | null>(null);

     const [progressValue, setProgressValue] = useState(0);
@@ -106,27 +91,31 @@ export default function ChatInputArea(props: ChatInputProps) {

     useEffect(() => {
         async function fetchImageData() {
-            if (imagePath) {
-                const response = await fetch(imagePath);
-                const blob = await response.blob();
-                const reader = new FileReader();
-                reader.onload = function () {
-                    const base64data = reader.result;
-                    setImageData(base64data as string);
-                };
-                reader.readAsDataURL(blob);
+            if (imagePaths.length > 0) {
+                const newImageData = await Promise.all(
+                    imagePaths.map(async (path) => {
+                        const response = await fetch(path);
+                        const blob = await response.blob();
+                        return new Promise<string>((resolve) => {
+                            const reader = new FileReader();
+                            reader.onload = () => resolve(reader.result as string);
+                            reader.readAsDataURL(blob);
+                        });
+                    }),
+                );
+                setImageData(newImageData);
             }
             setUploading(false);
         }
         setUploading(true);
         fetchImageData();
-    }, [imagePath]);
+    }, [imagePaths]);

     function onSendMessage() {
         if (imageUploaded) {
             setImageUploaded(false);
-            setImagePath("");
-            props.sendImage(imageData || "");
+            setImagePaths([]);
+            imageData.forEach((data) => props.sendImage(data));
         }
         if (!message.trim()) return;

@@ -172,18 +161,23 @@ export default function ChatInputArea(props: ChatInputProps) {
             setShowLoginPrompt(true);
             return;
         }
-        // check for image file
+        // check for image files
         const image_endings = ["jpg", "jpeg", "png", "webp"];
+        const newImagePaths: string[] = [];
         for (let i = 0; i < files.length; i++) {
             const file = files[i];
             const file_extension = file.name.split(".").pop();
             if (image_endings.includes(file_extension || "")) {
-                setImageUploaded(true);
-                setImagePath(DOMPurify.sanitize(URL.createObjectURL(file)));
-                return;
+                newImagePaths.push(DOMPurify.sanitize(URL.createObjectURL(file)));
             }
         }

+        if (newImagePaths.length > 0) {
+            setImageUploaded(true);
+            setImagePaths((prevPaths) => [...prevPaths, ...newImagePaths]);
+            return;
+        }
+
         uploadDataForIndexing(
             files,
             setWarning,
@@ -288,9 +282,12 @@ export default function ChatInputArea(props: ChatInputProps) {
         setIsDragAndDropping(false);
     }

-    function removeImageUpload() {
+    function removeImageUpload(index: number) {
+        setImagePaths((prevPaths) => prevPaths.filter((_, i) => i !== index));
+        setImageData((prevData) => prevData.filter((_, i) => i !== index));
+        if (imagePaths.length === 1) {
             setImageUploaded(false);
-        setImagePath("");
+        }
     }

     return (
@@ -407,24 +404,11 @@ export default function ChatInputArea(props: ChatInputProps) {
                 </div>
             )}
             <div
-                className={`${styles.actualInputArea} items-center justify-between dark:bg-neutral-700 relative ${isDragAndDropping && "animate-pulse"}`}
+                className={`${styles.actualInputArea} justify-between dark:bg-neutral-700 relative ${isDragAndDropping && "animate-pulse"}`}
                 onDragOver={handleDragOver}
                 onDragLeave={handleDragLeave}
                 onDrop={handleDragAndDropFiles}
             >
-                {imageUploaded && (
-                    <div className="absolute bottom-[80px] left-0 right-0 dark:bg-neutral-700 bg-white pt-5 pb-5 w-full rounded-lg border dark:border-none grid grid-cols-2">
-                        <div className="pl-4 pr-4">
-                            <img src={imagePath} alt="img" className="w-auto max-h-[100px]" />
-                        </div>
-                        <div className="pl-4 pr-4">
-                            <X
-                                className="w-6 h-6 float-right dark:hover:bg-[hsl(var(--background))] hover:bg-neutral-100 rounded-sm"
-                                onClick={removeImageUpload}
-                            />
-                        </div>
-                    </div>
-                )}
                 <input
                     type="file"
                     multiple={true}
@@ -432,6 +416,7 @@ export default function ChatInputArea(props: ChatInputProps) {
                     onChange={handleFileChange}
                     style={{ display: "none" }}
                 />
+                <div className="flex items-end pb-4">
                     <Button
                         variant={"ghost"}
                         className="!bg-none p-0 m-2 h-auto text-3xl rounded-full text-gray-300 hover:text-gray-500"
@@ -440,7 +425,28 @@ export default function ChatInputArea(props: ChatInputProps) {
                     >
                         <Paperclip className="w-8 h-8" />
                     </Button>
-                <div className="grid w-full gap-1.5 relative">
+                </div>
+                <div className="flex-grow flex flex-col w-full gap-1.5 relative pb-2">
+                    <div className="flex items-center gap-2 overflow-x-auto">
+                        {imageUploaded &&
+                            imagePaths.map((path, index) => (
+                                <div key={index} className="relative flex-shrink-0 pb-3 pt-2 group">
+                                    <img
+                                        src={path}
+                                        alt={`img-${index}`}
+                                        className="w-auto h-16 object-cover rounded-xl"
+                                    />
+                                    <Button
+                                        variant="ghost"
+                                        size="icon"
+                                        className="absolute -top-0 -right-2 h-5 w-5 rounded-full bg-neutral-200 dark:bg-neutral-600 hover:bg-neutral-300 dark:hover:bg-neutral-500 opacity-0 group-hover:opacity-100 transition-opacity"
+                                        onClick={() => removeImageUpload(index)}
+                                    >
+                                        <X className="h-3 w-3" />
+                                    </Button>
+                                </div>
+                            ))}
+                    </div>
                     <Textarea
                         ref={chatInputRef}
                         className={`border-none w-full h-16 min-h-16 max-h-[128px] md:py-4 rounded-lg resize-none dark:bg-neutral-700 ${props.isMobileWidth ? "text-md" : "text-lg"}`}
@@ -451,7 +457,7 @@ export default function ChatInputArea(props: ChatInputProps) {
                         onKeyDown={(e) => {
                             if (e.key === "Enter" && !e.shiftKey) {
                                 setImageUploaded(false);
-                                setImagePath("");
+                                setImagePaths([]);
                                 e.preventDefault();
                                 onSendMessage();
                             }
@@ -460,6 +466,7 @@ export default function ChatInputArea(props: ChatInputProps) {
                         disabled={props.sendDisabled || recording}
                     />
                 </div>
+                <div className="flex items-end pb-4">
                     {recording ? (
                         <TooltipProvider>
                             <Tooltip>
@@ -512,6 +519,7 @@ export default function ChatInputArea(props: ChatInputProps) {
                             <ArrowUp className="w-6 h-6" weight="bold" />
                         </Button>
                     </div>
+                </div>
         </>
     );
 }
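The input area above reads every attached image with `FileReader.readAsDataURL`, so each entry in `imageData` is a base64 data URL. For reference, a sketch of producing the same encoding in Python (the file name is a placeholder):

```python
import base64
import mimetypes


def to_data_url(path: str) -> str:
    """Encode an image file as a base64 data URL, mirroring FileReader.readAsDataURL."""
    mime_type = mimetypes.guess_type(path)[0] or "application/octet-stream"
    with open(path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")
    return f"data:{mime_type};base64,{encoded}"


# Illustrative usage; "diagram.png" is a placeholder file name.
print(to_data_url("diagram.png")[:60])
```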
**Web app: chat message styles (CSS module)**

@@ -57,7 +57,26 @@ div.emptyChatMessage {
     display: none;
 }

-div.chatMessageContainer img {
+div.imagesContainer {
+    display: flex;
+    overflow-x: auto;
+    padding-bottom: 8px;
+    margin-bottom: 8px;
+}
+
+div.imageWrapper {
+    flex: 0 0 auto;
+    margin-right: 8px;
+}
+
+div.imageWrapper img {
+    width: auto;
+    height: 128px;
+    object-fit: cover;
+    border-radius: 8px;
+}
+
+div.chatMessageContainer > img {
     width: auto;
     height: auto;
     max-width: 100%;
**Web app: chat message component and types**

@@ -116,7 +116,7 @@ export interface SingleChatMessage {
     rawQuery?: string;
     intent?: Intent;
     agent?: AgentData;
-    uploadedImageData?: string;
+    images?: string[];
 }

 export interface StreamMessage {
@@ -128,10 +128,9 @@ export interface StreamMessage {
     rawQuery: string;
     timestamp: string;
     agent?: AgentData;
-    uploadedImageData?: string;
+    images?: string[];
     intentType?: string;
     inferredQueries?: string[];
-    image?: string;
 }

 export interface ChatHistoryData {
@@ -213,7 +212,6 @@ interface ChatMessageProps {
     borderLeftColor?: string;
     isLastMessage?: boolean;
     agent?: AgentData;
-    uploadedImageData?: string;
 }

 interface TrainOfThoughtProps {
@@ -343,8 +341,17 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
             .replace(/\\\[/g, "LEFTBRACKET")
             .replace(/\\\]/g, "RIGHTBRACKET");

-        if (props.chatMessage.uploadedImageData) {
-            message = `![uploaded image](${props.chatMessage.uploadedImageData})\n\n${message}`;
+        if (props.chatMessage.images && props.chatMessage.images.length > 0) {
+            const imagesInMd = props.chatMessage.images
+                .map((image, index) => {
+                    const decodedImage = image.startsWith("data%3Aimage")
+                        ? decodeURIComponent(image)
+                        : image;
+                    const sanitizedImage = DOMPurify.sanitize(decodedImage);
+                    return `<div class="${styles.imageWrapper}"><img src="${sanitizedImage}" alt="uploaded image ${index + 1}" /></div>`;
+                })
+                .join("");
+            message = `<div class="${styles.imagesContainer}">${imagesInMd}</div>${message}`;
         }

         const intentTypeHandlers = {
@@ -384,7 +391,7 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>

         // Sanitize and set the rendered markdown
         setMarkdownRendered(DOMPurify.sanitize(markdownRendered));
-    }, [props.chatMessage.message, props.chatMessage.intent]);
+    }, [props.chatMessage.message, props.chatMessage.images, props.chatMessage.intent]);

     useEffect(() => {
         if (copySuccess) {
**Web app: home page**

@@ -44,7 +44,7 @@ function FisherYatesShuffle(array: any[]) {

 function ChatBodyData(props: ChatBodyDataProps) {
     const [message, setMessage] = useState("");
-    const [image, setImage] = useState<string | null>(null);
+    const [images, setImages] = useState<string[]>([]);
     const [processingMessage, setProcessingMessage] = useState(false);
     const [greeting, setGreeting] = useState("");
     const [shuffledOptions, setShuffledOptions] = useState<Suggestion[]>([]);
@@ -138,20 +138,21 @@ function ChatBodyData(props: ChatBodyDataProps) {
                 try {
                     const newConversationId = await createNewConversation(selectedAgent || "khoj");
                     onConversationIdChange?.(newConversationId);
-                    window.location.href = `/chat?conversationId=${newConversationId}`;
                     localStorage.setItem("message", message);
-                    if (image) {
-                        localStorage.setItem("image", image);
+                    if (images.length > 0) {
+                        localStorage.setItem("images", JSON.stringify(images));
                     }
+                    window.location.href = `/chat?conversationId=${newConversationId}`;
                 } catch (error) {
                     console.error("Error creating new conversation:", error);
                     setProcessingMessage(false);
                 }
                 setMessage("");
+                setImages([]);
             }
         };
         processMessage();
-        if (message) {
+        if (message || images.length > 0) {
             setProcessingMessage(true);
         }
     }, [selectedAgent, message, processingMessage, onConversationIdChange]);
@@ -224,7 +225,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
                 </div>
             )}
         </div>
-        <div className={`mx-auto ${props.isMobileWidth ? "w-full" : "w-fit"}`}>
+        <div className={`mx-auto ${props.isMobileWidth ? "w-full" : "w-fit max-w-screen-md"}`}>
             {!props.isMobileWidth && (
                 <div
                     className={`w-full ${styles.inputBox} shadow-lg bg-background align-middle items-center justify-center px-3 py-1 dark:bg-neutral-700 border-stone-100 dark:border-none dark:shadow-none rounded-2xl`}
@@ -232,7 +233,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
                     <ChatInputArea
                         isLoggedIn={props.isLoggedIn}
                         sendMessage={(message) => setMessage(message)}
-                        sendImage={(image) => setImage(image)}
+                        sendImage={(image) => setImages((prevImages) => [...prevImages, image])}
                         sendDisabled={processingMessage}
                         chatOptionsData={props.chatOptionsData}
                         conversationId={null}
@@ -313,7 +314,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
                     <ChatInputArea
                         isLoggedIn={props.isLoggedIn}
                         sendMessage={(message) => setMessage(message)}
-                        sendImage={(image) => setImage(image)}
+                        sendImage={(image) => setImages((prevImages) => [...prevImages, image])}
                         sendDisabled={processingMessage}
                         chatOptionsData={props.chatOptionsData}
                         conversationId={null}
**Web app: shared chat page**

@@ -28,12 +28,12 @@ interface ChatBodyDataProps {
     isLoggedIn: boolean;
     conversationId?: string;
     setQueryToProcess: (query: string) => void;
-    setImage64: (image64: string) => void;
+    setImages: (images: string[]) => void;
 }

 function ChatBodyData(props: ChatBodyDataProps) {
     const [message, setMessage] = useState("");
-    const [image, setImage] = useState<string | null>(null);
+    const [images, setImages] = useState<string[]>([]);
     const [processingMessage, setProcessingMessage] = useState(false);
     const [agentMetadata, setAgentMetadata] = useState<AgentData | null>(null);
     const setQueryToProcess = props.setQueryToProcess;
@@ -42,10 +42,28 @@ function ChatBodyData(props: ChatBodyDataProps) {
     const chatHistoryCustomClassName = props.isMobileWidth ? "w-full" : "w-4/6";

     useEffect(() => {
-        if (image) {
-            props.setImage64(encodeURIComponent(image));
+        if (images.length > 0) {
+            const encodedImages = images.map((image) => encodeURIComponent(image));
+            props.setImages(encodedImages);
         }
-    }, [image, props.setImage64]);
+    }, [images, props.setImages]);

+    useEffect(() => {
+        const storedImages = localStorage.getItem("images");
+        if (storedImages) {
+            const parsedImages: string[] = JSON.parse(storedImages);
+            setImages(parsedImages);
+            const encodedImages = parsedImages.map((img: string) => encodeURIComponent(img));
+            props.setImages(encodedImages);
+            localStorage.removeItem("images");
+        }
+
+        const storedMessage = localStorage.getItem("message");
+        if (storedMessage) {
+            setProcessingMessage(true);
+            setQueryToProcess(storedMessage);
+        }
+    }, [setQueryToProcess, props.setImages]);

     useEffect(() => {
         if (message) {
@@ -89,7 +107,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
                 <ChatInputArea
                     isLoggedIn={props.isLoggedIn}
                     sendMessage={(message) => setMessage(message)}
-                    sendImage={(image) => setImage(image)}
+                    sendImage={(image) => setImages((prevImages) => [...prevImages, image])}
                     sendDisabled={processingMessage}
                     chatOptionsData={props.chatOptionsData}
                     conversationId={props.conversationId}
@@ -112,7 +130,7 @@ export default function SharedChat() {
     const [processQuerySignal, setProcessQuerySignal] = useState(false);
     const [uploadedFiles, setUploadedFiles] = useState<string[]>([]);
     const [paramSlug, setParamSlug] = useState<string | undefined>(undefined);
-    const [image64, setImage64] = useState<string>("");
+    const [images, setImages] = useState<string[]>([]);

     const locationData = useIPLocationData() || {
         timezone: Intl.DateTimeFormat().resolvedOptions().timeZone,
@@ -170,7 +188,7 @@ export default function SharedChat() {
             completed: false,
             timestamp: new Date().toISOString(),
             rawQuery: queryToProcess || "",
-            uploadedImageData: decodeURIComponent(image64),
+            images: images,
         };
         setMessages((prevMessages) => [...prevMessages, newStreamMessage]);
         setProcessQuerySignal(true);
@@ -197,7 +215,7 @@ export default function SharedChat() {
                 if (done) {
                     setQueryToProcess("");
                     setProcessQuerySignal(false);
-                    setImage64("");
+                    setImages([]);
                     break;
                 }

@@ -239,7 +257,7 @@ export default function SharedChat() {
                 country_code: locationData.countryCode,
                 timezone: locationData.timezone,
             }),
-            ...(image64 && { image: image64 }),
+            ...(images.length > 0 && { image: images }),
        };

        const response = await fetch(chatAPI, {
@@ -302,7 +320,7 @@ export default function SharedChat() {
                        setTitle={setTitle}
                        setUploadedFiles={setUploadedFiles}
                        isMobileWidth={isMobileWidth}
-                        setImage64={setImage64}
+                        setImages={setImages}
                    />
                </Suspense>
            </div>
**Server: Gemini chat functions**

@@ -6,14 +6,17 @@ from typing import Dict, Optional

 from langchain.schema import ChatMessage

-from khoj.database.models import Agent, KhojUser
+from khoj.database.models import Agent, ChatModelOptions, KhojUser
 from khoj.processor.conversation import prompts
 from khoj.processor.conversation.google.utils import (
     format_messages_for_gemini,
     gemini_chat_completion_with_backoff,
     gemini_completion_with_backoff,
 )
-from khoj.processor.conversation.utils import generate_chatml_messages_with_context
+from khoj.processor.conversation.utils import (
+    construct_structured_message,
+    generate_chatml_messages_with_context,
+)
 from khoj.utils.helpers import ConversationCommand, is_none_or_empty
 from khoj.utils.rawconfig import LocationData

@@ -29,6 +32,8 @@ def extract_questions_gemini(
     max_tokens=None,
     location_data: LocationData = None,
     user: KhojUser = None,
+    query_images: Optional[list[str]] = None,
+    vision_enabled: bool = False,
     personality_context: Optional[str] = None,
 ):
     """
@@ -70,17 +75,17 @@ def extract_questions_gemini(
         text=text,
     )

-    messages = [ChatMessage(content=prompt, role="user")]
+    prompt = construct_structured_message(
+        message=prompt,
+        images=query_images,
+        model_type=ChatModelOptions.ModelType.GOOGLE,
+        vision_enabled=vision_enabled,
+    )

-    model_kwargs = {"response_mime_type": "application/json"}
+    messages = [ChatMessage(content=prompt, role="user"), ChatMessage(content=system_prompt, role="system")]

-    response = gemini_completion_with_backoff(
-        messages=messages,
-        system_prompt=system_prompt,
-        model_name=model,
-        temperature=temperature,
-        api_key=api_key,
-        model_kwargs=model_kwargs,
+    response = gemini_send_message_to_model(
+        messages, api_key, model, response_type="json_object", temperature=temperature
     )

     # Extract, Clean Message from Gemini's Response
@@ -102,7 +107,7 @@ def extract_questions_gemini(
     return questions


-def gemini_send_message_to_model(messages, api_key, model, response_type="text"):
+def gemini_send_message_to_model(messages, api_key, model, response_type="text", temperature=0, model_kwargs=None):
     """
     Send message to model
     """
@@ -114,7 +119,12 @@ def gemini_send_message_to_model(messages, api_key, model, response_type="text")

     # Get Response from Gemini
     return gemini_completion_with_backoff(
-        messages=messages, system_prompt=system_prompt, model_name=model, api_key=api_key, model_kwargs=model_kwargs
+        messages=messages,
+        system_prompt=system_prompt,
+        model_name=model,
+        api_key=api_key,
+        temperature=temperature,
+        model_kwargs=model_kwargs,
     )


@@ -133,6 +143,8 @@ def converse_gemini(
     location_data: LocationData = None,
     user_name: str = None,
     agent: Agent = None,
+    query_images: Optional[list[str]] = None,
+    vision_available: bool = False,
 ):
     """
     Converse with user using Google's Gemini
@@ -187,6 +199,9 @@ def converse_gemini(
         model_name=model,
         max_prompt_size=max_prompt_size,
         tokenizer_name=tokenizer_name,
+        query_images=query_images,
+        vision_enabled=vision_available,
+        model_type=ChatModelOptions.ModelType.GOOGLE,
     )

     messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
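Callers now thread the user's images and a vision flag into these Gemini functions. A sketch of invoking the updated `extract_questions_gemini`, mirroring the call site in the API router further below (the module path, model name, and argument values are assumptions for illustration):

```python
import os

# Assumed module path for the functions changed above.
from khoj.processor.conversation.google.gemini_chat import extract_questions_gemini

inferred_queries = extract_questions_gemini(
    "What changed between these two screenshots?",
    query_images=["data:image/png;base64,iVBORw0KGgo..."],  # base64 data URLs from the client
    model="gemini-1.5-flash",  # illustrative model name
    api_key=os.environ["GEMINI_API_KEY"],
    conversation_log={},
    location_data=None,
    user=None,
    vision_enabled=True,  # only set when the configured chat model supports vision
)
print(inferred_queries)
```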
**Server: Gemini utils (`khoj.processor.conversation.google.utils`)**

@@ -1,8 +1,11 @@
 import logging
 import random
+from io import BytesIO
 from threading import Thread

 import google.generativeai as genai
+import PIL.Image
+import requests
 from google.generativeai.types.answer_types import FinishReason
 from google.generativeai.types.generation_types import StopCandidateException
 from google.generativeai.types.safety_types import (
@@ -53,14 +56,14 @@ def gemini_completion_with_backoff(
         },
     )

-    formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
+    formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]

     # Start chat session. All messages up to the last are considered to be part of the chat history
     chat_session = model.start_chat(history=formatted_messages[0:-1])

     try:
         # Generate the response. The last message is considered to be the current prompt
-        aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"][0])
+        aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"])
         return aggregated_response.text
     except StopCandidateException as e:
         response_message, _ = handle_gemini_response(e.args)
@@ -117,11 +120,11 @@ def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_k
         },
     )

-    formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
+    formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]
     # all messages up to the last are considered to be part of the chat history
     chat_session = model.start_chat(history=formatted_messages[0:-1])
     # the last message is considered to be the current prompt
-    for chunk in chat_session.send_message(formatted_messages[-1]["parts"][0], stream=True):
+    for chunk in chat_session.send_message(formatted_messages[-1]["parts"], stream=True):
         message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
         message = message or chunk.text
         g.send(message)
@@ -191,14 +194,6 @@ def generate_safety_response(safety_ratings):


 def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str = None) -> tuple[list[str], str]:
-    if len(messages) == 1:
-        messages[0].role = "user"
-        return messages, system_prompt
-
-    for message in messages:
-        if message.role == "assistant":
-            message.role = "model"
-
     # Extract system message
     system_prompt = system_prompt or ""
     for message in messages.copy():
@@ -207,4 +202,31 @@ def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str =
             messages.remove(message)
     system_prompt = None if is_none_or_empty(system_prompt) else system_prompt

+    for message in messages:
+        # Convert message content to string list from chatml dictionary list
+        if isinstance(message.content, list):
+            # Convert image_urls to PIL.Image and place them at beginning of list (better for Gemini)
+            message.content = [
+                get_image_from_url(item["image_url"]["url"]) if item["type"] == "image_url" else item["text"]
+                for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1)
+            ]
+        elif isinstance(message.content, str):
+            message.content = [message.content]
+
+        if message.role == "assistant":
+            message.role = "model"
+
+    if len(messages) == 1:
+        messages[0].role = "user"
+
     return messages, system_prompt
+
+
+def get_image_from_url(image_url: str) -> PIL.Image:
+    try:
+        response = requests.get(image_url)
+        response.raise_for_status()  # Check if the request was successful
+        return PIL.Image.open(BytesIO(response.content))
+    except requests.exceptions.RequestException as e:
+        logger.error(f"Failed to get image from URL {image_url}: {e}")
+        return None
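The net effect of the reworked `format_messages_for_gemini` is that chatml-style content lists are flattened into Gemini `parts`, with `image_url` entries fetched into `PIL.Image` objects and sorted ahead of the text. A small sketch of that transformation (the URL and prompt are illustrative):

```python
from langchain.schema import ChatMessage

from khoj.processor.conversation.google.utils import format_messages_for_gemini

# A chatml-style user message, as construct_structured_message would build it.
message = ChatMessage(
    role="user",
    content=[
        {"type": "text", "text": "Describe this chart"},
        {"type": "image_url", "image_url": {"url": "https://example.com/chart.png"}},  # illustrative URL
    ],
)

messages, system_prompt = format_messages_for_gemini([message], "You are Khoj.")
# messages[0].content is now [<PIL.Image.Image>, "Describe this chart"]: the image
# was fetched via get_image_from_url and sorted ahead of the text, and this list
# is what chat_session.send_message receives as the message "parts".
print(messages[0].role, type(messages[0].content[0]).__name__)
```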
**Server: OpenAI chat functions**

@@ -30,7 +30,7 @@ def extract_questions(
     api_base_url=None,
     location_data: LocationData = None,
     user: KhojUser = None,
-    uploaded_image_url: Optional[str] = None,
+    query_images: Optional[list[str]] = None,
     vision_enabled: bool = False,
     personality_context: Optional[str] = None,
 ):
@@ -74,7 +74,7 @@ def extract_questions(

     prompt = construct_structured_message(
         message=prompt,
-        image_url=uploaded_image_url,
+        images=query_images,
         model_type=ChatModelOptions.ModelType.OPENAI,
         vision_enabled=vision_enabled,
     )
@@ -135,7 +135,7 @@ def converse(
     location_data: LocationData = None,
     user_name: str = None,
     agent: Agent = None,
-    image_url: Optional[str] = None,
+    query_images: Optional[list[str]] = None,
     vision_available: bool = False,
 ):
     """
@@ -191,7 +191,7 @@ def converse(
         model_name=model,
         max_prompt_size=max_prompt_size,
         tokenizer_name=tokenizer_name,
-        uploaded_image_url=image_url,
+        query_images=query_images,
         vision_enabled=vision_available,
         model_type=ChatModelOptions.ModelType.OPENAI,
     )
|
@ -109,7 +109,7 @@ def save_to_conversation_log(
|
||||||
client_application: ClientApplication = None,
|
client_application: ClientApplication = None,
|
||||||
conversation_id: str = None,
|
conversation_id: str = None,
|
||||||
automation_id: str = None,
|
automation_id: str = None,
|
||||||
uploaded_image_url: str = None,
|
query_images: List[str] = None,
|
||||||
):
|
):
|
||||||
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
updated_conversation = message_to_log(
|
updated_conversation = message_to_log(
|
||||||
|
@ -117,7 +117,7 @@ def save_to_conversation_log(
|
||||||
chat_response=chat_response,
|
chat_response=chat_response,
|
||||||
user_message_metadata={
|
user_message_metadata={
|
||||||
"created": user_message_time,
|
"created": user_message_time,
|
||||||
"uploadedImageData": uploaded_image_url,
|
"images": query_images,
|
||||||
},
|
},
|
||||||
khoj_message_metadata={
|
khoj_message_metadata={
|
||||||
"context": compiled_references,
|
"context": compiled_references,
|
||||||
|
@ -145,10 +145,18 @@ Khoj: "{inferred_queries if ("text-to-image" in intent_type) else chat_response}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Format user and system messages to chatml format
|
def construct_structured_message(message: str, images: list[str], model_type: str, vision_enabled: bool):
|
||||||
def construct_structured_message(message, image_url, model_type, vision_enabled):
|
"""
|
||||||
if image_url and vision_enabled and model_type == ChatModelOptions.ModelType.OPENAI:
|
Format messages into appropriate multimedia format for supported chat model types
|
||||||
return [{"type": "text", "text": message}, {"type": "image_url", "image_url": {"url": image_url}}]
|
"""
|
||||||
|
if not images or not vision_enabled:
|
||||||
|
return message
|
||||||
|
|
||||||
|
if model_type in [ChatModelOptions.ModelType.OPENAI, ChatModelOptions.ModelType.GOOGLE]:
|
||||||
|
return [
|
||||||
|
{"type": "text", "text": message},
|
||||||
|
*[{"type": "image_url", "image_url": {"url": image}} for image in images],
|
||||||
|
]
|
||||||
return message
|
return message
|
||||||
|
|
||||||
|
|
||||||
|
@ -160,7 +168,7 @@ def generate_chatml_messages_with_context(
|
||||||
loaded_model: Optional[Llama] = None,
|
loaded_model: Optional[Llama] = None,
|
||||||
max_prompt_size=None,
|
max_prompt_size=None,
|
||||||
tokenizer_name=None,
|
tokenizer_name=None,
|
||||||
uploaded_image_url=None,
|
query_images=None,
|
||||||
vision_enabled=False,
|
vision_enabled=False,
|
||||||
model_type="",
|
model_type="",
|
||||||
):
|
):
|
||||||
|
@ -186,9 +194,7 @@ def generate_chatml_messages_with_context(
|
||||||
else:
|
else:
|
||||||
message_content = chat["message"] + message_notes
|
message_content = chat["message"] + message_notes
|
||||||
|
|
||||||
message_content = construct_structured_message(
|
message_content = construct_structured_message(message_content, chat.get("images"), model_type, vision_enabled)
|
||||||
message_content, chat.get("uploadedImageData"), model_type, vision_enabled
|
|
||||||
)
|
|
||||||
|
|
||||||
reconstructed_message = ChatMessage(content=message_content, role=role)
|
reconstructed_message = ChatMessage(content=message_content, role=role)
|
||||||
|
|
||||||
|
@ -201,7 +207,7 @@ def generate_chatml_messages_with_context(
|
||||||
if not is_none_or_empty(user_message):
|
if not is_none_or_empty(user_message):
|
||||||
messages.append(
|
messages.append(
|
||||||
ChatMessage(
|
ChatMessage(
|
||||||
content=construct_structured_message(user_message, uploaded_image_url, model_type, vision_enabled),
|
content=construct_structured_message(user_message, query_images, model_type, vision_enabled),
|
||||||
role="user",
|
role="user",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -225,7 +231,6 @@ def truncate_messages(
|
||||||
tokenizer_name=None,
|
tokenizer_name=None,
|
||||||
) -> list[ChatMessage]:
|
) -> list[ChatMessage]:
|
||||||
"""Truncate messages to fit within max prompt size supported by model"""
|
"""Truncate messages to fit within max prompt size supported by model"""
|
||||||
|
|
||||||
default_tokenizer = "gpt-4o"
|
default_tokenizer = "gpt-4o"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -255,6 +260,7 @@ def truncate_messages(
|
||||||
system_message = messages.pop(idx)
|
system_message = messages.pop(idx)
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# TODO: Handle truncation of multi-part message.content, i.e when message.content is a list[dict] rather than a string
|
||||||
system_message_tokens = (
|
system_message_tokens = (
|
||||||
len(encoder.encode(system_message.content)) if system_message and type(system_message.content) == str else 0
|
len(encoder.encode(system_message.content)) if system_message and type(system_message.content) == str else 0
|
||||||
)
|
)
|
||||||
|
|
|
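The generalized `construct_structured_message` expands a query with N images into one text part plus N `image_url` parts for vision-enabled OpenAI and Google models, and passes plain strings through otherwise. A quick sketch of both cases (image data URLs truncated for brevity):

```python
from khoj.database.models import ChatModelOptions
from khoj.processor.conversation.utils import construct_structured_message

images = ["data:image/png;base64,iVBORw0KGgo...", "data:image/jpeg;base64,/9j/4AAQ..."]

# Vision-enabled OpenAI/Google model: one text part followed by an image part per image.
structured = construct_structured_message("Compare these", images, ChatModelOptions.ModelType.OPENAI, True)
# [{"type": "text", "text": "Compare these"},
#  {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
#  {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}}]

# No images, or vision disabled: the message string is returned unchanged.
plain = construct_structured_message("Compare these", [], ChatModelOptions.ModelType.OPENAI, True)
assert plain == "Compare these"
```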
**Server: text-to-image helper**

@@ -26,7 +26,7 @@ async def text_to_image(
     references: List[Dict[str, Any]],
     online_results: Dict[str, Any],
     send_status_func: Optional[Callable] = None,
-    uploaded_image_url: Optional[str] = None,
+    query_images: Optional[List[str]] = None,
     agent: Agent = None,
 ):
     status_code = 200
@@ -65,7 +65,7 @@ async def text_to_image(
         note_references=references,
         online_results=online_results,
         model_type=text_to_image_config.model_type,
-        uploaded_image_url=uploaded_image_url,
+        query_images=query_images,
         user=user,
         agent=agent,
     )
**Server: online search tools (`khoj.processor.tools.online_search`)**

@@ -62,7 +62,7 @@ async def search_online(
     user: KhojUser,
     send_status_func: Optional[Callable] = None,
     custom_filters: List[str] = [],
-    uploaded_image_url: str = None,
+    query_images: List[str] = None,
     agent: Agent = None,
 ):
     query += " ".join(custom_filters)
@@ -73,7 +73,7 @@ async def search_online(

     # Breakdown the query into subqueries to get the correct answer
     subqueries = await generate_online_subqueries(
-        query, conversation_history, location, user, uploaded_image_url=uploaded_image_url, agent=agent
+        query, conversation_history, location, user, query_images=query_images, agent=agent
     )
     response_dict = {}

@@ -151,7 +151,7 @@ async def read_webpages(
     location: LocationData,
     user: KhojUser,
     send_status_func: Optional[Callable] = None,
-    uploaded_image_url: str = None,
+    query_images: List[str] = None,
     agent: Agent = None,
 ):
     "Infer web pages to read from the query and extract relevant information from them"
@@ -159,7 +159,7 @@ async def read_webpages(
     if send_status_func:
         async for event in send_status_func(f"**Inferring web pages to read**"):
             yield {ChatEvent.STATUS: event}
-    urls = await infer_webpage_urls(query, conversation_history, location, user, uploaded_image_url)
+    urls = await infer_webpage_urls(query, conversation_history, location, user, query_images)

     logger.info(f"Reading web pages at: {urls}")
     if send_status_func:
@@ -347,7 +347,7 @@ async def extract_references_and_questions(
     conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
     location_data: LocationData = None,
     send_status_func: Optional[Callable] = None,
-    uploaded_image_url: Optional[str] = None,
+    query_images: Optional[List[str]] = None,
     agent: Agent = None,
 ):
     user = request.user.object if request.user.is_authenticated else None

@@ -438,7 +438,7 @@ async def extract_references_and_questions(
                 conversation_log=meta_log,
                 location_data=location_data,
                 user=user,
-                uploaded_image_url=uploaded_image_url,
+                query_images=query_images,
                 vision_enabled=vision_enabled,
                 personality_context=personality_context,
             )

@@ -459,12 +459,14 @@ async def extract_references_and_questions(
             chat_model = conversation_config.chat_model
             inferred_queries = extract_questions_gemini(
                 defiltered_query,
+                query_images=query_images,
                 model=chat_model,
                 api_key=api_key,
                 conversation_log=meta_log,
                 location_data=location_data,
                 max_tokens=conversation_config.max_prompt_size,
                 user=user,
+                vision_enabled=vision_enabled,
                 personality_context=personality_context,
             )
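Two new keywords reach the Gemini question extractor here: `query_images` carries the shared images and `vision_enabled` tells the extractor whether the selected model can actually consume them. A toy sketch of that gating decision (the helper name and behavior are illustrative assumptions, not khoj's API):

```python
# Hedged sketch of the pattern this hunk enables: pass images alongside a
# vision flag, and let the prompt builder drop them for text-only models.
from typing import List, Optional


def attach_images(vision_enabled: bool, query_images: Optional[List[str]]) -> Optional[List[str]]:
    """Forward images to the prompt builder only when the model has vision."""
    return query_images if vision_enabled else None


print(attach_images(True, ["https://bucket.example.com/a.webp"]))   # images kept
print(attach_images(False, ["https://bucket.example.com/a.webp"]))  # images dropped
```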
@@ -30,8 +30,10 @@ from khoj.processor.speech.text_to_speech import generate_text_to_speech
 from khoj.processor.tools.online_search import read_webpages, search_online
 from khoj.routers.api import extract_references_and_questions
 from khoj.routers.helpers import (
+    ApiImageRateLimiter,
     ApiUserRateLimiter,
     ChatEvent,
+    ChatRequestBody,
     CommonQueryParams,
     ConversationCommandRateLimiter,
     agenerate_chat_response,

@@ -524,22 +526,6 @@ async def set_conversation_title(
     )


-class ChatRequestBody(BaseModel):
-    q: str
-    n: Optional[int] = 7
-    d: Optional[float] = None
-    stream: Optional[bool] = False
-    title: Optional[str] = None
-    conversation_id: Optional[str] = None
-    city: Optional[str] = None
-    region: Optional[str] = None
-    country: Optional[str] = None
-    country_code: Optional[str] = None
-    timezone: Optional[str] = None
-    image: Optional[str] = None
-    create_new: Optional[bool] = False
-
-
 @api_chat.post("")
 @requires(["authenticated"])
 async def chat(

@@ -552,6 +538,7 @@ async def chat(
     rate_limiter_per_day=Depends(
         ApiUserRateLimiter(requests=600, subscribed_requests=6000, window=60 * 60 * 24, slug="chat_day")
     ),
+    image_rate_limiter=Depends(ApiImageRateLimiter(max_images=10, max_combined_size_mb=20)),
 ):
     # Access the parameters from the body
     q = body.q

@@ -565,9 +552,9 @@ async def chat(
     country = body.country or get_country_name_from_timezone(body.timezone)
     country_code = body.country_code or get_country_code_from_timezone(body.timezone)
     timezone = body.timezone
-    image = body.image
+    raw_images = body.images

-    async def event_generator(q: str, image: str):
+    async def event_generator(q: str, images: list[str]):
         start_time = time.perf_counter()
         ttft = None
         chat_metadata: dict = {}

@@ -577,16 +564,16 @@ async def chat(
         q = unquote(q)
         nonlocal conversation_id

-        uploaded_image_url = None
-        if image:
-            decoded_string = unquote(image)
-            base64_data = decoded_string.split(",", 1)[1]
-            image_bytes = base64.b64decode(base64_data)
-            webp_image_bytes = convert_image_to_webp(image_bytes)
-            try:
-                uploaded_image_url = upload_image_to_bucket(webp_image_bytes, request.user.object.id)
-            except:
-                uploaded_image_url = None
+        uploaded_images: list[str] = []
+        if images:
+            for image in images:
+                decoded_string = unquote(image)
+                base64_data = decoded_string.split(",", 1)[1]
+                image_bytes = base64.b64decode(base64_data)
+                webp_image_bytes = convert_image_to_webp(image_bytes)
+                uploaded_image = upload_image_to_bucket(webp_image_bytes, request.user.object.id)
+                if uploaded_image:
+                    uploaded_images.append(uploaded_image)

         async def send_event(event_type: ChatEvent, data: str | dict):
             nonlocal connection_alive, ttft
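The ingestion loop above is the heart of the multi-image change: each shared image is URL-decoded, stripped of its data-URL prefix, converted to webp, and uploaded, and only truthy upload results are kept. Note the diff drops the old bare `try/except` around the upload; the new loop instead relies on `upload_image_to_bucket` returning a falsy value on failure. A standalone sketch with the khoj-internal convert/upload helpers stubbed out so it runs anywhere:

```python
# Self-contained sketch of the per-image ingestion loop; the *_stub helpers
# stand in for khoj's convert_image_to_webp and upload_image_to_bucket.
import base64
from urllib.parse import unquote


def convert_image_to_webp_stub(image_bytes: bytes) -> bytes:
    return image_bytes  # khoj re-encodes to webp here to cut storage size


def upload_image_to_bucket_stub(image_bytes: bytes, user_id: int) -> str | None:
    return f"https://bucket.example.com/user-{user_id}/{len(image_bytes)}.webp"


def ingest_images(images: list[str], user_id: int) -> list[str]:
    uploaded_images: list[str] = []
    for image in images:
        decoded_string = unquote(image)                # undo URL encoding
        base64_data = decoded_string.split(",", 1)[1]  # drop "data:image/...;base64," prefix
        image_bytes = base64.b64decode(base64_data)
        webp_bytes = convert_image_to_webp_stub(image_bytes)
        uploaded = upload_image_to_bucket_stub(webp_bytes, user_id)
        if uploaded:  # only keep successful uploads, as in the diff
            uploaded_images.append(uploaded)
    return uploaded_images


sample = "data:image/png;base64," + base64.b64encode(b"fake-image").decode()
print(ingest_images([sample], user_id=1))
```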
@@ -693,7 +680,7 @@ async def chat(
             meta_log,
             is_automated_task,
             user=user,
-            uploaded_image_url=uploaded_image_url,
+            query_images=uploaded_images,
             agent=agent,
         )
         conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])

@@ -702,7 +689,7 @@ async def chat(
         ):
             yield result

-        mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_image_url, agent)
+        mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_images, agent)
         async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
             yield result
         if mode not in conversation_commands:

@@ -765,7 +752,7 @@ async def chat(
                     q,
                     contextual_data,
                     conversation_history=meta_log,
-                    uploaded_image_url=uploaded_image_url,
+                    query_images=uploaded_images,
                     user=user,
                     agent=agent,
                 )

@@ -786,7 +773,7 @@ async def chat(
                     intent_type="summarize",
                     client_application=request.user.client_app,
                     conversation_id=conversation_id,
-                    uploaded_image_url=uploaded_image_url,
+                    query_images=uploaded_images,
                 )
                 return

@@ -829,7 +816,7 @@ async def chat(
                 conversation_id=conversation_id,
                 inferred_queries=[query_to_run],
                 automation_id=automation.id,
-                uploaded_image_url=uploaded_image_url,
+                query_images=uploaded_images,
             )
             async for result in send_llm_response(llm_response):
                 yield result

@@ -849,7 +836,7 @@ async def chat(
                 conversation_commands,
                 location,
                 partial(send_event, ChatEvent.STATUS),
-                uploaded_image_url=uploaded_image_url,
+                query_images=uploaded_images,
                 agent=agent,
             ):
                 if isinstance(result, dict) and ChatEvent.STATUS in result:

@@ -893,7 +880,7 @@ async def chat(
                     user,
                     partial(send_event, ChatEvent.STATUS),
                     custom_filters,
-                    uploaded_image_url=uploaded_image_url,
+                    query_images=uploaded_images,
                     agent=agent,
                 ):
                     if isinstance(result, dict) and ChatEvent.STATUS in result:

@@ -917,7 +904,7 @@ async def chat(
                     location,
                     user,
                     partial(send_event, ChatEvent.STATUS),
-                    uploaded_image_url=uploaded_image_url,
+                    query_images=uploaded_images,
                     agent=agent,
                 ):
                     if isinstance(result, dict) and ChatEvent.STATUS in result:

@@ -967,20 +954,20 @@ async def chat(
                 references=compiled_references,
                 online_results=online_results,
                 send_status_func=partial(send_event, ChatEvent.STATUS),
-                uploaded_image_url=uploaded_image_url,
+                query_images=uploaded_images,
                 agent=agent,
             ):
                 if isinstance(result, dict) and ChatEvent.STATUS in result:
                     yield result[ChatEvent.STATUS]
                 else:
-                    image, status_code, improved_image_prompt, intent_type = result
+                    generated_image, status_code, improved_image_prompt, intent_type = result

-            if image is None or status_code != 200:
+            if generated_image is None or status_code != 200:
                 content_obj = {
                     "content-type": "application/json",
                     "intentType": intent_type,
                     "detail": improved_image_prompt,
-                    "image": image,
+                    "image": None,
                 }
                 async for result in send_llm_response(json.dumps(content_obj)):
                     yield result
@@ -988,7 +975,7 @@ async def chat(

             await sync_to_async(save_to_conversation_log)(
                 q,
-                image,
+                generated_image,
                 user,
                 meta_log,
                 user_message_time,

@@ -998,12 +985,12 @@ async def chat(
                 conversation_id=conversation_id,
                 compiled_references=compiled_references,
                 online_results=online_results,
-                uploaded_image_url=uploaded_image_url,
+                query_images=uploaded_images,
             )
             content_obj = {
                 "intentType": intent_type,
                 "inferredQueries": [improved_image_prompt],
-                "image": image,
+                "image": generated_image,
             }
             async for result in send_llm_response(json.dumps(content_obj)):
                 yield result
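Renaming the unpacked variable from `image` to `generated_image` in the hunks above does more than tidy naming: it keeps the model-generated picture from being confused with the user-attached `uploaded_images` that now flow through the same handler, and lets the failure payload explicitly send `"image": None` instead of a stale value. A small sketch of the two payload shapes (field names from the diff; the simplification to a None check is mine):

```python
# Sketch of the success and failure payloads the image-generation branch
# emits, treating any missing image as a failure for brevity.
import json


def image_response(generated_image: str | None, prompt: str, intent_type: str) -> str:
    if generated_image is None:
        return json.dumps({
            "content-type": "application/json",
            "intentType": intent_type,
            "detail": prompt,
            "image": None,  # failure: never leak a stale image value
        })
    return json.dumps({
        "intentType": intent_type,
        "inferredQueries": [prompt],
        "image": generated_image,
    })


print(image_response(None, "a red fox", "text-to-image"))
print(image_response("https://example.com/fox.webp", "a red fox", "text-to-image"))
```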
@@ -1023,7 +1010,7 @@ async def chat(
                 location_data=location,
                 note_references=compiled_references,
                 online_results=online_results,
-                uploaded_image_url=uploaded_image_url,
+                query_images=uploaded_images,
                 user=user,
                 agent=agent,
                 send_status_func=partial(send_event, ChatEvent.STATUS),

@@ -1053,7 +1040,7 @@ async def chat(
                 conversation_id=conversation_id,
                 compiled_references=compiled_references,
                 online_results=online_results,
-                uploaded_image_url=uploaded_image_url,
+                query_images=uploaded_images,
             )

             async for result in send_llm_response(json.dumps(content_obj)):

@@ -1076,7 +1063,7 @@ async def chat(
             conversation_id,
             location,
             user_name,
-            uploaded_image_url,
+            uploaded_images,
         )

         # Send Response

@@ -1102,9 +1089,9 @@ async def chat(

     ## Stream Text Response
     if stream:
-        return StreamingResponse(event_generator(q, image=image), media_type="text/plain")
+        return StreamingResponse(event_generator(q, images=raw_images), media_type="text/plain")
     ## Non-Streaming Text Response
     else:
-        response_iterator = event_generator(q, image=image)
+        response_iterator = event_generator(q, images=raw_images)
         response_data = await read_chat_stream(response_iterator)
         return Response(content=json.dumps(response_data), media_type="application/json", status_code=200)
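Both the streaming and non-streaming paths now feed the raw `images` list from the request body into `event_generator`. A hedged sketch of what a client call against the updated endpoint could look like; the host, port, path, and auth header are assumptions, while the body fields mirror `ChatRequestBody` as defined later in this diff:

```python
# Illustrative client call with multiple images; server URL and token are
# placeholders, not values confirmed by this diff.
import base64

import requests


def encode(path: str) -> str:
    with open(path, "rb") as f:
        return "data:image/jpeg;base64," + base64.b64encode(f.read()).decode()


response = requests.post(
    "http://localhost:42110/api/chat",  # assumed local khoj server
    headers={"Authorization": "Bearer <api-token>"},
    json={
        "q": "What do these two receipts add up to?",
        "images": [encode("receipt1.jpg"), encode("receipt2.jpg")],
        "stream": False,
    },
)
print(response.json())
```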
@@ -1,4 +1,5 @@
 import asyncio
+import base64
 import hashlib
 import json
 import logging

@@ -22,7 +23,7 @@ from typing import (
     Tuple,
     Union,
 )
-from urllib.parse import parse_qs, quote, urljoin, urlparse
+from urllib.parse import parse_qs, quote, unquote, urljoin, urlparse

 import cron_descriptor
 import pytz

@@ -31,6 +32,7 @@ from apscheduler.job import Job
 from apscheduler.triggers.cron import CronTrigger
 from asgiref.sync import sync_to_async
 from fastapi import Depends, Header, HTTPException, Request, UploadFile
+from pydantic import BaseModel
 from starlette.authentication import has_required_scope
 from starlette.requests import URL

@@ -296,7 +298,7 @@ async def aget_relevant_information_sources(
     conversation_history: dict,
     is_task: bool,
     user: KhojUser,
-    uploaded_image_url: str = None,
+    query_images: List[str] = None,
     agent: Agent = None,
 ):
     """

@@ -315,8 +317,8 @@ async def aget_relevant_information_sources(

     chat_history = construct_chat_history(conversation_history)

-    if uploaded_image_url:
-        query = f"[placeholder for user attached image]\n{query}"
+    if query_images:
+        query = f"[placeholder for {len(query_images)} user attached images]\n{query}"

     personality_context = (
         prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""

@@ -373,7 +375,7 @@ async def aget_relevant_output_modes(
     conversation_history: dict,
     is_task: bool = False,
     user: KhojUser = None,
-    uploaded_image_url: str = None,
+    query_images: List[str] = None,
     agent: Agent = None,
 ):
     """

@@ -395,8 +397,8 @@ async def aget_relevant_output_modes(

     chat_history = construct_chat_history(conversation_history)

-    if uploaded_image_url:
-        query = f"[placeholder for user attached image]\n{query}"
+    if query_images:
+        query = f"[placeholder for {len(query_images)} user attached images]\n{query}"

     personality_context = (
         prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
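The routing actors above pick tools and output modes with text-only prompts, so attached images are represented by a counted placeholder rather than sent to the model. A tiny, self-contained version of that logic:

```python
# Standalone copy of the placeholder behavior shown in the two hunks above.
def with_image_placeholder(query: str, query_images: list[str] | None) -> str:
    if query_images:
        return f"[placeholder for {len(query_images)} user attached images]\n{query}"
    return query


print(with_image_placeholder("What breed is this dog?", ["img1", "img2"]))
```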
@@ -439,7 +441,7 @@ async def infer_webpage_urls(
     conversation_history: dict,
     location_data: LocationData,
     user: KhojUser,
-    uploaded_image_url: str = None,
+    query_images: List[str] = None,
     agent: Agent = None,
 ) -> List[str]:
     """

@@ -465,7 +467,7 @@ async def infer_webpage_urls(

     with timer("Chat actor: Infer webpage urls to read", logger):
         response = await send_message_to_model_wrapper(
-            online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
+            online_queries_prompt, query_images=query_images, response_type="json_object", user=user
         )

     # Validate that the response is a non-empty, JSON-serializable list of URLs

@@ -485,7 +487,7 @@ async def generate_online_subqueries(
     conversation_history: dict,
     location_data: LocationData,
     user: KhojUser,
-    uploaded_image_url: str = None,
+    query_images: List[str] = None,
     agent: Agent = None,
 ) -> List[str]:
     """

@@ -511,7 +513,7 @@ async def generate_online_subqueries(

     with timer("Chat actor: Generate online search subqueries", logger):
         response = await send_message_to_model_wrapper(
-            online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
+            online_queries_prompt, query_images=query_images, response_type="json_object", user=user
         )

     # Validate that the response is a non-empty, JSON-serializable list

@@ -530,7 +532,7 @@ async def generate_online_subqueries(


 async def schedule_query(
-    q: str, conversation_history: dict, user: KhojUser, uploaded_image_url: str = None
+    q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None
 ) -> Tuple[str, ...]:
     """
     Schedule the date, time to run the query. Assume the server timezone is UTC.

@@ -543,7 +545,7 @@ async def schedule_query(
     )

     raw_response = await send_message_to_model_wrapper(
-        crontime_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
+        crontime_prompt, query_images=query_images, response_type="json_object", user=user
     )

     # Validate that the response is a non-empty, JSON-serializable list
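Each of these chat actors requests `response_type="json_object"` and then, per the recurring comment, validates that the model returned a non-empty, JSON-serializable list. A sketch of one way to implement that defensive parse (the helper is illustrative; khoj's actual validator may differ in shape):

```python
# Minimal sketch of a "non-empty JSON list or nothing" validator.
import json


def parse_json_list(raw_response: str) -> list:
    try:
        data = json.loads(raw_response.strip())
    except json.JSONDecodeError:
        return []
    return data if isinstance(data, list) and data else []


print(parse_json_list('["san francisco weather", "sf weekend events"]'))
print(parse_json_list("not json"))  # -> []
```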
@@ -589,7 +591,7 @@ async def extract_relevant_summary(
     q: str,
     corpus: str,
     conversation_history: dict,
-    uploaded_image_url: str = None,
+    query_images: List[str] = None,
     user: KhojUser = None,
     agent: Agent = None,
 ) -> Union[str, None]:

@@ -618,7 +620,7 @@ async def extract_relevant_summary(
         extract_relevant_information,
         prompts.system_prompt_extract_relevant_summary,
         user=user,
-        uploaded_image_url=uploaded_image_url,
+        query_images=query_images,
     )
     return response.strip()

@@ -629,7 +631,7 @@ async def generate_excalidraw_diagram(
     location_data: LocationData,
     note_references: List[Dict[str, Any]],
     online_results: Optional[dict] = None,
-    uploaded_image_url: Optional[str] = None,
+    query_images: List[str] = None,
     user: KhojUser = None,
     agent: Agent = None,
     send_status_func: Optional[Callable] = None,

@@ -644,7 +646,7 @@ async def generate_excalidraw_diagram(
         location_data=location_data,
         note_references=note_references,
         online_results=online_results,
-        uploaded_image_url=uploaded_image_url,
+        query_images=query_images,
         user=user,
         agent=agent,
     )

@@ -668,7 +670,7 @@ async def generate_better_diagram_description(
     location_data: LocationData,
     note_references: List[Dict[str, Any]],
     online_results: Optional[dict] = None,
-    uploaded_image_url: Optional[str] = None,
+    query_images: List[str] = None,
     user: KhojUser = None,
     agent: Agent = None,
 ) -> str:

@@ -711,7 +713,7 @@ async def generate_better_diagram_description(

     with timer("Chat actor: Generate better diagram description", logger):
         response = await send_message_to_model_wrapper(
-            improve_diagram_description_prompt, uploaded_image_url=uploaded_image_url, user=user
+            improve_diagram_description_prompt, query_images=query_images, user=user
         )
         response = response.strip()
         if response.startswith(('"', "'")) and response.endswith(('"', "'")):

@@ -753,7 +755,7 @@ async def generate_better_image_prompt(
     note_references: List[Dict[str, Any]],
     online_results: Optional[dict] = None,
     model_type: Optional[str] = None,
-    uploaded_image_url: Optional[str] = None,
+    query_images: Optional[List[str]] = None,
     user: KhojUser = None,
     agent: Agent = None,
 ) -> str:

@@ -805,7 +807,7 @@ async def generate_better_image_prompt(
     )

     with timer("Chat actor: Generate contextual image prompt", logger):
-        response = await send_message_to_model_wrapper(image_prompt, uploaded_image_url=uploaded_image_url, user=user)
+        response = await send_message_to_model_wrapper(image_prompt, query_images=query_images, user=user)
         response = response.strip()
         if response.startswith(('"', "'")) and response.endswith(('"', "'")):
             response = response[1:-1]
@@ -818,11 +820,11 @@ async def send_message_to_model_wrapper(
     system_message: str = "",
     response_type: str = "text",
     user: KhojUser = None,
-    uploaded_image_url: str = None,
+    query_images: List[str] = None,
 ):
     conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
     vision_available = conversation_config.vision_enabled
-    if not vision_available and uploaded_image_url:
+    if not vision_available and query_images:
         vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
         if vision_enabled_config:
             conversation_config = vision_enabled_config
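The fallback above is worth pausing on: when the default chat model cannot see images but the query carries some, the wrapper swaps in a vision-capable model config if one exists. A standalone sketch of that selection (the config type and model names are simplified stand-ins, not khoj's classes):

```python
# Illustrative version of the vision-fallback selection.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ChatConfig:
    name: str
    vision_enabled: bool


def pick_config(
    default: ChatConfig,
    vision_fallback: Optional[ChatConfig],
    query_images: Optional[List[str]],
) -> ChatConfig:
    if not default.vision_enabled and query_images and vision_fallback:
        return vision_fallback
    return default


default = ChatConfig("text-only-model", vision_enabled=False)
fallback = ChatConfig("vision-model", vision_enabled=True)
print(pick_config(default, fallback, ["data:image/webp;base64,..."]).name)  # vision-model
```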
@@ -875,7 +877,7 @@ async def send_message_to_model_wrapper(
             max_prompt_size=max_tokens,
             tokenizer_name=tokenizer,
             vision_enabled=vision_available,
-            uploaded_image_url=uploaded_image_url,
+            query_images=query_images,
             model_type=conversation_config.model_type,
         )

@@ -895,7 +897,7 @@ async def send_message_to_model_wrapper(
             max_prompt_size=max_tokens,
             tokenizer_name=tokenizer,
             vision_enabled=vision_available,
-            uploaded_image_url=uploaded_image_url,
+            query_images=query_images,
             model_type=conversation_config.model_type,
         )

@@ -913,7 +915,8 @@ async def send_message_to_model_wrapper(
             max_prompt_size=max_tokens,
             tokenizer_name=tokenizer,
             vision_enabled=vision_available,
-            uploaded_image_url=uploaded_image_url,
+            query_images=query_images,
+            model_type=conversation_config.model_type,
         )

         return gemini_send_message_to_model(

@@ -1004,6 +1007,7 @@ def send_message_to_model_wrapper_sync(
             model_name=chat_model,
             max_prompt_size=max_tokens,
             vision_enabled=vision_available,
+            model_type=conversation_config.model_type,
         )

         return gemini_send_message_to_model(

@@ -1029,7 +1033,7 @@ def generate_chat_response(
     conversation_id: str = None,
     location_data: LocationData = None,
     user_name: Optional[str] = None,
-    uploaded_image_url: Optional[str] = None,
+    query_images: Optional[List[str]] = None,
 ) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
     # Initialize Variables
     chat_response = None

@@ -1048,12 +1052,12 @@ def generate_chat_response(
             inferred_queries=inferred_queries,
             client_application=client_application,
             conversation_id=conversation_id,
-            uploaded_image_url=uploaded_image_url,
+            query_images=query_images,
         )

         conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
         vision_available = conversation_config.vision_enabled
-        if not vision_available and uploaded_image_url:
+        if not vision_available and query_images:
             vision_enabled_config = ConversationAdapters.get_vision_enabled_config()
             if vision_enabled_config:
                 conversation_config = vision_enabled_config
@@ -1084,7 +1088,7 @@ def generate_chat_response(
             chat_response = converse(
                 compiled_references,
                 q,
-                image_url=uploaded_image_url,
+                query_images=query_images,
                 online_results=online_results,
                 conversation_log=meta_log,
                 model=chat_model,

@@ -1122,8 +1126,9 @@ def generate_chat_response(
             chat_response = converse_gemini(
                 compiled_references,
                 q,
-                online_results,
-                meta_log,
+                query_images=query_images,
+                online_results=online_results,
+                conversation_log=meta_log,
                 model=conversation_config.chat_model,
                 api_key=api_key,
                 completion_func=partial_completion,

@@ -1133,6 +1138,7 @@ def generate_chat_response(
                 location_data=location_data,
                 user_name=user_name,
                 agent=agent,
+                vision_available=vision_available,
            )

         metadata.update({"chat_model": conversation_config.chat_model})
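Besides threading `query_images` and `vision_available` into `converse_gemini`, this hunk switches `online_results` and `meta_log` from positional to keyword arguments, which is what makes inserting a new parameter into the signature safe. A toy illustration of the hazard being avoided (signatures are illustrative, not khoj's):

```python
# Why the move to keyword arguments matters when a parameter is inserted.
def converse_v2(references, query, query_images=None, online_results=None, conversation_log=None):
    return online_results, conversation_log


# An old-style positional call silently misbinds after the inserted parameter:
assert converse_v2([], "q", {"web": 1}, {"log": 2}) == ({"log": 2}, None)
# Keyword arguments keep the bindings stable:
assert converse_v2([], "q", online_results={"web": 1}, conversation_log={"log": 2}) == ({"web": 1}, {"log": 2})
print("keyword arguments keep bindings stable")
```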
@@ -1144,6 +1150,22 @@ def generate_chat_response(
     return chat_response, metadata


+class ChatRequestBody(BaseModel):
+    q: str
+    n: Optional[int] = 7
+    d: Optional[float] = None
+    stream: Optional[bool] = False
+    title: Optional[str] = None
+    conversation_id: Optional[str] = None
+    city: Optional[str] = None
+    region: Optional[str] = None
+    country: Optional[str] = None
+    country_code: Optional[str] = None
+    timezone: Optional[str] = None
+    images: Optional[list[str]] = None
+    create_new: Optional[bool] = False
+
+
 class ApiUserRateLimiter:
     def __init__(self, requests: int, subscribed_requests: int, window: int, slug: str):
         self.requests = requests
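`ChatRequestBody` moves from `api_chat.py` into `khoj.routers.helpers` so the new `ApiImageRateLimiter` below can reference it without a circular import, and its single `image` field becomes `images: Optional[list[str]]`. A quick sketch of validating a multi-image request with a standalone copy of the relevant fields:

```python
# Trimmed, self-contained copy of the model for demonstration; the full
# class lives in khoj.routers.helpers per this diff.
from typing import Optional

from pydantic import BaseModel


class ChatRequestBody(BaseModel):
    q: str
    stream: Optional[bool] = False
    images: Optional[list[str]] = None


body = ChatRequestBody(q="Compare these screenshots", images=["data:image/png;base64,AAAA"])
print(len(body.images or []))  # 1
```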
@@ -1189,13 +1211,58 @@ class ApiUserRateLimiter:
             )
             raise HTTPException(
                 status_code=429,
-                detail="We're glad you're enjoying Khoj! You've exceeded your usage limit for today. Come back tomorrow or subscribe to increase your usage limit via [your settings](https://app.khoj.dev/settings).",
+                detail="I'm glad you're enjoying interacting with me! But you've exceeded your usage limit for today. Come back tomorrow or subscribe to increase your usage limit via [your settings](https://app.khoj.dev/settings).",
             )

         # Add the current request to the cache
         UserRequests.objects.create(user=user, slug=self.slug)


+class ApiImageRateLimiter:
+    def __init__(self, max_images: int = 10, max_combined_size_mb: float = 10):
+        self.max_images = max_images
+        self.max_combined_size_mb = max_combined_size_mb
+
+    def __call__(self, request: Request, body: ChatRequestBody):
+        if state.billing_enabled is False:
+            return
+
+        # Rate limiting is disabled if user unauthenticated.
+        # Other systems handle authentication
+        if not request.user.is_authenticated:
+            return
+
+        if not body.images:
+            return
+
+        # Check number of images
+        if len(body.images) > self.max_images:
+            raise HTTPException(
+                status_code=429,
+                detail=f"Those are way too many images for me! I can handle up to {self.max_images} images per message.",
+            )
+
+        # Check total size of images
+        total_size_mb = 0.0
+        for image in body.images:
+            # Unquote the image in case it's URL encoded
+            image = unquote(image)
+            # Assuming the image is a base64 encoded string
+            # Remove the data:image/jpeg;base64, part if present
+            if "," in image:
+                image = image.split(",", 1)[1]
+
+            # Decode base64 to get the actual size
+            image_bytes = base64.b64decode(image)
+            total_size_mb += len(image_bytes) / (1024 * 1024)  # Convert bytes to MB
+
+        if total_size_mb > self.max_combined_size_mb:
+            raise HTTPException(
+                status_code=429,
+                detail=f"Those images are way too large for me! I can handle up to {self.max_combined_size_mb}MB of images per message.",
+            )
+
+
 class ConversationCommandRateLimiter:
     def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int, slug: str):
         self.slug = slug
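Note that `ApiImageRateLimiter` measures the decoded bytes, not the length of the base64 text. A base64 payload inflates raw data by roughly 4/3, so decoding gives the exact figure while `len(encoded) * 3 / 4` is only an estimate. A self-contained check of that arithmetic:

```python
# Comparing the quick base64 size estimate against the exact decoded size,
# mirroring the measurement ApiImageRateLimiter performs.
import base64

payload = b"x" * (1024 * 1024)  # 1 MiB of raw bytes
encoded = base64.b64encode(payload).decode()

estimate_mb = len(encoded) * 3 / 4 / (1024 * 1024)
exact_mb = len(base64.b64decode(encoded)) / (1024 * 1024)
print(f"estimate: {estimate_mb:.2f} MB, exact: {exact_mb:.2f} MB")
```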
@@ -352,9 +352,9 @@ tool_descriptions_for_llm = {
 }

 mode_descriptions_for_llm = {
-    ConversationCommand.Image: "Use this if the user is requesting you to generate a picture based on their description.",
+    ConversationCommand.Image: "Use this if the user is requesting you to create a new picture based on their description.",
     ConversationCommand.Automation: "Use this if you are confident the user is requesting a response at a scheduled date, time and frequency",
-    ConversationCommand.Text: "Use this if the other response modes don't seem to fit the query.",
+    ConversationCommand.Text: "Use this if a normal text response would be sufficient for accurately responding to the query.",
     ConversationCommand.Diagram: "Use this if the user is requesting a visual representation that requires primitives like lines, rectangles, and text.",
 }