Merge pull request #957 from khoj-ai/features/include-full-file-in-convo-with-filter

Support including file attachments in chat messages

Now that models have much larger context windows, we can reasonably include the full text of certain files in chat messages. Do this when an explicit file filter is set on a conversation, and place the file contents in a separate user message to avoid conflating them with the user's query.
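A minimal sketch of the message shape this produces, assuming a simplified prompt builder (the helper name and message layout here are illustrative, not the exact implementation):

    from langchain.schema import ChatMessage

    def build_messages_with_file_context(query: str, query_files: str) -> list[ChatMessage]:
        """Sketch: put attached file text in its own user message, ahead of the query."""
        messages = []
        if query_files:
            # Keeping file contents separate from the query reduces the chance the
            # model conflates the user's instruction with the attached document text.
            messages.append(ChatMessage(content=f"Attached files:\n{query_files}", role="user"))
        messages.append(ChatMessage(content=query, role="user"))
        return messages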

Pipe the relevant attached_files context through all methods that call into the models.
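Condensed, the threading pattern visible in the backend diffs below: each model-facing entry point (converse, converse_anthropic, converse_gemini, converse_offline, and the extract_questions variants) gains a query_files parameter and simply forwards it to the shared prompt builder. A runnable toy version of that shape (generate_chatml_messages_with_context stands in for the real prompt builder):

    # Toy version of the parameter threading used across the chat backends.
    def generate_chatml_messages_with_context(user_query: str, query_files: str = None) -> list[dict]:
        messages = []
        if query_files:
            # Attached file text travels as a separate user message.
            messages.append({"role": "user", "content": query_files})
        messages.append({"role": "user", "content": user_query})
        return messages

    def converse(user_query: str, query_files: str = None) -> list[dict]:
        # Every model-specific converse/extract_questions function forwards
        # query_files unchanged; no per-backend handling is needed.
        return generate_chatml_messages_with_context(user_query, query_files=query_files)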

This breaks some prior behavior: we no longer automatically process files and generate embeddings on the backend to add uploaded documents to the "brain". To add documents to the brain (i.e., have search include them during question answering), you'll have to go to Settings and use the document upload flow there.
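Attaching a file now goes through a stateless conversion endpoint instead of the indexing pipeline. A sketch of that call from Python (the /api/content/convert route and multipart "files" field match the frontend code below; the host, port, and omitted authentication are assumptions):

    import requests

    # Convert a local file to text for attachment to a single chat message.
    # Nothing here is indexed or embedded; the "brain" is untouched.
    with open("notes.md", "rb") as f:
        response = requests.post(
            "http://localhost:42110/api/content/convert",  # default local Khoj address
            files=[("files", ("notes.md", f, "text/markdown"))],
        )
    response.raise_for_status()

    for converted in response.json():
        # Each entry mirrors the frontend AttachedFileText shape:
        # name, content, file_type, size.
        print(converted["name"], converted["file_type"], converted["size"])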
sabaimran 2024-11-11 11:34:42 -08:00 committed by GitHub
commit b563f46a2e
33 changed files with 880 additions and 418 deletions

View file

@@ -8,7 +8,7 @@ import ChatHistory from "../components/chatHistory/chatHistory";
 import { useSearchParams } from "next/navigation";
 import Loading from "../components/loading/loading";
-import { processMessageChunk } from "../common/chatFunctions";
+import { generateNewTitle, processMessageChunk } from "../common/chatFunctions";
 import "katex/dist/katex.min.css";
@@ -19,7 +19,11 @@ import {
     StreamMessage,
 } from "../components/chatMessage/chatMessage";
 import { useIPLocationData, useIsMobileWidth, welcomeConsole } from "../common/utils";
-import { ChatInputArea, ChatOptions } from "../components/chatInputArea/chatInputArea";
+import {
+    AttachedFileText,
+    ChatInputArea,
+    ChatOptions,
+} from "../components/chatInputArea/chatInputArea";
 import { useAuthenticatedData } from "../common/auth";
 import { AgentData } from "../agents/page";
@@ -30,7 +34,7 @@ interface ChatBodyDataProps {
     setQueryToProcess: (query: string) => void;
     streamedMessages: StreamMessage[];
     setStreamedMessages: (messages: StreamMessage[]) => void;
-    setUploadedFiles: (files: string[]) => void;
+    setUploadedFiles: (files: AttachedFileText[] | undefined) => void;
     isMobileWidth?: boolean;
     isLoggedIn: boolean;
     setImages: (images: string[]) => void;
@@ -77,7 +81,24 @@ function ChatBodyData(props: ChatBodyDataProps) {
                 setIsInResearchMode(true);
             }
         }
-    }, [setQueryToProcess, props.setImages]);
+
+        const storedUploadedFiles = localStorage.getItem("uploadedFiles");
+        if (storedUploadedFiles) {
+            const parsedFiles = storedUploadedFiles ? JSON.parse(storedUploadedFiles) : [];
+            const uploadedFiles: AttachedFileText[] = [];
+            for (const file of parsedFiles) {
+                uploadedFiles.push({
+                    name: file.name,
+                    file_type: file.file_type,
+                    content: file.content,
+                    size: file.size,
+                });
+            }
+            localStorage.removeItem("uploadedFiles");
+            props.setUploadedFiles(uploadedFiles);
+        }
+    }, [setQueryToProcess, props.setImages, conversationId]);

     useEffect(() => {
         if (message) {
@@ -100,6 +121,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
             ) {
                 setProcessingMessage(false);
                 setImages([]); // Reset images after processing
+                props.setUploadedFiles(undefined); // Reset uploaded files after processing
             } else {
                 setMessage("");
             }
@@ -153,7 +175,7 @@ export default function Chat() {
     const [messages, setMessages] = useState<StreamMessage[]>([]);
     const [queryToProcess, setQueryToProcess] = useState<string>("");
     const [processQuerySignal, setProcessQuerySignal] = useState(false);
-    const [uploadedFiles, setUploadedFiles] = useState<string[]>([]);
+    const [uploadedFiles, setUploadedFiles] = useState<AttachedFileText[] | undefined>(undefined);
     const [images, setImages] = useState<string[]>([]);

     const locationData = useIPLocationData() || {
@@ -192,6 +214,7 @@ export default function Chat() {
             timestamp: new Date().toISOString(),
             rawQuery: queryToProcess || "",
             images: images,
+            queryFiles: uploadedFiles,
         };
         setMessages((prevMessages) => [...prevMessages, newStreamMessage]);
         setProcessQuerySignal(true);
@@ -224,6 +247,9 @@ export default function Chat() {
                 setQueryToProcess("");
                 setProcessQuerySignal(false);
                 setImages([]);
+
+                if (conversationId) generateNewTitle(conversationId, setTitle);
+
                 break;
             }
@@ -273,6 +299,7 @@ export default function Chat() {
                 timezone: locationData.timezone,
             }),
             ...(images.length > 0 && { images: images }),
+            ...(uploadedFiles && { files: uploadedFiles }),
         };

         const response = await fetch(chatAPI, {
@@ -325,7 +352,7 @@ export default function Chat() {
                 <div>
                     <SidePanel
                         conversationId={conversationId}
-                        uploadedFiles={uploadedFiles}
+                        uploadedFiles={[]}
                         isMobileWidth={isMobileWidth}
                     />
                 </div>

View file

@@ -267,6 +267,78 @@ export async function createNewConversation(slug: string) {
     }
 }

+export async function packageFilesForUpload(files: FileList): Promise<FormData> {
+    const formData = new FormData();
+
+    const fileReadPromises = Array.from(files).map((file) => {
+        return new Promise<void>((resolve, reject) => {
+            let reader = new FileReader();
+            reader.onload = function (event) {
+                if (event.target === null) {
+                    reject();
+                    return;
+                }
+
+                let fileContents = event.target.result;
+                let fileType = file.type;
+                let fileName = file.name;
+                if (fileType === "") {
+                    let fileExtension = fileName.split(".").pop();
+                    if (fileExtension === "org") {
+                        fileType = "text/org";
+                    } else if (fileExtension === "md") {
+                        fileType = "text/markdown";
+                    } else if (fileExtension === "txt") {
+                        fileType = "text/plain";
+                    } else if (fileExtension === "html") {
+                        fileType = "text/html";
+                    } else if (fileExtension === "pdf") {
+                        fileType = "application/pdf";
+                    } else if (fileExtension === "docx") {
+                        fileType =
+                            "application/vnd.openxmlformats-officedocument.wordprocessingml.document";
+                    } else {
+                        // Skip this file if its type is not supported
+                        resolve();
+                        return;
+                    }
+                }
+
+                if (fileContents === null) {
+                    reject();
+                    return;
+                }
+
+                let fileObj = new Blob([fileContents], { type: fileType });
+                formData.append("files", fileObj, file.name);
+                resolve();
+            };
+            reader.onerror = reject;
+            reader.readAsArrayBuffer(file);
+        });
+    });
+
+    await Promise.all(fileReadPromises);
+    return formData;
+}
+
+export function generateNewTitle(conversationId: string, setTitle: (title: string) => void) {
+    fetch(`/api/chat/title?conversation_id=${conversationId}`, {
+        method: "POST",
+    })
+        .then((res) => {
+            if (!res.ok) throw new Error(`Failed to call API with error ${res.statusText}`);
+            return res.json();
+        })
+        .then((data) => {
+            setTitle(data.title);
+        })
+        .catch((err) => {
+            console.error(err);
+            return;
+        });
+}
+
 export function uploadDataForIndexing(
     files: FileList,
     setWarning: (warning: string) => void,

View file

@@ -49,8 +49,11 @@ import {
     Gavel,
     Broadcast,
     KeyReturn,
+    FilePdf,
+    FileMd,
+    MicrosoftWordLogo,
 } from "@phosphor-icons/react";
-import { Markdown, OrgMode, Pdf, Word } from "@/app/components/logo/fileLogo";
+import { OrgMode } from "@/app/components/logo/fileLogo";

 interface IconMap {
     [key: string]: (color: string, width: string, height: string) => JSX.Element | null;
@@ -238,11 +241,12 @@ function getIconFromFilename(
             return <OrgMode className={className} />;
         case "markdown":
         case "md":
-            return <Markdown className={className} />;
+            return <FileMd className={className} />;
         case "pdf":
-            return <Pdf className={className} />;
+            return <FilePdf className={className} />;
         case "doc":
-            return <Word className={className} />;
+        case "docx":
+            return <MicrosoftWordLogo className={className} />;
         case "jpg":
         case "jpeg":
         case "png":

View file

@@ -71,6 +71,16 @@ export function useIsMobileWidth() {
     return isMobileWidth;
 }

+export const convertBytesToText = (fileSize: number) => {
+    if (fileSize < 1024) {
+        return `${fileSize} B`;
+    } else if (fileSize < 1024 * 1024) {
+        return `${(fileSize / 1024).toFixed(2)} KB`;
+    } else {
+        return `${(fileSize / (1024 * 1024)).toFixed(2)} MB`;
+    }
+};
+
 export function useDebounce<T>(value: T, delay: number): T {
     const [debouncedValue, setDebouncedValue] = useState<T>(value);

View file

@@ -373,6 +373,7 @@ export default function ChatHistory(props: ChatHistoryProps) {
                                     images: message.images,
                                     conversationId: props.conversationId,
                                     turnId: messageTurnId,
+                                    queryFiles: message.queryFiles,
                                 }}
                                 customClassName="fullHistory"
                                 borderLeftColor={`${data?.agent?.color}-500`}

View file

@@ -40,19 +40,36 @@ import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/comp
 import { convertColorToTextClass, convertToBGClass } from "@/app/common/colorUtils";

 import LoginPrompt from "../loginPrompt/loginPrompt";
-import { uploadDataForIndexing } from "../../common/chatFunctions";
 import { InlineLoading } from "../loading/loading";
-import { getIconForSlashCommand } from "@/app/common/iconUtils";
+import { getIconForSlashCommand, getIconFromFilename } from "@/app/common/iconUtils";
+import { packageFilesForUpload } from "@/app/common/chatFunctions";
+import { convertBytesToText } from "@/app/common/utils";
+import {
+    Dialog,
+    DialogContent,
+    DialogDescription,
+    DialogHeader,
+    DialogTitle,
+    DialogTrigger,
+} from "@/components/ui/dialog";
+import { ScrollArea } from "@/components/ui/scroll-area";

 export interface ChatOptions {
     [key: string]: string;
 }

+export interface AttachedFileText {
+    name: string;
+    content: string;
+    file_type: string;
+    size: number;
+}
+
 interface ChatInputProps {
     sendMessage: (message: string) => void;
     sendImage: (image: string) => void;
     sendDisabled: boolean;
-    setUploadedFiles?: (files: string[]) => void;
+    setUploadedFiles: (files: AttachedFileText[]) => void;
     conversationId?: string | null;
     chatOptionsData?: ChatOptions | null;
     isMobileWidth?: boolean;
@@ -75,6 +92,9 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
     const [imagePaths, setImagePaths] = useState<string[]>([]);
     const [imageData, setImageData] = useState<string[]>([]);

+    const [attachedFiles, setAttachedFiles] = useState<FileList | null>(null);
+    const [convertedAttachedFiles, setConvertedAttachedFiles] = useState<AttachedFileText[]>([]);
+
     const [recording, setRecording] = useState(false);
     const [mediaRecorder, setMediaRecorder] = useState<MediaRecorder | null>(null);
@@ -154,6 +174,8 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
         }

         props.sendMessage(messageToSend);
+        setAttachedFiles(null);
+        setConvertedAttachedFiles([]);
         setMessage("");
     }
@@ -203,22 +225,69 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
             setImagePaths((prevPaths) => [...prevPaths, ...newImagePaths]);
             // Set focus to the input for user message after uploading files
             chatInputRef?.current?.focus();
+            return;
         }

-        uploadDataForIndexing(
-            files,
-            setWarning,
-            setUploading,
-            setError,
-            props.setUploadedFiles,
-            props.conversationId,
-        );
+        // Process all non-image files
+        const nonImageFiles = Array.from(files).filter(
+            (file) => !image_endings.includes(file.name.split(".").pop() || ""),
+        );
+
+        // Concatenate attachedFiles and files
+        const newFiles = nonImageFiles
+            ? Array.from(nonImageFiles).concat(Array.from(attachedFiles || []))
+            : Array.from(attachedFiles || []);
+
+        // Ensure files are below size limit (10 MB)
+        for (let i = 0; i < newFiles.length; i++) {
+            if (newFiles[i].size > 10 * 1024 * 1024) {
+                setWarning(
+                    `File ${newFiles[i].name} is too large. Please upload files smaller than 10 MB.`,
+                );
+                return;
+            }
+        }
+
+        const dataTransfer = new DataTransfer();
+        newFiles.forEach((file) => dataTransfer.items.add(file));
+        setAttachedFiles(dataTransfer.files);
+
+        // Extract text from files
+        extractTextFromFiles(dataTransfer.files).then((data) => {
+            props.setUploadedFiles(data);
+            setConvertedAttachedFiles(data);
+        });

         // Set focus to the input for user message after uploading files
         chatInputRef?.current?.focus();
     }

+    async function extractTextFromFiles(files: FileList): Promise<AttachedFileText[]> {
+        const formData = await packageFilesForUpload(files);
+        setUploading(true);
+
+        try {
+            const response = await fetch("/api/content/convert", {
+                method: "POST",
+                body: formData,
+            });
+            setUploading(false);
+
+            if (!response.ok) {
+                throw new Error(`HTTP error! status: ${response.status}`);
+            }
+
+            return await response.json();
+        } catch (error) {
+            setError(
+                "Error converting files. " +
+                    error +
+                    ". Please try again, or contact team@khoj.dev if the issue persists.",
+            );
+            console.error("Error converting files:", error);
+            return [];
+        }
+    }
+
     // Assuming this function is added within the same context as the provided excerpt
     async function startRecordingAndTranscribe() {
         try {
@@ -445,6 +514,93 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
                 </div>
             )}
             <div>
+                <div className="flex items-center gap-2 overflow-x-auto">
+                    {imageUploaded &&
+                        imagePaths.map((path, index) => (
+                            <div key={index} className="relative flex-shrink-0 pb-3 pt-2 group">
+                                <img
+                                    src={path}
+                                    alt={`img-${index}`}
+                                    className="w-auto h-16 object-cover rounded-xl"
+                                />
+                                <Button
+                                    variant="ghost"
+                                    size="icon"
+                                    className="absolute -top-0 -right-2 h-5 w-5 rounded-full bg-neutral-200 dark:bg-neutral-600 hover:bg-neutral-300 dark:hover:bg-neutral-500 opacity-0 group-hover:opacity-100 transition-opacity"
+                                    onClick={() => removeImageUpload(index)}
+                                >
+                                    <X className="h-3 w-3" />
+                                </Button>
+                            </div>
+                        ))}
+                    {convertedAttachedFiles &&
+                        Array.from(convertedAttachedFiles).map((file, index) => (
+                            <Dialog key={index}>
+                                <DialogTrigger asChild>
+                                    <div key={index} className="relative flex-shrink-0 p-2 group">
+                                        <div
+                                            className={`w-auto h-16 object-cover rounded-xl ${props.agentColor ? convertToBGClass(props.agentColor) : "bg-orange-300 hover:bg-orange-500"} bg-opacity-15`}
+                                        >
+                                            <div className="flex p-2 flex-col justify-start items-start h-full">
+                                                <span className="text-sm font-bold text-neutral-500 dark:text-neutral-400 text-ellipsis truncate max-w-[200px] break-words">
+                                                    {file.name}
+                                                </span>
+                                                <span className="flex items-center gap-1">
+                                                    {getIconFromFilename(file.file_type)}
+                                                    <span className="text-xs text-neutral-500 dark:text-neutral-400">
+                                                        {convertBytesToText(file.size)}
+                                                    </span>
+                                                </span>
+                                            </div>
+                                        </div>
+                                        <Button
+                                            variant="ghost"
+                                            size="icon"
+                                            className="absolute -top-0 -right-2 h-5 w-5 rounded-full bg-neutral-200 dark:bg-neutral-600 hover:bg-neutral-300 dark:hover:bg-neutral-500 opacity-0 group-hover:opacity-100 transition-opacity"
+                                            onClick={() => {
+                                                setAttachedFiles((prevFiles) => {
+                                                    const removeFile = file.name;
+                                                    if (!prevFiles) return null;
+                                                    const updatedFiles = Array.from(
+                                                        prevFiles,
+                                                    ).filter((file) => file.name !== removeFile);
+                                                    const dataTransfer = new DataTransfer();
+                                                    updatedFiles.forEach((file) =>
+                                                        dataTransfer.items.add(file),
+                                                    );
+                                                    const filteredConvertedAttachedFiles =
+                                                        convertedAttachedFiles.filter(
+                                                            (file) => file.name !== removeFile,
+                                                        );
+                                                    props.setUploadedFiles(
+                                                        filteredConvertedAttachedFiles,
+                                                    );
+                                                    setConvertedAttachedFiles(
+                                                        filteredConvertedAttachedFiles,
+                                                    );
+                                                    return dataTransfer.files;
+                                                });
+                                            }}
+                                        >
+                                            <X className="h-3 w-3" />
+                                        </Button>
+                                    </div>
+                                </DialogTrigger>
+                                <DialogContent>
+                                    <DialogHeader>
+                                        <DialogTitle>{file.name}</DialogTitle>
+                                    </DialogHeader>
+                                    <DialogDescription>
+                                        <ScrollArea className="h-72 w-full rounded-md">
+                                            {file.content}
+                                        </ScrollArea>
+                                    </DialogDescription>
+                                </DialogContent>
+                            </Dialog>
+                        ))}
+                </div>
                 <div
                     className={`${styles.actualInputArea} justify-between dark:bg-neutral-700 relative ${isDragAndDropping && "animate-pulse"}`}
                     onDragOver={handleDragOver}
@@ -453,12 +609,14 @@
                 >
                     <input
                         type="file"
+                        accept=".pdf,.doc,.docx,.txt,.md,.org,.jpg,.jpeg,.png,.webp"
                         multiple={true}
                         ref={fileInputRef}
                         onChange={handleFileChange}
                         style={{ display: "none" }}
                     />
-                    <div className="flex items-end pb-2">
+                    <div className="flex items-center">
                         <Button
                             variant={"ghost"}
                             className="!bg-none p-0 m-2 h-auto text-3xl rounded-full text-gray-300 hover:text-gray-500"
@@ -469,29 +627,6 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
                         </Button>
                     </div>
                     <div className="flex-grow flex flex-col w-full gap-1.5 relative">
-                        <div className="flex items-center gap-2 overflow-x-auto">
-                            {imageUploaded &&
-                                imagePaths.map((path, index) => (
-                                    <div
-                                        key={index}
-                                        className="relative flex-shrink-0 pb-3 pt-2 group"
-                                    >
-                                        <img
-                                            src={path}
-                                            alt={`img-${index}`}
-                                            className="w-auto h-16 object-cover rounded-xl"
-                                        />
-                                        <Button
-                                            variant="ghost"
-                                            size="icon"
-                                            className="absolute -top-0 -right-2 h-5 w-5 rounded-full bg-neutral-200 dark:bg-neutral-600 hover:bg-neutral-300 dark:hover:bg-neutral-500 opacity-0 group-hover:opacity-100 transition-opacity"
-                                            onClick={() => removeImageUpload(index)}
-                                        >
-                                            <X className="h-3 w-3" />
-                                        </Button>
-                                    </div>
-                                ))}
-                        </div>
                         <Textarea
                             ref={chatInputRef}
                             className={`border-none focus:border-none
@@ -582,10 +717,14 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
                             <span className="text-muted-foreground text-sm">Research Mode</span>
                             {useResearchMode ? (
                                 <ToggleRight
+                                    weight="fill"
                                     className={`w-6 h-6 inline-block ${props.agentColor ? convertColorToTextClass(props.agentColor) : convertColorToTextClass("orange")} rounded-full`}
                                 />
                             ) : (
-                                <ToggleLeft className={`w-6 h-6 inline-block rounded-full`} />
+                                <ToggleLeft
+                                    weight="fill"
+                                    className={`w-6 h-6 inline-block ${props.agentColor ? convertColorToTextClass(props.agentColor) : convertColorToTextClass("orange")} rounded-full`}
+                                />
                             )}
                         </Button>
                     </TooltipTrigger>

View file

@@ -40,6 +40,18 @@ import { AgentData } from "@/app/agents/page";
 import renderMathInElement from "katex/contrib/auto-render";
 import "katex/dist/katex.min.css";
 import ExcalidrawComponent from "../excalidraw/excalidraw";
+import { AttachedFileText } from "../chatInputArea/chatInputArea";
+import {
+    Dialog,
+    DialogContent,
+    DialogDescription,
+    DialogHeader,
+    DialogTrigger,
+} from "@/components/ui/dialog";
+import { DialogTitle } from "@radix-ui/react-dialog";
+import { convertBytesToText } from "@/app/common/utils";
+import { ScrollArea } from "@/components/ui/scroll-area";
+import { getIconFromFilename } from "@/app/common/iconUtils";

 const md = new markdownIt({
     html: true,
@@ -149,6 +161,7 @@ export interface SingleChatMessage {
     images?: string[];
     conversationId: string;
     turnId?: string;
+    queryFiles?: AttachedFileText[];
 }

 export interface StreamMessage {
@@ -165,6 +178,7 @@ export interface StreamMessage {
     intentType?: string;
     inferredQueries?: string[];
     turnId?: string;
+    queryFiles?: AttachedFileText[];
 }

 export interface ChatHistoryData {
@@ -398,7 +412,6 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
         if (props.chatMessage.intent) {
             const { type, "inferred-queries": inferredQueries } = props.chatMessage.intent;

-            console.log("intent type", type);
             if (type in intentTypeHandlers) {
                 message = intentTypeHandlers[type as keyof typeof intentTypeHandlers](message);
             }
@@ -695,6 +708,40 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
             onMouseLeave={(event) => setIsHovering(false)}
             onMouseEnter={(event) => setIsHovering(true)}
         >
+            {props.chatMessage.queryFiles && props.chatMessage.queryFiles.length > 0 && (
+                <div className="flex flex-wrap flex-col m-2 max-w-full">
+                    {props.chatMessage.queryFiles.map((file, index) => (
+                        <Dialog key={index}>
+                            <DialogTrigger asChild>
+                                <div className="flex items-center space-x-2 cursor-pointer bg-gray-500 bg-opacity-25 rounded-lg m-1 p-2 w-full">
+                                    <div className="flex-shrink-0">
+                                        {getIconFromFilename(file.file_type)}
+                                    </div>
+                                    <span className="truncate flex-1 min-w-0">{file.name}</span>
+                                    {file.size && (
+                                        <span className="text-gray-400 flex-shrink-0">
+                                            ({convertBytesToText(file.size)})
+                                        </span>
+                                    )}
+                                </div>
+                            </DialogTrigger>
+                            <DialogContent>
+                                <DialogHeader>
+                                    <DialogTitle>{file.name}</DialogTitle>
+                                </DialogHeader>
+                                <DialogDescription>
+                                    <ScrollArea className="h-72 w-full rounded-md">
+                                        {file.content}
+                                    </ScrollArea>
+                                </DialogDescription>
+                            </DialogContent>
+                        </Dialog>
+                    ))}
+                </div>
+            )}
             <div className={chatMessageWrapperClasses(props.chatMessage)}>
                 <div
                     ref={messageRef}

View file

@@ -81,111 +81,3 @@ export function OrgMode({ className }: { className?: string }) {
         </svg>
     );
 }
-
-export function Markdown({ className }: { className?: string }) {
-    const classes = className ?? "w-6 h-6 text-muted-foreground inline-flex mr-1";
-    return (
-        <svg
-            className={`${classes}`}
-            xmlns="http://www.w3.org/2000/svg"
-            width="208"
-            height="128"
-            viewBox="0 0 208 128"
-        >
-            <rect
-                width="198"
-                height="118"
-                x="5"
-                y="5"
-                ry="10"
-                stroke="#000"
-                strokeWidth="10"
-                fill="none"
-            />
-            <path d="M30 98V30h20l20 25 20-25h20v68H90V59L70 84 50 59v39zm125 0l-30-33h20V30h20v35h20z" />
-        </svg>
-    );
-}
-
-export function Pdf({ className }: { className?: string }) {
-    const classes = className ?? "w-6 h-6 text-muted-foreground inline-flex mr-1";
-    return (
-        <svg
-            className={`${classes}`}
-            xmlns="http://www.w3.org/2000/svg"
-            enableBackground="new 0 0 334.371 380.563"
-            version="1.1"
-            viewBox="0 0 14 16"
-        >
-            <g transform="matrix(.04589 0 0 .04589 -.66877 -.73379)">
-                <polygon
-                    points="51.791 356.65 51.791 23.99 204.5 23.99 282.65 102.07 282.65 356.65"
-                    fill="#fff"
-                    strokeWidth="212.65"
-                />
-                <path
-                    d="m201.19 31.99 73.46 73.393v243.26h-214.86v-316.66h141.4m6.623-16h-164.02v348.66h246.85v-265.9z"
-                    strokeWidth="21.791"
-                />
-            </g>
-            <g transform="matrix(.04589 0 0 .04589 -.66877 -.73379)">
-                <polygon
-                    points="282.65 356.65 51.791 356.65 51.791 23.99 204.5 23.99 206.31 25.8 206.31 100.33 280.9 100.33 282.65 102.07"
-                    fill="#fff"
-                    strokeWidth="212.65"
-                />
-                <path
-                    d="m198.31 31.99v76.337h76.337v240.32h-214.86v-316.66h138.52m9.5-16h-164.02v348.66h246.85v-265.9l-6.43-6.424h-69.907v-69.842z"
-                    strokeWidth="21.791"
-                />
-            </g>
-            <g transform="matrix(.04589 0 0 .04589 -.66877 -.73379)" strokeWidth="21.791">
-                <polygon points="258.31 87.75 219.64 87.75 219.64 48.667 258.31 86.38" />
-                <path d="m227.64 67.646 12.41 12.104h-12.41v-12.104m-5.002-27.229h-10.998v55.333h54.666v-12.742z" />
-            </g>
-            <g
-                transform="matrix(.04589 0 0 .04589 -.66877 -.73379)"
-                fill="#ed1c24"
-                strokeWidth="212.65"
-            >
-                <polygon points="311.89 284.49 22.544 284.49 22.544 167.68 37.291 152.94 37.291 171.49 297.15 171.49 297.15 152.94 311.89 167.68" />
-                <path d="m303.65 168.63 1.747 1.747v107.62h-276.35v-107.62l1.747-1.747v9.362h272.85v-9.362m-12.999-31.385v27.747h-246.86v-27.747l-27.747 27.747v126h302.35v-126z" />
-            </g>
-            <rect x="1.7219" y="7.9544" width="10.684" height="4.0307" fill="none" />
-            <g transform="matrix(.04589 0 0 .04589 1.7219 11.733)" fill="#fff" strokeWidth="21.791">
-                <path d="m9.216 0v-83.2h30.464q6.784 0 12.928 1.408 6.144 1.28 10.752 4.608 4.608 3.2 7.296 8.576 2.816 5.248 2.816 13.056 0 7.68-2.816 13.184-2.688 5.504-7.296 9.088-4.608 3.456-10.624 5.248-6.016 1.664-12.544 1.664h-8.96v26.368zm22.016-43.776h7.936q6.528 0 9.6-3.072 3.2-3.072 3.2-8.704t-3.456-7.936-9.856-2.304h-7.424z" />
-                <path d="m87.04 0v-83.2h24.576q9.472 0 17.28 2.304 7.936 2.304 13.568 7.296t8.704 12.8q3.2 7.808 3.2 18.816t-3.072 18.944-8.704 13.056q-5.504 5.12-13.184 7.552-7.552 2.432-16.512 2.432zm22.016-17.664h1.28q4.48 0 8.448-1.024 3.968-1.152 6.784-3.84 2.944-2.688 4.608-7.424t1.664-12.032-1.664-11.904-4.608-7.168q-2.816-2.56-6.784-3.456-3.968-1.024-8.448-1.024h-1.28z" />
-                <path d="m169.22 0v-83.2h54.272v18.432h-32.256v15.872h27.648v18.432h-27.648v30.464z" />
-            </g>
-        </svg>
-    );
-}
-
-export function Word({ className }: { className?: string }) {
-    const classes = className ?? "w-6 h-6 text-muted-foreground inline-flex mr-1";
-    return (
-        <svg
-            className={`${classes}`}
-            xmlns="http://www.w3.org/2000/svg"
-            fill="#FFF"
-            stroke-miterlimit="10"
-            strokeWidth="2"
-            viewBox="0 0 96 96"
-        >
-            <path
-                stroke="#979593"
-                d="M67.1716 7H27c-1.1046 0-2 .8954-2 2v78c0 1.1046.8954 2 2 2h58c1.1046 0 2-.8954 2-2V26.8284c0-.5304-.2107-1.0391-.5858-1.4142L68.5858 7.5858C68.2107 7.2107 67.702 7 67.1716 7z"
-            />
-            <path fill="none" stroke="#979593" d="M67 7v18c0 1.1046.8954 2 2 2h18" />
-            <path
-                fill="#C8C6C4"
-                d="M79 61H48v-2h31c.5523 0 1 .4477 1 1s-.4477 1-1 1zm0-6H48v-2h31c.5523 0 1 .4477 1 1s-.4477 1-1 1zm0-6H48v-2h31c.5523 0 1 .4477 1 1s-.4477 1-1 1zm0-6H48v-2h31c.5523 0 1 .4477 1 1s-.4477 1-1 1zm0 24H48v-2h31c.5523 0 1 .4477 1 1s-.4477 1-1 1z"
-            />
-            <path
-                fill="#185ABD"
-                d="M12 74h32c2.2091 0 4-1.7909 4-4V38c0-2.2091-1.7909-4-4-4H12c-2.2091 0-4 1.7909-4 4v32c0 2.2091 1.7909 4 4 4z"
-            />
-            <path d="M21.6245 60.6455c.0661.522.109.9769.1296 1.3657h.0762c.0306-.3685.0889-.8129.1751-1.3349.0862-.5211.1703-.961.2517-1.319L25.7911 44h4.5702l3.6562 15.1272c.183.7468.3353 1.6973.457 2.8532h.0608c.0508-.7979.1777-1.7184.3809-2.7615L37.8413 44H42l-5.1183 22h-4.86l-3.4885-14.5744c-.1016-.4197-.2158-.9663-.3428-1.6417-.127-.6745-.2057-1.1656-.236-1.4724h-.0608c-.0407.358-.1195.8896-.2364 1.595-.1169.7062-.211 1.2273-.2819 1.565L24.1 66h-4.9357L14 44h4.2349l3.1843 15.3882c.0709.3165.1392.7362.2053 1.2573z" />
-        </svg>
-    );
-}

View file

@@ -11,7 +11,11 @@ import { Card, CardTitle } from "@/components/ui/card";
 import SuggestionCard from "@/app/components/suggestions/suggestionCard";
 import SidePanel from "@/app/components/sidePanel/chatHistorySidePanel";
 import Loading from "@/app/components/loading/loading";
-import { ChatInputArea, ChatOptions } from "@/app/components/chatInputArea/chatInputArea";
+import {
+    AttachedFileText,
+    ChatInputArea,
+    ChatOptions,
+} from "@/app/components/chatInputArea/chatInputArea";
 import { Suggestion, suggestionsData } from "@/app/components/suggestions/suggestionsData";
 import LoginPrompt from "@/app/components/loginPrompt/loginPrompt";
@@ -34,7 +38,7 @@ import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover
 interface ChatBodyDataProps {
     chatOptionsData: ChatOptions | null;
     onConversationIdChange?: (conversationId: string) => void;
-    setUploadedFiles: (files: string[]) => void;
+    setUploadedFiles: (files: AttachedFileText[]) => void;
     isMobileWidth?: boolean;
     isLoggedIn: boolean;
     userConfig: UserConfig | null;
@@ -155,6 +159,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
                 if (images.length > 0) {
                     localStorage.setItem("images", JSON.stringify(images));
                 }
+
                 window.location.href = `/chat?conversationId=${newConversationId}`;
             } catch (error) {
                 console.error("Error creating new conversation:", error);
@@ -401,7 +406,7 @@ export default function Home() {
     const [chatOptionsData, setChatOptionsData] = useState<ChatOptions | null>(null);
     const [isLoading, setLoading] = useState(true);
     const [conversationId, setConversationID] = useState<string | null>(null);
-    const [uploadedFiles, setUploadedFiles] = useState<string[]>([]);
+    const [uploadedFiles, setUploadedFiles] = useState<AttachedFileText[] | null>(null);
     const isMobileWidth = useIsMobileWidth();

     const { userConfig: initialUserConfig, isLoadingUserConfig } = useUserConfig(true);
@@ -417,6 +422,12 @@ export default function Home() {
         setUserConfig(initialUserConfig);
     }, [initialUserConfig]);

+    useEffect(() => {
+        if (uploadedFiles) {
+            localStorage.setItem("uploadedFiles", JSON.stringify(uploadedFiles));
+        }
+    }, [uploadedFiles]);
+
     useEffect(() => {
         fetch("/api/chat/options")
             .then((response) => response.json())
@@ -442,7 +453,7 @@ export default function Home() {
                 <div className={`${styles.sidePanel}`}>
                     <SidePanel
                         conversationId={conversationId}
-                        uploadedFiles={uploadedFiles}
+                        uploadedFiles={[]}
                         isMobileWidth={isMobileWidth}
                     />
                 </div>

View file

@@ -137,10 +137,8 @@ const ManageFilesModal: React.FC<{ onClose: () => void }> = ({ onClose }) => {
     const deleteSelected = async () => {
         let filesToDelete = selectedFiles.length > 0 ? selectedFiles : filteredFiles;
-        console.log("Delete selected files", filesToDelete);

         if (filesToDelete.length === 0) {
-            console.log("No files to delete");
             return;
         }
@@ -162,15 +160,12 @@ const ManageFilesModal: React.FC<{ onClose: () => void }> = ({ onClose }) => {
             // Reset selectedFiles
             setSelectedFiles([]);
-
-            console.log("Deleted files:", filesToDelete);
         } catch (error) {
             console.error("Error deleting files:", error);
         }
     };

     const deleteFile = async (filename: string) => {
-        console.log("Delete selected file", filename);
         try {
             const response = await fetch(
                 `/api/content/file?filename=${encodeURIComponent(filename)}`,
@@ -189,8 +184,6 @@ const ManageFilesModal: React.FC<{ onClose: () => void }> = ({ onClose }) => {
             // Remove the file from selectedFiles if it's there
             setSelectedFiles((prevSelected) => prevSelected.filter((file) => file !== filename));
-
-            console.log("Deleted file:", filename);
         } catch (error) {
             console.error("Error deleting file:", error);
         }

View file

@@ -5,23 +5,25 @@ import React, { Suspense, useEffect, useRef, useState } from "react";
 import SidePanel from "../../components/sidePanel/chatHistorySidePanel";
 import ChatHistory from "../../components/chatHistory/chatHistory";
-import NavMenu from "../../components/navMenu/navMenu";
 import Loading from "../../components/loading/loading";

 import "katex/dist/katex.min.css";

-import { useIPLocationData, useIsMobileWidth, welcomeConsole } from "../../common/utils";
+import { useIsMobileWidth, welcomeConsole } from "../../common/utils";
 import { useAuthenticatedData } from "@/app/common/auth";
-import { ChatInputArea, ChatOptions } from "@/app/components/chatInputArea/chatInputArea";
+import {
+    AttachedFileText,
+    ChatInputArea,
+    ChatOptions,
+} from "@/app/components/chatInputArea/chatInputArea";
 import { StreamMessage } from "@/app/components/chatMessage/chatMessage";
-import { processMessageChunk } from "@/app/common/chatFunctions";
 import { AgentData } from "@/app/agents/page";

 interface ChatBodyDataProps {
     chatOptionsData: ChatOptions | null;
     setTitle: (title: string) => void;
-    setUploadedFiles: (files: string[]) => void;
+    setUploadedFiles: (files: AttachedFileText[]) => void;
     isMobileWidth?: boolean;
     publicConversationSlug: string;
     streamedMessages: StreamMessage[];
@@ -50,23 +52,6 @@ function ChatBodyData(props: ChatBodyDataProps) {
         }
     }, [images, props.setImages]);

-    useEffect(() => {
-        const storedImages = localStorage.getItem("images");
-        if (storedImages) {
-            const parsedImages: string[] = JSON.parse(storedImages);
-            setImages(parsedImages);
-            const encodedImages = parsedImages.map((img: string) => encodeURIComponent(img));
-            props.setImages(encodedImages);
-            localStorage.removeItem("images");
-        }
-
-        const storedMessage = localStorage.getItem("message");
-        if (storedMessage) {
-            setProcessingMessage(true);
-            setQueryToProcess(storedMessage);
-        }
-    }, [setQueryToProcess, props.setImages]);
-
     useEffect(() => {
         if (message) {
             setProcessingMessage(true);
@@ -130,14 +115,10 @@ export default function SharedChat() {
     const [conversationId, setConversationID] = useState<string | undefined>(undefined);
     const [messages, setMessages] = useState<StreamMessage[]>([]);
     const [queryToProcess, setQueryToProcess] = useState<string>("");
-    const [processQuerySignal, setProcessQuerySignal] = useState(false);
-    const [uploadedFiles, setUploadedFiles] = useState<string[]>([]);
+    const [uploadedFiles, setUploadedFiles] = useState<AttachedFileText[] | null>(null);
     const [paramSlug, setParamSlug] = useState<string | undefined>(undefined);
     const [images, setImages] = useState<string[]>([]);

-    const locationData = useIPLocationData() || {
-        timezone: Intl.DateTimeFormat().resolvedOptions().timeZone,
-    };
     const authenticatedData = useAuthenticatedData();
     const isMobileWidth = useIsMobileWidth();
@@ -161,6 +142,12 @@ export default function SharedChat() {
         setParamSlug(window.location.pathname.split("/").pop() || "");
     }, []);

+    useEffect(() => {
+        if (uploadedFiles) {
+            localStorage.setItem("uploadedFiles", JSON.stringify(uploadedFiles));
+        }
+    }, [uploadedFiles]);
+
     useEffect(() => {
         if (queryToProcess && !conversationId) {
             // If the user has not yet started conversing in the chat, create a new conversation
@@ -173,6 +160,11 @@ export default function SharedChat() {
                 .then((response) => response.json())
                 .then((data) => {
                     setConversationID(data.conversation_id);
+                    localStorage.setItem("message", queryToProcess);
+                    if (images.length > 0) {
+                        localStorage.setItem("images", JSON.stringify(images));
+                    }
+                    window.location.href = `/chat?conversationId=${data.conversation_id}`;
                 })
                 .catch((err) => {
                     console.error(err);
@@ -180,105 +172,8 @@ export default function SharedChat() {
                 });
             return;
         }
-
-        if (queryToProcess) {
-            // Add a new object to the state
-            const newStreamMessage: StreamMessage = {
-                rawResponse: "",
-                trainOfThought: [],
-                context: [],
-                onlineContext: {},
-                codeContext: {},
-                completed: false,
-                timestamp: new Date().toISOString(),
-                rawQuery: queryToProcess || "",
-                images: images,
-            };
-            setMessages((prevMessages) => [...prevMessages, newStreamMessage]);
-            setProcessQuerySignal(true);
-        }
     }, [queryToProcess, conversationId, paramSlug]);

-    useEffect(() => {
-        if (processQuerySignal) {
-            chat();
-        }
-    }, [processQuerySignal]);
-
-    async function readChatStream(response: Response) {
-        if (!response.ok) throw new Error(response.statusText);
-        if (!response.body) throw new Error("Response body is null");
-
-        const reader = response.body.getReader();
-        const decoder = new TextDecoder();
-        const eventDelimiter = "␃🔚␗";
-        let buffer = "";
-
-        while (true) {
-            const { done, value } = await reader.read();
-            if (done) {
-                setQueryToProcess("");
-                setProcessQuerySignal(false);
-                setImages([]);
-                break;
-            }
-
-            const chunk = decoder.decode(value, { stream: true });
-            buffer += chunk;
-
-            let newEventIndex;
-            while ((newEventIndex = buffer.indexOf(eventDelimiter)) !== -1) {
-                const event = buffer.slice(0, newEventIndex);
-                buffer = buffer.slice(newEventIndex + eventDelimiter.length);
-                if (event) {
-                    const currentMessage = messages.find((message) => !message.completed);
-
-                    if (!currentMessage) {
-                        console.error("No current message found");
-                        return;
-                    }
-
-                    processMessageChunk(event, currentMessage);
-
-                    setMessages([...messages]);
-                }
-            }
-        }
-    }
-
-    async function chat() {
-        if (!queryToProcess || !conversationId) return;
-        const chatAPI = "/api/chat?client=web";
-        const chatAPIBody = {
-            q: queryToProcess,
-            conversation_id: conversationId,
-            stream: true,
-            ...(locationData && {
-                region: locationData.region,
-                country: locationData.country,
-                city: locationData.city,
-                country_code: locationData.countryCode,
-                timezone: locationData.timezone,
-            }),
-            ...(images.length > 0 && { image: images }),
-        };
-
-        const response = await fetch(chatAPI, {
-            method: "POST",
-            headers: {
-                "Content-Type": "application/json",
-            },
-            body: JSON.stringify(chatAPIBody),
-        });
-
-        try {
-            await readChatStream(response);
-        } catch (error) {
-            console.error(error);
-        }
-    }
-
     if (isLoading) {
         return <Loading />;
     }
@@ -293,7 +188,7 @@ export default function SharedChat() {
             <div className={styles.sidePanel}>
                 <SidePanel
                     conversationId={conversationId ?? null}
-                    uploadedFiles={uploadedFiles}
+                    uploadedFiles={[]}
                     isMobileWidth={isMobileWidth}
                 />
             </div>

View file

@@ -9,7 +9,7 @@ const Textarea = React.forwardRef<HTMLTextAreaElement, TextareaProps>(
         return (
             <textarea
                 className={cn(
-                    "flex min-h-[80px] w-full rounded-md border border-input bg-background px-3 py-2 text-sm ring-offset-background placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-50",
+                    "flex min-h-[80px] w-full rounded-md border border-input bg-background px-3 py-2 text-sm ring-offset-background placeholder:text-muted-foreground disabled:cursor-not-allowed disabled:opacity-50",
                     className,
                 )}
                 ref={ref}

View file

@@ -1387,6 +1387,10 @@ class FileObjectAdapters:
     async def async_get_file_objects_by_name(user: KhojUser, file_name: str, agent: Agent = None):
         return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name=file_name, agent=agent))

+    @staticmethod
+    async def async_get_file_objects_by_names(user: KhojUser, file_names: List[str]):
+        return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name__in=file_names))
+
     @staticmethod
     async def async_get_all_file_objects(user: KhojUser):
         return await sync_to_async(list)(FileObject.objects.filter(user=user))
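One plausible use of the new adapter, sketched below: gather the full text of a conversation's filtered files to build the attached-file context for the chat prompt. The import path comes from the comment removed in the PDF processor diff further down; the FileObject field names (file_name, raw_text) and the surrounding function are assumptions, not the PR's exact code.

    from khoj.database.adapters import FileObjectAdapters

    async def gather_attached_file_context(user, file_filters: list[str]) -> str:
        """Sketch: stitch the full text of the filtered files into one context block."""
        file_objects = await FileObjectAdapters.async_get_file_objects_by_names(user, file_filters)
        # Assumes FileObject exposes file_name and raw_text fields.
        return "\n\n".join(f"# {fo.file_name}\n{fo.raw_text}" for fo in file_objects)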

View file

@@ -458,7 +458,11 @@ class Conversation(BaseModel):
     user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
     conversation_log = models.JSONField(default=dict)
     client = models.ForeignKey(ClientApplication, on_delete=models.CASCADE, default=None, null=True, blank=True)
+
+    # Slug is an app-generated conversation identifier. Need not be unique. Used as display title essentially.
     slug = models.CharField(max_length=200, default=None, null=True, blank=True)
+
+    # The title field is explicitly set by the user.
     title = models.CharField(max_length=200, default=None, null=True, blank=True)
     agent = models.ForeignKey(Agent, on_delete=models.SET_NULL, default=None, null=True, blank=True)
     file_filters = models.JSONField(default=list)

View file

@@ -1,6 +1,5 @@
 import logging
-import os
-from datetime import datetime
+import tempfile
 from typing import Dict, List, Tuple

 from langchain_community.document_loaders import Docx2txtLoader
@@ -58,28 +57,13 @@ class DocxToEntries(TextToEntries):
         file_to_text_map = dict()
         for docx_file in docx_files:
             try:
-                timestamp_now = datetime.utcnow().timestamp()
-                tmp_file = f"tmp_docx_file_{timestamp_now}.docx"
-                with open(tmp_file, "wb") as f:
-                    bytes_content = docx_files[docx_file]
-                    f.write(bytes_content)
-
-                # Load the content using Docx2txtLoader
-                loader = Docx2txtLoader(tmp_file)
-                docx_entries_per_file = loader.load()
-
-                # Convert the loaded entries into the desired format
-                docx_texts = [page.page_content for page in docx_entries_per_file]
+                docx_texts = DocxToEntries.extract_text(docx_files[docx_file])
                 entry_to_location_map += zip(docx_texts, [docx_file] * len(docx_texts))
                 entries.extend(docx_texts)
                 file_to_text_map[docx_file] = docx_texts
             except Exception as e:
-                logger.warning(f"Unable to process file: {docx_file}. This file will not be indexed.")
+                logger.warning(f"Unable to extract entries from file: {docx_file}")
                 logger.warning(e, exc_info=True)
-            finally:
-                if os.path.exists(f"{tmp_file}"):
-                    os.remove(f"{tmp_file}")

         return file_to_text_map, DocxToEntries.convert_docx_entries_to_maps(entries, dict(entry_to_location_map))

     @staticmethod
@@ -103,3 +87,25 @@ class DocxToEntries(TextToEntries):
         logger.debug(f"Converted {len(parsed_entries)} DOCX entries to dictionaries")

         return entries
+
+    @staticmethod
+    def extract_text(docx_file):
+        """Extract text from specified DOCX file"""
+        try:
+            docx_entry_by_pages = []
+            # Create temp file with .docx extension that gets auto-deleted
+            with tempfile.NamedTemporaryFile(suffix=".docx", delete=True) as tmp:
+                tmp.write(docx_file)
+                tmp.flush()  # Ensure all data is written
+
+                # Load the content using Docx2txtLoader
+                loader = Docx2txtLoader(tmp.name)
+                docx_entries_per_file = loader.load()
+
+                # Convert the loaded entries into the desired format
+                docx_entry_by_pages = [page.page_content for page in docx_entries_per_file]
+        except Exception as e:
+            logger.warning(f"Unable to extract text from file: {docx_file}")
+            logger.warning(e, exc_info=True)
+
+        return docx_entry_by_pages

View file

@@ -1,13 +1,10 @@
-import base64
 import logging
-import os
-from datetime import datetime
+import tempfile
+from io import BytesIO
 from typing import Dict, List, Tuple

 from langchain_community.document_loaders import PyMuPDFLoader

-# importing FileObjectAdapter so that we can add new files and debug file object db.
-# from khoj.database.adapters import FileObjectAdapters
 from khoj.database.models import Entry as DbEntry
 from khoj.database.models import KhojUser
 from khoj.processor.content.text_to_entries import TextToEntries
@@ -60,31 +57,13 @@ class PdfToEntries(TextToEntries):
         entry_to_location_map: List[Tuple[str, str]] = []
         for pdf_file in pdf_files:
             try:
-                # Write the PDF file to a temporary file, as it is stored in byte format in the pdf_file object and the PDF Loader expects a file path
-                timestamp_now = datetime.utcnow().timestamp()
-                tmp_file = f"tmp_pdf_file_{timestamp_now}.pdf"
-                with open(f"{tmp_file}", "wb") as f:
-                    bytes = pdf_files[pdf_file]
-                    f.write(bytes)
-                try:
-                    loader = PyMuPDFLoader(f"{tmp_file}", extract_images=False)
-                    pdf_entries_per_file = [page.page_content for page in loader.load()]
-                except ImportError:
-                    loader = PyMuPDFLoader(f"{tmp_file}")
-                    pdf_entries_per_file = [
-                        page.page_content for page in loader.load()
-                    ]  # page_content items list for a given pdf.
-                entry_to_location_map += zip(
-                    pdf_entries_per_file, [pdf_file] * len(pdf_entries_per_file)
-                )  # this is an indexed map of pdf_entries for the pdf.
+                pdf_entries_per_file = PdfToEntries.extract_text(pdf_files[pdf_file])
+                entry_to_location_map += zip(pdf_entries_per_file, [pdf_file] * len(pdf_entries_per_file))
                 entries.extend(pdf_entries_per_file)
                 file_to_text_map[pdf_file] = pdf_entries_per_file
             except Exception as e:
-                logger.warning(f"Unable to process file: {pdf_file}. This file will not be indexed.")
+                logger.warning(f"Unable to extract entries from file: {pdf_file}")
                 logger.warning(e, exc_info=True)
-            finally:
-                if os.path.exists(f"{tmp_file}"):
-                    os.remove(f"{tmp_file}")

         return file_to_text_map, PdfToEntries.convert_pdf_entries_to_maps(entries, dict(entry_to_location_map))
@@ -109,3 +88,32 @@ class PdfToEntries(TextToEntries):
         logger.debug(f"Converted {len(parsed_entries)} PDF entries to dictionaries")

         return entries
+
+    @staticmethod
+    def extract_text(pdf_file):
+        """Extract text from specified PDF files"""
+        try:
+            # Create temp file with .pdf extension that gets auto-deleted
+            with tempfile.NamedTemporaryFile(suffix=".pdf", delete=True) as tmpf:
+                tmpf.write(pdf_file)
+                tmpf.flush()  # Ensure all data is written
+
+                # Load the content using PyMuPDFLoader
+                loader = PyMuPDFLoader(tmpf.name, extract_images=True)
+                pdf_entries_per_file = loader.load()
+
+                # Convert the loaded entries into the desired format
+                pdf_entry_by_pages = [PdfToEntries.clean_text(page.page_content) for page in pdf_entries_per_file]
+        except Exception as e:
+            logger.warning(f"Unable to process file: {pdf_file}. This file will not be indexed.")
+            logger.warning(e, exc_info=True)
+
+        return pdf_entry_by_pages
+
+    @staticmethod
+    def clean_text(text: str) -> str:
+        # Remove null bytes
+        text = text.replace("\x00", "")
+        # Replace invalid Unicode
+        text = text.encode("utf-8", errors="ignore").decode("utf-8")
+        return text

View file

@@ -36,6 +36,7 @@ def extract_questions_anthropic(
     query_images: Optional[list[str]] = None,
     vision_enabled: bool = False,
     personality_context: Optional[str] = None,
+    query_files: str = None,
     tracer: dict = {},
 ):
     """
@@ -82,9 +83,12 @@ def extract_questions_anthropic(
         images=query_images,
         model_type=ChatModelOptions.ModelType.ANTHROPIC,
         vision_enabled=vision_enabled,
+        attached_file_context=query_files,
     )

-    messages = [ChatMessage(content=prompt, role="user")]
+    messages = []
+
+    messages.append(ChatMessage(content=prompt, role="user"))

     messages, system_prompt = format_messages_for_anthropic(messages, system_prompt)
@@ -148,6 +152,7 @@ def converse_anthropic(
     agent: Agent = None,
     query_images: Optional[list[str]] = None,
     vision_available: bool = False,
+    query_files: str = None,
     tracer: dict = {},
 ):
     """
@@ -205,6 +210,7 @@ def converse_anthropic(
         query_images=query_images,
         vision_enabled=vision_available,
         model_type=ChatModelOptions.ModelType.ANTHROPIC,
+        query_files=query_files,
     )

     messages, system_prompt = format_messages_for_anthropic(messages, system_prompt)

View file

@@ -37,6 +37,7 @@ def extract_questions_gemini(
     query_images: Optional[list[str]] = None,
     vision_enabled: bool = False,
     personality_context: Optional[str] = None,
+    query_files: str = None,
     tracer: dict = {},
 ):
     """
@@ -83,9 +84,13 @@ def extract_questions_gemini(
         images=query_images,
         model_type=ChatModelOptions.ModelType.GOOGLE,
         vision_enabled=vision_enabled,
+        attached_file_context=query_files,
     )

-    messages = [ChatMessage(content=prompt, role="user"), ChatMessage(content=system_prompt, role="system")]
+    messages = []
+
+    messages.append(ChatMessage(content=prompt, role="user"))
+    messages.append(ChatMessage(content=system_prompt, role="system"))

     response = gemini_send_message_to_model(
         messages, api_key, model, response_type="json_object", temperature=temperature, tracer=tracer
@@ -108,7 +113,13 @@ def extract_questions_gemini(
 def gemini_send_message_to_model(
-    messages, api_key, model, response_type="text", temperature=0, model_kwargs=None, tracer={}
+    messages,
+    api_key,
+    model,
+    response_type="text",
+    temperature=0,
+    model_kwargs=None,
+    tracer={},
 ):
     """
     Send message to model
@@ -151,6 +162,7 @@ def converse_gemini(
     agent: Agent = None,
     query_images: Optional[list[str]] = None,
     vision_available: bool = False,
+    query_files: str = None,
     tracer={},
 ):
     """
@@ -209,6 +221,7 @@ def converse_gemini(
         query_images=query_images,
         vision_enabled=vision_available,
         model_type=ChatModelOptions.ModelType.GOOGLE,
+        query_files=query_files,
     )

     messages, system_prompt = format_messages_for_gemini(messages, system_prompt)

View file

@@ -37,6 +37,7 @@ def extract_questions_offline(
     max_prompt_size: int = None,
     temperature: float = 0.7,
     personality_context: Optional[str] = None,
+    query_files: str = None,
     tracer: dict = {},
 ) -> List[str]:
     """
@@ -87,6 +88,7 @@ def extract_questions_offline(
         loaded_model=offline_chat_model,
         max_prompt_size=max_prompt_size,
         model_type=ChatModelOptions.ModelType.OFFLINE,
+        query_files=query_files,
     )

     state.chat_lock.acquire()
@@ -152,6 +154,7 @@ def converse_offline(
     location_data: LocationData = None,
     user_name: str = None,
     agent: Agent = None,
+    query_files: str = None,
     tracer: dict = {},
 ) -> Union[ThreadedGenerator, Iterator[str]]:
     """
@@ -216,6 +219,7 @@ def converse_offline(
         max_prompt_size=max_prompt_size,
         tokenizer_name=tokenizer_name,
         model_type=ChatModelOptions.ModelType.OFFLINE,
+        query_files=query_files,
     )

     truncated_messages = "\n".join({f"{message.content[:70]}..." for message in messages})

View file

@@ -34,6 +34,7 @@ def extract_questions(
     query_images: Optional[list[str]] = None,
     vision_enabled: bool = False,
     personality_context: Optional[str] = None,
+    query_files: str = None,
     tracer: dict = {},
 ):
     """
@@ -79,9 +80,11 @@ def extract_questions(
         images=query_images,
         model_type=ChatModelOptions.ModelType.OPENAI,
         vision_enabled=vision_enabled,
+        attached_file_context=query_files,
     )

-    messages = [ChatMessage(content=prompt, role="user")]
+    messages = []
+    messages.append(ChatMessage(content=prompt, role="user"))

     response = send_message_to_model(
         messages,
@@ -148,6 +151,7 @@ def converse(
     agent: Agent = None,
     query_images: Optional[list[str]] = None,
     vision_available: bool = False,
+    query_files: str = None,
     tracer: dict = {},
 ):
     """
@@ -206,6 +210,7 @@ def converse(
         query_images=query_images,
         vision_enabled=vision_available,
         model_type=ChatModelOptions.ModelType.OPENAI,
+        query_files=query_files,
     )
     truncated_messages = "\n".join({f"{message.content[:70]}..." for message in messages})
     logger.debug(f"Conversation Context for GPT: {truncated_messages}")

View file

@@ -988,16 +988,27 @@ You are an extremely smart and helpful title generator assistant. Given a user q
 # Examples:
 User: Show a new Calvin and Hobbes quote every morning at 9am. My Current Location: Shanghai, China
-Khoj: Your daily Calvin and Hobbes Quote
+Assistant: Your daily Calvin and Hobbes Quote

 User: Notify me when version 2.0.0 of the sentence transformers python package is released. My Current Location: Mexico City, Mexico
-Khoj: Sentence Transformers Python Package Version 2.0.0 Release
+Assistant: Sentence Transformers Python Package Version 2.0.0 Release

 User: Gather the latest tech news on the first sunday of every month.
-Khoj: Your Monthly Dose of Tech News
+Assistant: Your Monthly Dose of Tech News

 User Query: {query}
-Khoj:
+Assistant:
+""".strip()
+)
+
+conversation_title_generation = PromptTemplate.from_template(
+    """
+You are an extremely smart and helpful title generator assistant. Given a conversation, extract the subject of the conversation. Crisp, informative, ten words or less.
+
+Conversation History:
+{chat_history}
+Assistant:
 """.strip()
 )
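Editor's note: the new conversation_title_generation template is rendered the same way as the other templates in this file. A minimal sketch of the title-generation flow, assuming the LangChain PromptTemplate import that prompts.py already relies on; the sample history is hypothetical:

from langchain.prompts import PromptTemplate

conversation_title_generation = PromptTemplate.from_template(
    "You are an extremely smart and helpful title generator assistant. "
    "Given a conversation, extract the subject of the conversation. "
    "Crisp, informative, ten words or less.\n\n"
    "Conversation History:\n{chat_history}\nAssistant:"
)

# Hypothetical two-turn history, in the "User:/Assistant:" shape that
# construct_chat_history() produces from a conversation log.
chat_history = (
    "User: How do I cook risotto?\n"
    "Assistant: Toast the rice, then add stock one ladle at a time.\n\n"
)

# The rendered prompt is sent to the chat model; the stripped response
# becomes the conversation title.
prompt = conversation_title_generation.format(chat_history=chat_history)
print(prompt)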

View file

@@ -36,6 +36,7 @@ from khoj.utils.helpers import (
     is_none_or_empty,
     merge_dicts,
 )
+from khoj.utils.rawconfig import FileAttachment

 logger = logging.getLogger(__name__)
@@ -146,7 +147,7 @@ def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="A
             chat_history += f"User: {chat['intent']['query']}\n"

             if chat["intent"].get("inferred-queries"):
-                chat_history += f'Khoj: {{"queries": {chat["intent"].get("inferred-queries")}}}\n'
+                chat_history += f'{agent_name}: {{"queries": {chat["intent"].get("inferred-queries")}}}\n'

             chat_history += f"{agent_name}: {chat['message']}\n\n"
         elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")):
@@ -155,6 +156,16 @@ def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="A
         elif chat["by"] == "khoj" and ("excalidraw" in chat["intent"].get("type")):
             chat_history += f"User: {chat['intent']['query']}\n"
             chat_history += f"{agent_name}: {chat['intent']['inferred-queries'][0]}\n"
+        elif chat["by"] == "you":
+            raw_query_files = chat.get("queryFiles")
+            if raw_query_files:
+                query_files: Dict[str, str] = {}
+                for file in raw_query_files:
+                    query_files[file["name"]] = file["content"]
+
+                query_file_context = gather_raw_query_files(query_files)
+                chat_history += f"User: {query_file_context}\n"

     return chat_history
@@ -243,8 +254,9 @@ def save_to_conversation_log(
     conversation_id: str = None,
     automation_id: str = None,
     query_images: List[str] = None,
-    tracer: Dict[str, Any] = {},
+    raw_query_files: List[FileAttachment] = [],
     train_of_thought: List[Any] = [],
+    tracer: Dict[str, Any] = {},
 ):
     user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
     turn_id = tracer.get("mid") or str(uuid.uuid4())
@@ -255,6 +267,7 @@ def save_to_conversation_log(
             "created": user_message_time,
             "images": query_images,
             "turnId": turn_id,
+            "queryFiles": [file.model_dump(mode="json") for file in raw_query_files],
         },
         khoj_message_metadata={
             "context": compiled_references,
@@ -289,25 +302,50 @@ Khoj: "{inferred_queries if ("text-to-image" in intent_type) else chat_response}
     )


-def construct_structured_message(message: str, images: list[str], model_type: str, vision_enabled: bool):
+def construct_structured_message(
+    message: str, images: list[str], model_type: str, vision_enabled: bool, attached_file_context: str
+):
     """
     Format messages into appropriate multimedia format for supported chat model types
     """
-    if not images or not vision_enabled:
-        return message
-
     if model_type in [
         ChatModelOptions.ModelType.OPENAI,
         ChatModelOptions.ModelType.GOOGLE,
         ChatModelOptions.ModelType.ANTHROPIC,
     ]:
-        return [
+        constructed_messages: List[Any] = [
             {"type": "text", "text": message},
-            *[{"type": "image_url", "image_url": {"url": image}} for image in images],
         ]
+
+        if not is_none_or_empty(attached_file_context):
+            constructed_messages.append({"type": "text", "text": attached_file_context})
+        if vision_enabled and images:
+            for image in images:
+                constructed_messages.append({"type": "image_url", "image_url": {"url": image}})
+        return constructed_messages
+
+    if not is_none_or_empty(attached_file_context):
+        return f"{attached_file_context}\n\n{message}"
+
     return message


+def gather_raw_query_files(
+    query_files: Dict[str, str],
+):
+    """
+    Gather contextual data from the given (raw) files
+    """
+    if len(query_files) == 0:
+        return ""
+
+    contextual_data = " ".join(
+        [f"File: {file_name}\n\n{file_content}\n\n" for file_name, file_content in query_files.items()]
+    )
+    return f"I have attached the following files:\n\n{contextual_data}"
+
+
 def generate_chatml_messages_with_context(
     user_message,
     system_message=None,
@@ -320,6 +358,7 @@ def generate_chatml_messages_with_context(
     vision_enabled=False,
     model_type="",
     context_message="",
+    query_files: str = None,
 ):
     """Generate chat messages with appropriate context from previous conversation to send to the chat model"""
     # Set max prompt size from user config or based on pre-configured for model and machine specs
@@ -336,6 +375,8 @@ def generate_chatml_messages_with_context(
     chatml_messages: List[ChatMessage] = []
     for chat in conversation_log.get("chat", []):
         message_context = ""
+        message_attached_files = ""
+
         if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type", ""):
             message_context += chat.get("intent").get("inferred-queries")[0]
         if not is_none_or_empty(chat.get("context")):
@@ -347,14 +388,27 @@ def generate_chatml_messages_with_context(
                 }
             )
             message_context += f"{prompts.notes_conversation.format(references=references)}\n\n"

+        if chat.get("queryFiles"):
+            raw_query_files = chat.get("queryFiles")
+            query_files_dict = dict()
+            for file in raw_query_files:
+                query_files_dict[file["name"]] = file["content"]
+
+            message_attached_files = gather_raw_query_files(query_files_dict)
+            chatml_messages.append(ChatMessage(content=message_attached_files, role="user"))
+
         if not is_none_or_empty(chat.get("onlineContext")):
             message_context += f"{prompts.online_search_conversation.format(online_results=chat.get('onlineContext'))}"

         if not is_none_or_empty(message_context):
             reconstructed_context_message = ChatMessage(content=message_context, role="user")
             chatml_messages.insert(0, reconstructed_context_message)

         role = "user" if chat["by"] == "you" else "assistant"
-        message_content = construct_structured_message(chat["message"], chat.get("images"), model_type, vision_enabled)
+        message_content = construct_structured_message(
+            chat["message"], chat.get("images"), model_type, vision_enabled, attached_file_context=query_files
+        )
+
         reconstructed_message = ChatMessage(content=message_content, role=role)
         chatml_messages.insert(0, reconstructed_message)
@@ -366,14 +420,18 @@ def generate_chatml_messages_with_context(
     if not is_none_or_empty(user_message):
         messages.append(
             ChatMessage(
-                content=construct_structured_message(user_message, query_images, model_type, vision_enabled),
+                content=construct_structured_message(
+                    user_message, query_images, model_type, vision_enabled, query_files
+                ),
                 role="user",
             )
        )

     if not is_none_or_empty(context_message):
         messages.append(ChatMessage(content=context_message, role="user"))
+
     if len(chatml_messages) > 0:
         messages += chatml_messages
+
     if not is_none_or_empty(system_message):
         messages.append(ChatMessage(content=system_message, role="system"))
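Editor's note: to make the reworked message shape concrete, here is a small, self-contained sketch of what construct_structured_message now produces for an OPENAI/GOOGLE/ANTHROPIC model with one attached file and vision enabled. All values are hypothetical; the logic mirrors the branch above:

# Hypothetical inputs.
message = "Summarize the attached notes"
attached_file_context = (
    "I have attached the following files:\n\nFile: notes.md\n\n# Q3 notes\n\n"
)
images = ["data:image/png;base64,iVBORw0KGgo..."]

# Multi-part content list: text part(s) first, then one image_url part per image.
constructed_messages = [
    {"type": "text", "text": message},
    {"type": "text", "text": attached_file_context},  # only when file context is non-empty
]
for image in images:  # only when vision is enabled and images are present
    constructed_messages.append({"type": "image_url", "image_url": {"url": image}})

# For other model types, the file context is simply prepended to the message
# string: f"{attached_file_context}\n\n{message}".
print(constructed_messages)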

View file

@@ -28,6 +28,7 @@ async def text_to_image(
     send_status_func: Optional[Callable] = None,
     query_images: Optional[List[str]] = None,
     agent: Agent = None,
+    query_files: str = None,
     tracer: dict = {},
 ):
     status_code = 200
@@ -69,6 +70,7 @@ async def text_to_image(
         query_images=query_images,
         user=user,
         agent=agent,
+        query_files=query_files,
         tracer=tracer,
     )

View file

@@ -68,6 +68,7 @@ async def search_online(
     query_images: List[str] = None,
     previous_subqueries: Set = set(),
     agent: Agent = None,
+    query_files: str = None,
     tracer: dict = {},
 ):
     query += " ".join(custom_filters)
@@ -78,7 +79,14 @@ async def search_online(
     # Breakdown the query into subqueries to get the correct answer
     new_subqueries = await generate_online_subqueries(
-        query, conversation_history, location, user, query_images=query_images, agent=agent, tracer=tracer
+        query,
+        conversation_history,
+        location,
+        user,
+        query_images=query_images,
+        agent=agent,
+        tracer=tracer,
+        query_files=query_files,
     )
     subqueries = list(new_subqueries - previous_subqueries)
     response_dict: Dict[str, Dict[str, List[Dict] | Dict]] = {}
@@ -169,13 +177,21 @@ async def read_webpages(
     send_status_func: Optional[Callable] = None,
     query_images: List[str] = None,
     agent: Agent = None,
-    tracer: dict = {},
     max_webpages_to_read: int = DEFAULT_MAX_WEBPAGES_TO_READ,
+    query_files: str = None,
+    tracer: dict = {},
 ):
     "Infer web pages to read from the query and extract relevant information from them"
     logger.info(f"Inferring web pages to read")
     urls = await infer_webpage_urls(
-        query, conversation_history, location, user, query_images, agent=agent, tracer=tracer
+        query,
+        conversation_history,
+        location,
+        user,
+        query_images,
+        agent=agent,
+        query_files=query_files,
+        tracer=tracer,
     )

     # Get the top 10 web pages to read

View file

@@ -36,6 +36,7 @@ async def run_code(
     query_images: List[str] = None,
     agent: Agent = None,
     sandbox_url: str = SANDBOX_URL,
+    query_files: str = None,
     tracer: dict = {},
 ):
     # Generate Code
@@ -53,6 +54,7 @@ async def run_code(
             query_images,
             agent,
             tracer,
+            query_files,
         )
     except Exception as e:
         raise ValueError(f"Failed to generate code for {query} with error: {e}")
@@ -82,6 +84,7 @@ async def generate_python_code(
     query_images: List[str] = None,
     agent: Agent = None,
     tracer: dict = {},
+    query_files: str = None,
 ) -> List[str]:
     location = f"{location_data}" if location_data else "Unknown"
     username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
@@ -109,6 +112,7 @@ async def generate_python_code(
         response_type="json_object",
         user=user,
         tracer=tracer,
+        query_files=query_files,
     )

     # Validate that the response is a non-empty, JSON-serializable list

View file

@@ -351,6 +351,7 @@ async def extract_references_and_questions(
     query_images: Optional[List[str]] = None,
     previous_inferred_queries: Set = set(),
     agent: Agent = None,
+    query_files: str = None,
     tracer: dict = {},
 ):
     user = request.user.object if request.user.is_authenticated else None
@@ -425,6 +426,7 @@ async def extract_references_and_questions(
                 user=user,
                 max_prompt_size=conversation_config.max_prompt_size,
                 personality_context=personality_context,
+                query_files=query_files,
                 tracer=tracer,
             )
         elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
@@ -443,6 +445,7 @@ async def extract_references_and_questions(
                 query_images=query_images,
                 vision_enabled=vision_enabled,
                 personality_context=personality_context,
+                query_files=query_files,
                 tracer=tracer,
             )
         elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
@@ -458,6 +461,7 @@ async def extract_references_and_questions(
                 user=user,
                 vision_enabled=vision_enabled,
                 personality_context=personality_context,
+                query_files=query_files,
                 tracer=tracer,
             )
         elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
@@ -474,6 +478,7 @@ async def extract_references_and_questions(
                 user=user,
                 vision_enabled=vision_enabled,
                 personality_context=personality_context,
+                query_files=query_files,
                 tracer=tracer,
             )

View file

@@ -19,7 +19,6 @@ from khoj.database.adapters import (
     AgentAdapters,
     ConversationAdapters,
     EntryAdapters,
-    FileObjectAdapters,
     PublicConversationAdapters,
     aget_user_name,
 )
@@ -45,12 +44,13 @@ from khoj.routers.helpers import (
     ConversationCommandRateLimiter,
     DeleteMessageRequestBody,
     FeedbackData,
+    acreate_title_from_history,
     agenerate_chat_response,
     aget_relevant_information_sources,
     aget_relevant_output_modes,
     construct_automation_created_message,
     create_automation,
-    extract_relevant_info,
+    gather_raw_query_files,
     generate_excalidraw_diagram,
     generate_summary_from_files,
     get_conversation_command,
@@ -76,7 +76,12 @@ from khoj.utils.helpers import (
     get_device,
     is_none_or_empty,
 )
-from khoj.utils.rawconfig import FileFilterRequest, FilesFilterRequest, LocationData
+from khoj.utils.rawconfig import (
+    ChatRequestBody,
+    FileFilterRequest,
+    FilesFilterRequest,
+    LocationData,
+)

 # Initialize Router
 logger = logging.getLogger(__name__)
@@ -374,7 +379,7 @@ def fork_public_conversation(
             {
                 "status": "ok",
                 "next_url": redirect_uri,
-                "conversation_id": new_conversation.id,
+                "conversation_id": str(new_conversation.id),
             }
         ),
     )
@@ -530,6 +535,32 @@ async def set_conversation_title(
     )


+@api_chat.post("/title")
+@requires(["authenticated"])
+async def generate_chat_title(
+    request: Request,
+    common: CommonQueryParams,
+    conversation_id: str,
+):
+    user: KhojUser = request.user.object
+    conversation = await ConversationAdapters.aget_conversation_by_user(user=user, conversation_id=conversation_id)
+    if not conversation:
+        raise HTTPException(status_code=404, detail="Conversation not found")
+
+    # Conversation.title is explicitly set by the user. Do not override.
+    if conversation.title:
+        return {"status": "ok", "title": conversation.title}
+
+    new_title = await acreate_title_from_history(request.user.object, conversation=conversation)
+
+    conversation.slug = new_title
+    await conversation.asave()
+
+    return {"status": "ok", "title": new_title}
+
+
 @api_chat.delete("/conversation/message", response_class=Response)
 @requires(["authenticated"])
 def delete_message(request: Request, delete_request: DeleteMessageRequestBody) -> Response:
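Editor's note: a quick sketch of how a client might call the new title endpoint. The full route (/api/chat/title) and the auth header are assumptions inferred from the api_chat router prefix; adjust to your deployment:

import requests

# Hypothetical local server, API key, and conversation id.
resp = requests.post(
    "http://localhost:42110/api/chat/title",
    params={"conversation_id": "123"},
    headers={"Authorization": "Bearer <your-khoj-api-key>"},
)
print(resp.json())  # {"status": "ok", "title": "..."}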
@@ -571,6 +602,7 @@ async def chat(
     country_code = body.country_code or get_country_code_from_timezone(body.timezone)
     timezone = body.timezone
     raw_images = body.images
+    raw_query_files = body.files

     async def event_generator(q: str, images: list[str]):
         start_time = time.perf_counter()
@@ -582,6 +614,7 @@ async def chat(
         q = unquote(q)
         train_of_thought = []
         nonlocal conversation_id
+        nonlocal raw_query_files

         tracer: dict = {
             "mid": turn_id,
@@ -601,6 +634,11 @@ async def chat(
                 if uploaded_image:
                     uploaded_images.append(uploaded_image)

+        query_files: Dict[str, str] = {}
+        if raw_query_files:
+            for file in raw_query_files:
+                query_files[file.name] = file.content
+
         async def send_event(event_type: ChatEvent, data: str | dict):
             nonlocal connection_alive, ttft, train_of_thought
             if not connection_alive or await request.is_disconnected():
@@ -711,6 +749,8 @@ async def chat(
         ## Extract Document References
         compiled_references: List[Any] = []
         inferred_queries: List[Any] = []
+        file_filters = conversation.file_filters if conversation and conversation.file_filters else []
+        attached_file_context = gather_raw_query_files(query_files)

         if conversation_commands == [ConversationCommand.Default] or is_automated_task:
             conversation_commands = await aget_relevant_information_sources(
@@ -720,6 +760,7 @@ async def chat(
                 user=user,
                 query_images=uploaded_images,
                 agent=agent,
+                query_files=attached_file_context,
                 tracer=tracer,
             )
@@ -765,6 +806,7 @@ async def chat(
                     user_name=user_name,
                     location=location,
                     file_filters=conversation.file_filters if conversation else [],
+                    query_files=attached_file_context,
                     tracer=tracer,
                 ):
                     if isinstance(research_result, InformationCollectionIteration):
@@ -804,10 +846,6 @@ async def chat(
                     response_log = "No files selected for summarization. Please add files using the section on the left."
                     async for result in send_llm_response(response_log):
                         yield result
-                elif len(file_filters) > 1 and not agent_has_entries:
-                    response_log = "Only one file can be selected for summarization."
-                    async for result in send_llm_response(response_log):
-                        yield result
                 else:
                     async for response in generate_summary_from_files(
                         q=q,
@@ -817,6 +855,7 @@ async def chat(
                         query_images=uploaded_images,
                         agent=agent,
                         send_status_func=partial(send_event, ChatEvent.STATUS),
+                        query_files=attached_file_context,
                         tracer=tracer,
                     ):
                         if isinstance(response, dict) and ChatEvent.STATUS in response:
@@ -837,8 +876,9 @@ async def chat(
                    client_application=request.user.client_app,
                    conversation_id=conversation_id,
                    query_images=uploaded_images,
-                   tracer=tracer,
                    train_of_thought=train_of_thought,
+                   raw_query_files=raw_query_files,
+                   tracer=tracer,
                )
                return
@@ -882,8 +922,9 @@ async def chat(
                 inferred_queries=[query_to_run],
                 automation_id=automation.id,
                 query_images=uploaded_images,
-                tracer=tracer,
                 train_of_thought=train_of_thought,
+                raw_query_files=raw_query_files,
+                tracer=tracer,
             )
             async for result in send_llm_response(llm_response):
                 yield result
@@ -905,6 +946,7 @@ async def chat(
                 partial(send_event, ChatEvent.STATUS),
                 query_images=uploaded_images,
                 agent=agent,
+                query_files=attached_file_context,
                 tracer=tracer,
             ):
                 if isinstance(result, dict) and ChatEvent.STATUS in result:
@@ -950,6 +992,7 @@ async def chat(
                     custom_filters,
                     query_images=uploaded_images,
                     agent=agent,
+                    query_files=attached_file_context,
                     tracer=tracer,
                 ):
                     if isinstance(result, dict) and ChatEvent.STATUS in result:
@@ -975,6 +1018,7 @@ async def chat(
                     partial(send_event, ChatEvent.STATUS),
                     query_images=uploaded_images,
                     agent=agent,
+                    query_files=attached_file_context,
                     tracer=tracer,
                 ):
                     if isinstance(result, dict) and ChatEvent.STATUS in result:
@@ -1015,6 +1059,7 @@ async def chat(
                 partial(send_event, ChatEvent.STATUS),
                 query_images=uploaded_images,
                 agent=agent,
+                query_files=attached_file_context,
                 tracer=tracer,
             ):
                 if isinstance(result, dict) and ChatEvent.STATUS in result:
@@ -1055,6 +1100,7 @@ async def chat(
                 send_status_func=partial(send_event, ChatEvent.STATUS),
                 query_images=uploaded_images,
                 agent=agent,
+                query_files=attached_file_context,
                 tracer=tracer,
             ):
                 if isinstance(result, dict) and ChatEvent.STATUS in result:
@@ -1086,8 +1132,10 @@ async def chat(
                 compiled_references=compiled_references,
                 online_results=online_results,
                 query_images=uploaded_images,
-                tracer=tracer,
                 train_of_thought=train_of_thought,
+                attached_file_context=attached_file_context,
+                raw_query_files=raw_query_files,
+                tracer=tracer,
             )
             content_obj = {
                 "intentType": intent_type,
@@ -1116,6 +1164,7 @@ async def chat(
                 user=user,
                 agent=agent,
                 send_status_func=partial(send_event, ChatEvent.STATUS),
+                query_files=attached_file_context,
                 tracer=tracer,
             ):
                 if isinstance(result, dict) and ChatEvent.STATUS in result:
@@ -1144,8 +1193,10 @@ async def chat(
                 compiled_references=compiled_references,
                 online_results=online_results,
                 query_images=uploaded_images,
-                tracer=tracer,
                 train_of_thought=train_of_thought,
+                attached_file_context=attached_file_context,
+                raw_query_files=raw_query_files,
+                tracer=tracer,
             )

             async for result in send_llm_response(json.dumps(content_obj)):
@@ -1171,8 +1222,10 @@ async def chat(
                 user_name,
                 researched_results,
                 uploaded_images,
-                tracer,
                 train_of_thought,
+                attached_file_context,
+                raw_query_files,
+                tracer,
             )

             # Send Response
# Send Response # Send Response

View file

@@ -36,16 +36,18 @@ from khoj.database.models import (
     LocalPlaintextConfig,
     NotionConfig,
 )
+from khoj.processor.content.docx.docx_to_entries import DocxToEntries
+from khoj.processor.content.pdf.pdf_to_entries import PdfToEntries
 from khoj.routers.helpers import (
     ApiIndexedDataLimiter,
     CommonQueryParams,
     configure_content,
+    get_file_content,
     get_user_config,
     update_telemetry_state,
 )
 from khoj.utils import constants, state
 from khoj.utils.config import SearchModels
-from khoj.utils.helpers import get_file_type
 from khoj.utils.rawconfig import (
     ContentConfig,
     FullConfig,
@@ -375,6 +377,75 @@ async def delete_content_source(
     return {"status": "ok"}


+@api_content.post("/convert", status_code=200)
+@requires(["authenticated"])
+async def convert_documents(
+    request: Request,
+    files: List[UploadFile],
+    client: Optional[str] = None,
+):
+    MAX_FILE_SIZE_MB = 10  # 10MB limit
+    MAX_FILE_SIZE_BYTES = MAX_FILE_SIZE_MB * 1024 * 1024
+
+    converted_files = []
+    supported_files = ["org", "markdown", "pdf", "plaintext", "docx"]
+
+    for file in files:
+        # Check file size first
+        file_size = 0
+        content = await file.read()
+        file_size = len(content)
+        await file.seek(0)  # Reset file pointer
+
+        if file_size > MAX_FILE_SIZE_BYTES:
+            logger.warning(
+                f"Skipped converting oversized file ({file_size / 1024 / 1024:.1f}MB) sent by {client} client: {file.filename}"
+            )
+            continue
+
+        file_data = get_file_content(file)
+        if file_data.file_type in supported_files:
+            extracted_content = (
+                file_data.content.decode(file_data.encoding) if file_data.encoding else file_data.content
+            )
+            if file_data.file_type == "docx":
+                entries_per_page = DocxToEntries.extract_text(file_data.content)
+                annotated_pages = [
+                    f"Page {index} of {file_data.name}:\n\n{entry}" for index, entry in enumerate(entries_per_page)
+                ]
+                extracted_content = "\n".join(annotated_pages)
+
+            elif file_data.file_type == "pdf":
+                entries_per_page = PdfToEntries.extract_text(file_data.content)
+                annotated_pages = [
+                    f"Page {index} of {file_data.name}:\n\n{entry}" for index, entry in enumerate(entries_per_page)
+                ]
+                extracted_content = "\n".join(annotated_pages)
+
+            size_in_bytes = len(extracted_content.encode("utf-8"))
+
+            converted_files.append(
+                {
+                    "name": file_data.name,
+                    "content": extracted_content,
+                    "file_type": file_data.file_type,
+                    "size": size_in_bytes,
+                }
+            )
+        else:
+            logger.warning(f"Skipped converting unsupported file type sent by {client} client: {file.filename}")
+
+    update_telemetry_state(
+        request=request,
+        telemetry_type="api",
+        api="convert_documents",
+        client=client,
+    )
+
+    return Response(content=json.dumps(converted_files), media_type="application/json", status_code=200)
+
+
 async def indexer(
     request: Request,
     files: list[UploadFile],
@@ -398,12 +469,13 @@ async def indexer(
     try:
         logger.info(f"📬 Updating content index via API call by {client} client")
         for file in files:
-            file_content = file.file.read()
-            file_type, encoding = get_file_type(file.content_type, file_content)
-            if file_type in index_files:
-                index_files[file_type][file.filename] = file_content.decode(encoding) if encoding else file_content
+            file_data = get_file_content(file)
+            if file_data.file_type in index_files:
+                index_files[file_data.file_type][file_data.name] = (
+                    file_data.content.decode(file_data.encoding) if file_data.encoding else file_data.content
+                )
             else:
-                logger.warning(f"Skipped indexing unsupported file type sent by {client} client: {file.filename}")
+                logger.warning(f"Skipped indexing unsupported file type sent by {client} client: {file_data.name}")

         indexer_input = IndexerInput(
             org=index_files["org"],
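Editor's note: for reference, a minimal client-side sketch of the new convert endpoint. The full route (/api/content/convert) and the auth header are assumptions inferred from the api_content router; adjust to your deployment:

import requests

# Hypothetical local server and API key.
url = "http://localhost:42110/api/content/convert"
headers = {"Authorization": "Bearer <your-khoj-api-key>"}

with open("notes.pdf", "rb") as f:
    files = [("files", ("notes.pdf", f, "application/pdf"))]
    response = requests.post(url, headers=headers, files=files)

# Each converted file is returned as {"name", "content", "file_type", "size"},
# ready to be passed along in the `files` field of a chat request.
for converted in response.json():
    print(converted["name"], converted["file_type"], converted["size"])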

View file

@@ -105,6 +105,7 @@ from khoj.utils.config import OfflineChatProcessorModel
 from khoj.utils.helpers import (
     LRU,
     ConversationCommand,
+    get_file_type,
     is_none_or_empty,
     is_valid_url,
     log_telemetry,
@@ -112,7 +113,7 @@ from khoj.utils.helpers import (
     timer,
     tool_descriptions_for_llm,
 )
-from khoj.utils.rawconfig import LocationData
+from khoj.utils.rawconfig import ChatRequestBody, FileAttachment, FileData, LocationData

 logger = logging.getLogger(__name__)
@@ -168,6 +169,12 @@ async def is_ready_to_chat(user: KhojUser):
     raise HTTPException(status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings.")


+def get_file_content(file: UploadFile):
+    file_content = file.file.read()
+    file_type, encoding = get_file_type(file.content_type, file_content)
+    return FileData(name=file.filename, content=file_content, file_type=file_type, encoding=encoding)
+
+
 def update_telemetry_state(
     request: Request,
     telemetry_type: str,
@@ -249,6 +256,39 @@ async def agenerate_chat_response(*args):
     return await loop.run_in_executor(executor, generate_chat_response, *args)


+def gather_raw_query_files(
+    query_files: Dict[str, str],
+):
+    """
+    Gather contextual data from the given (raw) files
+    """
+    if len(query_files) == 0:
+        return ""
+
+    contextual_data = " ".join(
+        [f"File: {file_name}\n\n{file_content}\n\n" for file_name, file_content in query_files.items()]
+    )
+    return f"I have attached the following files:\n\n{contextual_data}"
+
+
+async def acreate_title_from_history(
+    user: KhojUser,
+    conversation: Conversation,
+):
+    """
+    Create a title from the given conversation history
+    """
+    chat_history = construct_chat_history(conversation.conversation_log)
+
+    title_generation_prompt = prompts.conversation_title_generation.format(chat_history=chat_history)
+
+    with timer("Chat actor: Generate title from conversation history", logger):
+        response = await send_message_to_model_wrapper(title_generation_prompt, user=user)
+
+    return response.strip()
+
+
 async def acreate_title_from_query(query: str, user: KhojUser = None) -> str:
     """
     Create a title from the given query
@@ -294,6 +334,7 @@ async def aget_relevant_information_sources(
     user: KhojUser,
     query_images: List[str] = None,
     agent: Agent = None,
+    query_files: str = None,
     tracer: dict = {},
 ):
     """
@@ -331,6 +372,7 @@ async def aget_relevant_information_sources(
             relevant_tools_prompt,
             response_type="json_object",
             user=user,
+            query_files=query_files,
             tracer=tracer,
         )
@@ -440,6 +482,7 @@ async def infer_webpage_urls(
     user: KhojUser,
     query_images: List[str] = None,
     agent: Agent = None,
+    query_files: str = None,
     tracer: dict = {},
 ) -> List[str]:
     """
@@ -469,6 +512,7 @@ async def infer_webpage_urls(
         query_images=query_images,
         response_type="json_object",
         user=user,
+        query_files=query_files,
         tracer=tracer,
     )
@@ -494,6 +538,7 @@ async def generate_online_subqueries(
     user: KhojUser,
     query_images: List[str] = None,
     agent: Agent = None,
+    query_files: str = None,
     tracer: dict = {},
 ) -> Set[str]:
     """
@@ -523,6 +568,7 @@ async def generate_online_subqueries(
         query_images=query_images,
         response_type="json_object",
         user=user,
+        query_files=query_files,
         tracer=tracer,
     )
@@ -645,26 +691,38 @@ async def generate_summary_from_files(
     query_images: List[str] = None,
     agent: Agent = None,
     send_status_func: Optional[Callable] = None,
+    query_files: str = None,
     tracer: dict = {},
 ):
     try:
-        file_object = None
+        file_objects = []
         if await EntryAdapters.aagent_has_entries(agent):
             file_names = await EntryAdapters.aget_agent_entry_filepaths(agent)
             if len(file_names) > 0:
-                file_object = await FileObjectAdapters.async_get_file_objects_by_name(None, file_names.pop(), agent)
-
-        if len(file_filters) > 0:
-            file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
-        if len(file_object) == 0:
-            response_log = "Sorry, I couldn't find the full text of this file."
+                file_objects = await FileObjectAdapters.async_get_file_objects_by_name(None, file_names.pop(), agent)
+
+        if not file_objects and not query_files:
+            response_log = "Sorry, I couldn't find anything to summarize."
             yield response_log
             return
-        contextual_data = " ".join([file.raw_text for file in file_object])
+
+        contextual_data = " ".join([f"File: {file.file_name}\n\n{file.raw_text}" for file in file_objects])
+
+        if query_files:
+            contextual_data += f"\n\n{query_files}"
+
         if not q:
             q = "Create a general summary of the file"
-        async for result in send_status_func(f"**Constructing Summary Using:** {file_object[0].file_name}"):
+
+        file_names = [file.file_name for file in file_objects]
+        file_names.extend(file_filters)
+
+        all_file_names = ""
+        for file_name in file_names:
+            all_file_names += f"- {file_name}\n"
+
+        async for result in send_status_func(f"**Constructing Summary Using:**\n{all_file_names}"):
             yield {ChatEvent.STATUS: result}

         response = await extract_relevant_summary(
response = await extract_relevant_summary( response = await extract_relevant_summary(
@ -694,6 +752,7 @@ async def generate_excalidraw_diagram(
user: KhojUser = None, user: KhojUser = None,
agent: Agent = None, agent: Agent = None,
send_status_func: Optional[Callable] = None, send_status_func: Optional[Callable] = None,
query_files: str = None,
tracer: dict = {}, tracer: dict = {},
): ):
if send_status_func: if send_status_func:
@ -709,6 +768,7 @@ async def generate_excalidraw_diagram(
query_images=query_images, query_images=query_images,
user=user, user=user,
agent=agent, agent=agent,
query_files=query_files,
tracer=tracer, tracer=tracer,
) )
@ -735,6 +795,7 @@ async def generate_better_diagram_description(
query_images: List[str] = None, query_images: List[str] = None,
user: KhojUser = None, user: KhojUser = None,
agent: Agent = None, agent: Agent = None,
query_files: str = None,
tracer: dict = {}, tracer: dict = {},
) -> str: ) -> str:
""" """
@ -773,7 +834,11 @@ async def generate_better_diagram_description(
with timer("Chat actor: Generate better diagram description", logger): with timer("Chat actor: Generate better diagram description", logger):
response = await send_message_to_model_wrapper( response = await send_message_to_model_wrapper(
improve_diagram_description_prompt, query_images=query_images, user=user, tracer=tracer improve_diagram_description_prompt,
query_images=query_images,
user=user,
query_files=query_files,
tracer=tracer,
) )
response = response.strip() response = response.strip()
if response.startswith(('"', "'")) and response.endswith(('"', "'")): if response.startswith(('"', "'")) and response.endswith(('"', "'")):
@ -820,6 +885,7 @@ async def generate_better_image_prompt(
query_images: Optional[List[str]] = None, query_images: Optional[List[str]] = None,
user: KhojUser = None, user: KhojUser = None,
agent: Agent = None, agent: Agent = None,
query_files: str = "",
tracer: dict = {}, tracer: dict = {},
) -> str: ) -> str:
""" """
@ -868,7 +934,7 @@ async def generate_better_image_prompt(
with timer("Chat actor: Generate contextual image prompt", logger): with timer("Chat actor: Generate contextual image prompt", logger):
response = await send_message_to_model_wrapper( response = await send_message_to_model_wrapper(
image_prompt, query_images=query_images, user=user, tracer=tracer image_prompt, query_images=query_images, user=user, query_files=query_files, tracer=tracer
) )
response = response.strip() response = response.strip()
if response.startswith(('"', "'")) and response.endswith(('"', "'")): if response.startswith(('"', "'")) and response.endswith(('"', "'")):
@@ -884,6 +950,7 @@ async def send_message_to_model_wrapper(
     user: KhojUser = None,
     query_images: List[str] = None,
     context: str = "",
+    query_files: str = None,
     tracer: dict = {},
 ):
     conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
@@ -923,6 +990,7 @@ async def send_message_to_model_wrapper(
             max_prompt_size=max_tokens,
             vision_enabled=vision_available,
             model_type=conversation_config.model_type,
+            query_files=query_files,
         )

         return send_message_to_model_offline(
@@ -949,6 +1017,7 @@ async def send_message_to_model_wrapper(
             vision_enabled=vision_available,
             query_images=query_images,
             model_type=conversation_config.model_type,
+            query_files=query_files,
         )

         return send_message_to_model(
@@ -971,6 +1040,7 @@ async def send_message_to_model_wrapper(
             vision_enabled=vision_available,
             query_images=query_images,
             model_type=conversation_config.model_type,
+            query_files=query_files,
         )

         return anthropic_send_message_to_model(
@@ -992,6 +1062,7 @@ async def send_message_to_model_wrapper(
             vision_enabled=vision_available,
             query_images=query_images,
             model_type=conversation_config.model_type,
+            query_files=query_files,
         )

         return gemini_send_message_to_model(
@@ -1006,6 +1077,7 @@ def send_message_to_model_wrapper_sync(
     system_message: str = "",
     response_type: str = "text",
     user: KhojUser = None,
+    query_files: str = "",
     tracer: dict = {},
 ):
     conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config(user)
@@ -1030,6 +1102,7 @@ def send_message_to_model_wrapper_sync(
             max_prompt_size=max_tokens,
             vision_enabled=vision_available,
             model_type=conversation_config.model_type,
+            query_files=query_files,
         )

         return send_message_to_model_offline(
@@ -1051,6 +1124,7 @@ def send_message_to_model_wrapper_sync(
             max_prompt_size=max_tokens,
             vision_enabled=vision_available,
             model_type=conversation_config.model_type,
+            query_files=query_files,
         )

         openai_response = send_message_to_model(
@@ -1072,6 +1146,7 @@ def send_message_to_model_wrapper_sync(
             max_prompt_size=max_tokens,
             vision_enabled=vision_available,
             model_type=conversation_config.model_type,
+            query_files=query_files,
         )

         return anthropic_send_message_to_model(
@@ -1091,6 +1166,7 @@ def send_message_to_model_wrapper_sync(
             max_prompt_size=max_tokens,
             vision_enabled=vision_available,
             model_type=conversation_config.model_type,
+            query_files=query_files,
         )

         return gemini_send_message_to_model(
@@ -1120,8 +1196,10 @@ def generate_chat_response(
     user_name: Optional[str] = None,
     meta_research: str = "",
     query_images: Optional[List[str]] = None,
-    tracer: dict = {},
     train_of_thought: List[Any] = [],
+    query_files: str = None,
+    raw_query_files: List[FileAttachment] = None,
+    tracer: dict = {},
 ) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
     # Initialize Variables
     chat_response = None
@@ -1142,8 +1220,9 @@ def generate_chat_response(
             client_application=client_application,
             conversation_id=conversation_id,
             query_images=query_images,
-            tracer=tracer,
             train_of_thought=train_of_thought,
+            raw_query_files=raw_query_files,
+            tracer=tracer,
         )

         query_to_run = q
@@ -1177,6 +1256,7 @@ def generate_chat_response(
                 location_data=location_data,
                 user_name=user_name,
                 agent=agent,
+                query_files=query_files,
                 tracer=tracer,
             )
@@ -1202,6 +1282,7 @@ def generate_chat_response(
                 user_name=user_name,
                 agent=agent,
                 vision_available=vision_available,
+                query_files=query_files,
                 tracer=tracer,
             )
@@ -1224,6 +1305,7 @@ def generate_chat_response(
                 user_name=user_name,
                 agent=agent,
                 vision_available=vision_available,
+                query_files=query_files,
                 tracer=tracer,
             )
         elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
@@ -1243,7 +1325,9 @@ def generate_chat_response(
                 location_data=location_data,
                 user_name=user_name,
                 agent=agent,
+                query_images=query_images,
                 vision_available=vision_available,
+                query_files=query_files,
                 tracer=tracer,
             )
@@ -1256,23 +1340,6 @@ def generate_chat_response(
     return chat_response, metadata


-class ChatRequestBody(BaseModel):
-    q: str
-    n: Optional[int] = 7
-    d: Optional[float] = None
-    stream: Optional[bool] = False
-    title: Optional[str] = None
-    conversation_id: Optional[str] = None
-    turn_id: Optional[str] = None
-    city: Optional[str] = None
-    region: Optional[str] = None
-    country: Optional[str] = None
-    country_code: Optional[str] = None
-    timezone: Optional[str] = None
-    images: Optional[list[str]] = None
-    create_new: Optional[bool] = False
-
-
 class DeleteMessageRequestBody(BaseModel):
     conversation_id: str
     turn_id: str
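Editor's note: to tie the plumbing together, a minimal sketch of how attached file text flows into a model call, using the helpers defined above. The sample attachment is hypothetical:

from khoj.routers.helpers import gather_raw_query_files
from khoj.utils.rawconfig import FileAttachment

# Hypothetical attachment, as it arrives on ChatRequestBody.files.
raw_query_files = [
    FileAttachment(name="notes.md", content="# Q3 notes\n...", file_type="markdown", size=14)
]

# 1. Flatten attachments into a name -> content map (as the /chat endpoint does)...
query_files = {f.name: f.content for f in raw_query_files}

# 2. ...then into the single "I have attached the following files:" string that
#    every chat actor now accepts via its query_files parameter.
attached_file_context = gather_raw_query_files(query_files)
print(attached_file_context)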

View file

@@ -11,6 +11,7 @@ from khoj.processor.conversation import prompts
 from khoj.processor.conversation.utils import (
     InformationCollectionIteration,
     clean_json,
+    construct_chat_history,
     construct_iteration_history,
     construct_tool_chat_history,
 )
@@ -19,8 +20,6 @@ from khoj.processor.tools.run_code import run_code
 from khoj.routers.api import extract_references_and_questions
 from khoj.routers.helpers import (
     ChatEvent,
-    construct_chat_history,
-    extract_relevant_info,
     generate_summary_from_files,
     send_message_to_model_wrapper,
 )
@@ -47,6 +46,7 @@ async def apick_next_tool(
     max_iterations: int = 5,
     send_status_func: Optional[Callable] = None,
     tracer: dict = {},
+    query_files: str = None,
 ):
     """Given a query, determine which of the available tools the agent should use in order to answer appropriately."""
@@ -92,6 +92,7 @@ async def apick_next_tool(
             response_type="json_object",
             user=user,
             query_images=query_images,
+            query_files=query_files,
             tracer=tracer,
         )
     except Exception as e:
@@ -151,6 +152,7 @@ async def execute_information_collection(
     location: LocationData = None,
     file_filters: List[str] = [],
     tracer: dict = {},
+    query_files: str = None,
 ):
     current_iteration = 0
     MAX_ITERATIONS = 5
@@ -174,6 +176,7 @@ async def execute_information_collection(
             MAX_ITERATIONS,
             send_status_func,
             tracer=tracer,
+            query_files=query_files,
         ):
             if isinstance(result, dict) and ChatEvent.STATUS in result:
                 yield result[ChatEvent.STATUS]
@@ -204,6 +207,7 @@ async def execute_information_collection(
                     previous_inferred_queries=previous_inferred_queries,
                     agent=agent,
                     tracer=tracer,
+                    query_files=query_files,
                 ):
                     if isinstance(result, dict) and ChatEvent.STATUS in result:
                         yield result[ChatEvent.STATUS]
@@ -265,6 +269,7 @@ async def execute_information_collection(
                     query_images=query_images,
                     agent=agent,
                     tracer=tracer,
+                    query_files=query_files,
                 ):
                     if isinstance(result, dict) and ChatEvent.STATUS in result:
                         yield result[ChatEvent.STATUS]
@@ -295,6 +300,7 @@ async def execute_information_collection(
                 send_status_func,
                 query_images=query_images,
                 agent=agent,
+                query_files=query_files,
                 tracer=tracer,
             ):
                 if isinstance(result, dict) and ChatEvent.STATUS in result:
@@ -320,6 +326,7 @@ async def execute_information_collection(
                 query_images=query_images,
                 agent=agent,
                 send_status_func=send_status_func,
+                query_files=query_files,
             ):
                 if isinstance(result, dict) and ChatEvent.STATUS in result:
                     yield result[ChatEvent.STATUS]

View file

@@ -138,6 +138,38 @@ class SearchResponse(ConfigBase):
     corpus_id: str


+class FileData(BaseModel):
+    name: str
+    content: bytes
+    file_type: str
+    encoding: str | None = None
+
+
+class FileAttachment(BaseModel):
+    name: str
+    content: str
+    file_type: str
+    size: int
+
+
+class ChatRequestBody(BaseModel):
+    q: str
+    n: Optional[int] = 7
+    d: Optional[float] = None
+    stream: Optional[bool] = False
+    title: Optional[str] = None
+    conversation_id: Optional[str] = None
+    turn_id: Optional[str] = None
+    city: Optional[str] = None
+    region: Optional[str] = None
+    country: Optional[str] = None
+    country_code: Optional[str] = None
+    timezone: Optional[str] = None
+    images: Optional[list[str]] = None
+    files: Optional[list[FileAttachment]] = []
+    create_new: Optional[bool] = False
+
+
 class Entry:
     raw: str
     compiled: str
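Editor's note: to illustrate the new request schema end to end, a minimal sketch constructing a chat request with one attached file. Field values are hypothetical; the fields themselves come straight from the models above:

from khoj.utils.rawconfig import ChatRequestBody, FileAttachment

body = ChatRequestBody(
    q="What do my notes say about Q3 planning?",
    conversation_id="123",
    stream=True,
    files=[
        FileAttachment(
            name="notes.md",
            content="# Q3 planning\n- ship file attachments in chat\n",
            file_type="markdown",
            size=45,
        )
    ],
)

# Serializes to the JSON payload the chat endpoint now accepts.
print(body.model_dump_json(exclude_none=True))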

View file

@@ -337,7 +337,6 @@ def test_summarize_one_file(client_offline_chat, default_user2: KhojUser):
     # Assert
     assert response_message != ""
     assert response_message != "No files selected for summarization. Please add files using the section on the left."
-    assert response_message != "Only one file can be selected for summarization."


 @pytest.mark.django_db(transaction=True)
@@ -375,7 +374,6 @@ def test_summarize_extra_text(client_offline_chat, default_user2: KhojUser):
     # Assert
     assert response_message != ""
     assert response_message != "No files selected for summarization. Please add files using the section on the left."
-    assert response_message != "Only one file can be selected for summarization."


 @pytest.mark.django_db(transaction=True)
@@ -404,7 +402,7 @@ def test_summarize_multiple_files(client_offline_chat, default_user2: KhojUser):
     response_message = response.json()["response"]

     # Assert
-    assert response_message == "Only one file can be selected for summarization."
+    assert response_message is not None


 @pytest.mark.django_db(transaction=True)
@@ -460,7 +458,6 @@ def test_summarize_different_conversation(client_offline_chat, default_user2: KhojUser):
     # Assert
     assert response_message != ""
     assert response_message != "No files selected for summarization. Please add files using the section on the left."
-    assert response_message != "Only one file can be selected for summarization."


 @pytest.mark.django_db(transaction=True)

View file

@@ -312,7 +312,6 @@ def test_summarize_one_file(chat_client, default_user2: KhojUser):
     # Assert
     assert response_message != ""
     assert response_message != "No files selected for summarization. Please add files using the section on the left."
-    assert response_message != "Only one file can be selected for summarization."


 @pytest.mark.django_db(transaction=True)
@@ -344,7 +343,6 @@ def test_summarize_extra_text(chat_client, default_user2: KhojUser):
     # Assert
     assert response_message != ""
     assert response_message != "No files selected for summarization. Please add files using the section on the left."
-    assert response_message != "Only one file can be selected for summarization."


 @pytest.mark.django_db(transaction=True)
@@ -371,7 +369,7 @@ def test_summarize_multiple_files(chat_client, default_user2: KhojUser):
     response_message = response.json()["response"]

     # Assert
-    assert response_message == "Only one file can be selected for summarization."
+    assert response_message is not None


 @pytest.mark.django_db(transaction=True)
@@ -435,7 +433,6 @@ def test_summarize_different_conversation(chat_client, default_user2: KhojUser):
     assert (
         response_message_conv1 != "No files selected for summarization. Please add files using the section on the left."
     )
-    assert response_message_conv1 != "Only one file can be selected for summarization."


 @pytest.mark.django_db(transaction=True)