Chat with Multiple Images. Support Vision with Gemini (#942)

## Overview
- Add vision support for Gemini models in Khoj
- Allow sharing multiple images as part of user query from the web app
- Handle multiple images shared in query to chat API
This commit is contained in:
Debanjum 2024-10-22 19:59:18 -07:00 committed by GitHub
commit c6f3253ebd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 454 additions and 294 deletions

View file

@ -27,14 +27,14 @@ interface ChatBodyDataProps {
setUploadedFiles: (files: string[]) => void;
isMobileWidth?: boolean;
isLoggedIn: boolean;
setImage64: (image64: string) => void;
setImages: (images: string[]) => void;
}
function ChatBodyData(props: ChatBodyDataProps) {
const searchParams = useSearchParams();
const conversationId = searchParams.get("conversationId");
const [message, setMessage] = useState("");
const [image, setImage] = useState<string | null>(null);
const [images, setImages] = useState<string[]>([]);
const [processingMessage, setProcessingMessage] = useState(false);
const [agentMetadata, setAgentMetadata] = useState<AgentData | null>(null);
@ -44,17 +44,20 @@ function ChatBodyData(props: ChatBodyDataProps) {
const chatHistoryCustomClassName = props.isMobileWidth ? "w-full" : "w-4/6";
useEffect(() => {
if (image) {
props.setImage64(encodeURIComponent(image));
if (images.length > 0) {
const encodedImages = images.map((image) => encodeURIComponent(image));
props.setImages(encodedImages);
}
}, [image, props.setImage64]);
}, [images, props.setImages]);
useEffect(() => {
const storedImage = localStorage.getItem("image");
if (storedImage) {
setImage(storedImage);
props.setImage64(encodeURIComponent(storedImage));
localStorage.removeItem("image");
const storedImages = localStorage.getItem("images");
if (storedImages) {
const parsedImages: string[] = JSON.parse(storedImages);
setImages(parsedImages);
const encodedImages = parsedImages.map((img: string) => encodeURIComponent(img));
props.setImages(encodedImages);
localStorage.removeItem("images");
}
const storedMessage = localStorage.getItem("message");
@ -62,7 +65,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
setProcessingMessage(true);
setQueryToProcess(storedMessage);
}
}, [setQueryToProcess]);
}, [setQueryToProcess, props.setImages]);
useEffect(() => {
if (message) {
@ -84,6 +87,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
props.streamedMessages[props.streamedMessages.length - 1].completed
) {
setProcessingMessage(false);
setImages([]); // Reset images after processing
} else {
setMessage("");
}
@ -113,7 +117,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
agentColor={agentMetadata?.color}
isLoggedIn={props.isLoggedIn}
sendMessage={(message) => setMessage(message)}
sendImage={(image) => setImage(image)}
sendImage={(image) => setImages((prevImages) => [...prevImages, image])}
sendDisabled={processingMessage}
chatOptionsData={props.chatOptionsData}
conversationId={conversationId}
@ -135,7 +139,7 @@ export default function Chat() {
const [queryToProcess, setQueryToProcess] = useState<string>("");
const [processQuerySignal, setProcessQuerySignal] = useState(false);
const [uploadedFiles, setUploadedFiles] = useState<string[]>([]);
const [image64, setImage64] = useState<string>("");
const [images, setImages] = useState<string[]>([]);
const locationData = useIPLocationData() || {
timezone: Intl.DateTimeFormat().resolvedOptions().timeZone,
@ -171,7 +175,7 @@ export default function Chat() {
completed: false,
timestamp: new Date().toISOString(),
rawQuery: queryToProcess || "",
uploadedImageData: decodeURIComponent(image64),
images: images,
};
setMessages((prevMessages) => [...prevMessages, newStreamMessage]);
setProcessQuerySignal(true);
@ -202,7 +206,7 @@ export default function Chat() {
if (done) {
setQueryToProcess("");
setProcessQuerySignal(false);
setImage64("");
setImages([]);
break;
}
@ -250,7 +254,7 @@ export default function Chat() {
country_code: locationData.countryCode,
timezone: locationData.timezone,
}),
...(image64 && { image: image64 }),
...(images.length > 0 && { images: images }),
};
const response = await fetch(chatAPI, {
@ -264,7 +268,8 @@ export default function Chat() {
try {
await readChatStream(response);
} catch (err) {
console.error(err);
const apiError = await response.json();
console.error(apiError);
// Retrieve latest message being processed
const currentMessage = messages.find((message) => !message.completed);
if (!currentMessage) return;
@ -273,7 +278,11 @@ export default function Chat() {
const errorMessage = (err as Error).message;
if (errorMessage.includes("Error in input stream"))
currentMessage.rawResponse = `Woops! The connection broke while I was writing my thoughts down. Maybe try again in a bit or dislike this message if the issue persists?`;
else
else if (response.status === 429) {
"detail" in apiError
? (currentMessage.rawResponse = `${apiError.detail}`)
: (currentMessage.rawResponse = `I'm a bit overwhelmed at the moment. Could you try again in a bit or dislike this message if the issue persists?`);
} else
currentMessage.rawResponse = `Umm, not sure what just happened. I see this error message: ${errorMessage}. Could you try again or dislike this message if the issue persists?`;
// Complete message streaming teardown properly
@ -332,7 +341,7 @@ export default function Chat() {
setUploadedFiles={setUploadedFiles}
isMobileWidth={isMobileWidth}
onConversationIdChange={handleConversationIdChange}
setImage64={setImage64}
setImages={setImages}
/>
</Suspense>
</div>

View file

@ -299,7 +299,7 @@ export default function ChatHistory(props: ChatHistoryProps) {
created: message.timestamp,
by: "you",
automationId: "",
uploadedImageData: message.uploadedImageData,
images: message.images,
}}
customClassName="fullHistory"
borderLeftColor={`${data?.agent?.color}-500`}
@ -348,7 +348,6 @@ export default function ChatHistory(props: ChatHistoryProps) {
created: new Date().getTime().toString(),
by: "you",
automationId: "",
uploadedImageData: props.pendingMessage,
}}
customClassName="fullHistory"
borderLeftColor={`${data?.agent?.color}-500`}

View file

@ -3,23 +3,7 @@ import React, { useEffect, useRef, useState } from "react";
import DOMPurify from "dompurify";
import "katex/dist/katex.min.css";
import {
ArrowRight,
ArrowUp,
Browser,
ChatsTeardrop,
GlobeSimple,
Gps,
Image,
Microphone,
Notebook,
Paperclip,
X,
Question,
Robot,
Shapes,
Stop,
} from "@phosphor-icons/react";
import { ArrowUp, Microphone, Paperclip, X, Stop } from "@phosphor-icons/react";
import {
Command,
@ -78,10 +62,11 @@ export default function ChatInputArea(props: ChatInputProps) {
const [loginRedirectMessage, setLoginRedirectMessage] = useState<string | null>(null);
const [showLoginPrompt, setShowLoginPrompt] = useState(false);
const [recording, setRecording] = useState(false);
const [imageUploaded, setImageUploaded] = useState(false);
const [imagePath, setImagePath] = useState<string>("");
const [imageData, setImageData] = useState<string | null>(null);
const [imagePaths, setImagePaths] = useState<string[]>([]);
const [imageData, setImageData] = useState<string[]>([]);
const [recording, setRecording] = useState(false);
const [mediaRecorder, setMediaRecorder] = useState<MediaRecorder | null>(null);
const [progressValue, setProgressValue] = useState(0);
@ -106,27 +91,31 @@ export default function ChatInputArea(props: ChatInputProps) {
useEffect(() => {
async function fetchImageData() {
if (imagePath) {
const response = await fetch(imagePath);
if (imagePaths.length > 0) {
const newImageData = await Promise.all(
imagePaths.map(async (path) => {
const response = await fetch(path);
const blob = await response.blob();
return new Promise<string>((resolve) => {
const reader = new FileReader();
reader.onload = function () {
const base64data = reader.result;
setImageData(base64data as string);
};
reader.onload = () => resolve(reader.result as string);
reader.readAsDataURL(blob);
});
}),
);
setImageData(newImageData);
}
setUploading(false);
}
setUploading(true);
fetchImageData();
}, [imagePath]);
}, [imagePaths]);
function onSendMessage() {
if (imageUploaded) {
setImageUploaded(false);
setImagePath("");
props.sendImage(imageData || "");
setImagePaths([]);
imageData.forEach((data) => props.sendImage(data));
}
if (!message.trim()) return;
@ -172,18 +161,23 @@ export default function ChatInputArea(props: ChatInputProps) {
setShowLoginPrompt(true);
return;
}
// check for image file
// check for image files
const image_endings = ["jpg", "jpeg", "png", "webp"];
const newImagePaths: string[] = [];
for (let i = 0; i < files.length; i++) {
const file = files[i];
const file_extension = file.name.split(".").pop();
if (image_endings.includes(file_extension || "")) {
setImageUploaded(true);
setImagePath(DOMPurify.sanitize(URL.createObjectURL(file)));
return;
newImagePaths.push(DOMPurify.sanitize(URL.createObjectURL(file)));
}
}
if (newImagePaths.length > 0) {
setImageUploaded(true);
setImagePaths((prevPaths) => [...prevPaths, ...newImagePaths]);
return;
}
uploadDataForIndexing(
files,
setWarning,
@ -288,9 +282,12 @@ export default function ChatInputArea(props: ChatInputProps) {
setIsDragAndDropping(false);
}
function removeImageUpload() {
function removeImageUpload(index: number) {
setImagePaths((prevPaths) => prevPaths.filter((_, i) => i !== index));
setImageData((prevData) => prevData.filter((_, i) => i !== index));
if (imagePaths.length === 1) {
setImageUploaded(false);
setImagePath("");
}
}
return (
@ -407,24 +404,11 @@ export default function ChatInputArea(props: ChatInputProps) {
</div>
)}
<div
className={`${styles.actualInputArea} items-center justify-between dark:bg-neutral-700 relative ${isDragAndDropping && "animate-pulse"}`}
className={`${styles.actualInputArea} justify-between dark:bg-neutral-700 relative ${isDragAndDropping && "animate-pulse"}`}
onDragOver={handleDragOver}
onDragLeave={handleDragLeave}
onDrop={handleDragAndDropFiles}
>
{imageUploaded && (
<div className="absolute bottom-[80px] left-0 right-0 dark:bg-neutral-700 bg-white pt-5 pb-5 w-full rounded-lg border dark:border-none grid grid-cols-2">
<div className="pl-4 pr-4">
<img src={imagePath} alt="img" className="w-auto max-h-[100px]" />
</div>
<div className="pl-4 pr-4">
<X
className="w-6 h-6 float-right dark:hover:bg-[hsl(var(--background))] hover:bg-neutral-100 rounded-sm"
onClick={removeImageUpload}
/>
</div>
</div>
)}
<input
type="file"
multiple={true}
@ -432,6 +416,7 @@ export default function ChatInputArea(props: ChatInputProps) {
onChange={handleFileChange}
style={{ display: "none" }}
/>
<div className="flex items-end pb-4">
<Button
variant={"ghost"}
className="!bg-none p-0 m-2 h-auto text-3xl rounded-full text-gray-300 hover:text-gray-500"
@ -440,7 +425,28 @@ export default function ChatInputArea(props: ChatInputProps) {
>
<Paperclip className="w-8 h-8" />
</Button>
<div className="grid w-full gap-1.5 relative">
</div>
<div className="flex-grow flex flex-col w-full gap-1.5 relative pb-2">
<div className="flex items-center gap-2 overflow-x-auto">
{imageUploaded &&
imagePaths.map((path, index) => (
<div key={index} className="relative flex-shrink-0 pb-3 pt-2 group">
<img
src={path}
alt={`img-${index}`}
className="w-auto h-16 object-cover rounded-xl"
/>
<Button
variant="ghost"
size="icon"
className="absolute -top-0 -right-2 h-5 w-5 rounded-full bg-neutral-200 dark:bg-neutral-600 hover:bg-neutral-300 dark:hover:bg-neutral-500 opacity-0 group-hover:opacity-100 transition-opacity"
onClick={() => removeImageUpload(index)}
>
<X className="h-3 w-3" />
</Button>
</div>
))}
</div>
<Textarea
ref={chatInputRef}
className={`border-none w-full h-16 min-h-16 max-h-[128px] md:py-4 rounded-lg resize-none dark:bg-neutral-700 ${props.isMobileWidth ? "text-md" : "text-lg"}`}
@ -451,7 +457,7 @@ export default function ChatInputArea(props: ChatInputProps) {
onKeyDown={(e) => {
if (e.key === "Enter" && !e.shiftKey) {
setImageUploaded(false);
setImagePath("");
setImagePaths([]);
e.preventDefault();
onSendMessage();
}
@ -460,6 +466,7 @@ export default function ChatInputArea(props: ChatInputProps) {
disabled={props.sendDisabled || recording}
/>
</div>
<div className="flex items-end pb-4">
{recording ? (
<TooltipProvider>
<Tooltip>
@ -512,6 +519,7 @@ export default function ChatInputArea(props: ChatInputProps) {
<ArrowUp className="w-6 h-6" weight="bold" />
</Button>
</div>
</div>
</>
);
}

View file

@ -57,7 +57,26 @@ div.emptyChatMessage {
display: none;
}
div.chatMessageContainer img {
div.imagesContainer {
display: flex;
overflow-x: auto;
padding-bottom: 8px;
margin-bottom: 8px;
}
div.imageWrapper {
flex: 0 0 auto;
margin-right: 8px;
}
div.imageWrapper img {
width: auto;
height: 128px;
object-fit: cover;
border-radius: 8px;
}
div.chatMessageContainer > img {
width: auto;
height: auto;
max-width: 100%;

View file

@ -116,7 +116,7 @@ export interface SingleChatMessage {
rawQuery?: string;
intent?: Intent;
agent?: AgentData;
uploadedImageData?: string;
images?: string[];
}
export interface StreamMessage {
@ -128,10 +128,9 @@ export interface StreamMessage {
rawQuery: string;
timestamp: string;
agent?: AgentData;
uploadedImageData?: string;
images?: string[];
intentType?: string;
inferredQueries?: string[];
image?: string;
}
export interface ChatHistoryData {
@ -213,7 +212,6 @@ interface ChatMessageProps {
borderLeftColor?: string;
isLastMessage?: boolean;
agent?: AgentData;
uploadedImageData?: string;
}
interface TrainOfThoughtProps {
@ -343,8 +341,17 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
.replace(/\\\[/g, "LEFTBRACKET")
.replace(/\\\]/g, "RIGHTBRACKET");
if (props.chatMessage.uploadedImageData) {
message = `![uploaded image](${props.chatMessage.uploadedImageData})\n\n${message}`;
if (props.chatMessage.images && props.chatMessage.images.length > 0) {
const imagesInMd = props.chatMessage.images
.map((image, index) => {
const decodedImage = image.startsWith("data%3Aimage")
? decodeURIComponent(image)
: image;
const sanitizedImage = DOMPurify.sanitize(decodedImage);
return `<div class="${styles.imageWrapper}"><img src="${sanitizedImage}" alt="uploaded image ${index + 1}" /></div>`;
})
.join("");
message = `<div class="${styles.imagesContainer}">${imagesInMd}</div>${message}`;
}
const intentTypeHandlers = {
@ -384,7 +391,7 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
// Sanitize and set the rendered markdown
setMarkdownRendered(DOMPurify.sanitize(markdownRendered));
}, [props.chatMessage.message, props.chatMessage.intent]);
}, [props.chatMessage.message, props.chatMessage.images, props.chatMessage.intent]);
useEffect(() => {
if (copySuccess) {

View file

@ -44,7 +44,7 @@ function FisherYatesShuffle(array: any[]) {
function ChatBodyData(props: ChatBodyDataProps) {
const [message, setMessage] = useState("");
const [image, setImage] = useState<string | null>(null);
const [images, setImages] = useState<string[]>([]);
const [processingMessage, setProcessingMessage] = useState(false);
const [greeting, setGreeting] = useState("");
const [shuffledOptions, setShuffledOptions] = useState<Suggestion[]>([]);
@ -138,20 +138,21 @@ function ChatBodyData(props: ChatBodyDataProps) {
try {
const newConversationId = await createNewConversation(selectedAgent || "khoj");
onConversationIdChange?.(newConversationId);
window.location.href = `/chat?conversationId=${newConversationId}`;
localStorage.setItem("message", message);
if (image) {
localStorage.setItem("image", image);
if (images.length > 0) {
localStorage.setItem("images", JSON.stringify(images));
}
window.location.href = `/chat?conversationId=${newConversationId}`;
} catch (error) {
console.error("Error creating new conversation:", error);
setProcessingMessage(false);
}
setMessage("");
setImages([]);
}
};
processMessage();
if (message) {
if (message || images.length > 0) {
setProcessingMessage(true);
}
}, [selectedAgent, message, processingMessage, onConversationIdChange]);
@ -224,7 +225,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
</div>
)}
</div>
<div className={`mx-auto ${props.isMobileWidth ? "w-full" : "w-fit"}`}>
<div className={`mx-auto ${props.isMobileWidth ? "w-full" : "w-fit max-w-screen-md"}`}>
{!props.isMobileWidth && (
<div
className={`w-full ${styles.inputBox} shadow-lg bg-background align-middle items-center justify-center px-3 py-1 dark:bg-neutral-700 border-stone-100 dark:border-none dark:shadow-none rounded-2xl`}
@ -232,7 +233,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
<ChatInputArea
isLoggedIn={props.isLoggedIn}
sendMessage={(message) => setMessage(message)}
sendImage={(image) => setImage(image)}
sendImage={(image) => setImages((prevImages) => [...prevImages, image])}
sendDisabled={processingMessage}
chatOptionsData={props.chatOptionsData}
conversationId={null}
@ -313,7 +314,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
<ChatInputArea
isLoggedIn={props.isLoggedIn}
sendMessage={(message) => setMessage(message)}
sendImage={(image) => setImage(image)}
sendImage={(image) => setImages((prevImages) => [...prevImages, image])}
sendDisabled={processingMessage}
chatOptionsData={props.chatOptionsData}
conversationId={null}

View file

@ -28,12 +28,12 @@ interface ChatBodyDataProps {
isLoggedIn: boolean;
conversationId?: string;
setQueryToProcess: (query: string) => void;
setImage64: (image64: string) => void;
setImages: (images: string[]) => void;
}
function ChatBodyData(props: ChatBodyDataProps) {
const [message, setMessage] = useState("");
const [image, setImage] = useState<string | null>(null);
const [images, setImages] = useState<string[]>([]);
const [processingMessage, setProcessingMessage] = useState(false);
const [agentMetadata, setAgentMetadata] = useState<AgentData | null>(null);
const setQueryToProcess = props.setQueryToProcess;
@ -42,10 +42,28 @@ function ChatBodyData(props: ChatBodyDataProps) {
const chatHistoryCustomClassName = props.isMobileWidth ? "w-full" : "w-4/6";
useEffect(() => {
if (image) {
props.setImage64(encodeURIComponent(image));
if (images.length > 0) {
const encodedImages = images.map((image) => encodeURIComponent(image));
props.setImages(encodedImages);
}
}, [image, props.setImage64]);
}, [images, props.setImages]);
useEffect(() => {
const storedImages = localStorage.getItem("images");
if (storedImages) {
const parsedImages: string[] = JSON.parse(storedImages);
setImages(parsedImages);
const encodedImages = parsedImages.map((img: string) => encodeURIComponent(img));
props.setImages(encodedImages);
localStorage.removeItem("images");
}
const storedMessage = localStorage.getItem("message");
if (storedMessage) {
setProcessingMessage(true);
setQueryToProcess(storedMessage);
}
}, [setQueryToProcess, props.setImages]);
useEffect(() => {
if (message) {
@ -89,7 +107,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
<ChatInputArea
isLoggedIn={props.isLoggedIn}
sendMessage={(message) => setMessage(message)}
sendImage={(image) => setImage(image)}
sendImage={(image) => setImages((prevImages) => [...prevImages, image])}
sendDisabled={processingMessage}
chatOptionsData={props.chatOptionsData}
conversationId={props.conversationId}
@ -112,7 +130,7 @@ export default function SharedChat() {
const [processQuerySignal, setProcessQuerySignal] = useState(false);
const [uploadedFiles, setUploadedFiles] = useState<string[]>([]);
const [paramSlug, setParamSlug] = useState<string | undefined>(undefined);
const [image64, setImage64] = useState<string>("");
const [images, setImages] = useState<string[]>([]);
const locationData = useIPLocationData() || {
timezone: Intl.DateTimeFormat().resolvedOptions().timeZone,
@ -170,7 +188,7 @@ export default function SharedChat() {
completed: false,
timestamp: new Date().toISOString(),
rawQuery: queryToProcess || "",
uploadedImageData: decodeURIComponent(image64),
images: images,
};
setMessages((prevMessages) => [...prevMessages, newStreamMessage]);
setProcessQuerySignal(true);
@ -197,7 +215,7 @@ export default function SharedChat() {
if (done) {
setQueryToProcess("");
setProcessQuerySignal(false);
setImage64("");
setImages([]);
break;
}
@ -239,7 +257,7 @@ export default function SharedChat() {
country_code: locationData.countryCode,
timezone: locationData.timezone,
}),
...(image64 && { image: image64 }),
...(images.length > 0 && { image: images }),
};
const response = await fetch(chatAPI, {
@ -302,7 +320,7 @@ export default function SharedChat() {
setTitle={setTitle}
setUploadedFiles={setUploadedFiles}
isMobileWidth={isMobileWidth}
setImage64={setImage64}
setImages={setImages}
/>
</Suspense>
</div>

View file

@ -6,14 +6,17 @@ from typing import Dict, Optional
from langchain.schema import ChatMessage
from khoj.database.models import Agent, KhojUser
from khoj.database.models import Agent, ChatModelOptions, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.google.utils import (
format_messages_for_gemini,
gemini_chat_completion_with_backoff,
gemini_completion_with_backoff,
)
from khoj.processor.conversation.utils import generate_chatml_messages_with_context
from khoj.processor.conversation.utils import (
construct_structured_message,
generate_chatml_messages_with_context,
)
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
from khoj.utils.rawconfig import LocationData
@ -29,6 +32,8 @@ def extract_questions_gemini(
max_tokens=None,
location_data: LocationData = None,
user: KhojUser = None,
query_images: Optional[list[str]] = None,
vision_enabled: bool = False,
personality_context: Optional[str] = None,
):
"""
@ -70,17 +75,17 @@ def extract_questions_gemini(
text=text,
)
messages = [ChatMessage(content=prompt, role="user")]
prompt = construct_structured_message(
message=prompt,
images=query_images,
model_type=ChatModelOptions.ModelType.GOOGLE,
vision_enabled=vision_enabled,
)
model_kwargs = {"response_mime_type": "application/json"}
messages = [ChatMessage(content=prompt, role="user"), ChatMessage(content=system_prompt, role="system")]
response = gemini_completion_with_backoff(
messages=messages,
system_prompt=system_prompt,
model_name=model,
temperature=temperature,
api_key=api_key,
model_kwargs=model_kwargs,
response = gemini_send_message_to_model(
messages, api_key, model, response_type="json_object", temperature=temperature
)
# Extract, Clean Message from Gemini's Response
@ -102,7 +107,7 @@ def extract_questions_gemini(
return questions
def gemini_send_message_to_model(messages, api_key, model, response_type="text"):
def gemini_send_message_to_model(messages, api_key, model, response_type="text", temperature=0, model_kwargs=None):
"""
Send message to model
"""
@ -114,7 +119,12 @@ def gemini_send_message_to_model(messages, api_key, model, response_type="text")
# Get Response from Gemini
return gemini_completion_with_backoff(
messages=messages, system_prompt=system_prompt, model_name=model, api_key=api_key, model_kwargs=model_kwargs
messages=messages,
system_prompt=system_prompt,
model_name=model,
api_key=api_key,
temperature=temperature,
model_kwargs=model_kwargs,
)
@ -133,6 +143,8 @@ def converse_gemini(
location_data: LocationData = None,
user_name: str = None,
agent: Agent = None,
query_images: Optional[list[str]] = None,
vision_available: bool = False,
):
"""
Converse with user using Google's Gemini
@ -187,6 +199,9 @@ def converse_gemini(
model_name=model,
max_prompt_size=max_prompt_size,
tokenizer_name=tokenizer_name,
query_images=query_images,
vision_enabled=vision_available,
model_type=ChatModelOptions.ModelType.GOOGLE,
)
messages, system_prompt = format_messages_for_gemini(messages, system_prompt)

View file

@ -1,8 +1,11 @@
import logging
import random
from io import BytesIO
from threading import Thread
import google.generativeai as genai
import PIL.Image
import requests
from google.generativeai.types.answer_types import FinishReason
from google.generativeai.types.generation_types import StopCandidateException
from google.generativeai.types.safety_types import (
@ -53,14 +56,14 @@ def gemini_completion_with_backoff(
},
)
formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]
# Start chat session. All messages up to the last are considered to be part of the chat history
chat_session = model.start_chat(history=formatted_messages[0:-1])
try:
# Generate the response. The last message is considered to be the current prompt
aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"][0])
aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"])
return aggregated_response.text
except StopCandidateException as e:
response_message, _ = handle_gemini_response(e.args)
@ -117,11 +120,11 @@ def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_k
},
)
formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]
# all messages up to the last are considered to be part of the chat history
chat_session = model.start_chat(history=formatted_messages[0:-1])
# the last message is considered to be the current prompt
for chunk in chat_session.send_message(formatted_messages[-1]["parts"][0], stream=True):
for chunk in chat_session.send_message(formatted_messages[-1]["parts"], stream=True):
message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
message = message or chunk.text
g.send(message)
@ -191,14 +194,6 @@ def generate_safety_response(safety_ratings):
def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str = None) -> tuple[list[str], str]:
if len(messages) == 1:
messages[0].role = "user"
return messages, system_prompt
for message in messages:
if message.role == "assistant":
message.role = "model"
# Extract system message
system_prompt = system_prompt or ""
for message in messages.copy():
@ -207,4 +202,31 @@ def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str =
messages.remove(message)
system_prompt = None if is_none_or_empty(system_prompt) else system_prompt
for message in messages:
# Convert message content to string list from chatml dictionary list
if isinstance(message.content, list):
# Convert image_urls to PIL.Image and place them at beginning of list (better for Gemini)
message.content = [
get_image_from_url(item["image_url"]["url"]) if item["type"] == "image_url" else item["text"]
for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1)
]
elif isinstance(message.content, str):
message.content = [message.content]
if message.role == "assistant":
message.role = "model"
if len(messages) == 1:
messages[0].role = "user"
return messages, system_prompt
def get_image_from_url(image_url: str) -> PIL.Image:
try:
response = requests.get(image_url)
response.raise_for_status() # Check if the request was successful
return PIL.Image.open(BytesIO(response.content))
except requests.exceptions.RequestException as e:
logger.error(f"Failed to get image from URL {image_url}: {e}")
return None

View file

@ -30,7 +30,7 @@ def extract_questions(
api_base_url=None,
location_data: LocationData = None,
user: KhojUser = None,
uploaded_image_url: Optional[str] = None,
query_images: Optional[list[str]] = None,
vision_enabled: bool = False,
personality_context: Optional[str] = None,
):
@ -74,7 +74,7 @@ def extract_questions(
prompt = construct_structured_message(
message=prompt,
image_url=uploaded_image_url,
images=query_images,
model_type=ChatModelOptions.ModelType.OPENAI,
vision_enabled=vision_enabled,
)
@ -135,7 +135,7 @@ def converse(
location_data: LocationData = None,
user_name: str = None,
agent: Agent = None,
image_url: Optional[str] = None,
query_images: Optional[list[str]] = None,
vision_available: bool = False,
):
"""
@ -191,7 +191,7 @@ def converse(
model_name=model,
max_prompt_size=max_prompt_size,
tokenizer_name=tokenizer_name,
uploaded_image_url=image_url,
query_images=query_images,
vision_enabled=vision_available,
model_type=ChatModelOptions.ModelType.OPENAI,
)

View file

@ -109,7 +109,7 @@ def save_to_conversation_log(
client_application: ClientApplication = None,
conversation_id: str = None,
automation_id: str = None,
uploaded_image_url: str = None,
query_images: List[str] = None,
):
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
updated_conversation = message_to_log(
@ -117,7 +117,7 @@ def save_to_conversation_log(
chat_response=chat_response,
user_message_metadata={
"created": user_message_time,
"uploadedImageData": uploaded_image_url,
"images": query_images,
},
khoj_message_metadata={
"context": compiled_references,
@ -145,10 +145,18 @@ Khoj: "{inferred_queries if ("text-to-image" in intent_type) else chat_response}
)
# Format user and system messages to chatml format
def construct_structured_message(message, image_url, model_type, vision_enabled):
if image_url and vision_enabled and model_type == ChatModelOptions.ModelType.OPENAI:
return [{"type": "text", "text": message}, {"type": "image_url", "image_url": {"url": image_url}}]
def construct_structured_message(message: str, images: list[str], model_type: str, vision_enabled: bool):
"""
Format messages into appropriate multimedia format for supported chat model types
"""
if not images or not vision_enabled:
return message
if model_type in [ChatModelOptions.ModelType.OPENAI, ChatModelOptions.ModelType.GOOGLE]:
return [
{"type": "text", "text": message},
*[{"type": "image_url", "image_url": {"url": image}} for image in images],
]
return message
@ -160,7 +168,7 @@ def generate_chatml_messages_with_context(
loaded_model: Optional[Llama] = None,
max_prompt_size=None,
tokenizer_name=None,
uploaded_image_url=None,
query_images=None,
vision_enabled=False,
model_type="",
):
@ -186,9 +194,7 @@ def generate_chatml_messages_with_context(
else:
message_content = chat["message"] + message_notes
message_content = construct_structured_message(
message_content, chat.get("uploadedImageData"), model_type, vision_enabled
)
message_content = construct_structured_message(message_content, chat.get("images"), model_type, vision_enabled)
reconstructed_message = ChatMessage(content=message_content, role=role)
@ -201,7 +207,7 @@ def generate_chatml_messages_with_context(
if not is_none_or_empty(user_message):
messages.append(
ChatMessage(
content=construct_structured_message(user_message, uploaded_image_url, model_type, vision_enabled),
content=construct_structured_message(user_message, query_images, model_type, vision_enabled),
role="user",
)
)
@ -225,7 +231,6 @@ def truncate_messages(
tokenizer_name=None,
) -> list[ChatMessage]:
"""Truncate messages to fit within max prompt size supported by model"""
default_tokenizer = "gpt-4o"
try:
@ -255,6 +260,7 @@ def truncate_messages(
system_message = messages.pop(idx)
break
# TODO: Handle truncation of multi-part message.content, i.e when message.content is a list[dict] rather than a string
system_message_tokens = (
len(encoder.encode(system_message.content)) if system_message and type(system_message.content) == str else 0
)

View file

@ -26,7 +26,7 @@ async def text_to_image(
references: List[Dict[str, Any]],
online_results: Dict[str, Any],
send_status_func: Optional[Callable] = None,
uploaded_image_url: Optional[str] = None,
query_images: Optional[List[str]] = None,
agent: Agent = None,
):
status_code = 200
@ -65,7 +65,7 @@ async def text_to_image(
note_references=references,
online_results=online_results,
model_type=text_to_image_config.model_type,
uploaded_image_url=uploaded_image_url,
query_images=query_images,
user=user,
agent=agent,
)

View file

@ -62,7 +62,7 @@ async def search_online(
user: KhojUser,
send_status_func: Optional[Callable] = None,
custom_filters: List[str] = [],
uploaded_image_url: str = None,
query_images: List[str] = None,
agent: Agent = None,
):
query += " ".join(custom_filters)
@ -73,7 +73,7 @@ async def search_online(
# Breakdown the query into subqueries to get the correct answer
subqueries = await generate_online_subqueries(
query, conversation_history, location, user, uploaded_image_url=uploaded_image_url, agent=agent
query, conversation_history, location, user, query_images=query_images, agent=agent
)
response_dict = {}
@ -151,7 +151,7 @@ async def read_webpages(
location: LocationData,
user: KhojUser,
send_status_func: Optional[Callable] = None,
uploaded_image_url: str = None,
query_images: List[str] = None,
agent: Agent = None,
):
"Infer web pages to read from the query and extract relevant information from them"
@ -159,7 +159,7 @@ async def read_webpages(
if send_status_func:
async for event in send_status_func(f"**Inferring web pages to read**"):
yield {ChatEvent.STATUS: event}
urls = await infer_webpage_urls(query, conversation_history, location, user, uploaded_image_url)
urls = await infer_webpage_urls(query, conversation_history, location, user, query_images)
logger.info(f"Reading web pages at: {urls}")
if send_status_func:

View file

@ -347,7 +347,7 @@ async def extract_references_and_questions(
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
location_data: LocationData = None,
send_status_func: Optional[Callable] = None,
uploaded_image_url: Optional[str] = None,
query_images: Optional[List[str]] = None,
agent: Agent = None,
):
user = request.user.object if request.user.is_authenticated else None
@ -438,7 +438,7 @@ async def extract_references_and_questions(
conversation_log=meta_log,
location_data=location_data,
user=user,
uploaded_image_url=uploaded_image_url,
query_images=query_images,
vision_enabled=vision_enabled,
personality_context=personality_context,
)
@ -459,12 +459,14 @@ async def extract_references_and_questions(
chat_model = conversation_config.chat_model
inferred_queries = extract_questions_gemini(
defiltered_query,
query_images=query_images,
model=chat_model,
api_key=api_key,
conversation_log=meta_log,
location_data=location_data,
max_tokens=conversation_config.max_prompt_size,
user=user,
vision_enabled=vision_enabled,
personality_context=personality_context,
)

View file

@ -30,8 +30,10 @@ from khoj.processor.speech.text_to_speech import generate_text_to_speech
from khoj.processor.tools.online_search import read_webpages, search_online
from khoj.routers.api import extract_references_and_questions
from khoj.routers.helpers import (
ApiImageRateLimiter,
ApiUserRateLimiter,
ChatEvent,
ChatRequestBody,
CommonQueryParams,
ConversationCommandRateLimiter,
agenerate_chat_response,
@ -524,22 +526,6 @@ async def set_conversation_title(
)
class ChatRequestBody(BaseModel):
q: str
n: Optional[int] = 7
d: Optional[float] = None
stream: Optional[bool] = False
title: Optional[str] = None
conversation_id: Optional[str] = None
city: Optional[str] = None
region: Optional[str] = None
country: Optional[str] = None
country_code: Optional[str] = None
timezone: Optional[str] = None
image: Optional[str] = None
create_new: Optional[bool] = False
@api_chat.post("")
@requires(["authenticated"])
async def chat(
@ -552,6 +538,7 @@ async def chat(
rate_limiter_per_day=Depends(
ApiUserRateLimiter(requests=600, subscribed_requests=6000, window=60 * 60 * 24, slug="chat_day")
),
image_rate_limiter=Depends(ApiImageRateLimiter(max_images=10, max_combined_size_mb=20)),
):
# Access the parameters from the body
q = body.q
@ -565,9 +552,9 @@ async def chat(
country = body.country or get_country_name_from_timezone(body.timezone)
country_code = body.country_code or get_country_code_from_timezone(body.timezone)
timezone = body.timezone
image = body.image
raw_images = body.images
async def event_generator(q: str, image: str):
async def event_generator(q: str, images: list[str]):
start_time = time.perf_counter()
ttft = None
chat_metadata: dict = {}
@ -577,16 +564,16 @@ async def chat(
q = unquote(q)
nonlocal conversation_id
uploaded_image_url = None
if image:
uploaded_images: list[str] = []
if images:
for image in images:
decoded_string = unquote(image)
base64_data = decoded_string.split(",", 1)[1]
image_bytes = base64.b64decode(base64_data)
webp_image_bytes = convert_image_to_webp(image_bytes)
try:
uploaded_image_url = upload_image_to_bucket(webp_image_bytes, request.user.object.id)
except:
uploaded_image_url = None
uploaded_image = upload_image_to_bucket(webp_image_bytes, request.user.object.id)
if uploaded_image:
uploaded_images.append(uploaded_image)
async def send_event(event_type: ChatEvent, data: str | dict):
nonlocal connection_alive, ttft
@ -693,7 +680,7 @@ async def chat(
meta_log,
is_automated_task,
user=user,
uploaded_image_url=uploaded_image_url,
query_images=uploaded_images,
agent=agent,
)
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
@ -702,7 +689,7 @@ async def chat(
):
yield result
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_image_url, agent)
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_images, agent)
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
yield result
if mode not in conversation_commands:
@ -765,7 +752,7 @@ async def chat(
q,
contextual_data,
conversation_history=meta_log,
uploaded_image_url=uploaded_image_url,
query_images=uploaded_images,
user=user,
agent=agent,
)
@ -786,7 +773,7 @@ async def chat(
intent_type="summarize",
client_application=request.user.client_app,
conversation_id=conversation_id,
uploaded_image_url=uploaded_image_url,
query_images=uploaded_images,
)
return
@ -829,7 +816,7 @@ async def chat(
conversation_id=conversation_id,
inferred_queries=[query_to_run],
automation_id=automation.id,
uploaded_image_url=uploaded_image_url,
query_images=uploaded_images,
)
async for result in send_llm_response(llm_response):
yield result
@ -849,7 +836,7 @@ async def chat(
conversation_commands,
location,
partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url,
query_images=uploaded_images,
agent=agent,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
@ -893,7 +880,7 @@ async def chat(
user,
partial(send_event, ChatEvent.STATUS),
custom_filters,
uploaded_image_url=uploaded_image_url,
query_images=uploaded_images,
agent=agent,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
@ -917,7 +904,7 @@ async def chat(
location,
user,
partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url,
query_images=uploaded_images,
agent=agent,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
@ -967,20 +954,20 @@ async def chat(
references=compiled_references,
online_results=online_results,
send_status_func=partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url,
query_images=uploaded_images,
agent=agent,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
image, status_code, improved_image_prompt, intent_type = result
generated_image, status_code, improved_image_prompt, intent_type = result
if image is None or status_code != 200:
if generated_image is None or status_code != 200:
content_obj = {
"content-type": "application/json",
"intentType": intent_type,
"detail": improved_image_prompt,
"image": image,
"image": None,
}
async for result in send_llm_response(json.dumps(content_obj)):
yield result
@ -988,7 +975,7 @@ async def chat(
await sync_to_async(save_to_conversation_log)(
q,
image,
generated_image,
user,
meta_log,
user_message_time,
@ -998,12 +985,12 @@ async def chat(
conversation_id=conversation_id,
compiled_references=compiled_references,
online_results=online_results,
uploaded_image_url=uploaded_image_url,
query_images=uploaded_images,
)
content_obj = {
"intentType": intent_type,
"inferredQueries": [improved_image_prompt],
"image": image,
"image": generated_image,
}
async for result in send_llm_response(json.dumps(content_obj)):
yield result
@ -1023,7 +1010,7 @@ async def chat(
location_data=location,
note_references=compiled_references,
online_results=online_results,
uploaded_image_url=uploaded_image_url,
query_images=uploaded_images,
user=user,
agent=agent,
send_status_func=partial(send_event, ChatEvent.STATUS),
@ -1053,7 +1040,7 @@ async def chat(
conversation_id=conversation_id,
compiled_references=compiled_references,
online_results=online_results,
uploaded_image_url=uploaded_image_url,
query_images=uploaded_images,
)
async for result in send_llm_response(json.dumps(content_obj)):
@ -1076,7 +1063,7 @@ async def chat(
conversation_id,
location,
user_name,
uploaded_image_url,
uploaded_images,
)
# Send Response
@ -1102,9 +1089,9 @@ async def chat(
## Stream Text Response
if stream:
return StreamingResponse(event_generator(q, image=image), media_type="text/plain")
return StreamingResponse(event_generator(q, images=raw_images), media_type="text/plain")
## Non-Streaming Text Response
else:
response_iterator = event_generator(q, image=image)
response_iterator = event_generator(q, images=raw_images)
response_data = await read_chat_stream(response_iterator)
return Response(content=json.dumps(response_data), media_type="application/json", status_code=200)

View file

@ -1,4 +1,5 @@
import asyncio
import base64
import hashlib
import json
import logging
@ -22,7 +23,7 @@ from typing import (
Tuple,
Union,
)
from urllib.parse import parse_qs, quote, urljoin, urlparse
from urllib.parse import parse_qs, quote, unquote, urljoin, urlparse
import cron_descriptor
import pytz
@ -31,6 +32,7 @@ from apscheduler.job import Job
from apscheduler.triggers.cron import CronTrigger
from asgiref.sync import sync_to_async
from fastapi import Depends, Header, HTTPException, Request, UploadFile
from pydantic import BaseModel
from starlette.authentication import has_required_scope
from starlette.requests import URL
@ -296,7 +298,7 @@ async def aget_relevant_information_sources(
conversation_history: dict,
is_task: bool,
user: KhojUser,
uploaded_image_url: str = None,
query_images: List[str] = None,
agent: Agent = None,
):
"""
@ -315,8 +317,8 @@ async def aget_relevant_information_sources(
chat_history = construct_chat_history(conversation_history)
if uploaded_image_url:
query = f"[placeholder for user attached image]\n{query}"
if query_images:
query = f"[placeholder for {len(query_images)} user attached images]\n{query}"
personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
@ -373,7 +375,7 @@ async def aget_relevant_output_modes(
conversation_history: dict,
is_task: bool = False,
user: KhojUser = None,
uploaded_image_url: str = None,
query_images: List[str] = None,
agent: Agent = None,
):
"""
@ -395,8 +397,8 @@ async def aget_relevant_output_modes(
chat_history = construct_chat_history(conversation_history)
if uploaded_image_url:
query = f"[placeholder for user attached image]\n{query}"
if query_images:
query = f"[placeholder for {len(query_images)} user attached images]\n{query}"
personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
@ -439,7 +441,7 @@ async def infer_webpage_urls(
conversation_history: dict,
location_data: LocationData,
user: KhojUser,
uploaded_image_url: str = None,
query_images: List[str] = None,
agent: Agent = None,
) -> List[str]:
"""
@ -465,7 +467,7 @@ async def infer_webpage_urls(
with timer("Chat actor: Infer webpage urls to read", logger):
response = await send_message_to_model_wrapper(
online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
online_queries_prompt, query_images=query_images, response_type="json_object", user=user
)
# Validate that the response is a non-empty, JSON-serializable list of URLs
@ -485,7 +487,7 @@ async def generate_online_subqueries(
conversation_history: dict,
location_data: LocationData,
user: KhojUser,
uploaded_image_url: str = None,
query_images: List[str] = None,
agent: Agent = None,
) -> List[str]:
"""
@ -511,7 +513,7 @@ async def generate_online_subqueries(
with timer("Chat actor: Generate online search subqueries", logger):
response = await send_message_to_model_wrapper(
online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
online_queries_prompt, query_images=query_images, response_type="json_object", user=user
)
# Validate that the response is a non-empty, JSON-serializable list
@ -530,7 +532,7 @@ async def generate_online_subqueries(
async def schedule_query(
q: str, conversation_history: dict, user: KhojUser, uploaded_image_url: str = None
q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None
) -> Tuple[str, ...]:
"""
Schedule the date, time to run the query. Assume the server timezone is UTC.
@ -543,7 +545,7 @@ async def schedule_query(
)
raw_response = await send_message_to_model_wrapper(
crontime_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
crontime_prompt, query_images=query_images, response_type="json_object", user=user
)
# Validate that the response is a non-empty, JSON-serializable list
@ -589,7 +591,7 @@ async def extract_relevant_summary(
q: str,
corpus: str,
conversation_history: dict,
uploaded_image_url: str = None,
query_images: List[str] = None,
user: KhojUser = None,
agent: Agent = None,
) -> Union[str, None]:
@ -618,7 +620,7 @@ async def extract_relevant_summary(
extract_relevant_information,
prompts.system_prompt_extract_relevant_summary,
user=user,
uploaded_image_url=uploaded_image_url,
query_images=query_images,
)
return response.strip()
@ -629,7 +631,7 @@ async def generate_excalidraw_diagram(
location_data: LocationData,
note_references: List[Dict[str, Any]],
online_results: Optional[dict] = None,
uploaded_image_url: Optional[str] = None,
query_images: List[str] = None,
user: KhojUser = None,
agent: Agent = None,
send_status_func: Optional[Callable] = None,
@ -644,7 +646,7 @@ async def generate_excalidraw_diagram(
location_data=location_data,
note_references=note_references,
online_results=online_results,
uploaded_image_url=uploaded_image_url,
query_images=query_images,
user=user,
agent=agent,
)
@ -668,7 +670,7 @@ async def generate_better_diagram_description(
location_data: LocationData,
note_references: List[Dict[str, Any]],
online_results: Optional[dict] = None,
uploaded_image_url: Optional[str] = None,
query_images: List[str] = None,
user: KhojUser = None,
agent: Agent = None,
) -> str:
@ -711,7 +713,7 @@ async def generate_better_diagram_description(
with timer("Chat actor: Generate better diagram description", logger):
response = await send_message_to_model_wrapper(
improve_diagram_description_prompt, uploaded_image_url=uploaded_image_url, user=user
improve_diagram_description_prompt, query_images=query_images, user=user
)
response = response.strip()
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
@ -753,7 +755,7 @@ async def generate_better_image_prompt(
note_references: List[Dict[str, Any]],
online_results: Optional[dict] = None,
model_type: Optional[str] = None,
uploaded_image_url: Optional[str] = None,
query_images: Optional[List[str]] = None,
user: KhojUser = None,
agent: Agent = None,
) -> str:
@ -805,7 +807,7 @@ async def generate_better_image_prompt(
)
with timer("Chat actor: Generate contextual image prompt", logger):
response = await send_message_to_model_wrapper(image_prompt, uploaded_image_url=uploaded_image_url, user=user)
response = await send_message_to_model_wrapper(image_prompt, query_images=query_images, user=user)
response = response.strip()
if response.startswith(('"', "'")) and response.endswith(('"', "'")):
response = response[1:-1]
@ -818,11 +820,11 @@ async def send_message_to_model_wrapper(
system_message: str = "",
response_type: str = "text",
user: KhojUser = None,
uploaded_image_url: str = None,
query_images: List[str] = None,
):
conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
vision_available = conversation_config.vision_enabled
if not vision_available and uploaded_image_url:
if not vision_available and query_images:
vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
if vision_enabled_config:
conversation_config = vision_enabled_config
@ -875,7 +877,7 @@ async def send_message_to_model_wrapper(
max_prompt_size=max_tokens,
tokenizer_name=tokenizer,
vision_enabled=vision_available,
uploaded_image_url=uploaded_image_url,
query_images=query_images,
model_type=conversation_config.model_type,
)
@ -895,7 +897,7 @@ async def send_message_to_model_wrapper(
max_prompt_size=max_tokens,
tokenizer_name=tokenizer,
vision_enabled=vision_available,
uploaded_image_url=uploaded_image_url,
query_images=query_images,
model_type=conversation_config.model_type,
)
@ -913,7 +915,8 @@ async def send_message_to_model_wrapper(
max_prompt_size=max_tokens,
tokenizer_name=tokenizer,
vision_enabled=vision_available,
uploaded_image_url=uploaded_image_url,
query_images=query_images,
model_type=conversation_config.model_type,
)
return gemini_send_message_to_model(
@ -1004,6 +1007,7 @@ def send_message_to_model_wrapper_sync(
model_name=chat_model,
max_prompt_size=max_tokens,
vision_enabled=vision_available,
model_type=conversation_config.model_type,
)
return gemini_send_message_to_model(
@ -1029,7 +1033,7 @@ def generate_chat_response(
conversation_id: str = None,
location_data: LocationData = None,
user_name: Optional[str] = None,
uploaded_image_url: Optional[str] = None,
query_images: Optional[List[str]] = None,
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
# Initialize Variables
chat_response = None
@ -1048,12 +1052,12 @@ def generate_chat_response(
inferred_queries=inferred_queries,
client_application=client_application,
conversation_id=conversation_id,
uploaded_image_url=uploaded_image_url,
query_images=query_images,
)
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
vision_available = conversation_config.vision_enabled
if not vision_available and uploaded_image_url:
if not vision_available and query_images:
vision_enabled_config = ConversationAdapters.get_vision_enabled_config()
if vision_enabled_config:
conversation_config = vision_enabled_config
@ -1084,7 +1088,7 @@ def generate_chat_response(
chat_response = converse(
compiled_references,
q,
image_url=uploaded_image_url,
query_images=query_images,
online_results=online_results,
conversation_log=meta_log,
model=chat_model,
@ -1122,8 +1126,9 @@ def generate_chat_response(
chat_response = converse_gemini(
compiled_references,
q,
online_results,
meta_log,
query_images=query_images,
online_results=online_results,
conversation_log=meta_log,
model=conversation_config.chat_model,
api_key=api_key,
completion_func=partial_completion,
@ -1133,6 +1138,7 @@ def generate_chat_response(
location_data=location_data,
user_name=user_name,
agent=agent,
vision_available=vision_available,
)
metadata.update({"chat_model": conversation_config.chat_model})
@ -1144,6 +1150,22 @@ def generate_chat_response(
return chat_response, metadata
class ChatRequestBody(BaseModel):
q: str
n: Optional[int] = 7
d: Optional[float] = None
stream: Optional[bool] = False
title: Optional[str] = None
conversation_id: Optional[str] = None
city: Optional[str] = None
region: Optional[str] = None
country: Optional[str] = None
country_code: Optional[str] = None
timezone: Optional[str] = None
images: Optional[list[str]] = None
create_new: Optional[bool] = False
class ApiUserRateLimiter:
def __init__(self, requests: int, subscribed_requests: int, window: int, slug: str):
self.requests = requests
@ -1189,13 +1211,58 @@ class ApiUserRateLimiter:
)
raise HTTPException(
status_code=429,
detail="We're glad you're enjoying Khoj! You've exceeded your usage limit for today. Come back tomorrow or subscribe to increase your usage limit via [your settings](https://app.khoj.dev/settings).",
detail="I'm glad you're enjoying interacting with me! But you've exceeded your usage limit for today. Come back tomorrow or subscribe to increase your usage limit via [your settings](https://app.khoj.dev/settings).",
)
# Add the current request to the cache
UserRequests.objects.create(user=user, slug=self.slug)
class ApiImageRateLimiter:
def __init__(self, max_images: int = 10, max_combined_size_mb: float = 10):
self.max_images = max_images
self.max_combined_size_mb = max_combined_size_mb
def __call__(self, request: Request, body: ChatRequestBody):
if state.billing_enabled is False:
return
# Rate limiting is disabled if user unauthenticated.
# Other systems handle authentication
if not request.user.is_authenticated:
return
if not body.images:
return
# Check number of images
if len(body.images) > self.max_images:
raise HTTPException(
status_code=429,
detail=f"Those are way too many images for me! I can handle up to {self.max_images} images per message.",
)
# Check total size of images
total_size_mb = 0.0
for image in body.images:
# Unquote the image in case it's URL encoded
image = unquote(image)
# Assuming the image is a base64 encoded string
# Remove the data:image/jpeg;base64, part if present
if "," in image:
image = image.split(",", 1)[1]
# Decode base64 to get the actual size
image_bytes = base64.b64decode(image)
total_size_mb += len(image_bytes) / (1024 * 1024) # Convert bytes to MB
if total_size_mb > self.max_combined_size_mb:
raise HTTPException(
status_code=429,
detail=f"Those images are way too large for me! I can handle up to {self.max_combined_size_mb}MB of images per message.",
)
class ConversationCommandRateLimiter:
def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int, slug: str):
self.slug = slug

View file

@ -352,9 +352,9 @@ tool_descriptions_for_llm = {
}
mode_descriptions_for_llm = {
ConversationCommand.Image: "Use this if the user is requesting you to generate a picture based on their description.",
ConversationCommand.Image: "Use this if the user is requesting you to create a new picture based on their description.",
ConversationCommand.Automation: "Use this if you are confident the user is requesting a response at a scheduled date, time and frequency",
ConversationCommand.Text: "Use this if the other response modes don't seem to fit the query.",
ConversationCommand.Text: "Use this if a normal text response would be sufficient for accurately responding to the query.",
ConversationCommand.Diagram: "Use this if the user is requesting a visual representation that requires primitives like lines, rectangles, and text.",
}