Mirror of https://github.com/khoj-ai/khoj.git, synced 2024-11-23 23:48:56 +01:00

Commit 750fbce0c2 — Merge branch 'master' into improve-agent-pane-on-home-screen
51 changed files with 1,472 additions and 399 deletions.
@@ -1,7 +1,7 @@
 {
     "id": "khoj",
     "name": "Khoj",
-    "version": "1.26.0",
+    "version": "1.26.4",
     "minAppVersion": "0.15.0",
     "description": "Your Second Brain",
     "author": "Khoj Inc.",
@@ -62,7 +62,7 @@ dependencies = [
     "requests >= 2.26.0",
     "tenacity == 8.3.0",
     "anyio == 3.7.1",
-    "pymupdf >= 1.23.5",
+    "pymupdf == 1.24.11",
     "django == 5.0.9",
     "authlib == 1.2.1",
     "llama-cpp-python == 0.2.88",
@@ -326,7 +326,7 @@
     entries.forEach(entry => {
         // If the element is in the viewport, fetch the remaining message and unobserve the element
         if (entry.isIntersecting) {
-            fetchRemainingChatMessages(chatHistoryUrl, headers);
+            fetchRemainingChatMessages(chatHistoryUrl, headers, chatBody.dataset.conversation_id, hostURL);
             observer.unobserve(entry.target);
         }
     });

@@ -342,7 +342,11 @@
         new Date(chat_log.created),
         chat_log.onlineContext,
         chat_log.intent?.type,
-        chat_log.intent?.["inferred-queries"]);
+        chat_log.intent?.["inferred-queries"],
+        chatBody.dataset.conversationId ?? "",
+        hostURL,
+    );

     chatBody.appendChild(messageElement);

     // When the 4th oldest message is within viewing distance (~60% scrolled up)

@@ -421,7 +425,7 @@
     }
 }

-function fetchRemainingChatMessages(chatHistoryUrl, headers) {
+function fetchRemainingChatMessages(chatHistoryUrl, headers, conversationId, hostURL) {
     // Create a new IntersectionObserver
     let observer = new IntersectionObserver((entries, observer) => {
         entries.forEach(entry => {

@@ -435,7 +439,9 @@
                 new Date(chat_log.created),
                 chat_log.onlineContext,
                 chat_log.intent?.type,
-                chat_log.intent?.["inferred-queries"]
+                chat_log.intent?.["inferred-queries"],
+                chatBody.dataset.conversationId ?? "",
+                hostURL,
             );
             entry.target.replaceWith(messageElement);
@@ -189,11 +189,19 @@ function processOnlineReferences(referenceSection, onlineContext) { //same
     return numOnlineReferences;
 }

-function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null) { //same
+function renderMessageWithReference(message, by, context=null, dt=null, onlineContext=null, intentType=null, inferredQueries=null, conversationId=null, hostURL=null) {
     let chatEl;
     if (intentType?.includes("text-to-image")) {
         let imageMarkdown = generateImageMarkdown(message, intentType, inferredQueries);
         chatEl = renderMessage(imageMarkdown, by, dt, null, false, "return");
+    } else if (intentType === "excalidraw") {
+        let domain = hostURL ?? "https://app.khoj.dev/";
+
+        if (!domain.endsWith("/")) domain += "/";
+
+        let excalidrawMessage = `Hey, I'm not ready to show you diagrams yet here. But you can view it in the web app at ${domain}chat?conversationId=${conversationId}`;
+
+        chatEl = renderMessage(excalidrawMessage, by, dt, null, false, "return");
     } else {
         chatEl = renderMessage(message, by, dt, null, false, "return");
     }

@@ -312,7 +320,6 @@ function formatHTMLMessage(message, raw=false, willReplace=true) { //same
 }

 function createReferenceSection(references, createLinkerSection=false) {
-    console.log("linker data: ", createLinkerSection);
     let referenceSection = document.createElement('div');
     referenceSection.classList.add("reference-section");
     referenceSection.classList.add("collapsed");

@@ -417,7 +424,11 @@ function handleImageResponse(imageJson, rawResponse) {
         rawResponse += `![generated_image](${imageJson.image})`;
     } else if (imageJson.intentType === "text-to-image-v3") {
         rawResponse = `![](data:image/webp;base64,${imageJson.image})`;
+    } else if (imageJson.intentType === "excalidraw") {
+        const redirectMessage = `Hey, I'm not ready to show you diagrams yet here. But you can view it in the web app`;
+        rawResponse += redirectMessage;
     }

     if (inferredQuery) {
         rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
     }
@@ -1,6 +1,6 @@
 {
     "name": "Khoj",
-    "version": "1.26.0",
+    "version": "1.26.4",
     "description": "Your Second Brain",
     "author": "Khoj Inc. <team@khoj.dev>",
     "license": "GPL-3.0-or-later",
@@ -6,7 +6,7 @@
 ;; Saba Imran <saba@khoj.dev>
 ;; Description: Your Second Brain
 ;; Keywords: search, chat, ai, org-mode, outlines, markdown, pdf, image
-;; Version: 1.26.0
+;; Version: 1.26.4
 ;; Package-Requires: ((emacs "27.1") (transient "0.3.0") (dash "2.19.1"))
 ;; URL: https://github.com/khoj-ai/khoj/tree/master/src/interface/emacs
@@ -1,7 +1,7 @@
 {
     "id": "khoj",
     "name": "Khoj",
-    "version": "1.26.0",
+    "version": "1.26.4",
     "minAppVersion": "0.15.0",
     "description": "Your Second Brain",
     "author": "Khoj Inc.",
@@ -1,6 +1,6 @@
 {
     "name": "Khoj",
-    "version": "1.26.0",
+    "version": "1.26.4",
     "description": "Your Second Brain",
     "author": "Debanjum Singh Solanky, Saba Imran <team@khoj.dev>",
     "license": "GPL-3.0-or-later",
@@ -484,12 +484,13 @@ export class KhojChatView extends KhojPaneView {
         dt?: Date,
         intentType?: string,
         inferredQueries?: string[],
+        conversationId?: string,
     ) {
         if (!message) return;

         let chatMessageEl;
-        if (intentType?.includes("text-to-image")) {
-            let imageMarkdown = this.generateImageMarkdown(message, intentType, inferredQueries);
+        if (intentType?.includes("text-to-image") || intentType === "excalidraw") {
+            let imageMarkdown = this.generateImageMarkdown(message, intentType, inferredQueries, conversationId);
             chatMessageEl = this.renderMessage(chatEl, imageMarkdown, sender, dt);
         } else {
             chatMessageEl = this.renderMessage(chatEl, message, sender, dt);

@@ -509,7 +510,7 @@ export class KhojChatView extends KhojPaneView {
         chatMessageBodyEl.appendChild(this.createReferenceSection(references));
     }

-    generateImageMarkdown(message: string, intentType: string, inferredQueries?: string[]) {
+    generateImageMarkdown(message: string, intentType: string, inferredQueries?: string[], conversationId?: string): string {
         let imageMarkdown = "";
         if (intentType === "text-to-image") {
             imageMarkdown = `![](data:image/png;base64,${message})`;

@@ -517,6 +518,10 @@ export class KhojChatView extends KhojPaneView {
             imageMarkdown = `![](${message})`;
         } else if (intentType === "text-to-image-v3") {
             imageMarkdown = `![](data:image/webp;base64,${message})`;
+        } else if (intentType === "excalidraw") {
+            const domain = this.setting.khojUrl.endsWith("/") ? this.setting.khojUrl : `${this.setting.khojUrl}/`;
+            const redirectMessage = `Hey, I'm not ready to show you diagrams yet here. But you can view it in ${domain}chat?conversationId=${conversationId}`;
+            imageMarkdown = redirectMessage;
         }
         if (inferredQueries) {
             imageMarkdown += "\n\n**Inferred Query**:";

@@ -884,6 +889,7 @@ export class KhojChatView extends KhojPaneView {
                 new Date(chatLog.created),
                 chatLog.intent?.type,
                 chatLog.intent?.["inferred-queries"],
+                chatBodyEl.dataset.conversationId ?? "",
             );
             // push the user messages to the chat history
             if(chatLog.by === "you"){

@@ -1354,6 +1360,10 @@ export class KhojChatView extends KhojPaneView {
             rawResponse += `![generated_image](${imageJson.image})`;
         } else if (imageJson.intentType === "text-to-image-v3") {
             rawResponse = `![](data:image/webp;base64,${imageJson.image})`;
+        } else if (imageJson.intentType === "excalidraw") {
+            const domain = this.setting.khojUrl.endsWith("/") ? this.setting.khojUrl : `${this.setting.khojUrl}/`;
+            const redirectMessage = `Hey, I'm not ready to show you diagrams yet here. But you can view it in ${domain}`;
+            rawResponse += redirectMessage;
         }
         if (inferredQuery) {
             rawResponse += `\n\n**Inferred Query**:\n\n${inferredQuery}`;
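The Obsidian client can't render Excalidraw scenes, so the hunks above fall back to a link into the web app. A minimal TypeScript sketch of the URL normalization both branches rely on (the helper name is hypothetical, not part of the diff):

function excalidrawRedirectUrl(khojUrl: string, conversationId?: string): string {
    // Ensure exactly one trailing slash before appending the chat route.
    const domain = khojUrl.endsWith("/") ? khojUrl : `${khojUrl}/`;
    return conversationId ? `${domain}chat?conversationId=${conversationId}` : domain;
}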
@@ -78,5 +78,9 @@
     "1.24.0": "0.15.0",
     "1.24.1": "0.15.0",
     "1.25.0": "0.15.0",
-    "1.26.0": "0.15.0"
+    "1.26.0": "0.15.0",
+    "1.26.1": "0.15.0",
+    "1.26.2": "0.15.0",
+    "1.26.3": "0.15.0",
+    "1.26.4": "0.15.0"
 }
@@ -79,7 +79,7 @@ div.titleBar {
 div.chatBoxBody {
     display: grid;
     height: 100%;
-    width: 70%;
+    width: 95%;
     margin: auto;
 }
@@ -47,7 +47,14 @@ export default function RootLayout({
                     child-src 'none';
                     object-src 'none';"
             ></meta>
-            <body className={inter.className}>{children}</body>
+            <body className={inter.className}>
+                {children}
+                <script
+                    dangerouslySetInnerHTML={{
+                        __html: `window.EXCALIDRAW_ASSET_PATH = 'https://assets.khoj.dev/@excalidraw/excalidraw/dist/';`,
+                    }}
+                />
+            </body>
         </html>
     );
 }
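The inline script sets `window.EXCALIDRAW_ASSET_PATH`, the global Excalidraw reads at load time to fetch its fonts and other static assets, pointing it at Khoj's CDN instead of the default location. A hedged sketch of a typed variant of the same idea (this declaration and helper are assumptions, not part of the diff):

declare global {
    interface Window {
        EXCALIDRAW_ASSET_PATH?: string;
    }
}

// Point Excalidraw at a self-hosted asset directory before the library loads.
export function configureExcalidrawAssets(baseUrl: string): void {
    window.EXCALIDRAW_ASSET_PATH = baseUrl.endsWith("/") ? baseUrl : `${baseUrl}/`;
}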
@@ -27,32 +27,37 @@ interface ChatBodyDataProps {
     setUploadedFiles: (files: string[]) => void;
     isMobileWidth?: boolean;
     isLoggedIn: boolean;
-    setImage64: (image64: string) => void;
+    setImages: (images: string[]) => void;
 }

 function ChatBodyData(props: ChatBodyDataProps) {
     const searchParams = useSearchParams();
     const conversationId = searchParams.get("conversationId");
     const [message, setMessage] = useState("");
-    const [image, setImage] = useState<string | null>(null);
+    const [images, setImages] = useState<string[]>([]);
     const [processingMessage, setProcessingMessage] = useState(false);
     const [agentMetadata, setAgentMetadata] = useState<AgentData | null>(null);

     const setQueryToProcess = props.setQueryToProcess;
     const onConversationIdChange = props.onConversationIdChange;

-    useEffect(() => {
-        if (image) {
-            props.setImage64(encodeURIComponent(image));
-        }
-    }, [image, props.setImage64]);
+    const chatHistoryCustomClassName = props.isMobileWidth ? "w-full" : "w-4/6";

     useEffect(() => {
-        const storedImage = localStorage.getItem("image");
-        if (storedImage) {
-            setImage(storedImage);
-            props.setImage64(encodeURIComponent(storedImage));
-            localStorage.removeItem("image");
+        if (images.length > 0) {
+            const encodedImages = images.map((image) => encodeURIComponent(image));
+            props.setImages(encodedImages);
+        }
+    }, [images, props.setImages]);
+
+    useEffect(() => {
+        const storedImages = localStorage.getItem("images");
+        if (storedImages) {
+            const parsedImages: string[] = JSON.parse(storedImages);
+            setImages(parsedImages);
+            const encodedImages = parsedImages.map((img: string) => encodeURIComponent(img));
+            props.setImages(encodedImages);
+            localStorage.removeItem("images");
         }

         const storedMessage = localStorage.getItem("message");

@@ -60,7 +65,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
             setProcessingMessage(true);
             setQueryToProcess(storedMessage);
         }
-    }, [setQueryToProcess]);
+    }, [setQueryToProcess, props.setImages]);

     useEffect(() => {
         if (message) {

@@ -82,6 +87,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
             props.streamedMessages[props.streamedMessages.length - 1].completed
         ) {
             setProcessingMessage(false);
+            setImages([]); // Reset images after processing
         } else {
             setMessage("");
         }

@@ -101,16 +107,17 @@ function ChatBodyData(props: ChatBodyDataProps) {
                     setAgent={setAgentMetadata}
                     pendingMessage={processingMessage ? message : ""}
                     incomingMessages={props.streamedMessages}
+                    customClassName={chatHistoryCustomClassName}
                 />
             </div>
             <div
-                className={`${styles.inputBox} p-1 md:px-2 shadow-md bg-background align-middle items-center justify-center dark:bg-neutral-700 dark:border-0 dark:shadow-sm rounded-t-2xl rounded-b-none md:rounded-xl h-fit`}
+                className={`${styles.inputBox} p-1 md:px-2 shadow-md bg-background align-middle items-center justify-center dark:bg-neutral-700 dark:border-0 dark:shadow-sm rounded-t-2xl rounded-b-none md:rounded-xl h-fit ${chatHistoryCustomClassName} mr-auto ml-auto`}
             >
                 <ChatInputArea
                     agentColor={agentMetadata?.color}
                     isLoggedIn={props.isLoggedIn}
                     sendMessage={(message) => setMessage(message)}
-                    sendImage={(image) => setImage(image)}
+                    sendImage={(image) => setImages((prevImages) => [...prevImages, image])}
                     sendDisabled={processingMessage}
                     chatOptionsData={props.chatOptionsData}
                     conversationId={conversationId}

@@ -132,7 +139,7 @@ export default function Chat() {
     const [queryToProcess, setQueryToProcess] = useState<string>("");
     const [processQuerySignal, setProcessQuerySignal] = useState(false);
     const [uploadedFiles, setUploadedFiles] = useState<string[]>([]);
-    const [image64, setImage64] = useState<string>("");
+    const [images, setImages] = useState<string[]>([]);

     const locationData = useIPLocationData() || {
         timezone: Intl.DateTimeFormat().resolvedOptions().timeZone,

@@ -168,7 +175,7 @@ export default function Chat() {
             completed: false,
             timestamp: new Date().toISOString(),
             rawQuery: queryToProcess || "",
-            uploadedImageData: decodeURIComponent(image64),
+            images: images,
         };
         setMessages((prevMessages) => [...prevMessages, newStreamMessage]);
         setProcessQuerySignal(true);

@@ -199,7 +206,7 @@ export default function Chat() {
                 if (done) {
                     setQueryToProcess("");
                     setProcessQuerySignal(false);
-                    setImage64("");
+                    setImages([]);
                     break;
                 }

@@ -247,7 +254,7 @@ export default function Chat() {
                 country_code: locationData.countryCode,
                 timezone: locationData.timezone,
             }),
-            ...(image64 && { image: image64 }),
+            ...(images.length > 0 && { images: images }),
         };

         const response = await fetch(chatAPI, {

@@ -261,7 +268,8 @@ export default function Chat() {
         try {
             await readChatStream(response);
         } catch (err) {
-            console.error(err);
+            const apiError = await response.json();
+            console.error(apiError);
             // Retrieve latest message being processed
             const currentMessage = messages.find((message) => !message.completed);
             if (!currentMessage) return;

@@ -270,7 +278,11 @@ export default function Chat() {
             const errorMessage = (err as Error).message;
             if (errorMessage.includes("Error in input stream"))
                 currentMessage.rawResponse = `Woops! The connection broke while I was writing my thoughts down. Maybe try again in a bit or dislike this message if the issue persists?`;
-            else
+            else if (response.status === 429) {
+                "detail" in apiError
+                    ? (currentMessage.rawResponse = `${apiError.detail}`)
+                    : (currentMessage.rawResponse = `I'm a bit overwhelmed at the moment. Could you try again in a bit or dislike this message if the issue persists?`);
+            } else
                 currentMessage.rawResponse = `Umm, not sure what just happened. I see this error message: ${errorMessage}. Could you try again or dislike this message if the issue persists?`;

             // Complete message streaming teardown properly

@@ -329,7 +341,7 @@ export default function Chat() {
                         setUploadedFiles={setUploadedFiles}
                         isMobileWidth={isMobileWidth}
                         onConversationIdChange={handleConversationIdChange}
-                        setImage64={setImage64}
+                        setImages={setImages}
                     />
                 </Suspense>
             </div>
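The error handling above now distinguishes stream drops from rate limiting: HTTP 429 responses surface the server's `detail` message when one is present. A compact sketch of that mapping (helper name hypothetical; message strings copied from the hunk):

function messageForChatError(status: number, apiError: { detail?: string }, err: Error): string {
    if (err.message.includes("Error in input stream"))
        return "Woops! The connection broke while I was writing my thoughts down. Maybe try again in a bit or dislike this message if the issue persists?";
    if (status === 429)
        // Prefer the server-provided detail when the rate limiter supplies one.
        return apiError.detail
            ?? "I'm a bit overwhelmed at the moment. Could you try again in a bit or dislike this message if the issue persists?";
    return `Umm, not sure what just happened. I see this error message: ${err.message}. Could you try again or dislike this message if the issue persists?`;
}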
@@ -5,10 +5,10 @@ export interface RawReferenceData {
     onlineContext?: OnlineContext;
 }

-export interface ResponseWithReferences {
-    context?: Context[];
-    online?: OnlineContext;
-    response?: string;
+export interface ResponseWithIntent {
+    intentType: string;
+    response: string;
+    inferredQueries?: string[];
 }

 interface MessageChunk {

@@ -49,10 +49,14 @@ export function convertMessageChunkToJson(chunk: string): MessageChunk {
 function handleJsonResponse(chunkData: any) {
     const jsonData = chunkData as any;
     if (jsonData.image || jsonData.detail) {
-        let responseWithReference = handleImageResponse(chunkData, true);
-        if (responseWithReference.response) return responseWithReference.response;
+        let responseWithIntent = handleImageResponse(chunkData, true);
+        return responseWithIntent;
     } else if (jsonData.response) {
-        return jsonData.response;
+        return {
+            response: jsonData.response,
+            intentType: "",
+            inferredQueries: [],
+        };
     } else {
         throw new Error("Invalid JSON response");
     }

@@ -80,8 +84,18 @@ export function processMessageChunk(
         return { context, onlineContext };
     } else if (chunk.type === "message") {
         const chunkData = chunk.data;
+        // Here, handle if the response is a JSON response with an image, but the intentType is excalidraw
         if (chunkData !== null && typeof chunkData === "object") {
-            currentMessage.rawResponse += handleJsonResponse(chunkData);
+            let responseWithIntent = handleJsonResponse(chunkData);
+
+            if (responseWithIntent.intentType && responseWithIntent.intentType === "excalidraw") {
+                currentMessage.rawResponse = responseWithIntent.response;
+            } else {
+                currentMessage.rawResponse += responseWithIntent.response;
+            }
+
+            currentMessage.intentType = responseWithIntent.intentType;
+            currentMessage.inferredQueries = responseWithIntent.inferredQueries;
         } else if (
             typeof chunkData === "string" &&
             chunkData.trim()?.startsWith("{") &&

@@ -89,7 +103,10 @@ export function processMessageChunk(
         ) {
             try {
                 const jsonData = JSON.parse(chunkData.trim());
-                currentMessage.rawResponse += handleJsonResponse(jsonData);
+                let responseWithIntent = handleJsonResponse(jsonData);
+                currentMessage.rawResponse += responseWithIntent.response;
+                currentMessage.intentType = responseWithIntent.intentType;
+                currentMessage.inferredQueries = responseWithIntent.inferredQueries;
             } catch (e) {
                 currentMessage.rawResponse += JSON.stringify(chunkData);
             }

@@ -111,42 +128,26 @@ export function processMessageChunk(
     return { context, onlineContext };
 }

-export function handleImageResponse(imageJson: any, liveStream: boolean): ResponseWithReferences {
+export function handleImageResponse(imageJson: any, liveStream: boolean): ResponseWithIntent {
     let rawResponse = "";

     if (imageJson.image) {
-        const inferredQuery = imageJson.inferredQueries?.[0] ?? "generated image";
-
-        // If response has image field, response is a generated image.
-        if (imageJson.intentType === "text-to-image") {
-            rawResponse += `![generated_image](data:image/png;base64,${imageJson.image})`;
-        } else if (imageJson.intentType === "text-to-image2") {
-            rawResponse += `![generated_image](${imageJson.image})`;
-        } else if (imageJson.intentType === "text-to-image-v3") {
-            rawResponse = `![](data:image/webp;base64,${imageJson.image})`;
-        }
-        if (inferredQuery && !liveStream) {
-            rawResponse += `\n\n${inferredQuery}`;
-        }
+        // If response has image field, response may be a generated image
+        rawResponse = imageJson.image;
     }

-    let reference: ResponseWithReferences = {};
+    let responseWithIntent: ResponseWithIntent = {
+        intentType: imageJson.intentType,
+        response: rawResponse,
+        inferredQueries: imageJson.inferredQueries,
+    };

-    if (imageJson.context && imageJson.context.length > 0) {
-        const rawReferenceAsJson = imageJson.context;
-        if (rawReferenceAsJson instanceof Array) {
-            reference.context = rawReferenceAsJson;
-        } else if (typeof rawReferenceAsJson === "object" && rawReferenceAsJson !== null) {
-            reference.online = rawReferenceAsJson;
-        }
-    }
     if (imageJson.detail) {
         // The detail field contains the improved image prompt
         rawResponse += imageJson.detail;
     }

-    reference.response = rawResponse;
-    return reference;
+    return responseWithIntent;
 }

 export function modifyFileFilterForConversation(
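The rewrite above threads intent metadata through the stream instead of flattening every JSON chunk to a string. A hedged sketch of how a consumer branches on the new shape (the interface is copied from the hunk; the consumer function is hypothetical):

interface ResponseWithIntent {
    intentType: string;
    response: string;
    inferredQueries?: string[];
}

function applyChunk(message: { rawResponse: string }, chunk: ResponseWithIntent): void {
    if (chunk.intentType === "excalidraw") {
        // A diagram payload is a serialized scene, so it replaces the body
        // outright instead of being appended like streamed prose.
        message.rawResponse = chunk.response;
    } else {
        message.rawResponse += chunk.response;
    }
}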
@@ -48,6 +48,7 @@ import {
     Oven,
     Gavel,
     Broadcast,
+    KeyReturn,
 } from "@phosphor-icons/react";
 import { Markdown, OrgMode, Pdf, Word } from "@/app/components/logo/fileLogo";

@@ -193,6 +194,10 @@ export function getIconForSlashCommand(command: string, customClassName: string
     }

+    if (command.includes("default")) {
+        return <KeyReturn className={className} />;
+    }
+
     if (command.includes("diagram")) {
         return <Shapes className={className} />;
     }
@@ -2,12 +2,7 @@ div.chatHistory {
     display: flex;
     flex-direction: column;
     height: 100%;
-}
-
-div.chatLayout {
-    height: 80vh;
-    overflow-y: auto;
-    margin: 0 auto;
+    margin: auto;
 }

 div.agentIndicator a {
@@ -37,6 +37,7 @@ interface ChatHistoryProps {
     pendingMessage?: string;
     publicConversationSlug?: string;
     setAgent: (agent: AgentData) => void;
+    customClassName?: string;
 }

 function constructTrainOfThought(

@@ -255,7 +256,7 @@ export default function ChatHistory(props: ChatHistoryProps) {
     return (
         <ScrollArea className={`h-[80vh] relative`} ref={scrollAreaRef}>
             <div>
-                <div className={styles.chatHistory}>
+                <div className={`${styles.chatHistory} ${props.customClassName}`}>
                     <div ref={sentinelRef} style={{ height: "1px" }}>
                         {fetchingData && (
                             <InlineLoading message="Loading Conversation" className="opacity-50" />

@@ -298,7 +299,7 @@ export default function ChatHistory(props: ChatHistoryProps) {
                                     created: message.timestamp,
                                     by: "you",
                                     automationId: "",
-                                    uploadedImageData: message.uploadedImageData,
+                                    images: message.images,
                                 }}
                                 customClassName="fullHistory"
                                 borderLeftColor={`${data?.agent?.color}-500`}

@@ -322,6 +323,12 @@ export default function ChatHistory(props: ChatHistoryProps) {
                                     by: "khoj",
                                     automationId: "",
                                     rawQuery: message.rawQuery,
+                                    intent: {
+                                        type: message.intentType || "",
+                                        query: message.rawQuery,
+                                        "memory-type": "",
+                                        "inferred-queries": message.inferredQueries || [],
+                                    },
                                 }}
                                 customClassName="fullHistory"
                                 borderLeftColor={`${data?.agent?.color}-500`}

@@ -341,7 +348,6 @@ export default function ChatHistory(props: ChatHistoryProps) {
                                     created: new Date().getTime().toString(),
                                     by: "you",
                                     automationId: "",
-                                    uploadedImageData: props.pendingMessage,
                                 }}
                                 customClassName="fullHistory"
                                 borderLeftColor={`${data?.agent?.color}-500`}

@@ -366,18 +372,20 @@ export default function ChatHistory(props: ChatHistoryProps) {
                     </div>
                 )}
             </div>
-            {!isNearBottom && (
-                <button
-                    title="Scroll to bottom"
-                    className="absolute bottom-4 right-5 bg-white dark:bg-[hsl(var(--background))] text-neutral-500 dark:text-white p-2 rounded-full shadow-xl"
-                    onClick={() => {
-                        scrollToBottom();
-                        setIsNearBottom(true);
-                    }}
-                >
-                    <ArrowDown size={24} />
-                </button>
-            )}
+            <div className={`${props.customClassName} fixed bottom-[15%] z-10`}>
+                {!isNearBottom && (
+                    <button
+                        title="Scroll to bottom"
+                        className="absolute bottom-0 right-0 bg-white dark:bg-[hsl(var(--background))] text-neutral-500 dark:text-white p-2 rounded-full shadow-xl"
+                        onClick={() => {
+                            scrollToBottom();
+                            setIsNearBottom(true);
+                        }}
+                    >
+                        <ArrowDown size={24} />
+                    </button>
+                )}
+            </div>
         </div>
     </ScrollArea>
 );
@@ -62,10 +62,11 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
     const [loginRedirectMessage, setLoginRedirectMessage] = useState<string | null>(null);
     const [showLoginPrompt, setShowLoginPrompt] = useState(false);

-    const [recording, setRecording] = useState(false);
     const [imageUploaded, setImageUploaded] = useState(false);
-    const [imagePath, setImagePath] = useState<string>("");
-    const [imageData, setImageData] = useState<string | null>(null);
+    const [imagePaths, setImagePaths] = useState<string[]>([]);
+    const [imageData, setImageData] = useState<string[]>([]);

+    const [recording, setRecording] = useState(false);
     const [mediaRecorder, setMediaRecorder] = useState<MediaRecorder | null>(null);

     const [progressValue, setProgressValue] = useState(0);

@@ -90,27 +91,31 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
     useEffect(() => {
         async function fetchImageData() {
-            if (imagePath) {
-                const response = await fetch(imagePath);
-                const blob = await response.blob();
-                const reader = new FileReader();
-                reader.onload = function () {
-                    const base64data = reader.result;
-                    setImageData(base64data as string);
-                };
-                reader.readAsDataURL(blob);
+            if (imagePaths.length > 0) {
+                const newImageData = await Promise.all(
+                    imagePaths.map(async (path) => {
+                        const response = await fetch(path);
+                        const blob = await response.blob();
+                        return new Promise<string>((resolve) => {
+                            const reader = new FileReader();
+                            reader.onload = () => resolve(reader.result as string);
+                            reader.readAsDataURL(blob);
+                        });
+                    }),
+                );
+                setImageData(newImageData);
             }
             setUploading(false);
         }
         setUploading(true);
         fetchImageData();
-    }, [imagePath]);
+    }, [imagePaths]);

     function onSendMessage() {
         if (imageUploaded) {
-            setImageUploaded(false);
-            setImagePath("");
-            props.sendImage(imageData || "");
+            setImagePaths([]);
+            imageData.forEach((data) => props.sendImage(data));
         }

         if (!message.trim()) return;

@@ -156,18 +161,23 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
             setShowLoginPrompt(true);
             return;
         }
-        // check for image file
+        // check for image files
         const image_endings = ["jpg", "jpeg", "png", "webp"];
+        const newImagePaths: string[] = [];
         for (let i = 0; i < files.length; i++) {
             const file = files[i];
             const file_extension = file.name.split(".").pop();
             if (image_endings.includes(file_extension || "")) {
-                setImageUploaded(true);
-                setImagePath(DOMPurify.sanitize(URL.createObjectURL(file)));
-                return;
+                newImagePaths.push(DOMPurify.sanitize(URL.createObjectURL(file)));
             }
         }

+        if (newImagePaths.length > 0) {
+            setImageUploaded(true);
+            setImagePaths((prevPaths) => [...prevPaths, ...newImagePaths]);
+            return;
+        }
+
         uploadDataForIndexing(
             files,
             setWarning,

@@ -272,9 +282,12 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
         setIsDragAndDropping(false);
     }

-    function removeImageUpload() {
-        setImageUploaded(false);
-        setImagePath("");
+    function removeImageUpload(index: number) {
+        setImagePaths((prevPaths) => prevPaths.filter((_, i) => i !== index));
+        setImageData((prevData) => prevData.filter((_, i) => i !== index));
+        if (imagePaths.length === 1) {
+            setImageUploaded(false);
+        }
     }

     return (

@@ -391,24 +404,11 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
                 </div>
             )}
             <div
-                className={`${styles.actualInputArea} items-center justify-between dark:bg-neutral-700 relative ${isDragAndDropping && "animate-pulse"}`}
+                className={`${styles.actualInputArea} justify-between dark:bg-neutral-700 relative ${isDragAndDropping && "animate-pulse"}`}
                 onDragOver={handleDragOver}
                 onDragLeave={handleDragLeave}
                 onDrop={handleDragAndDropFiles}
             >
-                {imageUploaded && (
-                    <div className="absolute bottom-[80px] left-0 right-0 dark:bg-neutral-700 bg-white pt-5 pb-5 w-full rounded-lg border dark:border-none grid grid-cols-2">
-                        <div className="pl-4 pr-4">
-                            <img src={imagePath} alt="img" className="w-auto max-h-[100px]" />
-                        </div>
-                        <div className="pl-4 pr-4">
-                            <X
-                                className="w-6 h-6 float-right dark:hover:bg-[hsl(var(--background))] hover:bg-neutral-100 rounded-sm"
-                                onClick={removeImageUpload}
-                            />
-                        </div>
-                    </div>
-                )}
                 <input
                     type="file"
                     multiple={true}

@@ -416,15 +416,37 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
                     onChange={handleFileChange}
                     style={{ display: "none" }}
                 />
-                <Button
-                    variant={"ghost"}
-                    className="!bg-none p-0 m-2 h-auto text-3xl rounded-full text-gray-300 hover:text-gray-500"
-                    disabled={props.sendDisabled}
-                    onClick={handleFileButtonClick}
-                >
-                    <Paperclip className="w-8 h-8" />
-                </Button>
-                <div className="grid w-full gap-1.5 relative">
+                <div className="flex items-end pb-4">
+                    <Button
+                        variant={"ghost"}
+                        className="!bg-none p-0 m-2 h-auto text-3xl rounded-full text-gray-300 hover:text-gray-500"
+                        disabled={props.sendDisabled}
+                        onClick={handleFileButtonClick}
+                    >
+                        <Paperclip className="w-8 h-8" />
+                    </Button>
+                </div>
+                <div className="flex-grow flex flex-col w-full gap-1.5 relative pb-2">
+                    <div className="flex items-center gap-2 overflow-x-auto">
+                        {imageUploaded &&
+                            imagePaths.map((path, index) => (
+                                <div key={index} className="relative flex-shrink-0 pb-3 pt-2 group">
+                                    <img
+                                        src={path}
+                                        alt={`img-${index}`}
+                                        className="w-auto h-16 object-cover rounded-xl"
+                                    />
+                                    <Button
+                                        variant="ghost"
+                                        size="icon"
+                                        className="absolute -top-0 -right-2 h-5 w-5 rounded-full bg-neutral-200 dark:bg-neutral-600 hover:bg-neutral-300 dark:hover:bg-neutral-500 opacity-0 group-hover:opacity-100 transition-opacity"
+                                        onClick={() => removeImageUpload(index)}
+                                    >
+                                        <X className="h-3 w-3" />
+                                    </Button>
+                                </div>
+                            ))}
+                    </div>
                     <Textarea
                         ref={chatInputRef}
                         className={`border-none w-full h-16 min-h-16 max-h-[128px] md:py-4 rounded-lg resize-none dark:bg-neutral-700 ${props.isMobileWidth ? "text-md" : "text-lg"}`}

@@ -435,7 +457,7 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
                         onKeyDown={(e) => {
                             if (e.key === "Enter" && !e.shiftKey) {
                                 setImageUploaded(false);
-                                setImagePath("");
+                                setImagePaths([]);
                                 e.preventDefault();
                                 onSendMessage();
                             }

@@ -444,57 +466,59 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
                         disabled={props.sendDisabled || recording}
                     />
                 </div>
-                {recording ? (
-                    <TooltipProvider>
-                        <Tooltip>
-                            <TooltipTrigger asChild>
-                                <Button
-                                    variant="default"
-                                    className={`${!recording && "hidden"} ${props.agentColor ? convertToBGClass(props.agentColor) : "bg-orange-300 hover:bg-orange-500"} rounded-full p-1 m-2 h-auto text-3xl transition transform md:hover:-translate-y-1`}
-                                    onClick={() => {
-                                        setRecording(!recording);
-                                    }}
-                                    disabled={props.sendDisabled}
-                                >
-                                    <Stop weight="fill" className="w-6 h-6" />
-                                </Button>
-                            </TooltipTrigger>
-                            <TooltipContent>
-                                Click to stop recording and transcribe your voice.
-                            </TooltipContent>
-                        </Tooltip>
-                    </TooltipProvider>
-                ) : mediaRecorder ? (
-                    <InlineLoading />
-                ) : (
-                    <TooltipProvider>
-                        <Tooltip>
-                            <TooltipTrigger asChild>
-                                <Button
-                                    variant="default"
-                                    className={`${!message || recording || "hidden"} ${props.agentColor ? convertToBGClass(props.agentColor) : "bg-orange-300 hover:bg-orange-500"} rounded-full p-1 m-2 h-auto text-3xl transition transform md:hover:-translate-y-1`}
-                                    onClick={() => {
-                                        setMessage("Listening...");
-                                        setRecording(!recording);
-                                    }}
-                                    disabled={props.sendDisabled}
-                                >
-                                    <Microphone weight="fill" className="w-6 h-6" />
-                                </Button>
-                            </TooltipTrigger>
-                            <TooltipContent>
-                                Click to transcribe your message with voice.
-                            </TooltipContent>
-                        </Tooltip>
-                    </TooltipProvider>
-                )}
-                <Button
-                    className={`${(!message || recording) && "hidden"} ${props.agentColor ? convertToBGClass(props.agentColor) : "bg-orange-300 hover:bg-orange-500"} rounded-full p-1 m-2 h-auto text-3xl transition transform md:hover:-translate-y-1`}
-                    onClick={onSendMessage}
-                    disabled={props.sendDisabled}
-                >
-                    <ArrowUp className="w-6 h-6" weight="bold" />
-                </Button>
+                <div className="flex items-end pb-4">
+                    {recording ? (
+                        <TooltipProvider>
+                            <Tooltip>
+                                <TooltipTrigger asChild>
+                                    <Button
+                                        variant="default"
+                                        className={`${!recording && "hidden"} ${props.agentColor ? convertToBGClass(props.agentColor) : "bg-orange-300 hover:bg-orange-500"} rounded-full p-1 m-2 h-auto text-3xl transition transform md:hover:-translate-y-1`}
+                                        onClick={() => {
+                                            setRecording(!recording);
+                                        }}
+                                        disabled={props.sendDisabled}
+                                    >
+                                        <Stop weight="fill" className="w-6 h-6" />
+                                    </Button>
+                                </TooltipTrigger>
+                                <TooltipContent>
+                                    Click to stop recording and transcribe your voice.
+                                </TooltipContent>
+                            </Tooltip>
+                        </TooltipProvider>
+                    ) : mediaRecorder ? (
+                        <InlineLoading />
+                    ) : (
+                        <TooltipProvider>
+                            <Tooltip>
+                                <TooltipTrigger asChild>
+                                    <Button
+                                        variant="default"
+                                        className={`${!message || recording || "hidden"} ${props.agentColor ? convertToBGClass(props.agentColor) : "bg-orange-300 hover:bg-orange-500"} rounded-full p-1 m-2 h-auto text-3xl transition transform md:hover:-translate-y-1`}
+                                        onClick={() => {
+                                            setMessage("Listening...");
+                                            setRecording(!recording);
+                                        }}
+                                        disabled={props.sendDisabled}
+                                    >
+                                        <Microphone weight="fill" className="w-6 h-6" />
+                                    </Button>
+                                </TooltipTrigger>
+                                <TooltipContent>
+                                    Click to transcribe your message with voice.
+                                </TooltipContent>
+                            </Tooltip>
+                        </TooltipProvider>
+                    )}
+                    <Button
+                        className={`${(!message || recording) && "hidden"} ${props.agentColor ? convertToBGClass(props.agentColor) : "bg-orange-300 hover:bg-orange-500"} rounded-full p-1 m-2 h-auto text-3xl transition transform md:hover:-translate-y-1`}
+                        onClick={onSendMessage}
+                        disabled={props.sendDisabled}
+                    >
+                        <ArrowUp className="w-6 h-6" weight="bold" />
+                    </Button>
+                </div>
             </div>
         </>
     );
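The multi-image refactor above promisifies FileReader so several blob URLs can be converted to data URLs in parallel. The same pattern as a standalone sketch (the function name is hypothetical):

async function blobUrlsToDataUrls(urls: string[]): Promise<string[]> {
    return Promise.all(
        urls.map(async (url) => {
            const blob = await (await fetch(url)).blob();
            return new Promise<string>((resolve) => {
                const reader = new FileReader();
                // reader.result is a data: URL because readAsDataURL is used.
                reader.onload = () => resolve(reader.result as string);
                reader.readAsDataURL(blob);
            });
        }),
    );
}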
@@ -57,7 +57,26 @@ div.emptyChatMessage {
     display: none;
 }

-div.chatMessageContainer img {
+div.imagesContainer {
+    display: flex;
+    overflow-x: auto;
+    padding-bottom: 8px;
+    margin-bottom: 8px;
+}
+
+div.imageWrapper {
+    flex: 0 0 auto;
+    margin-right: 8px;
+}
+
+div.imageWrapper img {
+    width: auto;
+    height: 128px;
+    object-fit: cover;
+    border-radius: 8px;
+}
+
+div.chatMessageContainer > img {
     width: auto;
     height: auto;
     max-width: 100%;
@@ -26,6 +26,7 @@ import {
     Palette,
     ClipboardText,
     Check,
+    Shapes,
 } from "@phosphor-icons/react";

 import DOMPurify from "dompurify";

@@ -35,6 +36,7 @@ import { AgentData } from "@/app/agents/page";

 import renderMathInElement from "katex/contrib/auto-render";
 import "katex/dist/katex.min.css";
+import ExcalidrawComponent from "../excalidraw/excalidraw";

 const md = new markdownIt({
     html: true,

@@ -114,7 +116,7 @@ export interface SingleChatMessage {
     rawQuery?: string;
     intent?: Intent;
     agent?: AgentData;
-    uploadedImageData?: string;
+    images?: string[];
 }

 export interface StreamMessage {

@@ -126,7 +128,9 @@ export interface StreamMessage {
     rawQuery: string;
     timestamp: string;
     agent?: AgentData;
-    uploadedImageData?: string;
+    images?: string[];
+    intentType?: string;
+    inferredQueries?: string[];
 }

 export interface ChatHistoryData {

@@ -208,7 +212,6 @@ interface ChatMessageProps {
     borderLeftColor?: string;
     isLastMessage?: boolean;
     agent?: AgentData;
-    uploadedImageData?: string;
 }

 interface TrainOfThoughtProps {

@@ -252,6 +255,10 @@ function chooseIconFromHeader(header: string, iconColor: string) {
         return <Aperture className={`${classNames}`} />;
     }

+    if (compareHeader.includes("diagram")) {
+        return <Shapes className={`${classNames}`} />;
+    }
+
     if (compareHeader.includes("paint")) {
         return <Palette className={`${classNames}`} />;
     }

@@ -283,6 +290,7 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
     const [markdownRendered, setMarkdownRendered] = useState<string>("");
     const [isPlaying, setIsPlaying] = useState<boolean>(false);
     const [interrupted, setInterrupted] = useState<boolean>(false);
+    const [excalidrawData, setExcalidrawData] = useState<string>("");

     const interruptedRef = useRef<boolean>(false);
     const messageRef = useRef<HTMLDivElement>(null);

@@ -321,6 +329,11 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
     useEffect(() => {
         let message = props.chatMessage.message;

+        if (props.chatMessage.intent && props.chatMessage.intent.type == "excalidraw") {
+            message = props.chatMessage.intent["inferred-queries"][0];
+            setExcalidrawData(props.chatMessage.message);
+        }
+
         // Replace LaTeX delimiters with placeholders
         message = message
             .replace(/\\\(/g, "LEFTPAREN")

@@ -328,26 +341,40 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
             .replace(/\\\[/g, "LEFTBRACKET")
             .replace(/\\\]/g, "RIGHTBRACKET");

-        if (props.chatMessage.uploadedImageData) {
-            message = `![uploaded image](${props.chatMessage.uploadedImageData})\n\n${message}`;
+        if (props.chatMessage.images && props.chatMessage.images.length > 0) {
+            const imagesInMd = props.chatMessage.images
+                .map((image, index) => {
+                    const decodedImage = image.startsWith("data%3Aimage")
+                        ? decodeURIComponent(image)
+                        : image;
+                    const sanitizedImage = DOMPurify.sanitize(decodedImage);
+                    return `<div class="${styles.imageWrapper}"><img src="${sanitizedImage}" alt="uploaded image ${index + 1}" /></div>`;
+                })
+                .join("");
+            message = `<div class="${styles.imagesContainer}">${imagesInMd}</div>${message}`;
         }

-        if (props.chatMessage.intent && props.chatMessage.intent.type == "text-to-image") {
-            message = `![generated image](data:image/png;base64,${message})`;
-        } else if (props.chatMessage.intent && props.chatMessage.intent.type == "text-to-image2") {
-            message = `![generated image](${message})`;
-        } else if (
-            props.chatMessage.intent &&
-            props.chatMessage.intent.type == "text-to-image-v3"
-        ) {
-            message = `![generated image](data:image/webp;base64,${message})`;
-        }
-        if (
-            props.chatMessage.intent &&
-            props.chatMessage.intent.type.includes("text-to-image") &&
-            props.chatMessage.intent["inferred-queries"]?.length > 0
-        ) {
-            message += `\n\n${props.chatMessage.intent["inferred-queries"][0]}`;
+        const intentTypeHandlers = {
+            "text-to-image": (msg: string) => `![generated image](data:image/png;base64,${msg})`,
+            "text-to-image2": (msg: string) => `![generated image](${msg})`,
+            "text-to-image-v3": (msg: string) =>
+                `![generated image](data:image/webp;base64,${msg})`,
+            excalidraw: (msg: string) => {
+                return msg;
+            },
+        };
+
+        if (props.chatMessage.intent) {
+            const { type, "inferred-queries": inferredQueries } = props.chatMessage.intent;
+
+            console.log("intent type", type);
+            if (type in intentTypeHandlers) {
+                message = intentTypeHandlers[type as keyof typeof intentTypeHandlers](message);
+            }
+
+            if (type.includes("text-to-image") && inferredQueries?.length > 0) {
+                message += `\n\n${inferredQueries[0]}`;
+            }
         }

         setTextRendered(message);

@@ -364,7 +391,7 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>

         // Sanitize and set the rendered markdown
         setMarkdownRendered(DOMPurify.sanitize(markdownRendered));
-    }, [props.chatMessage.message, props.chatMessage.intent]);
+    }, [props.chatMessage.message, props.chatMessage.images, props.chatMessage.intent]);

     useEffect(() => {
         if (copySuccess) {

@@ -554,6 +581,7 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
                     className={styles.chatMessage}
                     dangerouslySetInnerHTML={{ __html: markdownRendered }}
                 />
+                {excalidrawData && <ExcalidrawComponent data={excalidrawData} />}
             </div>
             <div className={styles.teaserReferencesContainer}>
                 <TeaserReferencesSection
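The if/else-if intent chain in chatMessage is collapsed into a lookup table above. Its essentials as a hedged standalone sketch (the keys mirror the intent types handled in this file; the dispatch helper is hypothetical):

const handlers: Record<string, (msg: string) => string> = {
    "text-to-image": (msg) => `![generated image](data:image/png;base64,${msg})`,
    "text-to-image2": (msg) => `![generated image](${msg})`,
    "text-to-image-v3": (msg) => `![generated image](data:image/webp;base64,${msg})`,
    excalidraw: (msg) => msg, // scene JSON is rendered by ExcalidrawComponent, not as markdown
};

function renderForIntent(message: string, intentType?: string): string {
    const handler = intentType ? handlers[intentType] : undefined;
    return handler ? handler(message) : message;
}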
src/interface/web/app/components/excalidraw/excalidraw.tsx — new file, 24 lines
@@ -0,0 +1,24 @@
+"use client";
+
+import dynamic from "next/dynamic";
+import { Suspense } from "react";
+import Loading from "../../components/loading/loading";
+
+// Since client components get prerenderd on server as well hence importing
+// the excalidraw stuff dynamically with ssr false
+
+const ExcalidrawWrapper = dynamic(() => import("./excalidrawWrapper").then((mod) => mod.default), {
+    ssr: false,
+});
+
+interface ExcalidrawComponentProps {
+    data: any;
+}
+
+export default function ExcalidrawComponent(props: ExcalidrawComponentProps) {
+    return (
+        <Suspense fallback={<Loading />}>
+            <ExcalidrawWrapper data={props.data} />
+        </Suspense>
+    );
+}
src/interface/web/app/components/excalidraw/excalidrawWrapper.tsx — new file, 149 lines (the module imported above)

@@ -0,0 +1,149 @@
+"use client";
+
+import { useState, useEffect } from "react";
+
+import dynamic from "next/dynamic";
+
+import { ExcalidrawProps } from "@excalidraw/excalidraw/types/types";
+import { ExcalidrawElement } from "@excalidraw/excalidraw/types/element/types";
+import { ExcalidrawElementSkeleton } from "@excalidraw/excalidraw/types/data/transform";
+
+const Excalidraw = dynamic<ExcalidrawProps>(
+    async () => (await import("@excalidraw/excalidraw")).Excalidraw,
+    {
+        ssr: false,
+    },
+);
+
+import { convertToExcalidrawElements } from "@excalidraw/excalidraw";
+
+import { Button } from "@/components/ui/button";
+
+import { ArrowsInSimple, ArrowsOutSimple } from "@phosphor-icons/react";
+
+interface ExcalidrawWrapperProps {
+    data: ExcalidrawElementSkeleton[];
+}
+
+export default function ExcalidrawWrapper(props: ExcalidrawWrapperProps) {
+    const [excalidrawElements, setExcalidrawElements] = useState<ExcalidrawElement[]>([]);
+    const [expanded, setExpanded] = useState<boolean>(false);
+
+    const isValidExcalidrawElement = (element: ExcalidrawElementSkeleton): boolean => {
+        return (
+            element.x !== undefined &&
+            element.y !== undefined &&
+            element.id !== undefined &&
+            element.type !== undefined
+        );
+    };
+
+    useEffect(() => {
+        if (expanded) {
+            onkeydown = (e) => {
+                if (e.key === "Escape") {
+                    setExpanded(false);
+                    // Trigger a resize event to make Excalidraw adjust its size
+                    window.dispatchEvent(new Event("resize"));
+                }
+            };
+        } else {
+            onkeydown = null;
+        }
+    }, [expanded]);
+
+    useEffect(() => {
+        // Do some basic validation
+        const basicValidSkeletons: ExcalidrawElementSkeleton[] = [];
+
+        for (const element of props.data) {
+            if (isValidExcalidrawElement(element as ExcalidrawElementSkeleton)) {
+                basicValidSkeletons.push(element as ExcalidrawElementSkeleton);
+            }
+        }
+
+        const validSkeletons: ExcalidrawElementSkeleton[] = [];
+        for (const element of basicValidSkeletons) {
+            if (element.type === "frame") {
+                continue;
+            }
+            if (element.type === "arrow") {
+                const start = basicValidSkeletons.find((child) => child.id === element.start?.id);
+                const end = basicValidSkeletons.find((child) => child.id === element.end?.id);
+                if (start && end) {
+                    validSkeletons.push(element);
+                }
+            } else {
+                validSkeletons.push(element);
+            }
+        }
+
+        for (const element of basicValidSkeletons) {
+            if (element.type === "frame") {
+                const children = element.children?.map((childId) => {
+                    return validSkeletons.find((child) => child.id === childId);
+                });
+                // Get the valid children, filter out any undefined values
+                const validChildrenIds: readonly string[] = children
+                    ?.map((child) => child?.id)
+                    .filter((id) => id !== undefined) as string[];
+
+                if (validChildrenIds === undefined || validChildrenIds.length === 0) {
+                    continue;
+                }
+
+                validSkeletons.push({
+                    ...element,
+                    children: validChildrenIds,
+                });
+            }
+        }
+
+        const elements = convertToExcalidrawElements(validSkeletons);
+        setExcalidrawElements(elements);
+    }, []);
+
+    return (
+        <div className="relative">
+            <div
+                className={`${expanded ? "fixed inset-0 bg-black bg-opacity-50 backdrop-blur-sm z-50 flex items-center justify-center" : ""}`}
+            >
+                <Button
+                    onClick={() => {
+                        setExpanded(!expanded);
+                        // Trigger a resize event to make Excalidraw adjust its size
+                        window.dispatchEvent(new Event("resize"));
+                    }}
+                    variant={"outline"}
+                    className={`${expanded ? "absolute top-2 left-2 z-[60]" : ""}`}
+                >
+                    {expanded ? (
+                        <ArrowsInSimple className="h-4 w-4" />
+                    ) : (
+                        <ArrowsOutSimple className="h-4 w-4" />
+                    )}
+                </Button>
+                <div
+                    className={`
+                        ${expanded ? "w-[80vw] h-[80vh]" : "w-full h-[500px]"}
+                        bg-white overflow-hidden rounded-lg relative
+                    `}
+                >
+                    <Excalidraw
+                        initialData={{
+                            elements: excalidrawElements,
+                            appState: { zenModeEnabled: true },
+                            scrollToContent: true,
+                        }}
+                        // TODO - Create a common function to detect if the theme is dark?
+                        theme={localStorage.getItem("theme") === "dark" ? "dark" : "light"}
+                        validateEmbeddable={true}
+                        renderTopRightUI={(isMobile, appState) => {
+                            return <></>;
+                        }}
+                    />
+                </div>
+            </div>
+        </div>
+    );
+}
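The wrapper's validation pass drops arrows whose endpoints are missing before calling convertToExcalidrawElements, so the scene never contains dangling bindings. The core of that filter as a hedged sketch (types follow the imports in the new file; the endpoint shape is assumed from the property accesses above):

import { ExcalidrawElementSkeleton } from "@excalidraw/excalidraw/types/data/transform";

function dropDanglingArrows(skeletons: ExcalidrawElementSkeleton[]): ExcalidrawElementSkeleton[] {
    const ids = new Set(skeletons.map((s) => s.id));
    return skeletons.filter((element) => {
        if (element.type !== "arrow") return true;
        // Keep an arrow only if both of its bound endpoints exist in the scene.
        const { start, end } = element as { start?: { id?: string }; end?: { id?: string } };
        return start?.id != null && ids.has(start.id) && end?.id != null && ids.has(end.id);
    });
}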
@@ -51,7 +51,7 @@ function FisherYatesShuffle(array: any[]) {

 function ChatBodyData(props: ChatBodyDataProps) {
     const [message, setMessage] = useState("");
-    const [image, setImage] = useState<string | null>(null);
+    const [images, setImages] = useState<string[]>([]);
     const [processingMessage, setProcessingMessage] = useState(false);
     const [greeting, setGreeting] = useState("");
     const [shuffledOptions, setShuffledOptions] = useState<Suggestion[]>([]);

@@ -151,20 +151,21 @@ function ChatBodyData(props: ChatBodyDataProps) {
             try {
                 const newConversationId = await createNewConversation(selectedAgent || "khoj");
                 onConversationIdChange?.(newConversationId);
-                window.location.href = `/chat?conversationId=${newConversationId}`;
                 localStorage.setItem("message", message);
-                if (image) {
-                    localStorage.setItem("image", image);
+                if (images.length > 0) {
+                    localStorage.setItem("images", JSON.stringify(images));
                 }
+                window.location.href = `/chat?conversationId=${newConversationId}`;
             } catch (error) {
                 console.error("Error creating new conversation:", error);
                 setProcessingMessage(false);
             }
             setMessage("");
+            setImages([]);
         };
         processMessage();
-        if (message) {
+        if (message || images.length > 0) {
             setProcessingMessage(true);
         }
     }, [selectedAgent, message, processingMessage, onConversationIdChange]);

@@ -290,7 +291,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
                     </ScrollArea>
                 )}
             </div>
-            <div className={`mx-auto ${props.isMobileWidth ? "w-full" : "w-fit"}`}>
+            <div className={`mx-auto ${props.isMobileWidth ? "w-full" : "w-fit max-w-screen-md"}`}>
                 {!props.isMobileWidth && (
                     <div
                         className={`w-full ${styles.inputBox} shadow-lg bg-background align-middle items-center justify-center px-3 py-1 dark:bg-neutral-700 border-stone-100 dark:border-none dark:shadow-none rounded-2xl`}

@@ -298,7 +299,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
                         <ChatInputArea
                             isLoggedIn={props.isLoggedIn}
                             sendMessage={(message) => setMessage(message)}
-                            sendImage={(image) => setImage(image)}
+                            sendImage={(image) => setImages((prevImages) => [...prevImages, image])}
                             sendDisabled={processingMessage}
                             chatOptionsData={props.chatOptionsData}
                             conversationId={null}

@@ -379,7 +380,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
                         <ChatInputArea
                             isLoggedIn={props.isLoggedIn}
                             sendMessage={(message) => setMessage(message)}
-                            sendImage={(image) => setImage(image)}
+                            sendImage={(image) => setImages((prevImages) => [...prevImages, image])}
                             sendDisabled={processingMessage}
                             chatOptionsData={props.chatOptionsData}
                             conversationId={null}
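The home page above stashes the draft under the "message" and "images" localStorage keys before redirecting; the chat page (earlier in this diff) reads and clears them on mount. The handoff protocol as a compact, hedged sketch (helper names hypothetical; exact cleanup order may differ from the pages):

function stashDraft(conversationId: string, message: string, images: string[]): void {
    localStorage.setItem("message", message);
    if (images.length > 0) localStorage.setItem("images", JSON.stringify(images));
    window.location.href = `/chat?conversationId=${conversationId}`;
}

function readStashedImages(): string[] {
    const stored = localStorage.getItem("images");
    if (!stored) return [];
    const images = JSON.parse(stored) as string[];
    localStorage.removeItem("images"); // consume once, as the chat page does
    return images;
}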
@@ -27,7 +27,14 @@ export default function RootLayout({
                     child-src 'none';
                     object-src 'none';"
             ></meta>
-            <body className={inter.className}>{children}</body>
+            <body className={inter.className}>
+                {children}
+                <script
+                    dangerouslySetInnerHTML={{
+                        __html: `window.EXCALIDRAW_ASSET_PATH = 'https://assets.khoj.dev/@excalidraw/excalidraw/dist/';`,
+                    }}
+                />
+            </body>
         </html>
     );
 }
@ -28,22 +28,42 @@ interface ChatBodyDataProps {
|
|||
isLoggedIn: boolean;
|
||||
conversationId?: string;
|
||||
setQueryToProcess: (query: string) => void;
|
||||
setImage64: (image64: string) => void;
|
||||
setImages: (images: string[]) => void;
|
||||
}
|
||||
|
||||
function ChatBodyData(props: ChatBodyDataProps) {
|
||||
const [message, setMessage] = useState("");
|
||||
const [image, setImage] = useState<string | null>(null);
|
||||
const [images, setImages] = useState<string[]>([]);
|
||||
const [processingMessage, setProcessingMessage] = useState(false);
|
||||
const [agentMetadata, setAgentMetadata] = useState<AgentData | null>(null);
|
||||
const setQueryToProcess = props.setQueryToProcess;
|
||||
const streamedMessages = props.streamedMessages;
|
||||
|
||||
const chatHistoryCustomClassName = props.isMobileWidth ? "w-full" : "w-4/6";
|
||||
|
||||
useEffect(() => {
|
||||
if (image) {
|
||||
props.setImage64(encodeURIComponent(image));
|
||||
if (images.length > 0) {
|
||||
const encodedImages = images.map((image) => encodeURIComponent(image));
|
||||
props.setImages(encodedImages);
|
||||
}
|
||||
}, [image, props.setImage64]);
|
||||
}, [images, props.setImages]);
|
||||
|
||||
useEffect(() => {
|
||||
const storedImages = localStorage.getItem("images");
|
||||
if (storedImages) {
|
||||
const parsedImages: string[] = JSON.parse(storedImages);
|
||||
setImages(parsedImages);
|
||||
const encodedImages = parsedImages.map((img: string) => encodeURIComponent(img));
|
||||
props.setImages(encodedImages);
|
||||
localStorage.removeItem("images");
|
||||
}
|
||||
|
||||
const storedMessage = localStorage.getItem("message");
|
||||
if (storedMessage) {
|
||||
setProcessingMessage(true);
|
||||
setQueryToProcess(storedMessage);
|
||||
}
|
||||
}, [setQueryToProcess, props.setImages]);
|
||||
|
||||
useEffect(() => {
|
||||
if (message) {
|
||||
|
@ -78,15 +98,16 @@ function ChatBodyData(props: ChatBodyDataProps) {
|
|||
setTitle={props.setTitle}
|
||||
pendingMessage={processingMessage ? message : ""}
|
||||
incomingMessages={props.streamedMessages}
|
||||
customClassName={chatHistoryCustomClassName}
|
||||
/>
|
||||
</div>
|
||||
<div
|
||||
className={`${styles.inputBox} p-1 md:px-2 shadow-md bg-background align-middle items-center justify-center dark:bg-neutral-700 dark:border-0 dark:shadow-sm rounded-t-2xl rounded-b-none md:rounded-xl`}
|
||||
className={`${styles.inputBox} p-1 md:px-2 shadow-md bg-background align-middle items-center justify-center dark:bg-neutral-700 dark:border-0 dark:shadow-sm rounded-t-2xl rounded-b-none md:rounded-xl h-fit ${chatHistoryCustomClassName} mr-auto ml-auto`}
|
||||
>
|
||||
<ChatInputArea
|
||||
isLoggedIn={props.isLoggedIn}
|
||||
sendMessage={(message) => setMessage(message)}
|
||||
sendImage={(image) => setImage(image)}
|
||||
sendImage={(image) => setImages((prevImages) => [...prevImages, image])}
|
||||
sendDisabled={processingMessage}
|
||||
chatOptionsData={props.chatOptionsData}
|
||||
conversationId={props.conversationId}
|
||||
|
@ -109,7 +130,7 @@ export default function SharedChat() {
|
|||
const [processQuerySignal, setProcessQuerySignal] = useState(false);
|
||||
const [uploadedFiles, setUploadedFiles] = useState<string[]>([]);
|
||||
const [paramSlug, setParamSlug] = useState<string | undefined>(undefined);
|
||||
const [image64, setImage64] = useState<string>("");
|
||||
const [images, setImages] = useState<string[]>([]);
|
||||
|
||||
const locationData = useIPLocationData() || {
|
||||
timezone: Intl.DateTimeFormat().resolvedOptions().timeZone,
|
||||
|
@ -167,7 +188,7 @@ export default function SharedChat() {
|
|||
completed: false,
|
||||
timestamp: new Date().toISOString(),
|
||||
rawQuery: queryToProcess || "",
|
||||
uploadedImageData: decodeURIComponent(image64),
|
||||
images: images,
|
||||
};
|
||||
setMessages((prevMessages) => [...prevMessages, newStreamMessage]);
|
||||
setProcessQuerySignal(true);
|
||||
|
@ -194,7 +215,7 @@ export default function SharedChat() {
|
|||
if (done) {
|
||||
setQueryToProcess("");
|
||||
setProcessQuerySignal(false);
|
||||
setImage64("");
|
||||
setImages([]);
|
||||
break;
|
||||
}
|
||||
|
||||
|
@@ -236,7 +257,7 @@ export default function SharedChat() {
                country_code: locationData.countryCode,
                timezone: locationData.timezone,
            }),
-            ...(image64 && { image: image64 }),
+            ...(images.length > 0 && { image: images }),
        };

        const response = await fetch(chatAPI, {
@@ -275,6 +296,19 @@ export default function SharedChat() {
            <div className={styles.chatBox}>
                <div className={styles.chatBoxBody}>
+                    {!isMobileWidth && title && (
+                        <div
+                            className={`${styles.chatTitleWrapper} text-nowrap text-ellipsis overflow-hidden max-w-screen-md grid items-top font-bold mr-8 pt-6 col-auto h-fit`}
+                        >
+                            {title && (
+                                <h2
+                                    className={`text-lg text-ellipsis whitespace-nowrap overflow-x-hidden`}
+                                >
+                                    {title}
+                                </h2>
+                            )}
+                        </div>
+                    )}
                    <Suspense fallback={<Loading />}>
                        <ChatBodyData
                            conversationId={conversationId}
@@ -286,7 +320,7 @@ export default function SharedChat() {
                            setTitle={setTitle}
                            setUploadedFiles={setUploadedFiles}
                            isMobileWidth={isMobileWidth}
-                            setImage64={setImage64}
+                            setImages={setImages}
                        />
                    </Suspense>
                </div>
@@ -75,7 +75,7 @@ div.titleBar {
div.chatBoxBody {
    display: grid;
    height: 100%;
-    width: 70%;
+    width: 95%;
    margin: auto;
}
@@ -1,6 +1,6 @@
{
    "name": "khoj-ai",
-    "version": "1.26.0",
+    "version": "1.26.4",
    "private": true,
    "scripts": {
        "dev": "next dev",
@@ -63,7 +63,8 @@
        "swr": "^2.2.5",
        "typescript": "^5",
        "vaul": "^0.9.1",
-        "zod": "^3.23.8"
+        "zod": "^3.23.8",
+        "@excalidraw/excalidraw": "^0.17.6"
    },
    "devDependencies": {
        "@types/dompurify": "^3.0.5",
@@ -286,6 +286,11 @@
  resolved "https://registry.yarnpkg.com/@eslint/js/-/js-8.57.1.tgz#de633db3ec2ef6a3c89e2f19038063e8a122e2c2"
  integrity sha512-d9zaMRSTIKDLhctzH12MtXvJKSSUhaHcjV+2Z+GK+EEY7XKpP5yR4x+N3TAcHTcu963nIr+TMcCb4DBCYX1z6Q==

+"@excalidraw/excalidraw@^0.17.6":
+  version "0.17.6"
+  resolved "https://registry.yarnpkg.com/@excalidraw/excalidraw/-/excalidraw-0.17.6.tgz#5fd208ce69d33ca712d1804b50d7d06d5c46ac4d"
+  integrity sha512-fyCl+zG/Z5yhHDh5Fq2ZGmphcrALmuOdtITm8gN4d8w4ntnaopTXcTfnAAaU3VleDC6LhTkoLOTG6P5kgREiIg==
+
"@floating-ui/core@^1.6.0":
  version "1.6.8"
  resolved "https://registry.yarnpkg.com/@floating-ui/core/-/core-1.6.8.tgz#aa43561be075815879305965020f492cdb43da12"
@@ -172,7 +172,7 @@ class UserAuthenticationBackend(AuthenticationBackend):
                        request=request,
                        telemetry_type="api",
                        api="create_user",
-                        metadata={"user_id": str(user.uuid)},
+                        metadata={"server_id": str(user.uuid)},
                    )
                    logger.log(logging.INFO, f"🥳 New User Created: {user.uuid}")
                else:
@@ -622,6 +622,8 @@ class AgentAdapters:
    @staticmethod
    def get_all_accessible_agents(user: KhojUser = None):
        public_query = Q(privacy_level=Agent.PrivacyLevel.PUBLIC)
+        # TODO Update this to allow any public agent that's officially approved once that experience is launched
+        public_query &= Q(managed_by_admin=True)
        if user:
            return (
                Agent.objects.filter(public_query | Q(creator=user))
@@ -640,6 +642,16 @@ class AgentAdapters:
        agents = await sync_to_async(AgentAdapters.get_all_accessible_agents)(user)
        return await sync_to_async(list)(agents)

+    @staticmethod
+    async def ais_agent_accessible(agent: Agent, user: KhojUser) -> bool:
+        if agent.privacy_level == Agent.PrivacyLevel.PUBLIC:
+            return True
+        if agent.creator == user:
+            return True
+        if agent.privacy_level == Agent.PrivacyLevel.PROTECTED:
+            return True
+        return False
+
    @staticmethod
    def get_conversation_agent_by_id(agent_id: int):
        agent = Agent.objects.filter(id=agent_id).first()
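The new `ais_agent_accessible` check above gates agent usage by privacy level. A minimal usage sketch, assuming an `Agent` fetched by id and a logged-in `KhojUser` (the surrounding function is hypothetical, not part of this diff):

```python
from khoj.database.adapters import AgentAdapters

async def run_agent_scoped_operation(agent, user):
    # PUBLIC and PROTECTED agents, plus agents the user created, pass this check
    if not await AgentAdapters.ais_agent_accessible(agent, user):
        raise PermissionError(f"Agent {agent.slug} is not accessible by {user}")
    ...  # proceed with the agent-scoped operation
```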
@@ -1463,12 +1475,15 @@ class EntryAdapters:
        file_filters = EntryAdapters.file_filter.get_filter_terms(query)
        date_filters = EntryAdapters.date_filter.get_query_date_range(query)

-        user_or_agent = Q(user=user)
+        owner_filter = Q()

+        if user != None:
+            owner_filter = Q(user=user)
        if agent != None:
-            user_or_agent |= Q(agent=agent)
+            owner_filter |= Q(agent=agent)

        if len(word_filters) == 0 and len(file_filters) == 0 and len(date_filters) == 0:
-            return Entry.objects.filter(user_or_agent)
+            return Entry.objects.filter(owner_filter)

        for term in word_filters:
            if term.startswith("+"):
@@ -1504,7 +1519,7 @@ class EntryAdapters:
            formatted_max_date = date.fromtimestamp(max_date).strftime("%Y-%m-%d")
            q_filter_terms &= Q(embeddings_dates__date__lte=formatted_max_date)

-        relevant_entries = Entry.objects.filter(user_or_agent).filter(q_filter_terms)
+        relevant_entries = Entry.objects.filter(owner_filter).filter(q_filter_terms)
        if file_type_filter:
            relevant_entries = relevant_entries.filter(file_type=file_type_filter)
        return relevant_entries
@@ -1519,13 +1534,18 @@ class EntryAdapters:
        max_distance: float = math.inf,
        agent: Agent = None,
    ):
-        user_or_agent = Q(user=user)
+        owner_filter = Q()

+        if user != None:
+            owner_filter = Q(user=user)
        if agent != None:
-            user_or_agent |= Q(agent=agent)
+            owner_filter |= Q(agent=agent)

+        if owner_filter == Q():
+            return Entry.objects.none()
+
        relevant_entries = EntryAdapters.apply_filters(user, raw_query, file_type_filter, agent)
-        relevant_entries = relevant_entries.filter(user_or_agent).annotate(
+        relevant_entries = relevant_entries.filter(owner_filter).annotate(
            distance=CosineDistance("embeddings", embeddings)
        )
        relevant_entries = relevant_entries.filter(distance__lte=max_distance)
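The `owner_filter` refactor above starts from an empty `Q()` and only ORs in owners that are actually present, so an anonymous query scoped to an agent still matches that agent's entries. A standalone sketch of the pattern (illustrative only, not a drop-in for the adapter code):

```python
from django.db.models import Q

def build_owner_filter(user=None, agent=None) -> Q:
    # Start empty; OR in each owner that is actually present
    owner_filter = Q()
    if user is not None:
        owner_filter = Q(user=user)
    if agent is not None:
        owner_filter |= Q(agent=agent)
    return owner_filter

# Note: filtering by an empty Q() matches every row, which is why
# search_with_embeddings above short-circuits to Entry.objects.none()
# when neither a user nor an agent is given.
```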
@@ -0,0 +1,46 @@
+# Generated by Django 5.0.8 on 2024-10-21 05:16
+
+import django.contrib.postgres.fields
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+    dependencies = [
+        ("database", "0069_webscraper_serverchatsettings_web_scraper"),
+    ]
+
+    operations = [
+        migrations.AlterField(
+            model_name="agent",
+            name="input_tools",
+            field=django.contrib.postgres.fields.ArrayField(
+                base_field=models.CharField(
+                    choices=[
+                        ("general", "General"),
+                        ("online", "Online"),
+                        ("notes", "Notes"),
+                        ("summarize", "Summarize"),
+                        ("webpage", "Webpage"),
+                    ],
+                    max_length=200,
+                ),
+                blank=True,
+                default=list,
+                null=True,
+                size=None,
+            ),
+        ),
+        migrations.AlterField(
+            model_name="agent",
+            name="output_modes",
+            field=django.contrib.postgres.fields.ArrayField(
+                base_field=models.CharField(
+                    choices=[("text", "Text"), ("image", "Image"), ("automation", "Automation")], max_length=200
+                ),
+                blank=True,
+                default=list,
+                null=True,
+                size=None,
+            ),
+        ),
+    ]
@@ -180,8 +180,12 @@ class Agent(BaseModel):
    )  # Creator will only be null when the agents are managed by admin
    name = models.CharField(max_length=200)
    personality = models.TextField()
-    input_tools = ArrayField(models.CharField(max_length=200, choices=InputToolOptions.choices), default=list)
-    output_modes = ArrayField(models.CharField(max_length=200, choices=OutputModeOptions.choices), default=list)
+    input_tools = ArrayField(
+        models.CharField(max_length=200, choices=InputToolOptions.choices), default=list, null=True, blank=True
+    )
+    output_modes = ArrayField(
+        models.CharField(max_length=200, choices=OutputModeOptions.choices), default=list, null=True, blank=True
+    )
    managed_by_admin = models.BooleanField(default=False)
    chat_model = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE)
    slug = models.CharField(max_length=200, unique=True)
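Making these `ArrayField`s optional needs both flags, which is what the migration above alters: `blank=True` lets forms and admin accept an empty value, while `null=True` lets the database column itself hold NULL rather than only an empty list. A tiny sketch of the distinction on a hypothetical model (not from this diff):

```python
from django.contrib.postgres.fields import ArrayField
from django.db import models

class Example(models.Model):
    # blank=True -> form validation allows leaving the field empty
    # null=True  -> the column accepts NULL, not just [] (the default=list case)
    tags = ArrayField(models.CharField(max_length=200), default=list, null=True, blank=True)
```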
@@ -67,7 +67,7 @@ class PdfToEntries(TextToEntries):
                    bytes = pdf_files[pdf_file]
                    f.write(bytes)
                try:
-                    loader = PyMuPDFLoader(f"{tmp_file}", extract_images=True)
+                    loader = PyMuPDFLoader(f"{tmp_file}", extract_images=False)
                    pdf_entries_per_file = [page.page_content for page in loader.load()]
                except ImportError:
                    loader = PyMuPDFLoader(f"{tmp_file}")
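Turning off `extract_images` skips the heavier image-extraction path, and the `except ImportError` branch keeps PDF indexing working when that optional dependency is missing. A condensed sketch of the load-with-fallback shape, assuming langchain's `PyMuPDFLoader` as imported in this file:

```python
# tmp_file is a path to a PDF already written to disk.
def load_pdf_pages(tmp_file: str) -> list[str]:
    from langchain_community.document_loaders import PyMuPDFLoader

    try:
        loader = PyMuPDFLoader(tmp_file, extract_images=False)
        return [page.page_content for page in loader.load()]
    except ImportError:
        # Image extraction pulls in optional dependencies; retry without the flag
        loader = PyMuPDFLoader(tmp_file)
        return [page.page_content for page in loader.load()]
```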
@@ -6,14 +6,17 @@ from typing import Dict, Optional

from langchain.schema import ChatMessage

-from khoj.database.models import Agent, KhojUser
+from khoj.database.models import Agent, ChatModelOptions, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.google.utils import (
    format_messages_for_gemini,
    gemini_chat_completion_with_backoff,
    gemini_completion_with_backoff,
)
-from khoj.processor.conversation.utils import generate_chatml_messages_with_context
+from khoj.processor.conversation.utils import (
+    construct_structured_message,
+    generate_chatml_messages_with_context,
+)
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
from khoj.utils.rawconfig import LocationData

@@ -29,6 +32,8 @@ def extract_questions_gemini(
    max_tokens=None,
    location_data: LocationData = None,
    user: KhojUser = None,
+    query_images: Optional[list[str]] = None,
+    vision_enabled: bool = False,
    personality_context: Optional[str] = None,
):
    """
@@ -70,17 +75,17 @@ def extract_questions_gemini(
        text=text,
    )

-    messages = [ChatMessage(content=prompt, role="user")]
+    prompt = construct_structured_message(
+        message=prompt,
+        images=query_images,
+        model_type=ChatModelOptions.ModelType.GOOGLE,
+        vision_enabled=vision_enabled,
+    )

-    model_kwargs = {"response_mime_type": "application/json"}
+    messages = [ChatMessage(content=prompt, role="user"), ChatMessage(content=system_prompt, role="system")]

-    response = gemini_completion_with_backoff(
-        messages=messages,
-        system_prompt=system_prompt,
-        model_name=model,
-        temperature=temperature,
-        api_key=api_key,
-        model_kwargs=model_kwargs,
+    response = gemini_send_message_to_model(
+        messages, api_key, model, response_type="json_object", temperature=temperature
    )

    # Extract, Clean Message from Gemini's Response
@@ -102,7 +107,7 @@ def extract_questions_gemini(
    return questions


-def gemini_send_message_to_model(messages, api_key, model, response_type="text"):
+def gemini_send_message_to_model(messages, api_key, model, response_type="text", temperature=0, model_kwargs=None):
    """
    Send message to model
    """
@@ -114,7 +119,12 @@ def gemini_send_message_to_model(messages, api_key, model, response_type="text")

    # Get Response from Gemini
    return gemini_completion_with_backoff(
-        messages=messages, system_prompt=system_prompt, model_name=model, api_key=api_key, model_kwargs=model_kwargs
+        messages=messages,
+        system_prompt=system_prompt,
+        model_name=model,
+        api_key=api_key,
+        temperature=temperature,
+        model_kwargs=model_kwargs,
    )


@@ -133,6 +143,8 @@ def converse_gemini(
    location_data: LocationData = None,
    user_name: str = None,
    agent: Agent = None,
+    query_images: Optional[list[str]] = None,
+    vision_available: bool = False,
):
    """
    Converse with user using Google's Gemini
@@ -187,6 +199,9 @@ def converse_gemini(
        model_name=model,
        max_prompt_size=max_prompt_size,
        tokenizer_name=tokenizer_name,
+        query_images=query_images,
+        vision_enabled=vision_available,
+        model_type=ChatModelOptions.ModelType.GOOGLE,
    )

    messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
@@ -1,8 +1,11 @@
import logging
import random
+from io import BytesIO
from threading import Thread

import google.generativeai as genai
+import PIL.Image
+import requests
from google.generativeai.types.answer_types import FinishReason
from google.generativeai.types.generation_types import StopCandidateException
from google.generativeai.types.safety_types import (
@@ -53,14 +56,14 @@ def gemini_completion_with_backoff(
        },
    )

-    formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
+    formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]

    # Start chat session. All messages up to the last are considered to be part of the chat history
    chat_session = model.start_chat(history=formatted_messages[0:-1])

    try:
        # Generate the response. The last message is considered to be the current prompt
-        aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"][0])
+        aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"])
        return aggregated_response.text
    except StopCandidateException as e:
        response_message, _ = handle_gemini_response(e.args)
@@ -117,11 +120,11 @@ def gemini_llm_thread(g, messages, system_prompt, model_name, temperature, api_k
        },
    )

-    formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
+    formatted_messages = [{"role": message.role, "parts": message.content} for message in messages]
    # all messages up to the last are considered to be part of the chat history
    chat_session = model.start_chat(history=formatted_messages[0:-1])
    # the last message is considered to be the current prompt
-    for chunk in chat_session.send_message(formatted_messages[-1]["parts"][0], stream=True):
+    for chunk in chat_session.send_message(formatted_messages[-1]["parts"], stream=True):
        message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
        message = message or chunk.text
        g.send(message)
@@ -191,14 +194,6 @@ def generate_safety_response(safety_ratings):


def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str = None) -> tuple[list[str], str]:
-    if len(messages) == 1:
-        messages[0].role = "user"
-        return messages, system_prompt
-
-    for message in messages:
-        if message.role == "assistant":
-            message.role = "model"
-
    # Extract system message
    system_prompt = system_prompt or ""
    for message in messages.copy():
@@ -207,4 +202,31 @@ def format_messages_for_gemini(messages: list[ChatMessage], system_prompt: str =
            messages.remove(message)
    system_prompt = None if is_none_or_empty(system_prompt) else system_prompt

+    for message in messages:
+        # Convert message content to string list from chatml dictionary list
+        if isinstance(message.content, list):
+            # Convert image_urls to PIL.Image and place them at beginning of list (better for Gemini)
+            message.content = [
+                get_image_from_url(item["image_url"]["url"]) if item["type"] == "image_url" else item["text"]
+                for item in sorted(message.content, key=lambda x: 0 if x["type"] == "image_url" else 1)
+            ]
+        elif isinstance(message.content, str):
+            message.content = [message.content]
+
+        if message.role == "assistant":
+            message.role = "model"
+
+    if len(messages) == 1:
+        messages[0].role = "user"
+
    return messages, system_prompt


+def get_image_from_url(image_url: str) -> PIL.Image:
+    try:
+        response = requests.get(image_url)
+        response.raise_for_status()  # Check if the request was successful
+        return PIL.Image.open(BytesIO(response.content))
+    except requests.exceptions.RequestException as e:
+        logger.error(f"Failed to get image from URL {image_url}: {e}")
+        return None
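After this change every message's `content` is a list of parts (strings and `PIL.Image` objects), which matches what the Gemini SDK's `parts` field expects. A hedged sketch of the transformation on a vision-style chatml message (the URL and text are example data only):

```python
from langchain.schema import ChatMessage

# Example chatml-style message with an attached image, as produced by
# construct_structured_message for vision-enabled models.
message = ChatMessage(
    role="assistant",
    content=[
        {"type": "text", "text": "What is in this picture?"},
        {"type": "image_url", "image_url": {"url": "https://example.com/cat.webp"}},
    ],
)

# format_messages_for_gemini would turn this into roughly:
#   role: "model"  (Gemini's name for the assistant role)
#   content: [<PIL.Image for cat.webp>, "What is in this picture?"]
# with images sorted to the front of the parts list.
```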
@@ -30,7 +30,7 @@ def extract_questions(
    api_base_url=None,
    location_data: LocationData = None,
    user: KhojUser = None,
-    uploaded_image_url: Optional[str] = None,
+    query_images: Optional[list[str]] = None,
    vision_enabled: bool = False,
    personality_context: Optional[str] = None,
):
@@ -74,7 +74,7 @@ def extract_questions(

    prompt = construct_structured_message(
        message=prompt,
-        image_url=uploaded_image_url,
+        images=query_images,
        model_type=ChatModelOptions.ModelType.OPENAI,
        vision_enabled=vision_enabled,
    )
@@ -135,7 +135,7 @@ def converse(
    location_data: LocationData = None,
    user_name: str = None,
    agent: Agent = None,
-    image_url: Optional[str] = None,
+    query_images: Optional[list[str]] = None,
    vision_available: bool = False,
):
    """
@@ -191,7 +191,7 @@ def converse(
        model_name=model,
        max_prompt_size=max_prompt_size,
        tokenizer_name=tokenizer_name,
-        uploaded_image_url=image_url,
+        query_images=query_images,
        vision_enabled=vision_available,
        model_type=ChatModelOptions.ModelType.OPENAI,
    )
@@ -176,6 +176,150 @@ Improved Prompt:
""".strip()
)

+## Diagram Generation
+## --
+
+improve_diagram_description_prompt = PromptTemplate.from_template(
+    """
+you are an architect working with a novice artist using a diagramming tool.
+{personality_context}
+
+you need to convert the user's query to a description format that the novice artist can use very well. you are allowed to use primitives like
+- text
+- rectangle
+- diamond
+- ellipse
+- line
+- arrow
+- frame
+
+use these primitives to describe what sort of diagram the drawer should create. the artist must recreate the diagram every time, so include all relevant prior information in your description.
+
+use simple, concise language.
+
+Today's Date: {current_date}
+User's Location: {location}
+
+User's Notes:
+{references}
+
+Online References:
+{online_results}
+
+Conversation Log:
+{chat_history}
+
+Query: {query}
+
+
+""".strip()
+)
+
+excalidraw_diagram_generation_prompt = PromptTemplate.from_template(
+    """
+You are a program manager with the ability to describe diagrams to compose in professional, fine detail.
+{personality_context}
+
+You need to create a declarative description of the diagram and relevant components, using this base schema. Use the `label` property to specify the text to be rendered in the respective elements. Always use light colors for the `backgroundColor` property, like white, or light blue, green, red. "type", "x", "y", "id", are required properties for all elements.
+
+{{
+    type: string,
+    x: number,
+    y: number,
+    strokeColor: string,
+    backgroundColor: string,
+    width: number,
+    height: number,
+    id: string,
+    label: {{
+        text: string,
+    }}
+}}
+
+Valid types:
+- text
+- rectangle
+- diamond
+- ellipse
+- line
+- arrow
+
+For arrows and lines, you can use the `points` property to specify the start and end points of the arrow. You may also use the `label` property to specify the text to be rendered. You may use the `start` and `end` properties to connect the linear elements to other elements. The start and end point can either be the ID to map to an existing object, or the `type` to create a new object. Mapping to an existing object is useful if you want to connect it to multiple objects. Lines and arrows can only start and end at rectangle, text, diamond, or ellipse elements.
+
+{{
+    type: "arrow",
+    id: string,
+    x: number,
+    y: number,
+    width: number,
+    height: number,
+    strokeColor: string,
+    start: {{
+        id: string,
+        type: string,
+    }},
+    end: {{
+        id: string,
+        type: string,
+    }},
+    label: {{
+        text: string,
+    }}
+    points: [
+        [number, number],
+        [number, number],
+    ]
+}}
+
+For text, you must use the `text` property to specify the text to be rendered. You may also use `fontSize` property to specify the font size of the text. Only use the `text` element for titles, subtitles, and overviews. For labels, use the `label` property in the respective elements.
+
+{{
+    type: "text",
+    id: string,
+    x: number,
+    y: number,
+    fontSize: number,
+    text: string,
+}}
+
+For frames, use the `children` property to specify the elements that are inside the frame by their ids.
+
+{{
+    type: "frame",
+    id: string,
+    x: number,
+    y: number,
+    width: number,
+    height: number,
+    name: string,
+    children: [
+        string
+    ]
+}}
+
+Here's an example of a valid diagram:
+
+Design Description: Create a diagram describing a circular development process with 3 stages: design, implementation and feedback. The design stage is connected to the implementation stage and the implementation stage is connected to the feedback stage and the feedback stage is connected to the design stage. Each stage should be labeled with the stage name.
+
+Response:
+
+[
+    {{"type":"text","x":-150,"y":50,"width":300,"height":40,"id":"title_text","text":"Circular Development Process","fontSize":24}},
+    {{"type":"ellipse","x":-169,"y":113,"width":188,"height":202,"id":"design_ellipse", "label": {{"text": "Design"}}}},
+    {{"type":"ellipse","x":62,"y":394,"width":186,"height":188,"id":"implement_ellipse", "label": {{"text": "Implement"}}}},
+    {{"type":"ellipse","x":-348,"y":430,"width":184,"height":170,"id":"feedback_ellipse", "label": {{"text": "Feedback"}}}},
+    {{"type":"arrow","x":21,"y":273,"id":"design_to_implement_arrow","points":[[0,0],[86,105]],"start":{{"id":"design_ellipse"}}, "end":{{"id":"implement_ellipse"}}}},
+    {{"type":"arrow","x":50,"y":519,"id":"implement_to_feedback_arrow","points":[[0,0],[-198,-6]],"start":{{"id":"implement_ellipse"}}, "end":{{"id":"feedback_ellipse"}}}},
+    {{"type":"arrow","x":-228,"y":417,"id":"feedback_to_design_arrow","points":[[0,0],[85,-123]],"start":{{"id":"feedback_ellipse"}}, "end":{{"id":"design_ellipse"}}}},
+]
+
+Create a detailed diagram from the provided context and user prompt below. Return a valid JSON object:
+
+Diagram Description: {query}
+
+""".strip()
+)
+
## Online Search Conversation
## --
online_search_conversation = PromptTemplate.from_template(
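The prompt above asks the model for a bare JSON array of primitive elements. A minimal sketch of parsing and sanity-checking that output before handing it to the Excalidraw frontend (the function and its thresholds are illustrative; field names follow the schema stated in the prompt):

```python
import json

REQUIRED_KEYS = {"type", "x", "y", "id"}  # required for all elements per the prompt

def parse_diagram_response(raw_response: str) -> list[dict]:
    elements = json.loads(raw_response)
    if not isinstance(elements, list) or not elements:
        raise ValueError("Expected a non-empty JSON array of diagram elements")
    for element in elements:
        missing = REQUIRED_KEYS - element.keys()
        if missing:
            raise ValueError(f"Element {element.get('id')} missing keys: {missing}")
    return elements
```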
@@ -109,7 +109,7 @@ def save_to_conversation_log(
    client_application: ClientApplication = None,
    conversation_id: str = None,
    automation_id: str = None,
-    uploaded_image_url: str = None,
+    query_images: List[str] = None,
):
    user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    updated_conversation = message_to_log(
@@ -117,7 +117,7 @@ def save_to_conversation_log(
        chat_response=chat_response,
        user_message_metadata={
            "created": user_message_time,
-            "uploadedImageData": uploaded_image_url,
+            "images": query_images,
        },
        khoj_message_metadata={
            "context": compiled_references,
@@ -145,10 +145,18 @@ Khoj: "{inferred_queries if ("text-to-image" in intent_type) else chat_response}
)


# Format user and system messages to chatml format
-def construct_structured_message(message, image_url, model_type, vision_enabled):
-    if image_url and vision_enabled and model_type == ChatModelOptions.ModelType.OPENAI:
-        return [{"type": "text", "text": message}, {"type": "image_url", "image_url": {"url": image_url}}]
+def construct_structured_message(message: str, images: list[str], model_type: str, vision_enabled: bool):
+    """
+    Format messages into appropriate multimedia format for supported chat model types
+    """
+    if not images or not vision_enabled:
+        return message
+
+    if model_type in [ChatModelOptions.ModelType.OPENAI, ChatModelOptions.ModelType.GOOGLE]:
+        return [
+            {"type": "text", "text": message},
+            *[{"type": "image_url", "image_url": {"url": image}} for image in images],
+        ]
    return message
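With this rewrite, a message only becomes a multi-part list when images are attached and the model supports vision; otherwise it passes through as a plain string. Roughly, for a vision-enabled OpenAI or Google model (illustrative values):

```python
# construct_structured_message(
#     "Describe these", ["https://.../a.webp", "https://.../b.webp"],
#     ChatModelOptions.ModelType.OPENAI, vision_enabled=True)
# returns a chatml-style parts list:
[
    {"type": "text", "text": "Describe these"},
    {"type": "image_url", "image_url": {"url": "https://.../a.webp"}},
    {"type": "image_url", "image_url": {"url": "https://.../b.webp"}},
]
# With vision_enabled=False, or no images, it simply returns "Describe these".
```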
@@ -160,7 +168,7 @@ def generate_chatml_messages_with_context(
    loaded_model: Optional[Llama] = None,
    max_prompt_size=None,
    tokenizer_name=None,
-    uploaded_image_url=None,
+    query_images=None,
    vision_enabled=False,
    model_type="",
):
@@ -181,11 +189,12 @@ def generate_chatml_messages_with_context(
        message_notes = f'\n\n Notes:\n{chat.get("context")}' if chat.get("context") else "\n"
        role = "user" if chat["by"] == "you" else "assistant"

-        message_content = chat["message"] + message_notes
+        if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type"):
+            message_content = chat.get("intent").get("inferred-queries")[0] + message_notes
+        else:
+            message_content = chat["message"] + message_notes

-        message_content = construct_structured_message(
-            message_content, chat.get("uploadedImageData"), model_type, vision_enabled
-        )
+        message_content = construct_structured_message(message_content, chat.get("images"), model_type, vision_enabled)

        reconstructed_message = ChatMessage(content=message_content, role=role)

@@ -198,7 +207,7 @@ def generate_chatml_messages_with_context(
    if not is_none_or_empty(user_message):
        messages.append(
            ChatMessage(
-                content=construct_structured_message(user_message, uploaded_image_url, model_type, vision_enabled),
+                content=construct_structured_message(user_message, query_images, model_type, vision_enabled),
                role="user",
            )
        )
@@ -222,7 +231,6 @@ def truncate_messages(
    tokenizer_name=None,
) -> list[ChatMessage]:
    """Truncate messages to fit within max prompt size supported by model"""

    default_tokenizer = "gpt-4o"

    try:
@@ -252,6 +260,7 @@ def truncate_messages(
            system_message = messages.pop(idx)
            break

+    # TODO: Handle truncation of multi-part message.content, i.e when message.content is a list[dict] rather than a string
    system_message_tokens = (
        len(encoder.encode(system_message.content)) if system_message and type(system_message.content) == str else 0
    )
@@ -26,7 +26,7 @@ async def text_to_image(
    references: List[Dict[str, Any]],
    online_results: Dict[str, Any],
    send_status_func: Optional[Callable] = None,
-    uploaded_image_url: Optional[str] = None,
+    query_images: Optional[List[str]] = None,
    agent: Agent = None,
):
    status_code = 200
@@ -65,7 +65,7 @@ async def text_to_image(
        note_references=references,
        online_results=online_results,
        model_type=text_to_image_config.model_type,
-        uploaded_image_url=uploaded_image_url,
+        query_images=query_images,
        user=user,
        agent=agent,
    )
@@ -87,18 +87,18 @@ async def text_to_image(
        if "content_policy_violation" in e.message:
            logger.error(f"Image Generation blocked by OpenAI: {e}")
            status_code = e.status_code  # type: ignore
-            message = f"Image generation blocked by OpenAI: {e.message}"  # type: ignore
+            message = f"Image generation blocked by OpenAI due to policy violation"  # type: ignore
            yield image_url or image, status_code, message, intent_type.value
            return
        else:
            logger.error(f"Image Generation failed with {e}", exc_info=True)
-            message = f"Image generation failed with OpenAI error: {e.message}"  # type: ignore
+            message = f"Image generation failed using OpenAI"  # type: ignore
            status_code = e.status_code  # type: ignore
            yield image_url or image, status_code, message, intent_type.value
            return
    except requests.RequestException as e:
        logger.error(f"Image Generation failed with {e}", exc_info=True)
-        message = f"Image generation using {text2image_model} via {text_to_image_config.model_type} failed with error: {e}"
+        message = f"Image generation using {text2image_model} via {text_to_image_config.model_type} failed due to a network error."
        status_code = 502
        yield image_url or image, status_code, message, intent_type.value
        return
@@ -62,7 +62,7 @@ async def search_online(
    user: KhojUser,
    send_status_func: Optional[Callable] = None,
    custom_filters: List[str] = [],
-    uploaded_image_url: str = None,
+    query_images: List[str] = None,
    agent: Agent = None,
):
    query += " ".join(custom_filters)
@@ -73,7 +73,7 @@ async def search_online(

    # Breakdown the query into subqueries to get the correct answer
    subqueries = await generate_online_subqueries(
-        query, conversation_history, location, user, uploaded_image_url=uploaded_image_url, agent=agent
+        query, conversation_history, location, user, query_images=query_images, agent=agent
    )
    response_dict = {}

@@ -151,7 +151,7 @@ async def read_webpages(
    location: LocationData,
    user: KhojUser,
    send_status_func: Optional[Callable] = None,
-    uploaded_image_url: str = None,
+    query_images: List[str] = None,
    agent: Agent = None,
):
    "Infer web pages to read from the query and extract relevant information from them"
@@ -159,7 +159,7 @@ async def read_webpages(
    if send_status_func:
        async for event in send_status_func(f"**Inferring web pages to read**"):
            yield {ChatEvent.STATUS: event}
-    urls = await infer_webpage_urls(query, conversation_history, location, user, uploaded_image_url)
+    urls = await infer_webpage_urls(query, conversation_history, location, user, query_images)

    logger.info(f"Reading web pages at: {urls}")
    if send_status_func:
@@ -21,6 +21,7 @@ from starlette.authentication import has_required_scope, requires
from khoj.configure import initialize_content
from khoj.database import adapters
from khoj.database.adapters import (
+    AgentAdapters,
    AutomationAdapters,
    ConversationAdapters,
    EntryAdapters,
@@ -114,10 +115,16 @@ async def execute_search(
    dedupe: Optional[bool] = True,
    agent: Optional[Agent] = None,
):
-    start_time = time.time()
-
    # Run validation checks
    results: List[SearchResponse] = []

+    start_time = time.time()
+
+    # Ensure the agent, if present, is accessible by the user
+    if user and agent and not await AgentAdapters.ais_agent_accessible(agent, user):
+        logger.error(f"Agent {agent.slug} is not accessible by user {user}")
+        return results
+
    if q is None or q == "":
        logger.warning(f"No query param (q) passed in API call to initiate search")
        return results
@@ -340,7 +347,7 @@ async def extract_references_and_questions(
    conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
    location_data: LocationData = None,
    send_status_func: Optional[Callable] = None,
-    uploaded_image_url: Optional[str] = None,
+    query_images: Optional[List[str]] = None,
    agent: Agent = None,
):
    user = request.user.object if request.user.is_authenticated else None
@@ -431,7 +438,7 @@ async def extract_references_and_questions(
                conversation_log=meta_log,
                location_data=location_data,
                user=user,
-                uploaded_image_url=uploaded_image_url,
+                query_images=query_images,
                vision_enabled=vision_enabled,
                personality_context=personality_context,
            )
@@ -452,12 +459,14 @@ async def extract_references_and_questions(
            chat_model = conversation_config.chat_model
            inferred_queries = extract_questions_gemini(
                defiltered_query,
+                query_images=query_images,
                model=chat_model,
                api_key=api_key,
                conversation_log=meta_log,
                location_data=location_data,
                max_tokens=conversation_config.max_prompt_size,
                user=user,
+                vision_enabled=vision_enabled,
                personality_context=personality_context,
            )
@@ -30,8 +30,10 @@ from khoj.processor.speech.text_to_speech import generate_text_to_speech
from khoj.processor.tools.online_search import read_webpages, search_online
from khoj.routers.api import extract_references_and_questions
from khoj.routers.helpers import (
+    ApiImageRateLimiter,
    ApiUserRateLimiter,
    ChatEvent,
+    ChatRequestBody,
    CommonQueryParams,
    ConversationCommandRateLimiter,
    agenerate_chat_response,
@@ -40,6 +42,7 @@ from khoj.routers.helpers import (
    construct_automation_created_message,
    create_automation,
    extract_relevant_summary,
+    generate_excalidraw_diagram,
    get_conversation_command,
    is_query_empty,
    is_ready_to_chat,
@@ -523,22 +526,6 @@ async def set_conversation_title(
    )


-class ChatRequestBody(BaseModel):
-    q: str
-    n: Optional[int] = 7
-    d: Optional[float] = None
-    stream: Optional[bool] = False
-    title: Optional[str] = None
-    conversation_id: Optional[str] = None
-    city: Optional[str] = None
-    region: Optional[str] = None
-    country: Optional[str] = None
-    country_code: Optional[str] = None
-    timezone: Optional[str] = None
-    image: Optional[str] = None
-    create_new: Optional[bool] = False
-
-
@api_chat.post("")
@requires(["authenticated"])
async def chat(
@@ -551,6 +538,7 @@ async def chat(
    rate_limiter_per_day=Depends(
        ApiUserRateLimiter(requests=600, subscribed_requests=6000, window=60 * 60 * 24, slug="chat_day")
    ),
+    image_rate_limiter=Depends(ApiImageRateLimiter(max_images=10, max_combined_size_mb=20)),
):
    # Access the parameters from the body
    q = body.q
@@ -564,9 +552,9 @@ async def chat(
    country = body.country or get_country_name_from_timezone(body.timezone)
    country_code = body.country_code or get_country_code_from_timezone(body.timezone)
    timezone = body.timezone
-    image = body.image
+    raw_images = body.images

-    async def event_generator(q: str, image: str):
+    async def event_generator(q: str, images: list[str]):
        start_time = time.perf_counter()
        ttft = None
        chat_metadata: dict = {}
@@ -576,16 +564,16 @@ async def chat(
        q = unquote(q)
        nonlocal conversation_id

-        uploaded_image_url = None
-        if image:
-            decoded_string = unquote(image)
-            base64_data = decoded_string.split(",", 1)[1]
-            image_bytes = base64.b64decode(base64_data)
-            webp_image_bytes = convert_image_to_webp(image_bytes)
-            try:
-                uploaded_image_url = upload_image_to_bucket(webp_image_bytes, request.user.object.id)
-            except:
-                uploaded_image_url = None
+        uploaded_images: list[str] = []
+        if images:
+            for image in images:
+                decoded_string = unquote(image)
+                base64_data = decoded_string.split(",", 1)[1]
+                image_bytes = base64.b64decode(base64_data)
+                webp_image_bytes = convert_image_to_webp(image_bytes)
+                uploaded_image = upload_image_to_bucket(webp_image_bytes, request.user.object.id)
+                if uploaded_image:
+                    uploaded_images.append(uploaded_image)

        async def send_event(event_type: ChatEvent, data: str | dict):
            nonlocal connection_alive, ttft
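Each uploaded image arrives as a URL-encoded base64 data URL; the loop above splits off the `data:image/...;base64,` header before decoding. A self-contained sketch of just that decode step (standalone, without the webp conversion and bucket-upload helpers used in the diff):

```python
import base64
from urllib.parse import unquote

def decode_data_url(encoded_image: str) -> bytes:
    """Decode a URL-encoded 'data:image/...;base64,...' string to raw bytes."""
    decoded_string = unquote(encoded_image)
    # Everything after the first comma is the base64 payload
    base64_data = decoded_string.split(",", 1)[1]
    return base64.b64decode(base64_data)
```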
@@ -692,7 +680,7 @@ async def chat(
                meta_log,
                is_automated_task,
                user=user,
-                uploaded_image_url=uploaded_image_url,
+                query_images=uploaded_images,
                agent=agent,
            )
            conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
@@ -701,7 +689,7 @@ async def chat(
            ):
                yield result

-            mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_image_url, agent)
+            mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, user, uploaded_images, agent)
            async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
                yield result
            if mode not in conversation_commands:
@@ -764,7 +752,7 @@ async def chat(
                    q,
                    contextual_data,
                    conversation_history=meta_log,
-                    uploaded_image_url=uploaded_image_url,
+                    query_images=uploaded_images,
                    user=user,
                    agent=agent,
                )
@@ -785,7 +773,7 @@ async def chat(
                    intent_type="summarize",
                    client_application=request.user.client_app,
                    conversation_id=conversation_id,
-                    uploaded_image_url=uploaded_image_url,
+                    query_images=uploaded_images,
                )
                return

@@ -828,7 +816,7 @@ async def chat(
                conversation_id=conversation_id,
                inferred_queries=[query_to_run],
                automation_id=automation.id,
-                uploaded_image_url=uploaded_image_url,
+                query_images=uploaded_images,
            )
            async for result in send_llm_response(llm_response):
                yield result
@@ -848,7 +836,7 @@ async def chat(
            conversation_commands,
            location,
            partial(send_event, ChatEvent.STATUS),
-            uploaded_image_url=uploaded_image_url,
+            query_images=uploaded_images,
            agent=agent,
        ):
            if isinstance(result, dict) and ChatEvent.STATUS in result:
@@ -859,7 +847,7 @@ async def chat(
                defiltered_query = result[2]
        except Exception as e:
            error_message = f"Error searching knowledge base: {e}. Attempting to respond without document references."
-            logger.warning(error_message)
+            logger.error(error_message, exc_info=True)
            async for result in send_event(
                ChatEvent.STATUS, "Document search failed. I'll try respond without document references"
            ):
@@ -892,7 +880,7 @@ async def chat(
                    user,
                    partial(send_event, ChatEvent.STATUS),
                    custom_filters,
-                    uploaded_image_url=uploaded_image_url,
+                    query_images=uploaded_images,
                    agent=agent,
                ):
                    if isinstance(result, dict) and ChatEvent.STATUS in result:
@@ -916,7 +904,7 @@ async def chat(
                    location,
                    user,
                    partial(send_event, ChatEvent.STATUS),
-                    uploaded_image_url=uploaded_image_url,
+                    query_images=uploaded_images,
                    agent=agent,
                ):
                    if isinstance(result, dict) and ChatEvent.STATUS in result:
@@ -966,20 +954,20 @@ async def chat(
                references=compiled_references,
                online_results=online_results,
                send_status_func=partial(send_event, ChatEvent.STATUS),
-                uploaded_image_url=uploaded_image_url,
+                query_images=uploaded_images,
                agent=agent,
            ):
                if isinstance(result, dict) and ChatEvent.STATUS in result:
                    yield result[ChatEvent.STATUS]
                else:
-                    image, status_code, improved_image_prompt, intent_type = result
+                    generated_image, status_code, improved_image_prompt, intent_type = result

-            if image is None or status_code != 200:
+            if generated_image is None or status_code != 200:
                content_obj = {
                    "content-type": "application/json",
                    "intentType": intent_type,
                    "detail": improved_image_prompt,
-                    "image": image,
+                    "image": None,
                }
                async for result in send_llm_response(json.dumps(content_obj)):
                    yield result
@@ -987,7 +975,7 @@ async def chat(

            await sync_to_async(save_to_conversation_log)(
                q,
-                image,
+                generated_image,
                user,
                meta_log,
                user_message_time,
@@ -997,17 +985,68 @@ async def chat(
                conversation_id=conversation_id,
                compiled_references=compiled_references,
                online_results=online_results,
-                uploaded_image_url=uploaded_image_url,
+                query_images=uploaded_images,
            )
            content_obj = {
                "intentType": intent_type,
                "inferredQueries": [improved_image_prompt],
-                "image": image,
+                "image": generated_image,
            }
            async for result in send_llm_response(json.dumps(content_obj)):
                yield result
            return

+        if ConversationCommand.Diagram in conversation_commands:
+            async for result in send_event(ChatEvent.STATUS, f"Creating diagram"):
+                yield result
+
+            intent_type = "excalidraw"
+            inferred_queries = []
+            diagram_description = ""
+
+            async for result in generate_excalidraw_diagram(
+                q=defiltered_query,
+                conversation_history=meta_log,
+                location_data=location,
+                note_references=compiled_references,
+                online_results=online_results,
+                query_images=uploaded_images,
+                user=user,
+                agent=agent,
+                send_status_func=partial(send_event, ChatEvent.STATUS),
+            ):
+                if isinstance(result, dict) and ChatEvent.STATUS in result:
+                    yield result[ChatEvent.STATUS]
+                else:
+                    better_diagram_description_prompt, excalidraw_diagram_description = result
+                    inferred_queries.append(better_diagram_description_prompt)
+                    diagram_description = excalidraw_diagram_description
+
+            content_obj = {
+                "intentType": intent_type,
+                "inferredQueries": inferred_queries,
+                "image": diagram_description,
+            }
+
+            await sync_to_async(save_to_conversation_log)(
+                q,
+                excalidraw_diagram_description,
+                user,
+                meta_log,
+                user_message_time,
+                intent_type="excalidraw",
+                inferred_queries=[better_diagram_description_prompt],
+                client_application=request.user.client_app,
+                conversation_id=conversation_id,
+                compiled_references=compiled_references,
+                online_results=online_results,
+                query_images=uploaded_images,
+            )
+
+            async for result in send_llm_response(json.dumps(content_obj)):
+                yield result
+            return
+
        ## Generate Text Output
        async for result in send_event(ChatEvent.STATUS, f"**Generating a well-informed response**"):
            yield result
@@ -1024,7 +1063,7 @@ async def chat(
            conversation_id,
            location,
            user_name,
-            uploaded_image_url,
+            uploaded_images,
        )

        # Send Response
@@ -1050,9 +1089,9 @@ async def chat(

    ## Stream Text Response
    if stream:
-        return StreamingResponse(event_generator(q, image=image), media_type="text/plain")
+        return StreamingResponse(event_generator(q, images=raw_images), media_type="text/plain")
    ## Non-Streaming Text Response
    else:
-        response_iterator = event_generator(q, image=image)
+        response_iterator = event_generator(q, images=raw_images)
        response_data = await read_chat_stream(response_iterator)
        return Response(content=json.dumps(response_data), media_type="application/json", status_code=200)
@@ -90,7 +90,7 @@ async def login_magic_link(request: Request, form: MagicLinkForm):
            request=request,
            telemetry_type="api",
            api="create_user",
-            metadata={"user_id": str(user.uuid)},
+            metadata={"server_id": str(user.uuid)},
        )
        logger.log(logging.INFO, f"🥳 New User Created: {user.uuid}")

@@ -175,7 +175,7 @@ async def auth(request: Request):
            request=request,
            telemetry_type="api",
            api="create_user",
-            metadata={"user_id": str(khoj_user.uuid)},
+            metadata={"server_id": str(khoj_user.uuid)},
        )
        logger.log(logging.INFO, f"🥳 New User Created: {khoj_user.uuid}")
        return RedirectResponse(url=next_url, status_code=HTTP_302_FOUND)
@@ -1,4 +1,5 @@
+import asyncio
import base64
import hashlib
import json
import logging
@@ -14,6 +15,7 @@ from typing import (
    Annotated,
    Any,
+    AsyncGenerator,
    Callable,
    Dict,
    Iterator,
    List,
@@ -21,7 +23,7 @@ from typing import (
    Tuple,
    Union,
)
-from urllib.parse import parse_qs, quote, urljoin, urlparse
+from urllib.parse import parse_qs, quote, unquote, urljoin, urlparse

import cron_descriptor
import pytz
@@ -30,6 +32,7 @@ from apscheduler.job import Job
from apscheduler.triggers.cron import CronTrigger
from asgiref.sync import sync_to_async
from fastapi import Depends, Header, HTTPException, Request, UploadFile
+from pydantic import BaseModel
from starlette.authentication import has_required_scope
from starlette.requests import URL

@@ -215,6 +218,9 @@ def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="A
        elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")):
            chat_history += f"User: {chat['intent']['query']}\n"
            chat_history += f"{agent_name}: [generated image redacted for space]\n"
+        elif chat["by"] == "khoj" and ("excalidraw" in chat["intent"].get("type")):
+            chat_history += f"User: {chat['intent']['query']}\n"
+            chat_history += f"{agent_name}: {chat['intent']['inferred-queries'][0]}\n"
    return chat_history
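For excalidraw turns the raw stored message is a JSON blob of diagram elements, so the history branch above substitutes the first inferred query (the improved diagram description) instead. An illustrative before/after on an example conversation entry:

```python
# Example conversation entry for a diagram turn (illustrative data only):
chat = {
    "by": "khoj",
    "intent": {
        "type": "excalidraw",
        "query": "Draw the dev process",
        "inferred-queries": ["Create a circular diagram with design, implement, feedback stages"],
    },
}
# The new branch renders this turn into the history string as:
#   User: Draw the dev process
#   AI: Create a circular diagram with design, implement, feedback stages
# rather than dumping the raw element JSON into the prompt.
```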
@@ -235,6 +241,8 @@ def get_conversation_command(query: str, any_references: bool = False) -> Conver
        return ConversationCommand.AutomatedTask
    elif query.startswith("/summarize"):
        return ConversationCommand.Summarize
+    elif query.startswith("/diagram"):
+        return ConversationCommand.Diagram
    # If no relevant notes found for the given query
    elif not any_references:
        return ConversationCommand.General
@@ -290,7 +298,7 @@ async def aget_relevant_information_sources(
    conversation_history: dict,
    is_task: bool,
    user: KhojUser,
-    uploaded_image_url: str = None,
+    query_images: List[str] = None,
    agent: Agent = None,
):
    """
@@ -309,8 +317,8 @@ async def aget_relevant_information_sources(

    chat_history = construct_chat_history(conversation_history)

-    if uploaded_image_url:
-        query = f"[placeholder for user attached image]\n{query}"
+    if query_images:
+        query = f"[placeholder for {len(query_images)} user attached images]\n{query}"

    personality_context = (
        prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
@@ -367,7 +375,7 @@ async def aget_relevant_output_modes(
    conversation_history: dict,
    is_task: bool = False,
    user: KhojUser = None,
-    uploaded_image_url: str = None,
+    query_images: List[str] = None,
    agent: Agent = None,
):
    """
@@ -389,8 +397,8 @@ async def aget_relevant_output_modes(

    chat_history = construct_chat_history(conversation_history)

-    if uploaded_image_url:
-        query = f"[placeholder for user attached image]\n{query}"
+    if query_images:
+        query = f"[placeholder for {len(query_images)} user attached images]\n{query}"

    personality_context = (
        prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
@@ -433,7 +441,7 @@ async def infer_webpage_urls(
    conversation_history: dict,
    location_data: LocationData,
    user: KhojUser,
-    uploaded_image_url: str = None,
+    query_images: List[str] = None,
    agent: Agent = None,
) -> List[str]:
    """
@@ -459,7 +467,7 @@ async def infer_webpage_urls(

    with timer("Chat actor: Infer webpage urls to read", logger):
        response = await send_message_to_model_wrapper(
-            online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
+            online_queries_prompt, query_images=query_images, response_type="json_object", user=user
        )

    # Validate that the response is a non-empty, JSON-serializable list of URLs
@@ -479,7 +487,7 @@ async def generate_online_subqueries(
    conversation_history: dict,
    location_data: LocationData,
    user: KhojUser,
-    uploaded_image_url: str = None,
+    query_images: List[str] = None,
    agent: Agent = None,
) -> List[str]:
    """
@@ -505,7 +513,7 @@ async def generate_online_subqueries(

    with timer("Chat actor: Generate online search subqueries", logger):
        response = await send_message_to_model_wrapper(
-            online_queries_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
+            online_queries_prompt, query_images=query_images, response_type="json_object", user=user
        )

    # Validate that the response is a non-empty, JSON-serializable list
@@ -524,7 +532,7 @@ async def generate_online_subqueries(


async def schedule_query(
-    q: str, conversation_history: dict, user: KhojUser, uploaded_image_url: str = None
+    q: str, conversation_history: dict, user: KhojUser, query_images: List[str] = None
) -> Tuple[str, ...]:
    """
    Schedule the date, time to run the query. Assume the server timezone is UTC.
@@ -537,7 +545,7 @@ async def schedule_query(
    )

    raw_response = await send_message_to_model_wrapper(
-        crontime_prompt, uploaded_image_url=uploaded_image_url, response_type="json_object", user=user
+        crontime_prompt, query_images=query_images, response_type="json_object", user=user
    )

    # Validate that the response is a non-empty, JSON-serializable list
@@ -583,7 +591,7 @@ async def extract_relevant_summary(
    q: str,
    corpus: str,
    conversation_history: dict,
-    uploaded_image_url: str = None,
+    query_images: List[str] = None,
    user: KhojUser = None,
    agent: Agent = None,
) -> Union[str, None]:
@@ -612,11 +620,134 @@ async def extract_relevant_summary(
            extract_relevant_information,
            prompts.system_prompt_extract_relevant_summary,
            user=user,
-            uploaded_image_url=uploaded_image_url,
+            query_images=query_images,
        )
        return response.strip()


+async def generate_excalidraw_diagram(
+    q: str,
+    conversation_history: Dict[str, Any],
+    location_data: LocationData,
+    note_references: List[Dict[str, Any]],
+    online_results: Optional[dict] = None,
+    query_images: List[str] = None,
+    user: KhojUser = None,
+    agent: Agent = None,
+    send_status_func: Optional[Callable] = None,
+):
+    if send_status_func:
+        async for event in send_status_func("**Enhancing the Diagramming Prompt**"):
+            yield {ChatEvent.STATUS: event}
+
+    better_diagram_description_prompt = await generate_better_diagram_description(
+        q=q,
+        conversation_history=conversation_history,
+        location_data=location_data,
+        note_references=note_references,
+        online_results=online_results,
+        query_images=query_images,
+        user=user,
+        agent=agent,
+    )
+
+    if send_status_func:
+        async for event in send_status_func(f"**Diagram to Create:**:\n{better_diagram_description_prompt}"):
+            yield {ChatEvent.STATUS: event}
+
+    excalidraw_diagram_description = await generate_excalidraw_diagram_from_description(
+        q=better_diagram_description_prompt,
+        user=user,
+        agent=agent,
+    )
+
+    yield better_diagram_description_prompt, excalidraw_diagram_description
+
+
+async def generate_better_diagram_description(
+    q: str,
+    conversation_history: Dict[str, Any],
+    location_data: LocationData,
+    note_references: List[Dict[str, Any]],
+    online_results: Optional[dict] = None,
+    query_images: List[str] = None,
+    user: KhojUser = None,
+    agent: Agent = None,
+) -> str:
+    """
+    Generate a diagram description from the given query and context
+    """
+
+    today_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d, %A")
+    personality_context = (
+        prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
+    )
+
+    if location_data:
+        location_prompt = prompts.user_location.format(location=f"{location_data}")
+    else:
+        location_prompt = "Unknown"
+
+    user_references = "\n\n".join([f"# {item['compiled']}" for item in note_references])
+
+    chat_history = construct_chat_history(conversation_history)
+
+    simplified_online_results = {}
+
+    if online_results:
+        for result in online_results:
+            if online_results[result].get("answerBox"):
+                simplified_online_results[result] = online_results[result]["answerBox"]
+            elif online_results[result].get("webpages"):
+                simplified_online_results[result] = online_results[result]["webpages"]
+
+    improve_diagram_description_prompt = prompts.improve_diagram_description_prompt.format(
+        query=q,
+        chat_history=chat_history,
+        location=location_prompt,
+        current_date=today_date,
+        references=user_references,
+        online_results=simplified_online_results,
+        personality_context=personality_context,
+    )
+
+    with timer("Chat actor: Generate better diagram description", logger):
+        response = await send_message_to_model_wrapper(
+            improve_diagram_description_prompt, query_images=query_images, user=user
+        )
+        response = response.strip()
+        if response.startswith(('"', "'")) and response.endswith(('"', "'")):
+            response = response[1:-1]
+
+    return response
+
+
+async def generate_excalidraw_diagram_from_description(
+    q: str,
+    user: KhojUser = None,
+    agent: Agent = None,
+) -> str:
+    personality_context = (
+        prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
+    )
+
+    excalidraw_diagram_generation = prompts.excalidraw_diagram_generation_prompt.format(
+        personality_context=personality_context,
+        query=q,
+    )
+
+    with timer("Chat actor: Generate excalidraw diagram", logger):
+        raw_response = await send_message_to_model_wrapper(message=excalidraw_diagram_generation, user=user)
+        raw_response = raw_response.strip()
+        raw_response = remove_json_codeblock(raw_response)
+        response: Dict[str, str] = json.loads(raw_response)
+        if not response or not isinstance(response, List) or not isinstance(response[0], Dict):
+            # TODO Some additional validation here that it's a valid Excalidraw diagram
+            raise AssertionError(f"Invalid response for improving diagram description: {response}")
+
+    return response
+
+
async def generate_better_image_prompt(
    q: str,
    conversation_history: str,
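The diagram pipeline added above is a two-step actor chain: first rewrite the user query into a drawable description, then have the model emit Excalidraw elements for it. A hedged sketch of consuming the async generator, mirroring how api_chat.py drives it (the wrapper function and `print` sink are illustrative, not part of this diff):

```python
# ChatEvent.STATUS entries are progress updates; the final yield is the
# (description, diagram_elements) tuple.
async def make_diagram(q, meta_log, location, references, user, agent):
    description, elements = None, None
    async for result in generate_excalidraw_diagram(
        q=q,
        conversation_history=meta_log,
        location_data=location,
        note_references=references,
        user=user,
        agent=agent,
    ):
        if isinstance(result, dict) and ChatEvent.STATUS in result:
            print(result[ChatEvent.STATUS])  # stream progress to the client
        else:
            description, elements = result
    return description, elements
```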
|
@ -624,7 +755,7 @@ async def generate_better_image_prompt(
    note_references: List[Dict[str, Any]],
    online_results: Optional[dict] = None,
    model_type: Optional[str] = None,
    uploaded_image_url: Optional[str] = None,
    query_images: Optional[List[str]] = None,
    user: KhojUser = None,
    agent: Agent = None,
) -> str:

@ -676,7 +807,7 @@ async def generate_better_image_prompt(
    )

    with timer("Chat actor: Generate contextual image prompt", logger):
        response = await send_message_to_model_wrapper(image_prompt, uploaded_image_url=uploaded_image_url, user=user)
        response = await send_message_to_model_wrapper(image_prompt, query_images=query_images, user=user)
        response = response.strip()
        if response.startswith(('"', "'")) and response.endswith(('"', "'")):
            response = response[1:-1]

@ -689,11 +820,11 @@ async def send_message_to_model_wrapper(
    system_message: str = "",
    response_type: str = "text",
    user: KhojUser = None,
    uploaded_image_url: str = None,
    query_images: List[str] = None,
):
    conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
    vision_available = conversation_config.vision_enabled
    if not vision_available and uploaded_image_url:
    if not vision_available and query_images:
        vision_enabled_config = await ConversationAdapters.aget_vision_enabled_config()
        if vision_enabled_config:
            conversation_config = vision_enabled_config

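The fallback pattern here, repeated below in generate_chat_response, swaps in a vision-capable model only when the request actually carries images. A condensed sketch as a hypothetical helper, using the adapter names from the surrounding code:

async def pick_model_config(user, query_images):
    # Prefer the default chat model; fall back to a vision-enabled one only if images are attached.
    config = await ConversationAdapters.aget_default_conversation_config(user)
    if not config.vision_enabled and query_images:
        vision_config = await ConversationAdapters.aget_vision_enabled_config()
        config = vision_config or config  # keep the default if no vision model is configured
    return config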
@ -746,7 +877,7 @@ async def send_message_to_model_wrapper(
            max_prompt_size=max_tokens,
            tokenizer_name=tokenizer,
            vision_enabled=vision_available,
            uploaded_image_url=uploaded_image_url,
            query_images=query_images,
            model_type=conversation_config.model_type,
        )

@ -766,7 +897,7 @@ async def send_message_to_model_wrapper(
            max_prompt_size=max_tokens,
            tokenizer_name=tokenizer,
            vision_enabled=vision_available,
            uploaded_image_url=uploaded_image_url,
            query_images=query_images,
            model_type=conversation_config.model_type,
        )

@ -784,7 +915,8 @@ async def send_message_to_model_wrapper(
            max_prompt_size=max_tokens,
            tokenizer_name=tokenizer,
            vision_enabled=vision_available,
            uploaded_image_url=uploaded_image_url,
            query_images=query_images,
            model_type=conversation_config.model_type,
        )

        return gemini_send_message_to_model(

@ -875,6 +1007,7 @@ def send_message_to_model_wrapper_sync(
            model_name=chat_model,
            max_prompt_size=max_tokens,
            vision_enabled=vision_available,
            model_type=conversation_config.model_type,
        )

        return gemini_send_message_to_model(

@ -900,7 +1033,7 @@ def generate_chat_response(
    conversation_id: str = None,
    location_data: LocationData = None,
    user_name: Optional[str] = None,
    uploaded_image_url: Optional[str] = None,
    query_images: Optional[List[str]] = None,
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
    # Initialize Variables
    chat_response = None

@ -919,12 +1052,12 @@ def generate_chat_response(
        inferred_queries=inferred_queries,
        client_application=client_application,
        conversation_id=conversation_id,
        uploaded_image_url=uploaded_image_url,
        query_images=query_images,
    )

    conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
    vision_available = conversation_config.vision_enabled
    if not vision_available and uploaded_image_url:
    if not vision_available and query_images:
        vision_enabled_config = ConversationAdapters.get_vision_enabled_config()
        if vision_enabled_config:
            conversation_config = vision_enabled_config

@ -955,7 +1088,7 @@ def generate_chat_response(
        chat_response = converse(
            compiled_references,
            q,
            image_url=uploaded_image_url,
            query_images=query_images,
            online_results=online_results,
            conversation_log=meta_log,
            model=chat_model,

@ -993,8 +1126,9 @@ def generate_chat_response(
        chat_response = converse_gemini(
            compiled_references,
            q,
            online_results,
            meta_log,
            query_images=query_images,
            online_results=online_results,
            conversation_log=meta_log,
            model=conversation_config.chat_model,
            api_key=api_key,
            completion_func=partial_completion,

@ -1004,6 +1138,7 @@ def generate_chat_response(
            location_data=location_data,
            user_name=user_name,
            agent=agent,
            vision_available=vision_available,
        )

    metadata.update({"chat_model": conversation_config.chat_model})

@ -1015,6 +1150,22 @@ def generate_chat_response(
    return chat_response, metadata


class ChatRequestBody(BaseModel):
    q: str
    n: Optional[int] = 7
    d: Optional[float] = None
    stream: Optional[bool] = False
    title: Optional[str] = None
    conversation_id: Optional[str] = None
    city: Optional[str] = None
    region: Optional[str] = None
    country: Optional[str] = None
    country_code: Optional[str] = None
    timezone: Optional[str] = None
    images: Optional[list[str]] = None
    create_new: Optional[bool] = False

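For illustration, a request built against this new body model; the field values here are made up:

body = ChatRequestBody(
    q="Draw a flowchart of my morning routine",
    stream=True,
    conversation_id="42",
    timezone="Europe/Berlin",
    images=["data:image/png;base64,iVBORw0KGgo="],  # hypothetical base64 payload
)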

class ApiUserRateLimiter:
    def __init__(self, requests: int, subscribed_requests: int, window: int, slug: str):
        self.requests = requests

@ -1060,13 +1211,58 @@ class ApiUserRateLimiter:
            )
            raise HTTPException(
                status_code=429,
                detail="We're glad you're enjoying Khoj! You've exceeded your usage limit for today. Come back tomorrow or subscribe to increase your usage limit via [your settings](https://app.khoj.dev/settings).",
                detail="I'm glad you're enjoying interacting with me! But you've exceeded your usage limit for today. Come back tomorrow or subscribe to increase your usage limit via [your settings](https://app.khoj.dev/settings).",
            )

        # Add the current request to the cache
        UserRequests.objects.create(user=user, slug=self.slug)

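Since every allowed request is recorded as a UserRequests row, the limit check above reduces to counting recent rows for this user and slug. A sketch of that sliding-window count, assuming UserRequests has a created_at timestamp (the field name is an assumption):

from datetime import datetime, timedelta

# Count this user's requests inside the window, then compare against the cap.
cutoff = datetime.now() - timedelta(seconds=self.window)
count = UserRequests.objects.filter(user=user, slug=self.slug, created_at__gte=cutoff).count()
if count >= self.requests:  # or self.subscribed_requests for subscribed users
    raise HTTPException(status_code=429, detail="Rate limit exceeded.")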

class ApiImageRateLimiter:
    def __init__(self, max_images: int = 10, max_combined_size_mb: float = 10):
        self.max_images = max_images
        self.max_combined_size_mb = max_combined_size_mb

    def __call__(self, request: Request, body: ChatRequestBody):
        if state.billing_enabled is False:
            return

        # Rate limiting is disabled if user unauthenticated.
        # Other systems handle authentication
        if not request.user.is_authenticated:
            return

        if not body.images:
            return

        # Check number of images
        if len(body.images) > self.max_images:
            raise HTTPException(
                status_code=429,
                detail=f"Those are way too many images for me! I can handle up to {self.max_images} images per message.",
            )

        # Check total size of images
        total_size_mb = 0.0
        for image in body.images:
            # Unquote the image in case it's URL encoded
            image = unquote(image)
            # Assuming the image is a base64 encoded string
            # Remove the data:image/jpeg;base64, part if present
            if "," in image:
                image = image.split(",", 1)[1]

            # Decode base64 to get the actual size
            image_bytes = base64.b64decode(image)
            total_size_mb += len(image_bytes) / (1024 * 1024)  # Convert bytes to MB

        if total_size_mb > self.max_combined_size_mb:
            raise HTTPException(
                status_code=429,
                detail=f"Those images are way too large for me! I can handle up to {self.max_combined_size_mb}MB of images per message.",
            )


class ConversationCommandRateLimiter:
    def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int, slug: str):
        self.slug = slug

@ -82,7 +82,7 @@ async def subscribe(request: Request):
        request=request,
        telemetry_type="api",
        api="create_user",
        metadata={"user_id": str(user.user.uuid)},
        metadata={"server_id": str(user.user.uuid)},
    )
    logger.log(logging.INFO, f"🥳 New User Created: {user.user.uuid}")

@ -51,17 +51,6 @@ def chat_page(request: Request):
    )


@web_client.get("/experimental", response_class=FileResponse)
@requires(["authenticated"], redirect="login_page")
def experimental_page(request: Request):
    return templates.TemplateResponse(
        "index.html",
        context={
            "request": request,
        },
    )


@web_client.get("/factchecker", response_class=FileResponse)
def fact_checker_page(request: Request):
    return templates.TemplateResponse(

@ -318,6 +318,7 @@ class ConversationCommand(str, Enum):
    Automation = "automation"
    AutomatedTask = "automated_task"
    Summarize = "summarize"
    Diagram = "diagram"


command_descriptions = {

@ -326,10 +327,11 @@ command_descriptions = {
    ConversationCommand.Default: "The default command when no command specified. It intelligently auto-switches between general and notes mode.",
    ConversationCommand.Online: "Search for information on the internet.",
    ConversationCommand.Webpage: "Get information from webpage suggested by you.",
    ConversationCommand.Image: "Generate images by describing your imagination in words.",
    ConversationCommand.Image: "Generate illustrative, creative images by describing your imagination in words.",
    ConversationCommand.Automation: "Automatically run your query at a specified time or interval.",
    ConversationCommand.Help: "Get help with how to use or setup Khoj from the documentation",
    ConversationCommand.Summarize: "Get help with a question pertaining to an entire document.",
    ConversationCommand.Diagram: "Draw a flowchart, diagram, or any other visual representation best expressed with primitives like lines, rectangles, and text.",
}

command_descriptions_for_agent = {

@ -350,15 +352,17 @@ tool_descriptions_for_llm = {
}

mode_descriptions_for_llm = {
    ConversationCommand.Image: "Use this if the user is requesting you to generate a picture based on their description.",
    ConversationCommand.Image: "Use this if the user is requesting you to create a new picture based on their description.",
    ConversationCommand.Automation: "Use this if you are confident the user is requesting a response at a scheduled date, time and frequency",
    ConversationCommand.Text: "Use this if the other response modes don't seem to fit the query.",
    ConversationCommand.Text: "Use this if a normal text response would be sufficient for accurately responding to the query.",
    ConversationCommand.Diagram: "Use this if the user is requesting a visual representation that requires primitives like lines, rectangles, and text.",
}

mode_descriptions_for_agent = {
    ConversationCommand.Image: "Agent can generate image in response.",
    ConversationCommand.Automation: "Agent can schedule a task to run at a scheduled date, time and frequency in response.",
    ConversationCommand.Text: "Agent can generate text in response.",
    ConversationCommand.Diagram: "Agent can generate a visual representation that requires primitives like lines, rectangles, and text.",
}

@ -178,6 +178,13 @@ def api_user4(default_user4):
    )


@pytest.mark.django_db
@pytest.fixture
def default_openai_chat_model_option():
    chat_model = ChatModelOptionsFactory(chat_model="gpt-4o-mini", model_type="openai")
    return chat_model


@pytest.mark.django_db
@pytest.fixture
def offline_agent():

tests/test_agents.py (new file, 211 lines)
@ -0,0 +1,211 @@
# tests/test_agents.py
import os

import pytest
from asgiref.sync import sync_to_async

from khoj.database.adapters import AgentAdapters
from khoj.database.models import Agent, ChatModelOptions, Entry, KhojUser
from khoj.routers.api import execute_search
from khoj.utils.helpers import get_absolute_path
from tests.helpers import ChatModelOptionsFactory


def test_create_default_agent(default_user: KhojUser):
    ChatModelOptionsFactory()

    agent = AgentAdapters.create_default_agent(default_user)
    assert agent is not None
    assert agent.input_tools == []
    assert agent.output_modes == []
    assert agent.privacy_level == Agent.PrivacyLevel.PUBLIC
    assert agent.managed_by_admin == True


@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_create_or_update_agent(default_user: KhojUser, default_openai_chat_model_option: ChatModelOptions):
    new_agent = await AgentAdapters.aupdate_agent(
        default_user,
        "Test Agent",
        "Test Personality",
        Agent.PrivacyLevel.PRIVATE,
        "icon",
        "color",
        default_openai_chat_model_option.chat_model,
        [],
        [],
        [],
    )
    assert new_agent is not None
    assert new_agent.name == "Test Agent"
    assert new_agent.privacy_level == Agent.PrivacyLevel.PRIVATE
    assert new_agent.creator == default_user


@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_create_or_update_agent_with_knowledge_base(
    default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client
):
    full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
    new_agent = await AgentAdapters.aupdate_agent(
        default_user2,
        "Test Agent",
        "Test Personality",
        Agent.PrivacyLevel.PRIVATE,
        "icon",
        "color",
        default_openai_chat_model_option.chat_model,
        [full_filename],
        [],
        [],
    )
    entries = await sync_to_async(list)(Entry.objects.filter(agent=new_agent))
    file_names = set()
    for entry in entries:
        file_names.add(entry.file_path)

    assert new_agent is not None
    assert new_agent.name == "Test Agent"
    assert new_agent.privacy_level == Agent.PrivacyLevel.PRIVATE
    assert new_agent.creator == default_user2
    assert len(entries) > 0
    assert full_filename in file_names
    assert len(file_names) == 1


@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_create_or_update_agent_with_knowledge_base_and_search(
    default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client
):
    full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
    new_agent = await AgentAdapters.aupdate_agent(
        default_user2,
        "Test Agent",
        "Test Personality",
        Agent.PrivacyLevel.PRIVATE,
        "icon",
        "color",
        default_openai_chat_model_option.chat_model,
        [full_filename],
        [],
        [],
    )

    search_result = await execute_search(user=default_user2, q="having kids", agent=new_agent)

    assert len(search_result) == 5


@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_agent_with_knowledge_base_and_search_not_creator(
    default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client, default_user3: KhojUser
):
    full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
    new_agent = await AgentAdapters.aupdate_agent(
        default_user2,
        "Test Agent",
        "Test Personality",
        Agent.PrivacyLevel.PUBLIC,
        "icon",
        "color",
        default_openai_chat_model_option.chat_model,
        [full_filename],
        [],
        [],
    )

    search_result = await execute_search(user=default_user3, q="having kids", agent=new_agent)

    assert len(search_result) == 5


@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_agent_with_knowledge_base_and_search_not_creator_and_private(
    default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client, default_user3: KhojUser
):
    full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
    new_agent = await AgentAdapters.aupdate_agent(
        default_user2,
        "Test Agent",
        "Test Personality",
        Agent.PrivacyLevel.PRIVATE,
        "icon",
        "color",
        default_openai_chat_model_option.chat_model,
        [full_filename],
        [],
        [],
    )

    search_result = await execute_search(user=default_user3, q="having kids", agent=new_agent)

    assert len(search_result) == 0


@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_agent_with_knowledge_base_and_search_not_creator_and_private_accessible_to_none(
    default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client
):
    full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
    new_agent = await AgentAdapters.aupdate_agent(
        default_user2,
        "Test Agent",
        "Test Personality",
        Agent.PrivacyLevel.PRIVATE,
        "icon",
        "color",
        default_openai_chat_model_option.chat_model,
        [full_filename],
        [],
        [],
    )

    search_result = await execute_search(user=None, q="having kids", agent=new_agent)

    assert len(search_result) == 5


@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
async def test_multiple_agents_with_knowledge_base_and_users(
    default_user2: KhojUser, default_openai_chat_model_option: ChatModelOptions, chat_client, default_user3: KhojUser
):
    full_filename = get_absolute_path("tests/data/markdown/having_kids.markdown")
    new_agent = await AgentAdapters.aupdate_agent(
        default_user2,
        "Test Agent",
        "Test Personality",
        Agent.PrivacyLevel.PUBLIC,
        "icon",
        "color",
        default_openai_chat_model_option.chat_model,
        [full_filename],
        [],
        [],
    )

    full_filename2 = get_absolute_path("tests/data/markdown/Namita.markdown")
    new_agent2 = await AgentAdapters.aupdate_agent(
        default_user2,
        "Test Agent 2",
        "Test Personality",
        Agent.PrivacyLevel.PUBLIC,
        "icon",
        "color",
        default_openai_chat_model_option.chat_model,
        [full_filename2],
        [],
        [],
    )

    search_result = await execute_search(user=default_user3, q="having kids", agent=new_agent2)
    search_result2 = await execute_search(user=default_user3, q="Namita", agent=new_agent2)

    assert len(search_result) == 0
    assert len(search_result2) == 1
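These new tests exercise agent creation and knowledge-base privacy scoping end to end; running "pytest tests/test_agents.py" executes just this file.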
@ -1,6 +1,8 @@
import os
import re

import pytest

from khoj.processor.content.pdf.pdf_to_entries import PdfToEntries
from khoj.utils.fs_syncer import get_pdf_files
from khoj.utils.rawconfig import TextContentConfig

@ -37,6 +39,7 @@ def test_multi_page_pdf_to_jsonl():
    assert len(entries[1]) == 6


@pytest.mark.skip(reason="Temporarily disabled OCR due to performance issues")
def test_ocr_page_pdf_to_jsonl():
    "Convert multiple pages from single PDF file to jsonl."
    # Arrange

@ -78,5 +78,9 @@
  "1.24.0": "0.15.0",
  "1.24.1": "0.15.0",
  "1.25.0": "0.15.0",
  "1.26.0": "0.15.0"
  "1.26.0": "0.15.0",
  "1.26.1": "0.15.0",
  "1.26.2": "0.15.0",
  "1.26.3": "0.15.0",
  "1.26.4": "0.15.0"
}
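Note the trailing comma added to the existing "1.26.0" entry, which was previously the last key in versions.json; the new 1.26.x entries all continue to require Obsidian 0.15.0 or newer.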