Mirror of https://github.com/khoj-ai/khoj.git, synced 2024-11-23 23:48:56 +01:00
Merge pull request #957 from khoj-ai/features/include-full-file-in-convo-with-filter

Support including file attachments in the chat message

Now that models have much larger context windows, we can reasonably include the full text of certain files in chat messages. Do this when an explicit file filter is set on a conversation, and place the attached file contents in a separate user message to mitigate any confusion in the operation. Pipe the relevant attached_files context through all methods that call into the models.

This breaks certain prior behaviors. Attached files are no longer automatically processed into embeddings on the backend and added to the "brain". To add documents to the brain (i.e., have search include them during question / response), you'll have to go to Settings and use the upload documents flow there.
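For orientation, here is a minimal sketch of the attachment flow this change introduces on the web client, assembled from the frontend diff below. The AttachedFileText shape and the /api/content/convert and /api/chat endpoints come from the diff; sendChatWithAttachments is a hypothetical helper written for illustration, not a function in the codebase.

// Shape of an attached file after server-side text extraction (AttachedFileText in the diff).
interface AttachedFileText {
    name: string;
    content: string;
    file_type: string;
    size: number;
}

// Hypothetical helper: extract text from uploaded files, then send them with the chat query.
async function sendChatWithAttachments(
    query: string,
    conversationId: string,
    files: FormData,
): Promise<Response> {
    // Convert raw uploads to plain text via the content conversion endpoint.
    const converted: AttachedFileText[] = await fetch("/api/content/convert", {
        method: "POST",
        body: files,
    }).then((res) => res.json());

    // Attach the converted file texts to the chat request body when present.
    const body = {
        q: query,
        conversation_id: conversationId,
        stream: true,
        ...(converted.length > 0 && { files: converted }),
    };

    return fetch("/api/chat?client=web", {
        method: "POST",
        headers: { "Content-Type": "application/json" },
        body: JSON.stringify(body),
    });
}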
Commit b563f46a2e
33 changed files with 880 additions and 418 deletions
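One detail worth calling out from the diff: when a chat is started from the home page or a shared conversation page, the converted attachments are stashed in localStorage and re-read after the redirect to /chat. A rough sketch of that handoff, under the same assumptions as above (the "uploadedFiles" storage key and shape come from the diff; the helper names are illustrative):

// On the originating page: persist converted attachments before redirecting.
function stashAttachments(uploadedFiles: AttachedFileText[]): void {
    localStorage.setItem("uploadedFiles", JSON.stringify(uploadedFiles));
}

// On the chat page: restore, forward to state, and clear the stash.
function restoreAttachments(setUploadedFiles: (files: AttachedFileText[]) => void): void {
    const stored = localStorage.getItem("uploadedFiles");
    if (!stored) return;
    const parsed: AttachedFileText[] = JSON.parse(stored);
    localStorage.removeItem("uploadedFiles");
    setUploadedFiles(parsed);
}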
@@ -8,7 +8,7 @@ import ChatHistory from "../components/chatHistory/chatHistory";
 import { useSearchParams } from "next/navigation";
 import Loading from "../components/loading/loading";

-import { processMessageChunk } from "../common/chatFunctions";
+import { generateNewTitle, processMessageChunk } from "../common/chatFunctions";

 import "katex/dist/katex.min.css";

@@ -19,7 +19,11 @@ import {
     StreamMessage,
 } from "../components/chatMessage/chatMessage";
 import { useIPLocationData, useIsMobileWidth, welcomeConsole } from "../common/utils";
-import { ChatInputArea, ChatOptions } from "../components/chatInputArea/chatInputArea";
+import {
+    AttachedFileText,
+    ChatInputArea,
+    ChatOptions,
+} from "../components/chatInputArea/chatInputArea";
 import { useAuthenticatedData } from "../common/auth";
 import { AgentData } from "../agents/page";

@@ -30,7 +34,7 @@ interface ChatBodyDataProps {
     setQueryToProcess: (query: string) => void;
     streamedMessages: StreamMessage[];
     setStreamedMessages: (messages: StreamMessage[]) => void;
-    setUploadedFiles: (files: string[]) => void;
+    setUploadedFiles: (files: AttachedFileText[] | undefined) => void;
     isMobileWidth?: boolean;
     isLoggedIn: boolean;
     setImages: (images: string[]) => void;
@@ -77,7 +81,24 @@ function ChatBodyData(props: ChatBodyDataProps) {
                 setIsInResearchMode(true);
             }
         }
-    }, [setQueryToProcess, props.setImages]);
+
+        const storedUploadedFiles = localStorage.getItem("uploadedFiles");
+
+        if (storedUploadedFiles) {
+            const parsedFiles = storedUploadedFiles ? JSON.parse(storedUploadedFiles) : [];
+            const uploadedFiles: AttachedFileText[] = [];
+            for (const file of parsedFiles) {
+                uploadedFiles.push({
+                    name: file.name,
+                    file_type: file.file_type,
+                    content: file.content,
+                    size: file.size,
+                });
+            }
+            localStorage.removeItem("uploadedFiles");
+            props.setUploadedFiles(uploadedFiles);
+        }
+    }, [setQueryToProcess, props.setImages, conversationId]);

     useEffect(() => {
         if (message) {
@@ -100,6 +121,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
             ) {
                 setProcessingMessage(false);
                 setImages([]); // Reset images after processing
+                props.setUploadedFiles(undefined); // Reset uploaded files after processing
             } else {
                 setMessage("");
             }
@@ -153,7 +175,7 @@ export default function Chat() {
     const [messages, setMessages] = useState<StreamMessage[]>([]);
     const [queryToProcess, setQueryToProcess] = useState<string>("");
     const [processQuerySignal, setProcessQuerySignal] = useState(false);
-    const [uploadedFiles, setUploadedFiles] = useState<string[]>([]);
+    const [uploadedFiles, setUploadedFiles] = useState<AttachedFileText[] | undefined>(undefined);
    const [images, setImages] = useState<string[]>([]);

     const locationData = useIPLocationData() || {
@@ -192,6 +214,7 @@ export default function Chat() {
                 timestamp: new Date().toISOString(),
                 rawQuery: queryToProcess || "",
                 images: images,
+                queryFiles: uploadedFiles,
             };
             setMessages((prevMessages) => [...prevMessages, newStreamMessage]);
             setProcessQuerySignal(true);
@@ -224,6 +247,9 @@ export default function Chat() {
                 setQueryToProcess("");
                 setProcessQuerySignal(false);
                 setImages([]);
+
+                if (conversationId) generateNewTitle(conversationId, setTitle);
+
                 break;
             }

@@ -273,6 +299,7 @@ export default function Chat() {
                 timezone: locationData.timezone,
             }),
             ...(images.length > 0 && { images: images }),
+            ...(uploadedFiles && { files: uploadedFiles }),
         };

         const response = await fetch(chatAPI, {
@@ -325,7 +352,7 @@ export default function Chat() {
                 <div>
                     <SidePanel
                         conversationId={conversationId}
-                        uploadedFiles={uploadedFiles}
+                        uploadedFiles={[]}
                         isMobileWidth={isMobileWidth}
                     />
                 </div>

@@ -267,6 +267,78 @@ export async function createNewConversation(slug: string) {
     }
 }

+export async function packageFilesForUpload(files: FileList): Promise<FormData> {
+    const formData = new FormData();
+
+    const fileReadPromises = Array.from(files).map((file) => {
+        return new Promise<void>((resolve, reject) => {
+            let reader = new FileReader();
+            reader.onload = function (event) {
+                if (event.target === null) {
+                    reject();
+                    return;
+                }
+
+                let fileContents = event.target.result;
+                let fileType = file.type;
+                let fileName = file.name;
+                if (fileType === "") {
+                    let fileExtension = fileName.split(".").pop();
+                    if (fileExtension === "org") {
+                        fileType = "text/org";
+                    } else if (fileExtension === "md") {
+                        fileType = "text/markdown";
+                    } else if (fileExtension === "txt") {
+                        fileType = "text/plain";
+                    } else if (fileExtension === "html") {
+                        fileType = "text/html";
+                    } else if (fileExtension === "pdf") {
+                        fileType = "application/pdf";
+                    } else if (fileExtension === "docx") {
+                        fileType =
+                            "application/vnd.openxmlformats-officedocument.wordprocessingml.document";
+                    } else {
+                        // Skip this file if its type is not supported
+                        resolve();
+                        return;
+                    }
+                }
+
+                if (fileContents === null) {
+                    reject();
+                    return;
+                }
+
+                let fileObj = new Blob([fileContents], { type: fileType });
+                formData.append("files", fileObj, file.name);
+                resolve();
+            };
+            reader.onerror = reject;
+            reader.readAsArrayBuffer(file);
+        });
+    });
+
+    await Promise.all(fileReadPromises);
+    return formData;
+}
+
+export function generateNewTitle(conversationId: string, setTitle: (title: string) => void) {
+    fetch(`/api/chat/title?conversation_id=${conversationId}`, {
+        method: "POST",
+    })
+        .then((res) => {
+            if (!res.ok) throw new Error(`Failed to call API with error ${res.statusText}`);
+            return res.json();
+        })
+        .then((data) => {
+            setTitle(data.title);
+        })
+        .catch((err) => {
+            console.error(err);
+            return;
+        });
+}
+
 export function uploadDataForIndexing(
     files: FileList,
     setWarning: (warning: string) => void,

@@ -49,8 +49,11 @@ import {
     Gavel,
     Broadcast,
     KeyReturn,
+    FilePdf,
+    FileMd,
+    MicrosoftWordLogo,
 } from "@phosphor-icons/react";
-import { Markdown, OrgMode, Pdf, Word } from "@/app/components/logo/fileLogo";
+import { OrgMode } from "@/app/components/logo/fileLogo";

 interface IconMap {
     [key: string]: (color: string, width: string, height: string) => JSX.Element | null;
@@ -238,11 +241,12 @@ function getIconFromFilename(
             return <OrgMode className={className} />;
         case "markdown":
         case "md":
-            return <Markdown className={className} />;
+            return <FileMd className={className} />;
         case "pdf":
-            return <Pdf className={className} />;
+            return <FilePdf className={className} />;
         case "doc":
-            return <Word className={className} />;
+        case "docx":
+            return <MicrosoftWordLogo className={className} />;
         case "jpg":
         case "jpeg":
         case "png":

@@ -71,6 +71,16 @@ export function useIsMobileWidth() {
     return isMobileWidth;
 }

+export const convertBytesToText = (fileSize: number) => {
+    if (fileSize < 1024) {
+        return `${fileSize} B`;
+    } else if (fileSize < 1024 * 1024) {
+        return `${(fileSize / 1024).toFixed(2)} KB`;
+    } else {
+        return `${(fileSize / (1024 * 1024)).toFixed(2)} MB`;
+    }
+};
+
 export function useDebounce<T>(value: T, delay: number): T {
     const [debouncedValue, setDebouncedValue] = useState<T>(value);


@@ -373,6 +373,7 @@ export default function ChatHistory(props: ChatHistoryProps) {
                                 images: message.images,
                                 conversationId: props.conversationId,
                                 turnId: messageTurnId,
+                                queryFiles: message.queryFiles,
                             }}
                             customClassName="fullHistory"
                             borderLeftColor={`${data?.agent?.color}-500`}

@@ -40,19 +40,36 @@ import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/comp
 import { convertColorToTextClass, convertToBGClass } from "@/app/common/colorUtils";

 import LoginPrompt from "../loginPrompt/loginPrompt";
-import { uploadDataForIndexing } from "../../common/chatFunctions";
 import { InlineLoading } from "../loading/loading";
-import { getIconForSlashCommand } from "@/app/common/iconUtils";
+import { getIconForSlashCommand, getIconFromFilename } from "@/app/common/iconUtils";
+import { packageFilesForUpload } from "@/app/common/chatFunctions";
+import { convertBytesToText } from "@/app/common/utils";
+import {
+    Dialog,
+    DialogContent,
+    DialogDescription,
+    DialogHeader,
+    DialogTitle,
+    DialogTrigger,
+} from "@/components/ui/dialog";
+import { ScrollArea } from "@/components/ui/scroll-area";

 export interface ChatOptions {
     [key: string]: string;
 }
+
+export interface AttachedFileText {
+    name: string;
+    content: string;
+    file_type: string;
+    size: number;
+}

 interface ChatInputProps {
     sendMessage: (message: string) => void;
     sendImage: (image: string) => void;
     sendDisabled: boolean;
-    setUploadedFiles?: (files: string[]) => void;
+    setUploadedFiles: (files: AttachedFileText[]) => void;
     conversationId?: string | null;
     chatOptionsData?: ChatOptions | null;
     isMobileWidth?: boolean;
@@ -75,6 +92,9 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
     const [imagePaths, setImagePaths] = useState<string[]>([]);
     const [imageData, setImageData] = useState<string[]>([]);

+    const [attachedFiles, setAttachedFiles] = useState<FileList | null>(null);
+    const [convertedAttachedFiles, setConvertedAttachedFiles] = useState<AttachedFileText[]>([]);
+
     const [recording, setRecording] = useState(false);
     const [mediaRecorder, setMediaRecorder] = useState<MediaRecorder | null>(null);

@@ -154,6 +174,8 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
         }

         props.sendMessage(messageToSend);
+        setAttachedFiles(null);
+        setConvertedAttachedFiles([]);
         setMessage("");
     }

@@ -203,22 +225,69 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
             setImagePaths((prevPaths) => [...prevPaths, ...newImagePaths]);
             // Set focus to the input for user message after uploading files
             chatInputRef?.current?.focus();
-            return;
         }

-        uploadDataForIndexing(
-            files,
-            setWarning,
-            setUploading,
-            setError,
-            props.setUploadedFiles,
-            props.conversationId,
+        // Process all non-image files
+        const nonImageFiles = Array.from(files).filter(
+            (file) => !image_endings.includes(file.name.split(".").pop() || ""),
         );
+
+        // Concatenate attachedFiles and files
+        const newFiles = nonImageFiles
+            ? Array.from(nonImageFiles).concat(Array.from(attachedFiles || []))
+            : Array.from(attachedFiles || []);
+
+        // Ensure files are below size limit (10 MB)
+        for (let i = 0; i < newFiles.length; i++) {
+            if (newFiles[i].size > 10 * 1024 * 1024) {
+                setWarning(
+                    `File ${newFiles[i].name} is too large. Please upload files smaller than 10 MB.`,
+                );
+                return;
+            }
+        }
+
+        const dataTransfer = new DataTransfer();
+        newFiles.forEach((file) => dataTransfer.items.add(file));
+        setAttachedFiles(dataTransfer.files);
+
+        // Extract text from files
+        extractTextFromFiles(dataTransfer.files).then((data) => {
+            props.setUploadedFiles(data);
+            setConvertedAttachedFiles(data);
+        });
+
         // Set focus to the input for user message after uploading files
         chatInputRef?.current?.focus();
     }

+    async function extractTextFromFiles(files: FileList): Promise<AttachedFileText[]> {
+        const formData = await packageFilesForUpload(files);
+        setUploading(true);
+
+        try {
+            const response = await fetch("/api/content/convert", {
+                method: "POST",
+                body: formData,
+            });
+            setUploading(false);
+
+            if (!response.ok) {
+                throw new Error(`HTTP error! status: ${response.status}`);
+            }
+
+            return await response.json();
+        } catch (error) {
+            setError(
+                "Error converting files. " +
+                    error +
+                    ". Please try again, or contact team@khoj.dev if the issue persists.",
+            );
+            console.error("Error converting files:", error);
+            return [];
+        }
+    }
+
     // Assuming this function is added within the same context as the provided excerpt
     async function startRecordingAndTranscribe() {
         try {
@@ -445,6 +514,93 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
                 </div>
             )}
             <div>
+                <div className="flex items-center gap-2 overflow-x-auto">
+                    {imageUploaded &&
+                        imagePaths.map((path, index) => (
+                            <div key={index} className="relative flex-shrink-0 pb-3 pt-2 group">
+                                <img
+                                    src={path}
+                                    alt={`img-${index}`}
+                                    className="w-auto h-16 object-cover rounded-xl"
+                                />
+                                <Button
+                                    variant="ghost"
+                                    size="icon"
+                                    className="absolute -top-0 -right-2 h-5 w-5 rounded-full bg-neutral-200 dark:bg-neutral-600 hover:bg-neutral-300 dark:hover:bg-neutral-500 opacity-0 group-hover:opacity-100 transition-opacity"
+                                    onClick={() => removeImageUpload(index)}
+                                >
+                                    <X className="h-3 w-3" />
+                                </Button>
+                            </div>
+                        ))}
+                    {convertedAttachedFiles &&
+                        Array.from(convertedAttachedFiles).map((file, index) => (
+                            <Dialog key={index}>
+                                <DialogTrigger asChild>
+                                    <div key={index} className="relative flex-shrink-0 p-2 group">
+                                        <div
+                                            className={`w-auto h-16 object-cover rounded-xl ${props.agentColor ? convertToBGClass(props.agentColor) : "bg-orange-300 hover:bg-orange-500"} bg-opacity-15`}
+                                        >
+                                            <div className="flex p-2 flex-col justify-start items-start h-full">
+                                                <span className="text-sm font-bold text-neutral-500 dark:text-neutral-400 text-ellipsis truncate max-w-[200px] break-words">
+                                                    {file.name}
+                                                </span>
+                                                <span className="flex items-center gap-1">
+                                                    {getIconFromFilename(file.file_type)}
+                                                    <span className="text-xs text-neutral-500 dark:text-neutral-400">
+                                                        {convertBytesToText(file.size)}
+                                                    </span>
+                                                </span>
+                                            </div>
+                                        </div>
+                                        <Button
+                                            variant="ghost"
+                                            size="icon"
+                                            className="absolute -top-0 -right-2 h-5 w-5 rounded-full bg-neutral-200 dark:bg-neutral-600 hover:bg-neutral-300 dark:hover:bg-neutral-500 opacity-0 group-hover:opacity-100 transition-opacity"
+                                            onClick={() => {
+                                                setAttachedFiles((prevFiles) => {
+                                                    const removeFile = file.name;
+                                                    if (!prevFiles) return null;
+                                                    const updatedFiles = Array.from(
+                                                        prevFiles,
+                                                    ).filter((file) => file.name !== removeFile);
+                                                    const dataTransfer = new DataTransfer();
+                                                    updatedFiles.forEach((file) =>
+                                                        dataTransfer.items.add(file),
+                                                    );
+
+                                                    const filteredConvertedAttachedFiles =
+                                                        convertedAttachedFiles.filter(
+                                                            (file) => file.name !== removeFile,
+                                                        );
+
+                                                    props.setUploadedFiles(
+                                                        filteredConvertedAttachedFiles,
+                                                    );
+                                                    setConvertedAttachedFiles(
+                                                        filteredConvertedAttachedFiles,
+                                                    );
+                                                    return dataTransfer.files;
+                                                });
+                                            }}
+                                        >
+                                            <X className="h-3 w-3" />
+                                        </Button>
+                                    </div>
+                                </DialogTrigger>
+                                <DialogContent>
+                                    <DialogHeader>
+                                        <DialogTitle>{file.name}</DialogTitle>
+                                    </DialogHeader>
+                                    <DialogDescription>
+                                        <ScrollArea className="h-72 w-full rounded-md">
+                                            {file.content}
+                                        </ScrollArea>
+                                    </DialogDescription>
+                                </DialogContent>
+                            </Dialog>
+                        ))}
+                </div>
                 <div
                     className={`${styles.actualInputArea} justify-between dark:bg-neutral-700 relative ${isDragAndDropping && "animate-pulse"}`}
                     onDragOver={handleDragOver}
@@ -453,12 +609,14 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
                 >
                     <input
                         type="file"
+                        accept=".pdf,.doc,.docx,.txt,.md,.org,.jpg,.jpeg,.png,.webp"
                         multiple={true}
                         ref={fileInputRef}
                         onChange={handleFileChange}
                         style={{ display: "none" }}
                     />
-                    <div className="flex items-end pb-2">
+                    <div className="flex items-center">
                         <Button
                             variant={"ghost"}
                             className="!bg-none p-0 m-2 h-auto text-3xl rounded-full text-gray-300 hover:text-gray-500"
@@ -469,29 +627,6 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
                         </Button>
                     </div>
                     <div className="flex-grow flex flex-col w-full gap-1.5 relative">
-                        <div className="flex items-center gap-2 overflow-x-auto">
-                            {imageUploaded &&
-                                imagePaths.map((path, index) => (
-                                    <div
-                                        key={index}
-                                        className="relative flex-shrink-0 pb-3 pt-2 group"
-                                    >
-                                        <img
-                                            src={path}
-                                            alt={`img-${index}`}
-                                            className="w-auto h-16 object-cover rounded-xl"
-                                        />
-                                        <Button
-                                            variant="ghost"
-                                            size="icon"
-                                            className="absolute -top-0 -right-2 h-5 w-5 rounded-full bg-neutral-200 dark:bg-neutral-600 hover:bg-neutral-300 dark:hover:bg-neutral-500 opacity-0 group-hover:opacity-100 transition-opacity"
-                                            onClick={() => removeImageUpload(index)}
-                                        >
-                                            <X className="h-3 w-3" />
-                                        </Button>
-                                    </div>
-                                ))}
-                        </div>
                         <Textarea
                             ref={chatInputRef}
                             className={`border-none focus:border-none
@@ -582,10 +717,14 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
                             <span className="text-muted-foreground text-sm">Research Mode</span>
                             {useResearchMode ? (
                                 <ToggleRight
+                                    weight="fill"
                                     className={`w-6 h-6 inline-block ${props.agentColor ? convertColorToTextClass(props.agentColor) : convertColorToTextClass("orange")} rounded-full`}
                                 />
                             ) : (
-                                <ToggleLeft className={`w-6 h-6 inline-block rounded-full`} />
+                                <ToggleLeft
+                                    weight="fill"
+                                    className={`w-6 h-6 inline-block ${props.agentColor ? convertColorToTextClass(props.agentColor) : convertColorToTextClass("orange")} rounded-full`}
+                                />
                             )}
                         </Button>
                     </TooltipTrigger>

@@ -40,6 +40,18 @@ import { AgentData } from "@/app/agents/page";
 import renderMathInElement from "katex/contrib/auto-render";
 import "katex/dist/katex.min.css";
 import ExcalidrawComponent from "../excalidraw/excalidraw";
+import { AttachedFileText } from "../chatInputArea/chatInputArea";
+import {
+    Dialog,
+    DialogContent,
+    DialogDescription,
+    DialogHeader,
+    DialogTrigger,
+} from "@/components/ui/dialog";
+import { DialogTitle } from "@radix-ui/react-dialog";
+import { convertBytesToText } from "@/app/common/utils";
+import { ScrollArea } from "@/components/ui/scroll-area";
+import { getIconFromFilename } from "@/app/common/iconUtils";

 const md = new markdownIt({
     html: true,
@@ -149,6 +161,7 @@ export interface SingleChatMessage {
     images?: string[];
     conversationId: string;
     turnId?: string;
+    queryFiles?: AttachedFileText[];
 }

 export interface StreamMessage {
@@ -165,6 +178,7 @@ export interface StreamMessage {
     intentType?: string;
     inferredQueries?: string[];
     turnId?: string;
+    queryFiles?: AttachedFileText[];
 }

 export interface ChatHistoryData {
@@ -398,7 +412,6 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
     if (props.chatMessage.intent) {
         const { type, "inferred-queries": inferredQueries } = props.chatMessage.intent;

-        console.log("intent type", type);
         if (type in intentTypeHandlers) {
             message = intentTypeHandlers[type as keyof typeof intentTypeHandlers](message);
         }
@@ -695,6 +708,40 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
             onMouseLeave={(event) => setIsHovering(false)}
             onMouseEnter={(event) => setIsHovering(true)}
         >
+            {props.chatMessage.queryFiles && props.chatMessage.queryFiles.length > 0 && (
+                <div className="flex flex-wrap flex-col m-2 max-w-full">
+                    {props.chatMessage.queryFiles.map((file, index) => (
+                        <Dialog key={index}>
+                            <DialogTrigger asChild>
+                                <div
+                                    className="flex items-center space-x-2 cursor-pointer bg-gray-500 bg-opacity-25 rounded-lg m-1 p-2 w-full
+                                    "
+                                >
+                                    <div className="flex-shrink-0">
+                                        {getIconFromFilename(file.file_type)}
+                                    </div>
+                                    <span className="truncate flex-1 min-w-0">{file.name}</span>
+                                    {file.size && (
+                                        <span className="text-gray-400 flex-shrink-0">
+                                            ({convertBytesToText(file.size)})
+                                        </span>
+                                    )}
+                                </div>
+                            </DialogTrigger>
+                            <DialogContent>
+                                <DialogHeader>
+                                    <DialogTitle>{file.name}</DialogTitle>
+                                </DialogHeader>
+                                <DialogDescription>
+                                    <ScrollArea className="h-72 w-full rounded-md">
+                                        {file.content}
+                                    </ScrollArea>
+                                </DialogDescription>
+                            </DialogContent>
+                        </Dialog>
+                    ))}
+                </div>
+            )}
             <div className={chatMessageWrapperClasses(props.chatMessage)}>
                 <div
                     ref={messageRef}

@@ -81,111 +81,3 @@ export function OrgMode({ className }: { className?: string }) {
         </svg>
     );
 }
-
-export function Markdown({ className }: { className?: string }) {
-    const classes = className ?? "w-6 h-6 text-muted-foreground inline-flex mr-1";
-    return (
-        <svg
-            className={`${classes}`}
-            xmlns="http://www.w3.org/2000/svg"
-            width="208"
-            height="128"
-            viewBox="0 0 208 128"
-        >
-            <rect
-                width="198"
-                height="118"
-                x="5"
-                y="5"
-                ry="10"
-                stroke="#000"
-                strokeWidth="10"
-                fill="none"
-            />
-            <path d="M30 98V30h20l20 25 20-25h20v68H90V59L70 84 50 59v39zm125 0l-30-33h20V30h20v35h20z" />
-        </svg>
-    );
-}
-
-export function Pdf({ className }: { className?: string }) {
-    const classes = className ?? "w-6 h-6 text-muted-foreground inline-flex mr-1";
-    return (
-        <svg
-            className={`${classes}`}
-            xmlns="http://www.w3.org/2000/svg"
-            enableBackground="new 0 0 334.371 380.563"
-            version="1.1"
-            viewBox="0 0 14 16"
-        >
-            <g transform="matrix(.04589 0 0 .04589 -.66877 -.73379)">
-                <polygon
-                    points="51.791 356.65 51.791 23.99 204.5 23.99 282.65 102.07 282.65 356.65"
-                    fill="#fff"
-                    strokeWidth="212.65"
-                />
-                <path
-                    d="m201.19 31.99 73.46 73.393v243.26h-214.86v-316.66h141.4m6.623-16h-164.02v348.66h246.85v-265.9z"
-                    strokeWidth="21.791"
-                />
-            </g>
-            <g transform="matrix(.04589 0 0 .04589 -.66877 -.73379)">
-                <polygon
-                    points="282.65 356.65 51.791 356.65 51.791 23.99 204.5 23.99 206.31 25.8 206.31 100.33 280.9 100.33 282.65 102.07"
-                    fill="#fff"
-                    strokeWidth="212.65"
-                />
-                <path
-                    d="m198.31 31.99v76.337h76.337v240.32h-214.86v-316.66h138.52m9.5-16h-164.02v348.66h246.85v-265.9l-6.43-6.424h-69.907v-69.842z"
-                    strokeWidth="21.791"
-                />
-            </g>
-            <g transform="matrix(.04589 0 0 .04589 -.66877 -.73379)" strokeWidth="21.791">
-                <polygon points="258.31 87.75 219.64 87.75 219.64 48.667 258.31 86.38" />
-                <path d="m227.64 67.646 12.41 12.104h-12.41v-12.104m-5.002-27.229h-10.998v55.333h54.666v-12.742z" />
-            </g>
-            <g
-                transform="matrix(.04589 0 0 .04589 -.66877 -.73379)"
-                fill="#ed1c24"
-                strokeWidth="212.65"
-            >
-                <polygon points="311.89 284.49 22.544 284.49 22.544 167.68 37.291 152.94 37.291 171.49 297.15 171.49 297.15 152.94 311.89 167.68" />
-                <path d="m303.65 168.63 1.747 1.747v107.62h-276.35v-107.62l1.747-1.747v9.362h272.85v-9.362m-12.999-31.385v27.747h-246.86v-27.747l-27.747 27.747v126h302.35v-126z" />
-            </g>
-            <rect x="1.7219" y="7.9544" width="10.684" height="4.0307" fill="none" />
-            <g transform="matrix(.04589 0 0 .04589 1.7219 11.733)" fill="#fff" strokeWidth="21.791">
-                <path d="m9.216 0v-83.2h30.464q6.784 0 12.928 1.408 6.144 1.28 10.752 4.608 4.608 3.2 7.296 8.576 2.816 5.248 2.816 13.056 0 7.68-2.816 13.184-2.688 5.504-7.296 9.088-4.608 3.456-10.624 5.248-6.016 1.664-12.544 1.664h-8.96v26.368zm22.016-43.776h7.936q6.528 0 9.6-3.072 3.2-3.072 3.2-8.704t-3.456-7.936-9.856-2.304h-7.424z" />
-                <path d="m87.04 0v-83.2h24.576q9.472 0 17.28 2.304 7.936 2.304 13.568 7.296t8.704 12.8q3.2 7.808 3.2 18.816t-3.072 18.944-8.704 13.056q-5.504 5.12-13.184 7.552-7.552 2.432-16.512 2.432zm22.016-17.664h1.28q4.48 0 8.448-1.024 3.968-1.152 6.784-3.84 2.944-2.688 4.608-7.424t1.664-12.032-1.664-11.904-4.608-7.168q-2.816-2.56-6.784-3.456-3.968-1.024-8.448-1.024h-1.28z" />
-                <path d="m169.22 0v-83.2h54.272v18.432h-32.256v15.872h27.648v18.432h-27.648v30.464z" />
-            </g>
-        </svg>
-    );
-}
-
-export function Word({ className }: { className?: string }) {
-    const classes = className ?? "w-6 h-6 text-muted-foreground inline-flex mr-1";
-    return (
-        <svg
-            className={`${classes}`}
-            xmlns="http://www.w3.org/2000/svg"
-            fill="#FFF"
-            stroke-miterlimit="10"
-            strokeWidth="2"
-            viewBox="0 0 96 96"
-        >
-            <path
-                stroke="#979593"
-                d="M67.1716 7H27c-1.1046 0-2 .8954-2 2v78c0 1.1046.8954 2 2 2h58c1.1046 0 2-.8954 2-2V26.8284c0-.5304-.2107-1.0391-.5858-1.4142L68.5858 7.5858C68.2107 7.2107 67.702 7 67.1716 7z"
-            />
-            <path fill="none" stroke="#979593" d="M67 7v18c0 1.1046.8954 2 2 2h18" />
-            <path
-                fill="#C8C6C4"
-                d="M79 61H48v-2h31c.5523 0 1 .4477 1 1s-.4477 1-1 1zm0-6H48v-2h31c.5523 0 1 .4477 1 1s-.4477 1-1 1zm0-6H48v-2h31c.5523 0 1 .4477 1 1s-.4477 1-1 1zm0-6H48v-2h31c.5523 0 1 .4477 1 1s-.4477 1-1 1zm0 24H48v-2h31c.5523 0 1 .4477 1 1s-.4477 1-1 1z"
-            />
-            <path
-                fill="#185ABD"
-                d="M12 74h32c2.2091 0 4-1.7909 4-4V38c0-2.2091-1.7909-4-4-4H12c-2.2091 0-4 1.7909-4 4v32c0 2.2091 1.7909 4 4 4z"
-            />
-            <path d="M21.6245 60.6455c.0661.522.109.9769.1296 1.3657h.0762c.0306-.3685.0889-.8129.1751-1.3349.0862-.5211.1703-.961.2517-1.319L25.7911 44h4.5702l3.6562 15.1272c.183.7468.3353 1.6973.457 2.8532h.0608c.0508-.7979.1777-1.7184.3809-2.7615L37.8413 44H42l-5.1183 22h-4.86l-3.4885-14.5744c-.1016-.4197-.2158-.9663-.3428-1.6417-.127-.6745-.2057-1.1656-.236-1.4724h-.0608c-.0407.358-.1195.8896-.2364 1.595-.1169.7062-.211 1.2273-.2819 1.565L24.1 66h-4.9357L14 44h4.2349l3.1843 15.3882c.0709.3165.1392.7362.2053 1.2573z" />
-        </svg>
-    );
-}

@@ -11,7 +11,11 @@ import { Card, CardTitle } from "@/components/ui/card";
 import SuggestionCard from "@/app/components/suggestions/suggestionCard";
 import SidePanel from "@/app/components/sidePanel/chatHistorySidePanel";
 import Loading from "@/app/components/loading/loading";
-import { ChatInputArea, ChatOptions } from "@/app/components/chatInputArea/chatInputArea";
+import {
+    AttachedFileText,
+    ChatInputArea,
+    ChatOptions,
+} from "@/app/components/chatInputArea/chatInputArea";
 import { Suggestion, suggestionsData } from "@/app/components/suggestions/suggestionsData";
 import LoginPrompt from "@/app/components/loginPrompt/loginPrompt";

@@ -34,7 +38,7 @@ import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover
 interface ChatBodyDataProps {
     chatOptionsData: ChatOptions | null;
     onConversationIdChange?: (conversationId: string) => void;
-    setUploadedFiles: (files: string[]) => void;
+    setUploadedFiles: (files: AttachedFileText[]) => void;
     isMobileWidth?: boolean;
     isLoggedIn: boolean;
     userConfig: UserConfig | null;
@@ -155,6 +159,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
             if (images.length > 0) {
                 localStorage.setItem("images", JSON.stringify(images));
             }
+
             window.location.href = `/chat?conversationId=${newConversationId}`;
         } catch (error) {
             console.error("Error creating new conversation:", error);
@@ -401,7 +406,7 @@ export default function Home() {
     const [chatOptionsData, setChatOptionsData] = useState<ChatOptions | null>(null);
     const [isLoading, setLoading] = useState(true);
     const [conversationId, setConversationID] = useState<string | null>(null);
-    const [uploadedFiles, setUploadedFiles] = useState<string[]>([]);
+    const [uploadedFiles, setUploadedFiles] = useState<AttachedFileText[] | null>(null);
     const isMobileWidth = useIsMobileWidth();

     const { userConfig: initialUserConfig, isLoadingUserConfig } = useUserConfig(true);
@@ -417,6 +422,12 @@ export default function Home() {
         setUserConfig(initialUserConfig);
     }, [initialUserConfig]);

+    useEffect(() => {
+        if (uploadedFiles) {
+            localStorage.setItem("uploadedFiles", JSON.stringify(uploadedFiles));
+        }
+    }, [uploadedFiles]);
+
     useEffect(() => {
         fetch("/api/chat/options")
             .then((response) => response.json())
@@ -442,7 +453,7 @@ export default function Home() {
             <div className={`${styles.sidePanel}`}>
                 <SidePanel
                     conversationId={conversationId}
-                    uploadedFiles={uploadedFiles}
+                    uploadedFiles={[]}
                     isMobileWidth={isMobileWidth}
                 />
             </div>

@@ -137,10 +137,8 @@ const ManageFilesModal: React.FC<{ onClose: () => void }> = ({ onClose }) => {

     const deleteSelected = async () => {
         let filesToDelete = selectedFiles.length > 0 ? selectedFiles : filteredFiles;
-        console.log("Delete selected files", filesToDelete);

         if (filesToDelete.length === 0) {
-            console.log("No files to delete");
             return;
         }

@@ -162,15 +160,12 @@ const ManageFilesModal: React.FC<{ onClose: () => void }> = ({ onClose }) => {

             // Reset selectedFiles
             setSelectedFiles([]);
-
-            console.log("Deleted files:", filesToDelete);
         } catch (error) {
             console.error("Error deleting files:", error);
         }
     };

     const deleteFile = async (filename: string) => {
-        console.log("Delete selected file", filename);
         try {
             const response = await fetch(
                 `/api/content/file?filename=${encodeURIComponent(filename)}`,
@@ -189,8 +184,6 @@ const ManageFilesModal: React.FC<{ onClose: () => void }> = ({ onClose }) => {

             // Remove the file from selectedFiles if it's there
             setSelectedFiles((prevSelected) => prevSelected.filter((file) => file !== filename));
-
-            console.log("Deleted file:", filename);
         } catch (error) {
             console.error("Error deleting file:", error);
         }

@@ -5,23 +5,25 @@ import React, { Suspense, useEffect, useRef, useState } from "react";

 import SidePanel from "../../components/sidePanel/chatHistorySidePanel";
 import ChatHistory from "../../components/chatHistory/chatHistory";
-import NavMenu from "../../components/navMenu/navMenu";
 import Loading from "../../components/loading/loading";

 import "katex/dist/katex.min.css";

-import { useIPLocationData, useIsMobileWidth, welcomeConsole } from "../../common/utils";
+import { useIsMobileWidth, welcomeConsole } from "../../common/utils";
 import { useAuthenticatedData } from "@/app/common/auth";

-import { ChatInputArea, ChatOptions } from "@/app/components/chatInputArea/chatInputArea";
+import {
+    AttachedFileText,
+    ChatInputArea,
+    ChatOptions,
+} from "@/app/components/chatInputArea/chatInputArea";
 import { StreamMessage } from "@/app/components/chatMessage/chatMessage";
-import { processMessageChunk } from "@/app/common/chatFunctions";
 import { AgentData } from "@/app/agents/page";

 interface ChatBodyDataProps {
     chatOptionsData: ChatOptions | null;
     setTitle: (title: string) => void;
-    setUploadedFiles: (files: string[]) => void;
+    setUploadedFiles: (files: AttachedFileText[]) => void;
     isMobileWidth?: boolean;
     publicConversationSlug: string;
     streamedMessages: StreamMessage[];
@@ -50,23 +52,6 @@ function ChatBodyData(props: ChatBodyDataProps) {
         }
     }, [images, props.setImages]);

-    useEffect(() => {
-        const storedImages = localStorage.getItem("images");
-        if (storedImages) {
-            const parsedImages: string[] = JSON.parse(storedImages);
-            setImages(parsedImages);
-            const encodedImages = parsedImages.map((img: string) => encodeURIComponent(img));
-            props.setImages(encodedImages);
-            localStorage.removeItem("images");
-        }
-
-        const storedMessage = localStorage.getItem("message");
-        if (storedMessage) {
-            setProcessingMessage(true);
-            setQueryToProcess(storedMessage);
-        }
-    }, [setQueryToProcess, props.setImages]);
-
     useEffect(() => {
         if (message) {
             setProcessingMessage(true);
@@ -130,14 +115,10 @@ export default function SharedChat() {
     const [conversationId, setConversationID] = useState<string | undefined>(undefined);
     const [messages, setMessages] = useState<StreamMessage[]>([]);
     const [queryToProcess, setQueryToProcess] = useState<string>("");
-    const [processQuerySignal, setProcessQuerySignal] = useState(false);
-    const [uploadedFiles, setUploadedFiles] = useState<string[]>([]);
+    const [uploadedFiles, setUploadedFiles] = useState<AttachedFileText[] | null>(null);
     const [paramSlug, setParamSlug] = useState<string | undefined>(undefined);
     const [images, setImages] = useState<string[]>([]);

-    const locationData = useIPLocationData() || {
-        timezone: Intl.DateTimeFormat().resolvedOptions().timeZone,
-    };
     const authenticatedData = useAuthenticatedData();
     const isMobileWidth = useIsMobileWidth();

@@ -161,6 +142,12 @@ export default function SharedChat() {
         setParamSlug(window.location.pathname.split("/").pop() || "");
     }, []);

+    useEffect(() => {
+        if (uploadedFiles) {
+            localStorage.setItem("uploadedFiles", JSON.stringify(uploadedFiles));
+        }
+    }, [uploadedFiles]);
+
     useEffect(() => {
         if (queryToProcess && !conversationId) {
             // If the user has not yet started conversing in the chat, create a new conversation
@@ -173,6 +160,11 @@ export default function SharedChat() {
                 .then((response) => response.json())
                 .then((data) => {
                     setConversationID(data.conversation_id);
+                    localStorage.setItem("message", queryToProcess);
+                    if (images.length > 0) {
+                        localStorage.setItem("images", JSON.stringify(images));
+                    }
+                    window.location.href = `/chat?conversationId=${data.conversation_id}`;
                 })
                 .catch((err) => {
                     console.error(err);
@@ -180,105 +172,8 @@ export default function SharedChat() {
                 });
             return;
         }

-        if (queryToProcess) {
-            // Add a new object to the state
-            const newStreamMessage: StreamMessage = {
-                rawResponse: "",
-                trainOfThought: [],
-                context: [],
-                onlineContext: {},
-                codeContext: {},
-                completed: false,
-                timestamp: new Date().toISOString(),
-                rawQuery: queryToProcess || "",
-                images: images,
-            };
-            setMessages((prevMessages) => [...prevMessages, newStreamMessage]);
-            setProcessQuerySignal(true);
-        }
     }, [queryToProcess, conversationId, paramSlug]);

-    useEffect(() => {
-        if (processQuerySignal) {
-            chat();
-        }
-    }, [processQuerySignal]);
-
-    async function readChatStream(response: Response) {
-        if (!response.ok) throw new Error(response.statusText);
-        if (!response.body) throw new Error("Response body is null");
-
-        const reader = response.body.getReader();
-        const decoder = new TextDecoder();
-        const eventDelimiter = "␃🔚␗";
-        let buffer = "";
-
-        while (true) {
-            const { done, value } = await reader.read();
-            if (done) {
-                setQueryToProcess("");
-                setProcessQuerySignal(false);
-                setImages([]);
-                break;
-            }
-
-            const chunk = decoder.decode(value, { stream: true });
-
-            buffer += chunk;
-
-            let newEventIndex;
-            while ((newEventIndex = buffer.indexOf(eventDelimiter)) !== -1) {
-                const event = buffer.slice(0, newEventIndex);
-                buffer = buffer.slice(newEventIndex + eventDelimiter.length);
-                if (event) {
-                    const currentMessage = messages.find((message) => !message.completed);
-
-                    if (!currentMessage) {
-                        console.error("No current message found");
-                        return;
-                    }
-
-                    processMessageChunk(event, currentMessage);
-
-                    setMessages([...messages]);
-                }
-            }
-        }
-    }
-
-    async function chat() {
-        if (!queryToProcess || !conversationId) return;
-        const chatAPI = "/api/chat?client=web";
-        const chatAPIBody = {
-            q: queryToProcess,
-            conversation_id: conversationId,
-            stream: true,
-            ...(locationData && {
-                region: locationData.region,
-                country: locationData.country,
-                city: locationData.city,
-                country_code: locationData.countryCode,
-                timezone: locationData.timezone,
-            }),
-            ...(images.length > 0 && { image: images }),
-        };
-
-        const response = await fetch(chatAPI, {
-            method: "POST",
-            headers: {
-                "Content-Type": "application/json",
-            },
-            body: JSON.stringify(chatAPIBody),
-        });
-
-        try {
-            await readChatStream(response);
-        } catch (error) {
-            console.error(error);
-        }
-    }
-
     if (isLoading) {
         return <Loading />;
     }
@@ -293,7 +188,7 @@ export default function SharedChat() {
             <div className={styles.sidePanel}>
                 <SidePanel
                     conversationId={conversationId ?? null}
-                    uploadedFiles={uploadedFiles}
+                    uploadedFiles={[]}
                     isMobileWidth={isMobileWidth}
                 />
             </div>

@@ -9,7 +9,7 @@ const Textarea = React.forwardRef<HTMLTextAreaElement, TextareaProps>(
         return (
             <textarea
                 className={cn(
-                    "flex min-h-[80px] w-full rounded-md border border-input bg-background px-3 py-2 text-sm ring-offset-background placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-50",
+                    "flex min-h-[80px] w-full rounded-md border border-input bg-background px-3 py-2 text-sm ring-offset-background placeholder:text-muted-foreground disabled:cursor-not-allowed disabled:opacity-50",
                     className,
                 )}
                 ref={ref}

@@ -1387,6 +1387,10 @@ class FileObjectAdapters:
     async def async_get_file_objects_by_name(user: KhojUser, file_name: str, agent: Agent = None):
         return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name=file_name, agent=agent))

+    @staticmethod
+    async def async_get_file_objects_by_names(user: KhojUser, file_names: List[str]):
+        return await sync_to_async(list)(FileObject.objects.filter(user=user, file_name__in=file_names))
+
     @staticmethod
     async def async_get_all_file_objects(user: KhojUser):
         return await sync_to_async(list)(FileObject.objects.filter(user=user))

@@ -458,7 +458,11 @@ class Conversation(BaseModel):
     user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
     conversation_log = models.JSONField(default=dict)
     client = models.ForeignKey(ClientApplication, on_delete=models.CASCADE, default=None, null=True, blank=True)
+
+    # Slug is an app-generated conversation identifier. Need not be unique. Used as display title essentially.
     slug = models.CharField(max_length=200, default=None, null=True, blank=True)
+
+    # The title field is explicitly set by the user.
     title = models.CharField(max_length=200, default=None, null=True, blank=True)
     agent = models.ForeignKey(Agent, on_delete=models.SET_NULL, default=None, null=True, blank=True)
     file_filters = models.JSONField(default=list)

@@ -1,6 +1,5 @@
 import logging
-import os
-from datetime import datetime
+import tempfile
 from typing import Dict, List, Tuple

 from langchain_community.document_loaders import Docx2txtLoader

@@ -58,28 +57,13 @@ class DocxToEntries(TextToEntries):
 file_to_text_map = dict()
 for docx_file in docx_files:
 try:
-timestamp_now = datetime.utcnow().timestamp()
-tmp_file = f"tmp_docx_file_{timestamp_now}.docx"
-with open(tmp_file, "wb") as f:
-bytes_content = docx_files[docx_file]
-f.write(bytes_content)
-
-# Load the content using Docx2txtLoader
-loader = Docx2txtLoader(tmp_file)
-docx_entries_per_file = loader.load()
-
-# Convert the loaded entries into the desired format
-docx_texts = [page.page_content for page in docx_entries_per_file]
-
+docx_texts = DocxToEntries.extract_text(docx_files[docx_file])
 entry_to_location_map += zip(docx_texts, [docx_file] * len(docx_texts))
 entries.extend(docx_texts)
 file_to_text_map[docx_file] = docx_texts
 except Exception as e:
-logger.warning(f"Unable to process file: {docx_file}. This file will not be indexed.")
+logger.warning(f"Unable to extract entries from file: {docx_file}")
 logger.warning(e, exc_info=True)
-finally:
-if os.path.exists(f"{tmp_file}"):
-os.remove(f"{tmp_file}")
 return file_to_text_map, DocxToEntries.convert_docx_entries_to_maps(entries, dict(entry_to_location_map))

 @staticmethod

@@ -103,3 +87,25 @@ class DocxToEntries(TextToEntries):
 logger.debug(f"Converted {len(parsed_entries)} DOCX entries to dictionaries")

 return entries
+
+@staticmethod
+def extract_text(docx_file):
+"""Extract text from specified DOCX file"""
+try:
+docx_entry_by_pages = []
+# Create temp file with .docx extension that gets auto-deleted
+with tempfile.NamedTemporaryFile(suffix=".docx", delete=True) as tmp:
+tmp.write(docx_file)
+tmp.flush() # Ensure all data is written
+
+# Load the content using Docx2txtLoader
+loader = Docx2txtLoader(tmp.name)
+docx_entries_per_file = loader.load()
+
+# Convert the loaded entries into the desired format
+docx_entry_by_pages = [page.page_content for page in docx_entries_per_file]
+except Exception as e:
+logger.warning(f"Unable to extract text from file: {docx_file}")
+logger.warning(e, exc_info=True)
+
+return docx_entry_by_pages
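Both document processors now follow the same shape: raw bytes go into a static extract_text, which writes them to an auto-deleted temp file and hands the path to the langchain loader. A rough sketch of calling the new docx helper directly (the sample path is illustrative):

from khoj.processor.content.docx.docx_to_entries import DocxToEntries

with open("notes.docx", "rb") as f:  # any local .docx file; path is illustrative
    raw_bytes = f.read()

pages = DocxToEntries.extract_text(raw_bytes)  # list of extracted text blocks; empty list if extraction fails
print(f"Extracted {len(pages)} text blocks")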
@@ -1,13 +1,10 @@
-import base64
 import logging
-import os
-from datetime import datetime
+import tempfile
+from io import BytesIO
 from typing import Dict, List, Tuple

 from langchain_community.document_loaders import PyMuPDFLoader

-# importing FileObjectAdapter so that we can add new files and debug file object db.
-# from khoj.database.adapters import FileObjectAdapters
 from khoj.database.models import Entry as DbEntry
 from khoj.database.models import KhojUser
 from khoj.processor.content.text_to_entries import TextToEntries

@@ -60,31 +57,13 @@ class PdfToEntries(TextToEntries):
 entry_to_location_map: List[Tuple[str, str]] = []
 for pdf_file in pdf_files:
 try:
-# Write the PDF file to a temporary file, as it is stored in byte format in the pdf_file object and the PDF Loader expects a file path
-timestamp_now = datetime.utcnow().timestamp()
-tmp_file = f"tmp_pdf_file_{timestamp_now}.pdf"
-with open(f"{tmp_file}", "wb") as f:
-bytes = pdf_files[pdf_file]
-f.write(bytes)
-try:
-loader = PyMuPDFLoader(f"{tmp_file}", extract_images=False)
-pdf_entries_per_file = [page.page_content for page in loader.load()]
-except ImportError:
-loader = PyMuPDFLoader(f"{tmp_file}")
-pdf_entries_per_file = [
-page.page_content for page in loader.load()
-] # page_content items list for a given pdf.
-entry_to_location_map += zip(
-pdf_entries_per_file, [pdf_file] * len(pdf_entries_per_file)
-) # this is an indexed map of pdf_entries for the pdf.
+pdf_entries_per_file = PdfToEntries.extract_text(pdf_files[pdf_file])
+entry_to_location_map += zip(pdf_entries_per_file, [pdf_file] * len(pdf_entries_per_file))
 entries.extend(pdf_entries_per_file)
 file_to_text_map[pdf_file] = pdf_entries_per_file
 except Exception as e:
-logger.warning(f"Unable to process file: {pdf_file}. This file will not be indexed.")
+logger.warning(f"Unable to extract entries from file: {pdf_file}")
 logger.warning(e, exc_info=True)
-finally:
-if os.path.exists(f"{tmp_file}"):
-os.remove(f"{tmp_file}")

 return file_to_text_map, PdfToEntries.convert_pdf_entries_to_maps(entries, dict(entry_to_location_map))

@@ -109,3 +88,32 @@ class PdfToEntries(TextToEntries):
 logger.debug(f"Converted {len(parsed_entries)} PDF entries to dictionaries")

 return entries
+
+@staticmethod
+def extract_text(pdf_file):
+"""Extract text from specified PDF files"""
+try:
+# Create temp file with .pdf extension that gets auto-deleted
+with tempfile.NamedTemporaryFile(suffix=".pdf", delete=True) as tmpf:
+tmpf.write(pdf_file)
+tmpf.flush() # Ensure all data is written
+
+# Load the content using PyMuPDFLoader
+loader = PyMuPDFLoader(tmpf.name, extract_images=True)
+pdf_entries_per_file = loader.load()
+
+# Convert the loaded entries into the desired format
+pdf_entry_by_pages = [PdfToEntries.clean_text(page.page_content) for page in pdf_entries_per_file]
+except Exception as e:
+logger.warning(f"Unable to process file: {pdf_file}. This file will not be indexed.")
+logger.warning(e, exc_info=True)
+
+return pdf_entry_by_pages
+
+@staticmethod
+def clean_text(text: str) -> str:
+# Remove null bytes
+text = text.replace("\x00", "")
+# Replace invalid Unicode
+text = text.encode("utf-8", errors="ignore").decode("utf-8")
+return text
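clean_text is there presumably because extracted PDF text can contain null bytes or invalid Unicode that break JSON serialization and database writes. A quick check of its behavior (input chosen only to illustrate):

from khoj.processor.content.pdf.pdf_to_entries import PdfToEntries

dirty = "Hello\x00 world \ud800"   # null byte plus a lone surrogate, for illustration
clean = PdfToEntries.clean_text(dirty)
assert clean == "Hello world "      # null byte removed, unencodable character dropped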
@@ -36,6 +36,7 @@ def extract_questions_anthropic(
 query_images: Optional[list[str]] = None,
 vision_enabled: bool = False,
 personality_context: Optional[str] = None,
+query_files: str = None,
 tracer: dict = {},
 ):
 """

@@ -82,9 +83,12 @@ def extract_questions_anthropic(
 images=query_images,
 model_type=ChatModelOptions.ModelType.ANTHROPIC,
 vision_enabled=vision_enabled,
+attached_file_context=query_files,
 )

-messages = [ChatMessage(content=prompt, role="user")]
+messages = []
+
+messages.append(ChatMessage(content=prompt, role="user"))

 messages, system_prompt = format_messages_for_anthropic(messages, system_prompt)

@@ -148,6 +152,7 @@ def converse_anthropic(
 agent: Agent = None,
 query_images: Optional[list[str]] = None,
 vision_available: bool = False,
+query_files: str = None,
 tracer: dict = {},
 ):
 """

@@ -205,6 +210,7 @@ def converse_anthropic(
 query_images=query_images,
 vision_enabled=vision_available,
 model_type=ChatModelOptions.ModelType.ANTHROPIC,
+query_files=query_files,
 )

 messages, system_prompt = format_messages_for_anthropic(messages, system_prompt)
@@ -37,6 +37,7 @@ def extract_questions_gemini(
 query_images: Optional[list[str]] = None,
 vision_enabled: bool = False,
 personality_context: Optional[str] = None,
+query_files: str = None,
 tracer: dict = {},
 ):
 """

@@ -83,9 +84,13 @@ def extract_questions_gemini(
 images=query_images,
 model_type=ChatModelOptions.ModelType.GOOGLE,
 vision_enabled=vision_enabled,
+attached_file_context=query_files,
 )

-messages = [ChatMessage(content=prompt, role="user"), ChatMessage(content=system_prompt, role="system")]
+messages = []
+
+messages.append(ChatMessage(content=prompt, role="user"))
+messages.append(ChatMessage(content=system_prompt, role="system"))

 response = gemini_send_message_to_model(
 messages, api_key, model, response_type="json_object", temperature=temperature, tracer=tracer

@@ -108,7 +113,13 @@ def extract_questions_gemini(


 def gemini_send_message_to_model(
-messages, api_key, model, response_type="text", temperature=0, model_kwargs=None, tracer={}
+messages,
+api_key,
+model,
+response_type="text",
+temperature=0,
+model_kwargs=None,
+tracer={},
 ):
 """
 Send message to model

@@ -151,6 +162,7 @@ def converse_gemini(
 agent: Agent = None,
 query_images: Optional[list[str]] = None,
 vision_available: bool = False,
+query_files: str = None,
 tracer={},
 ):
 """

@@ -209,6 +221,7 @@ def converse_gemini(
 query_images=query_images,
 vision_enabled=vision_available,
 model_type=ChatModelOptions.ModelType.GOOGLE,
+query_files=query_files,
 )

 messages, system_prompt = format_messages_for_gemini(messages, system_prompt)
@@ -37,6 +37,7 @@ def extract_questions_offline(
 max_prompt_size: int = None,
 temperature: float = 0.7,
 personality_context: Optional[str] = None,
+query_files: str = None,
 tracer: dict = {},
 ) -> List[str]:
 """

@@ -87,6 +88,7 @@ def extract_questions_offline(
 loaded_model=offline_chat_model,
 max_prompt_size=max_prompt_size,
 model_type=ChatModelOptions.ModelType.OFFLINE,
+query_files=query_files,
 )

 state.chat_lock.acquire()

@@ -152,6 +154,7 @@ def converse_offline(
 location_data: LocationData = None,
 user_name: str = None,
 agent: Agent = None,
+query_files: str = None,
 tracer: dict = {},
 ) -> Union[ThreadedGenerator, Iterator[str]]:
 """

@@ -216,6 +219,7 @@ def converse_offline(
 max_prompt_size=max_prompt_size,
 tokenizer_name=tokenizer_name,
 model_type=ChatModelOptions.ModelType.OFFLINE,
+query_files=query_files,
 )

 truncated_messages = "\n".join({f"{message.content[:70]}..." for message in messages})
@@ -34,6 +34,7 @@ def extract_questions(
 query_images: Optional[list[str]] = None,
 vision_enabled: bool = False,
 personality_context: Optional[str] = None,
+query_files: str = None,
 tracer: dict = {},
 ):
 """

@@ -79,9 +80,11 @@ def extract_questions(
 images=query_images,
 model_type=ChatModelOptions.ModelType.OPENAI,
 vision_enabled=vision_enabled,
+attached_file_context=query_files,
 )

-messages = [ChatMessage(content=prompt, role="user")]
+messages = []
+messages.append(ChatMessage(content=prompt, role="user"))

 response = send_message_to_model(
 messages,

@@ -148,6 +151,7 @@ def converse(
 agent: Agent = None,
 query_images: Optional[list[str]] = None,
 vision_available: bool = False,
+query_files: str = None,
 tracer: dict = {},
 ):
 """

@@ -206,6 +210,7 @@ def converse(
 query_images=query_images,
 vision_enabled=vision_available,
 model_type=ChatModelOptions.ModelType.OPENAI,
+query_files=query_files,
 )
 truncated_messages = "\n".join({f"{message.content[:70]}..." for message in messages})
 logger.debug(f"Conversation Context for GPT: {truncated_messages}")
@@ -988,16 +988,27 @@ You are an extremely smart and helpful title generator assistant. Given a user q

 # Examples:
 User: Show a new Calvin and Hobbes quote every morning at 9am. My Current Location: Shanghai, China
-Khoj: Your daily Calvin and Hobbes Quote
+Assistant: Your daily Calvin and Hobbes Quote

 User: Notify me when version 2.0.0 of the sentence transformers python package is released. My Current Location: Mexico City, Mexico
-Khoj: Sentence Transformers Python Package Version 2.0.0 Release
+Assistant: Sentence Transformers Python Package Version 2.0.0 Release

 User: Gather the latest tech news on the first sunday of every month.
-Khoj: Your Monthly Dose of Tech News
+Assistant: Your Monthly Dose of Tech News

 User Query: {query}
-Khoj:
+Assistant:
+""".strip()
+)
+
+conversation_title_generation = PromptTemplate.from_template(
+"""
+You are an extremely smart and helpful title generator assistant. Given a conversation, extract the subject of the conversation. Crisp, informative, ten words or less.
+
+Conversation History:
+{chat_history}
+
+Assistant:
 """.strip()
 )
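The new conversation_title_generation template only needs the serialized chat history. A rough sketch of how it gets filled in, mirroring the acreate_title_from_history helper added later in this diff (import paths and the in-scope conversation object are assumptions):

from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import construct_chat_history

chat_history = construct_chat_history(conversation.conversation_log)  # `conversation` assumed in scope
title_prompt = prompts.conversation_title_generation.format(chat_history=chat_history)
# title_prompt is then sent to the configured chat model; the reply becomes the conversation slug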
@@ -36,6 +36,7 @@ from khoj.utils.helpers import (
 is_none_or_empty,
 merge_dicts,
 )
+from khoj.utils.rawconfig import FileAttachment

 logger = logging.getLogger(__name__)

@@ -146,7 +147,7 @@ def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="A
 chat_history += f"User: {chat['intent']['query']}\n"

 if chat["intent"].get("inferred-queries"):
-chat_history += f'Khoj: {{"queries": {chat["intent"].get("inferred-queries")}}}\n'
+chat_history += f'{agent_name}: {{"queries": {chat["intent"].get("inferred-queries")}}}\n'

 chat_history += f"{agent_name}: {chat['message']}\n\n"
 elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")):

@@ -155,6 +156,16 @@ def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="A
 elif chat["by"] == "khoj" and ("excalidraw" in chat["intent"].get("type")):
 chat_history += f"User: {chat['intent']['query']}\n"
 chat_history += f"{agent_name}: {chat['intent']['inferred-queries'][0]}\n"
+elif chat["by"] == "you":
+raw_query_files = chat.get("queryFiles")
+if raw_query_files:
+query_files: Dict[str, str] = {}
+for file in raw_query_files:
+query_files[file["name"]] = file["content"]
+
+query_file_context = gather_raw_query_files(query_files)
+chat_history += f"User: {query_file_context}\n"

 return chat_history


@@ -243,8 +254,9 @@ def save_to_conversation_log(
 conversation_id: str = None,
 automation_id: str = None,
 query_images: List[str] = None,
-tracer: Dict[str, Any] = {},
+raw_query_files: List[FileAttachment] = [],
 train_of_thought: List[Any] = [],
+tracer: Dict[str, Any] = {},
 ):
 user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
 turn_id = tracer.get("mid") or str(uuid.uuid4())

@@ -255,6 +267,7 @@ def save_to_conversation_log(
 "created": user_message_time,
 "images": query_images,
 "turnId": turn_id,
+"queryFiles": [file.model_dump(mode="json") for file in raw_query_files],
 },
 khoj_message_metadata={
 "context": compiled_references,

@@ -289,25 +302,50 @@ Khoj: "{inferred_queries if ("text-to-image" in intent_type) else chat_response}
 )


-def construct_structured_message(message: str, images: list[str], model_type: str, vision_enabled: bool):
+def construct_structured_message(
+message: str, images: list[str], model_type: str, vision_enabled: bool, attached_file_context: str
+):
 """
 Format messages into appropriate multimedia format for supported chat model types
 """
-if not images or not vision_enabled:
-return message
-
 if model_type in [
 ChatModelOptions.ModelType.OPENAI,
 ChatModelOptions.ModelType.GOOGLE,
 ChatModelOptions.ModelType.ANTHROPIC,
 ]:
-return [
+constructed_messages: List[Any] = [
 {"type": "text", "text": message},
-*[{"type": "image_url", "image_url": {"url": image}} for image in images],
 ]
+
+if not is_none_or_empty(attached_file_context):
+constructed_messages.append({"type": "text", "text": attached_file_context})
+if vision_enabled and images:
+for image in images:
+constructed_messages.append({"type": "image_url", "image_url": {"url": image}})
+return constructed_messages
+
+if not is_none_or_empty(attached_file_context):
+return f"{attached_file_context}\n\n{message}"
+
 return message


+def gather_raw_query_files(
+query_files: Dict[str, str],
+):
+"""
+Gather contextual data from the given (raw) files
+"""
+
+if len(query_files) == 0:
+return ""
+
+contextual_data = " ".join(
+[f"File: {file_name}\n\n{file_content}\n\n" for file_name, file_content in query_files.items()]
+)
+return f"I have attached the following files:\n\n{contextual_data}"
+
+
 def generate_chatml_messages_with_context(
 user_message,
 system_message=None,
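For multimodal models, attached file text now travels as its own text part alongside the user message and any images. A sketch of the resulting payload for an OpenAI-style model (module paths are as assumed here; content values are illustrative):

from khoj.database.models import ChatModelOptions
from khoj.processor.conversation.utils import construct_structured_message, gather_raw_query_files

attached = gather_raw_query_files({"notes.md": "# Meeting notes"})
parts = construct_structured_message(
    "Summarize my notes",
    [],                                   # no images this turn
    ChatModelOptions.ModelType.OPENAI,
    vision_enabled=False,
    attached_file_context=attached,
)
# parts is roughly:
# [{"type": "text", "text": "Summarize my notes"},
#  {"type": "text", "text": "I have attached the following files:\n\nFile: notes.md\n\n# Meeting notes\n\n"}]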
@@ -320,6 +358,7 @@ def generate_chatml_messages_with_context(
 vision_enabled=False,
 model_type="",
 context_message="",
+query_files: str = None,
 ):
 """Generate chat messages with appropriate context from previous conversation to send to the chat model"""
 # Set max prompt size from user config or based on pre-configured for model and machine specs

@@ -336,6 +375,8 @@ def generate_chatml_messages_with_context(
 chatml_messages: List[ChatMessage] = []
 for chat in conversation_log.get("chat", []):
 message_context = ""
+message_attached_files = ""
+
 if chat["by"] == "khoj" and "excalidraw" in chat["intent"].get("type", ""):
 message_context += chat.get("intent").get("inferred-queries")[0]
 if not is_none_or_empty(chat.get("context")):

@@ -347,14 +388,27 @@ def generate_chatml_messages_with_context(
 }
 )
 message_context += f"{prompts.notes_conversation.format(references=references)}\n\n"

+if chat.get("queryFiles"):
+raw_query_files = chat.get("queryFiles")
+query_files_dict = dict()
+for file in raw_query_files:
+query_files_dict[file["name"]] = file["content"]
+
+message_attached_files = gather_raw_query_files(query_files_dict)
+chatml_messages.append(ChatMessage(content=message_attached_files, role="user"))
+
 if not is_none_or_empty(chat.get("onlineContext")):
 message_context += f"{prompts.online_search_conversation.format(online_results=chat.get('onlineContext'))}"

 if not is_none_or_empty(message_context):
 reconstructed_context_message = ChatMessage(content=message_context, role="user")
 chatml_messages.insert(0, reconstructed_context_message)

 role = "user" if chat["by"] == "you" else "assistant"
-message_content = construct_structured_message(chat["message"], chat.get("images"), model_type, vision_enabled)
+message_content = construct_structured_message(
+chat["message"], chat.get("images"), model_type, vision_enabled, attached_file_context=query_files
+)

 reconstructed_message = ChatMessage(content=message_content, role=role)
 chatml_messages.insert(0, reconstructed_message)

@@ -366,14 +420,18 @@ def generate_chatml_messages_with_context(
 if not is_none_or_empty(user_message):
 messages.append(
 ChatMessage(
-content=construct_structured_message(user_message, query_images, model_type, vision_enabled),
+content=construct_structured_message(
+user_message, query_images, model_type, vision_enabled, query_files
+),
 role="user",
 )
 )
 if not is_none_or_empty(context_message):
 messages.append(ChatMessage(content=context_message, role="user"))

 if len(chatml_messages) > 0:
 messages += chatml_messages

 if not is_none_or_empty(system_message):
 messages.append(ChatMessage(content=system_message, role="system"))
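Models without a multimodal message format get the simpler path: the flattened attachment text is prepended to the query, separated by a blank line. A small sketch (same module-path assumptions as above):

from khoj.database.models import ChatModelOptions
from khoj.processor.conversation.utils import construct_structured_message

flat = construct_structured_message(
    "Summarize the attached notes",
    [],
    ChatModelOptions.ModelType.OFFLINE,    # any model type outside the multimodal list
    vision_enabled=False,
    attached_file_context="File: notes.md\n\n# Standup notes",
)
assert flat.startswith("File: notes.md")   # attachment text leads, the query follows after a blank line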
@@ -28,6 +28,7 @@ async def text_to_image(
 send_status_func: Optional[Callable] = None,
 query_images: Optional[List[str]] = None,
 agent: Agent = None,
+query_files: str = None,
 tracer: dict = {},
 ):
 status_code = 200

@@ -69,6 +70,7 @@ async def text_to_image(
 query_images=query_images,
 user=user,
 agent=agent,
+query_files=query_files,
 tracer=tracer,
 )
@@ -68,6 +68,7 @@ async def search_online(
 query_images: List[str] = None,
 previous_subqueries: Set = set(),
 agent: Agent = None,
+query_files: str = None,
 tracer: dict = {},
 ):
 query += " ".join(custom_filters)

@@ -78,7 +79,14 @@ async def search_online(

 # Breakdown the query into subqueries to get the correct answer
 new_subqueries = await generate_online_subqueries(
-query, conversation_history, location, user, query_images=query_images, agent=agent, tracer=tracer
+query,
+conversation_history,
+location,
+user,
+query_images=query_images,
+agent=agent,
+tracer=tracer,
+query_files=query_files,
 )
 subqueries = list(new_subqueries - previous_subqueries)
 response_dict: Dict[str, Dict[str, List[Dict] | Dict]] = {}

@@ -169,13 +177,21 @@ async def read_webpages(
 send_status_func: Optional[Callable] = None,
 query_images: List[str] = None,
 agent: Agent = None,
-tracer: dict = {},
 max_webpages_to_read: int = DEFAULT_MAX_WEBPAGES_TO_READ,
+query_files: str = None,
+tracer: dict = {},
 ):
 "Infer web pages to read from the query and extract relevant information from them"
 logger.info(f"Inferring web pages to read")
 urls = await infer_webpage_urls(
-query, conversation_history, location, user, query_images, agent=agent, tracer=tracer
+query,
+conversation_history,
+location,
+user,
+query_images,
+agent=agent,
+query_files=query_files,
+tracer=tracer,
 )

 # Get the top 10 web pages to read
@@ -36,6 +36,7 @@ async def run_code(
 query_images: List[str] = None,
 agent: Agent = None,
 sandbox_url: str = SANDBOX_URL,
+query_files: str = None,
 tracer: dict = {},
 ):
 # Generate Code

@@ -53,6 +54,7 @@ async def run_code(
 query_images,
 agent,
 tracer,
+query_files,
 )
 except Exception as e:
 raise ValueError(f"Failed to generate code for {query} with error: {e}")

@@ -82,6 +84,7 @@ async def generate_python_code(
 query_images: List[str] = None,
 agent: Agent = None,
 tracer: dict = {},
+query_files: str = None,
 ) -> List[str]:
 location = f"{location_data}" if location_data else "Unknown"
 username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""

@@ -109,6 +112,7 @@ async def generate_python_code(
 response_type="json_object",
 user=user,
 tracer=tracer,
+query_files=query_files,
 )

 # Validate that the response is a non-empty, JSON-serializable list
@@ -351,6 +351,7 @@ async def extract_references_and_questions(
 query_images: Optional[List[str]] = None,
 previous_inferred_queries: Set = set(),
 agent: Agent = None,
+query_files: str = None,
 tracer: dict = {},
 ):
 user = request.user.object if request.user.is_authenticated else None

@@ -425,6 +426,7 @@ async def extract_references_and_questions(
 user=user,
 max_prompt_size=conversation_config.max_prompt_size,
 personality_context=personality_context,
+query_files=query_files,
 tracer=tracer,
 )
 elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:

@@ -443,6 +445,7 @@ async def extract_references_and_questions(
 query_images=query_images,
 vision_enabled=vision_enabled,
 personality_context=personality_context,
+query_files=query_files,
 tracer=tracer,
 )
 elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:

@@ -458,6 +461,7 @@ async def extract_references_and_questions(
 user=user,
 vision_enabled=vision_enabled,
 personality_context=personality_context,
+query_files=query_files,
 tracer=tracer,
 )
 elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:

@@ -474,6 +478,7 @@ async def extract_references_and_questions(
 user=user,
 vision_enabled=vision_enabled,
 personality_context=personality_context,
+query_files=query_files,
 tracer=tracer,
 )
@@ -19,7 +19,6 @@ from khoj.database.adapters import (
 AgentAdapters,
 ConversationAdapters,
 EntryAdapters,
-FileObjectAdapters,
 PublicConversationAdapters,
 aget_user_name,
 )

@@ -45,12 +44,13 @@ from khoj.routers.helpers import (
 ConversationCommandRateLimiter,
 DeleteMessageRequestBody,
 FeedbackData,
+acreate_title_from_history,
 agenerate_chat_response,
 aget_relevant_information_sources,
 aget_relevant_output_modes,
 construct_automation_created_message,
 create_automation,
-extract_relevant_info,
+gather_raw_query_files,
 generate_excalidraw_diagram,
 generate_summary_from_files,
 get_conversation_command,

@@ -76,7 +76,12 @@ from khoj.utils.helpers import (
 get_device,
 is_none_or_empty,
 )
-from khoj.utils.rawconfig import FileFilterRequest, FilesFilterRequest, LocationData
+from khoj.utils.rawconfig import (
+ChatRequestBody,
+FileFilterRequest,
+FilesFilterRequest,
+LocationData,
+)

 # Initialize Router
 logger = logging.getLogger(__name__)
@@ -374,7 +379,7 @@ def fork_public_conversation(
 {
 "status": "ok",
 "next_url": redirect_uri,
-"conversation_id": new_conversation.id,
+"conversation_id": str(new_conversation.id),
 }
 ),
 )

@@ -530,6 +535,32 @@ async def set_conversation_title(
 )


+@api_chat.post("/title")
+@requires(["authenticated"])
+async def generate_chat_title(
+request: Request,
+common: CommonQueryParams,
+conversation_id: str,
+):
+user: KhojUser = request.user.object
+conversation = await ConversationAdapters.aget_conversation_by_user(user=user, conversation_id=conversation_id)
+
+# Conversation.title is explicitly set by the user. Do not override.
+if conversation.title:
+return {"status": "ok", "title": conversation.title}
+
+if not conversation:
+raise HTTPException(status_code=404, detail="Conversation not found")
+
+new_title = await acreate_title_from_history(request.user.object, conversation=conversation)
+
+conversation.slug = new_title
+
+conversation.asave()
+
+return {"status": "ok", "title": new_title}
+
+
 @api_chat.delete("/conversation/message", response_class=Response)
 @requires(["authenticated"])
 def delete_message(request: Request, delete_request: DeleteMessageRequestBody) -> Response:
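The new POST /title route regenerates a conversation's display slug from its history. A hedged client sketch, assuming the api_chat router keeps its usual /api/chat prefix, Khoj's default local port, and an authenticated session cookie:

import requests

resp = requests.post(
    "http://localhost:42110/api/chat/title",                # host, port and prefix are assumptions
    params={"conversation_id": "42"},                        # illustrative conversation id
    cookies={"session": "<authenticated session cookie>"},
)
print(resp.json())  # e.g. {"status": "ok", "title": "Quarterly Revenue Questions"}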
@@ -571,6 +602,7 @@ async def chat(
 country_code = body.country_code or get_country_code_from_timezone(body.timezone)
 timezone = body.timezone
 raw_images = body.images
+raw_query_files = body.files

 async def event_generator(q: str, images: list[str]):
 start_time = time.perf_counter()

@@ -582,6 +614,7 @@ async def chat(
 q = unquote(q)
 train_of_thought = []
 nonlocal conversation_id
+nonlocal raw_query_files

 tracer: dict = {
 "mid": turn_id,

@@ -601,6 +634,11 @@ async def chat(
 if uploaded_image:
 uploaded_images.append(uploaded_image)

+query_files: Dict[str, str] = {}
+if raw_query_files:
+for file in raw_query_files:
+query_files[file.name] = file.content
+
 async def send_event(event_type: ChatEvent, data: str | dict):
 nonlocal connection_alive, ttft, train_of_thought
 if not connection_alive or await request.is_disconnected():

@@ -711,6 +749,8 @@ async def chat(
 ## Extract Document References
 compiled_references: List[Any] = []
 inferred_queries: List[Any] = []
+file_filters = conversation.file_filters if conversation and conversation.file_filters else []
+attached_file_context = gather_raw_query_files(query_files)

 if conversation_commands == [ConversationCommand.Default] or is_automated_task:
 conversation_commands = await aget_relevant_information_sources(

@@ -720,6 +760,7 @@ async def chat(
 user=user,
 query_images=uploaded_images,
 agent=agent,
+query_files=attached_file_context,
 tracer=tracer,
 )

@@ -765,6 +806,7 @@ async def chat(
 user_name=user_name,
 location=location,
 file_filters=conversation.file_filters if conversation else [],
+query_files=attached_file_context,
 tracer=tracer,
 ):
 if isinstance(research_result, InformationCollectionIteration):

@@ -804,10 +846,6 @@ async def chat(
 response_log = "No files selected for summarization. Please add files using the section on the left."
 async for result in send_llm_response(response_log):
 yield result
-elif len(file_filters) > 1 and not agent_has_entries:
-response_log = "Only one file can be selected for summarization."
-async for result in send_llm_response(response_log):
-yield result
 else:
 async for response in generate_summary_from_files(
 q=q,

@@ -817,6 +855,7 @@ async def chat(
 query_images=uploaded_images,
 agent=agent,
 send_status_func=partial(send_event, ChatEvent.STATUS),
+query_files=attached_file_context,
 tracer=tracer,
 ):
 if isinstance(response, dict) and ChatEvent.STATUS in response:

@@ -837,8 +876,9 @@ async def chat(
 client_application=request.user.client_app,
 conversation_id=conversation_id,
 query_images=uploaded_images,
-tracer=tracer,
 train_of_thought=train_of_thought,
+raw_query_files=raw_query_files,
+tracer=tracer,
 )
 return

@@ -882,8 +922,9 @@ async def chat(
 inferred_queries=[query_to_run],
 automation_id=automation.id,
 query_images=uploaded_images,
-tracer=tracer,
 train_of_thought=train_of_thought,
+raw_query_files=raw_query_files,
+tracer=tracer,
 )
 async for result in send_llm_response(llm_response):
 yield result

@@ -905,6 +946,7 @@ async def chat(
 partial(send_event, ChatEvent.STATUS),
 query_images=uploaded_images,
 agent=agent,
+query_files=attached_file_context,
 tracer=tracer,
 ):
 if isinstance(result, dict) and ChatEvent.STATUS in result:

@@ -950,6 +992,7 @@ async def chat(
 custom_filters,
 query_images=uploaded_images,
 agent=agent,
+query_files=attached_file_context,
 tracer=tracer,
 ):
 if isinstance(result, dict) and ChatEvent.STATUS in result:

@@ -975,6 +1018,7 @@ async def chat(
 partial(send_event, ChatEvent.STATUS),
 query_images=uploaded_images,
 agent=agent,
+query_files=attached_file_context,
 tracer=tracer,
 ):
 if isinstance(result, dict) and ChatEvent.STATUS in result:

@@ -1015,6 +1059,7 @@ async def chat(
 partial(send_event, ChatEvent.STATUS),
 query_images=uploaded_images,
 agent=agent,
+query_files=attached_file_context,
 tracer=tracer,
 ):
 if isinstance(result, dict) and ChatEvent.STATUS in result:
@@ -1055,6 +1100,7 @@ async def chat(
 send_status_func=partial(send_event, ChatEvent.STATUS),
 query_images=uploaded_images,
 agent=agent,
+query_files=attached_file_context,
 tracer=tracer,
 ):
 if isinstance(result, dict) and ChatEvent.STATUS in result:

@@ -1086,8 +1132,10 @@ async def chat(
 compiled_references=compiled_references,
 online_results=online_results,
 query_images=uploaded_images,
-tracer=tracer,
 train_of_thought=train_of_thought,
+attached_file_context=attached_file_context,
+raw_query_files=raw_query_files,
+tracer=tracer,
 )
 content_obj = {
 "intentType": intent_type,

@@ -1116,6 +1164,7 @@ async def chat(
 user=user,
 agent=agent,
 send_status_func=partial(send_event, ChatEvent.STATUS),
+query_files=attached_file_context,
 tracer=tracer,
 ):
 if isinstance(result, dict) and ChatEvent.STATUS in result:

@@ -1144,8 +1193,10 @@ async def chat(
 compiled_references=compiled_references,
 online_results=online_results,
 query_images=uploaded_images,
-tracer=tracer,
 train_of_thought=train_of_thought,
+attached_file_context=attached_file_context,
+raw_query_files=raw_query_files,
+tracer=tracer,
 )

 async for result in send_llm_response(json.dumps(content_obj)):

@@ -1171,8 +1222,10 @@ async def chat(
 user_name,
 researched_results,
 uploaded_images,
-tracer,
 train_of_thought,
+attached_file_context,
+raw_query_files,
+tracer,
 )

 # Send Response
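Inside the chat endpoint the attachments are used twice: the raw FileAttachment objects are persisted onto the conversation log as queryFiles, while a flattened {name: content} view becomes attached_file_context and is threaded into every downstream helper as query_files. A condensed sketch of that flow (the FileAttachment constructor arguments are inferred from the /api/content/convert response shape in this diff, so treat them as assumptions):

from khoj.routers.helpers import gather_raw_query_files
from khoj.utils.rawconfig import FileAttachment

content = "# Launch plan\n- ship docs"
raw_query_files = [
    FileAttachment(name="plan.md", content=content, file_type="markdown", size=len(content.encode("utf-8")))
]

query_files = {file.name: file.content for file in raw_query_files}
attached_file_context = gather_raw_query_files(query_files)
# attached_file_context starts with "I have attached the following files:" and is what the
# search, code, image and summarize helpers receive via their new query_files parameter.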
@@ -36,16 +36,18 @@ from khoj.database.models import (
 LocalPlaintextConfig,
 NotionConfig,
 )
+from khoj.processor.content.docx.docx_to_entries import DocxToEntries
+from khoj.processor.content.pdf.pdf_to_entries import PdfToEntries
 from khoj.routers.helpers import (
 ApiIndexedDataLimiter,
 CommonQueryParams,
 configure_content,
+get_file_content,
 get_user_config,
 update_telemetry_state,
 )
 from khoj.utils import constants, state
 from khoj.utils.config import SearchModels
-from khoj.utils.helpers import get_file_type
 from khoj.utils.rawconfig import (
 ContentConfig,
 FullConfig,
@@ -375,6 +377,75 @@ async def delete_content_source(
 return {"status": "ok"}


+@api_content.post("/convert", status_code=200)
+@requires(["authenticated"])
+async def convert_documents(
+request: Request,
+files: List[UploadFile],
+client: Optional[str] = None,
+):
+MAX_FILE_SIZE_MB = 10 # 10MB limit
+MAX_FILE_SIZE_BYTES = MAX_FILE_SIZE_MB * 1024 * 1024
+
+converted_files = []
+supported_files = ["org", "markdown", "pdf", "plaintext", "docx"]
+
+for file in files:
+# Check file size first
+file_size = 0
+content = await file.read()
+file_size = len(content)
+await file.seek(0) # Reset file pointer
+
+if file_size > MAX_FILE_SIZE_BYTES:
+logger.warning(
+f"Skipped converting oversized file ({file_size / 1024 / 1024:.1f}MB) sent by {client} client: {file.filename}"
+)
+continue
+
+file_data = get_file_content(file)
+if file_data.file_type in supported_files:
+extracted_content = (
+file_data.content.decode(file_data.encoding) if file_data.encoding else file_data.content
+)
+
+if file_data.file_type == "docx":
+entries_per_page = DocxToEntries.extract_text(file_data.content)
+annotated_pages = [
+f"Page {index} of {file_data.name}:\n\n{entry}" for index, entry in enumerate(entries_per_page)
+]
+extracted_content = "\n".join(annotated_pages)
+
+elif file_data.file_type == "pdf":
+entries_per_page = PdfToEntries.extract_text(file_data.content)
+annotated_pages = [
+f"Page {index} of {file_data.name}:\n\n{entry}" for index, entry in enumerate(entries_per_page)
+]
+extracted_content = "\n".join(annotated_pages)
+
+size_in_bytes = len(extracted_content.encode("utf-8"))
+
+converted_files.append(
+{
+"name": file_data.name,
+"content": extracted_content,
+"file_type": file_data.file_type,
+"size": size_in_bytes,
+}
+)
+else:
+logger.warning(f"Skipped converting unsupported file type sent by {client} client: {file.filename}")
+
+update_telemetry_state(
+request=request,
+telemetry_type="api",
+api="convert_documents",
+client=client,
+)
+
+return Response(content=json.dumps(converted_files), media_type="application/json", status_code=200)
+
+
 async def indexer(
 request: Request,
 files: list[UploadFile],
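The /convert route extracts text without adding anything to the search index, which is what clients now call before attaching a file to a chat message. A hedged client-side sketch, with the same host/prefix assumptions as the /title example above:

import requests

with open("contract.pdf", "rb") as f:                        # path is illustrative
    resp = requests.post(
        "http://localhost:42110/api/content/convert",        # host, port and prefix are assumptions
        files=[("files", ("contract.pdf", f, "application/pdf"))],
        cookies={"session": "<authenticated session cookie>"},
    )

for converted in resp.json():
    print(converted["name"], converted["file_type"], converted["size"])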
@@ -398,12 +469,13 @@ async def indexer(
 try:
 logger.info(f"📬 Updating content index via API call by {client} client")
 for file in files:
-file_content = file.file.read()
-file_type, encoding = get_file_type(file.content_type, file_content)
-if file_type in index_files:
-index_files[file_type][file.filename] = file_content.decode(encoding) if encoding else file_content
+file_data = get_file_content(file)
+if file_data.file_type in index_files:
+index_files[file_data.file_type][file_data.name] = (
+file_data.content.decode(file_data.encoding) if file_data.encoding else file_data.content
+)
 else:
-logger.warning(f"Skipped indexing unsupported file type sent by {client} client: {file.filename}")
+logger.warning(f"Skipped indexing unsupported file type sent by {client} client: {file_data.name}")

 indexer_input = IndexerInput(
 org=index_files["org"],
@@ -105,6 +105,7 @@ from khoj.utils.config import OfflineChatProcessorModel
 from khoj.utils.helpers import (
 LRU,
 ConversationCommand,
+get_file_type,
 is_none_or_empty,
 is_valid_url,
 log_telemetry,

@@ -112,7 +113,7 @@ from khoj.utils.helpers import (
 timer,
 tool_descriptions_for_llm,
 )
-from khoj.utils.rawconfig import LocationData
+from khoj.utils.rawconfig import ChatRequestBody, FileAttachment, FileData, LocationData

 logger = logging.getLogger(__name__)
@ -168,6 +169,12 @@ async def is_ready_to_chat(user: KhojUser):
|
||||||
raise HTTPException(status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings.")
|
raise HTTPException(status_code=500, detail="Set your OpenAI API key or enable Local LLM via Khoj settings.")
|
||||||
|
|
||||||
|
|
||||||
|
def get_file_content(file: UploadFile):
|
||||||
|
file_content = file.file.read()
|
||||||
|
file_type, encoding = get_file_type(file.content_type, file_content)
|
||||||
|
return FileData(name=file.filename, content=file_content, file_type=file_type, encoding=encoding)
|
||||||
|
|
||||||
|
|
||||||
def update_telemetry_state(
|
def update_telemetry_state(
|
||||||
request: Request,
|
request: Request,
|
||||||
telemetry_type: str,
|
telemetry_type: str,
|
||||||
|
@@ -249,6 +256,39 @@ async def agenerate_chat_response(*args):
     return await loop.run_in_executor(executor, generate_chat_response, *args)


+def gather_raw_query_files(
+    query_files: Dict[str, str],
+):
+    """
+    Gather contextual data from the given (raw) files
+    """
+
+    if len(query_files) == 0:
+        return ""
+
+    contextual_data = " ".join(
+        [f"File: {file_name}\n\n{file_content}\n\n" for file_name, file_content in query_files.items()]
+    )
+    return f"I have attached the following files:\n\n{contextual_data}"
+
+
+async def acreate_title_from_history(
+    user: KhojUser,
+    conversation: Conversation,
+):
+    """
+    Create a title from the given conversation history
+    """
+    chat_history = construct_chat_history(conversation.conversation_log)
+
+    title_generation_prompt = prompts.conversation_title_generation.format(chat_history=chat_history)
+
+    with timer("Chat actor: Generate title from conversation history", logger):
+        response = await send_message_to_model_wrapper(title_generation_prompt, user=user)
+
+    return response.strip()
+
+
 async def acreate_title_from_query(query: str, user: KhojUser = None) -> str:
     """
     Create a title from the given query
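The new gather_raw_query_files helper turns a mapping of file name to raw text into a single preamble string that is passed along to the model. A standalone sketch of that contract, with invented file contents, looks like this:

from typing import Dict


def gather_raw_query_files(query_files: Dict[str, str]) -> str:
    # Mirrors the helper added in the hunk above: flatten raw file texts
    # into one attachment preamble for the model.
    if len(query_files) == 0:
        return ""
    contextual_data = " ".join(
        [f"File: {file_name}\n\n{file_content}\n\n" for file_name, file_content in query_files.items()]
    )
    return f"I have attached the following files:\n\n{contextual_data}"


# Example input/output (sample contents are illustrative):
preamble = gather_raw_query_files({"todo.org": "* Buy milk", "notes.md": "# Meeting notes"})
print(preamble.splitlines()[0])  # -> "I have attached the following files:"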
@@ -294,6 +334,7 @@ async def aget_relevant_information_sources(
     user: KhojUser,
     query_images: List[str] = None,
     agent: Agent = None,
+    query_files: str = None,
     tracer: dict = {},
 ):
     """
@@ -331,6 +372,7 @@ async def aget_relevant_information_sources(
         relevant_tools_prompt,
         response_type="json_object",
         user=user,
+        query_files=query_files,
         tracer=tracer,
     )
@@ -440,6 +482,7 @@ async def infer_webpage_urls(
     user: KhojUser,
     query_images: List[str] = None,
     agent: Agent = None,
+    query_files: str = None,
     tracer: dict = {},
 ) -> List[str]:
     """
@@ -469,6 +512,7 @@ async def infer_webpage_urls(
         query_images=query_images,
         response_type="json_object",
         user=user,
+        query_files=query_files,
         tracer=tracer,
     )
@@ -494,6 +538,7 @@ async def generate_online_subqueries(
     user: KhojUser,
     query_images: List[str] = None,
     agent: Agent = None,
+    query_files: str = None,
     tracer: dict = {},
 ) -> Set[str]:
     """
@@ -523,6 +568,7 @@ async def generate_online_subqueries(
         query_images=query_images,
         response_type="json_object",
         user=user,
+        query_files=query_files,
         tracer=tracer,
     )
@@ -645,26 +691,38 @@ async def generate_summary_from_files(
     query_images: List[str] = None,
     agent: Agent = None,
     send_status_func: Optional[Callable] = None,
+    query_files: str = None,
     tracer: dict = {},
 ):
     try:
-        file_object = None
+        file_objects = None
         if await EntryAdapters.aagent_has_entries(agent):
             file_names = await EntryAdapters.aget_agent_entry_filepaths(agent)
             if len(file_names) > 0:
-                file_object = await FileObjectAdapters.async_get_file_objects_by_name(None, file_names.pop(), agent)
+                file_objects = await FileObjectAdapters.async_get_file_objects_by_name(None, file_names.pop(), agent)

-        if len(file_filters) > 0:
-            file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
-
-        if len(file_object) == 0:
-            response_log = "Sorry, I couldn't find the full text of this file."
+        if (file_objects and len(file_objects) == 0 and not query_files) or (not file_objects and not query_files):
+            response_log = "Sorry, I couldn't find anything to summarize."
             yield response_log
             return
-        contextual_data = " ".join([file.raw_text for file in file_object])
+
+        contextual_data = " ".join([f"File: {file.file_name}\n\n{file.raw_text}" for file in file_objects])
+
+        if query_files:
+            contextual_data += f"\n\n{query_files}"
+
         if not q:
             q = "Create a general summary of the file"
-        async for result in send_status_func(f"**Constructing Summary Using:** {file_object[0].file_name}"):
+
+        file_names = [file.file_name for file in file_objects]
+        file_names.extend(file_filters)
+
+        all_file_names = ""
+
+        for file_name in file_names:
+            all_file_names += f"- {file_name}\n"
+
+        async for result in send_status_func(f"**Constructing Summary Using:**\n{all_file_names}"):
             yield {ChatEvent.STATUS: result}

         response = await extract_relevant_summary(
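With this change the summarization status message lists every file being used instead of a single file name. A small standalone sketch of that list construction, using invented file names and plain dicts in place of the real file objects, is shown below:

# Invented stand-ins for the file objects and conversation file filters
# that generate_summary_from_files works with above.
file_objects = [{"file_name": "report.pdf"}, {"file_name": "notes.docx"}]
file_filters = ["journal.org"]

file_names = [file["file_name"] for file in file_objects]
file_names.extend(file_filters)

all_file_names = ""
for file_name in file_names:
    all_file_names += f"- {file_name}\n"

print(f"**Constructing Summary Using:**\n{all_file_names}")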
@@ -694,6 +752,7 @@ async def generate_excalidraw_diagram(
     user: KhojUser = None,
     agent: Agent = None,
     send_status_func: Optional[Callable] = None,
+    query_files: str = None,
     tracer: dict = {},
 ):
     if send_status_func:
@@ -709,6 +768,7 @@ async def generate_excalidraw_diagram(
         query_images=query_images,
         user=user,
         agent=agent,
+        query_files=query_files,
         tracer=tracer,
     )
@@ -735,6 +795,7 @@ async def generate_better_diagram_description(
     query_images: List[str] = None,
     user: KhojUser = None,
     agent: Agent = None,
+    query_files: str = None,
     tracer: dict = {},
 ) -> str:
     """
@@ -773,7 +834,11 @@ async def generate_better_diagram_description(

     with timer("Chat actor: Generate better diagram description", logger):
         response = await send_message_to_model_wrapper(
-            improve_diagram_description_prompt, query_images=query_images, user=user, tracer=tracer
+            improve_diagram_description_prompt,
+            query_images=query_images,
+            user=user,
+            query_files=query_files,
+            tracer=tracer,
         )
         response = response.strip()
         if response.startswith(('"', "'")) and response.endswith(('"', "'")):
@@ -820,6 +885,7 @@ async def generate_better_image_prompt(
     query_images: Optional[List[str]] = None,
     user: KhojUser = None,
     agent: Agent = None,
+    query_files: str = "",
     tracer: dict = {},
 ) -> str:
     """
@@ -868,7 +934,7 @@ async def generate_better_image_prompt(

     with timer("Chat actor: Generate contextual image prompt", logger):
         response = await send_message_to_model_wrapper(
-            image_prompt, query_images=query_images, user=user, tracer=tracer
+            image_prompt, query_images=query_images, user=user, query_files=query_files, tracer=tracer
         )
         response = response.strip()
         if response.startswith(('"', "'")) and response.endswith(('"', "'")):
@@ -884,6 +950,7 @@ async def send_message_to_model_wrapper(
     user: KhojUser = None,
     query_images: List[str] = None,
     context: str = "",
+    query_files: str = None,
     tracer: dict = {},
 ):
     conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
@@ -923,6 +990,7 @@ async def send_message_to_model_wrapper(
             max_prompt_size=max_tokens,
             vision_enabled=vision_available,
             model_type=conversation_config.model_type,
+            query_files=query_files,
         )

         return send_message_to_model_offline(
@@ -949,6 +1017,7 @@ async def send_message_to_model_wrapper(
             vision_enabled=vision_available,
             query_images=query_images,
             model_type=conversation_config.model_type,
+            query_files=query_files,
         )

         return send_message_to_model(
@@ -971,6 +1040,7 @@ async def send_message_to_model_wrapper(
             vision_enabled=vision_available,
             query_images=query_images,
             model_type=conversation_config.model_type,
+            query_files=query_files,
         )

         return anthropic_send_message_to_model(
@@ -992,6 +1062,7 @@ async def send_message_to_model_wrapper(
             vision_enabled=vision_available,
             query_images=query_images,
             model_type=conversation_config.model_type,
+            query_files=query_files,
         )

         return gemini_send_message_to_model(
@@ -1006,6 +1077,7 @@ def send_message_to_model_wrapper_sync(
     system_message: str = "",
     response_type: str = "text",
     user: KhojUser = None,
+    query_files: str = "",
     tracer: dict = {},
 ):
     conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config(user)
@@ -1030,6 +1102,7 @@ def send_message_to_model_wrapper_sync(
             max_prompt_size=max_tokens,
             vision_enabled=vision_available,
             model_type=conversation_config.model_type,
+            query_files=query_files,
         )

         return send_message_to_model_offline(
@@ -1051,6 +1124,7 @@ def send_message_to_model_wrapper_sync(
             max_prompt_size=max_tokens,
             vision_enabled=vision_available,
             model_type=conversation_config.model_type,
+            query_files=query_files,
         )

         openai_response = send_message_to_model(
@@ -1072,6 +1146,7 @@ def send_message_to_model_wrapper_sync(
             max_prompt_size=max_tokens,
             vision_enabled=vision_available,
             model_type=conversation_config.model_type,
+            query_files=query_files,
         )

         return anthropic_send_message_to_model(
@@ -1091,6 +1166,7 @@ def send_message_to_model_wrapper_sync(
             max_prompt_size=max_tokens,
             vision_enabled=vision_available,
             model_type=conversation_config.model_type,
+            query_files=query_files,
         )

         return gemini_send_message_to_model(
@@ -1120,8 +1196,10 @@ def generate_chat_response(
     user_name: Optional[str] = None,
     meta_research: str = "",
     query_images: Optional[List[str]] = None,
-    tracer: dict = {},
     train_of_thought: List[Any] = [],
+    query_files: str = None,
+    raw_query_files: List[FileAttachment] = None,
+    tracer: dict = {},
 ) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
     # Initialize Variables
     chat_response = None
@@ -1142,8 +1220,9 @@ def generate_chat_response(
         client_application=client_application,
         conversation_id=conversation_id,
         query_images=query_images,
-        tracer=tracer,
         train_of_thought=train_of_thought,
+        raw_query_files=raw_query_files,
+        tracer=tracer,
     )

     query_to_run = q
@@ -1177,6 +1256,7 @@ def generate_chat_response(
             location_data=location_data,
             user_name=user_name,
             agent=agent,
+            query_files=query_files,
             tracer=tracer,
         )
@@ -1202,6 +1282,7 @@ def generate_chat_response(
             user_name=user_name,
             agent=agent,
             vision_available=vision_available,
+            query_files=query_files,
             tracer=tracer,
         )
@@ -1224,6 +1305,7 @@ def generate_chat_response(
             user_name=user_name,
             agent=agent,
             vision_available=vision_available,
+            query_files=query_files,
             tracer=tracer,
         )
     elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
@@ -1243,7 +1325,9 @@ def generate_chat_response(
             location_data=location_data,
             user_name=user_name,
             agent=agent,
+            query_images=query_images,
             vision_available=vision_available,
+            query_files=query_files,
             tracer=tracer,
         )
@@ -1256,23 +1340,6 @@ def generate_chat_response(
     return chat_response, metadata


-class ChatRequestBody(BaseModel):
-    q: str
-    n: Optional[int] = 7
-    d: Optional[float] = None
-    stream: Optional[bool] = False
-    title: Optional[str] = None
-    conversation_id: Optional[str] = None
-    turn_id: Optional[str] = None
-    city: Optional[str] = None
-    region: Optional[str] = None
-    country: Optional[str] = None
-    country_code: Optional[str] = None
-    timezone: Optional[str] = None
-    images: Optional[list[str]] = None
-    create_new: Optional[bool] = False
-
-
 class DeleteMessageRequestBody(BaseModel):
     conversation_id: str
     turn_id: str
@@ -11,6 +11,7 @@ from khoj.processor.conversation import prompts
 from khoj.processor.conversation.utils import (
     InformationCollectionIteration,
     clean_json,
+    construct_chat_history,
     construct_iteration_history,
     construct_tool_chat_history,
 )
@@ -19,8 +20,6 @@ from khoj.processor.tools.run_code import run_code
 from khoj.routers.api import extract_references_and_questions
 from khoj.routers.helpers import (
     ChatEvent,
-    construct_chat_history,
-    extract_relevant_info,
     generate_summary_from_files,
     send_message_to_model_wrapper,
 )
@@ -47,6 +46,7 @@ async def apick_next_tool(
     max_iterations: int = 5,
     send_status_func: Optional[Callable] = None,
     tracer: dict = {},
+    query_files: str = None,
 ):
     """Given a query, determine which of the available tools the agent should use in order to answer appropriately."""
@@ -92,6 +92,7 @@ async def apick_next_tool(
                 response_type="json_object",
                 user=user,
                 query_images=query_images,
+                query_files=query_files,
                 tracer=tracer,
             )
     except Exception as e:
@@ -151,6 +152,7 @@ async def execute_information_collection(
     location: LocationData = None,
     file_filters: List[str] = [],
     tracer: dict = {},
+    query_files: str = None,
 ):
     current_iteration = 0
     MAX_ITERATIONS = 5
@@ -174,6 +176,7 @@ async def execute_information_collection(
            MAX_ITERATIONS,
            send_status_func,
            tracer=tracer,
+           query_files=query_files,
        ):
            if isinstance(result, dict) and ChatEvent.STATUS in result:
                yield result[ChatEvent.STATUS]
@@ -204,6 +207,7 @@ async def execute_information_collection(
                    previous_inferred_queries=previous_inferred_queries,
                    agent=agent,
                    tracer=tracer,
+                   query_files=query_files,
                ):
                    if isinstance(result, dict) and ChatEvent.STATUS in result:
                        yield result[ChatEvent.STATUS]
@@ -265,6 +269,7 @@ async def execute_information_collection(
                    query_images=query_images,
                    agent=agent,
                    tracer=tracer,
+                   query_files=query_files,
                ):
                    if isinstance(result, dict) and ChatEvent.STATUS in result:
                        yield result[ChatEvent.STATUS]
@@ -295,6 +300,7 @@ async def execute_information_collection(
                send_status_func,
                query_images=query_images,
                agent=agent,
+               query_files=query_files,
                tracer=tracer,
            ):
                if isinstance(result, dict) and ChatEvent.STATUS in result:
@@ -320,6 +326,7 @@ async def execute_information_collection(
                query_images=query_images,
                agent=agent,
                send_status_func=send_status_func,
+               query_files=query_files,
            ):
                if isinstance(result, dict) and ChatEvent.STATUS in result:
                    yield result[ChatEvent.STATUS]
@@ -138,6 +138,38 @@ class SearchResponse(ConfigBase):
     corpus_id: str


+class FileData(BaseModel):
+    name: str
+    content: bytes
+    file_type: str
+    encoding: str | None = None
+
+
+class FileAttachment(BaseModel):
+    name: str
+    content: str
+    file_type: str
+    size: int
+
+
+class ChatRequestBody(BaseModel):
+    q: str
+    n: Optional[int] = 7
+    d: Optional[float] = None
+    stream: Optional[bool] = False
+    title: Optional[str] = None
+    conversation_id: Optional[str] = None
+    turn_id: Optional[str] = None
+    city: Optional[str] = None
+    region: Optional[str] = None
+    country: Optional[str] = None
+    country_code: Optional[str] = None
+    timezone: Optional[str] = None
+    images: Optional[list[str]] = None
+    files: Optional[list[FileAttachment]] = []
+    create_new: Optional[bool] = False
+
+
 class Entry:
     raw: str
     compiled: str
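To illustrate how the new request models compose, the sketch below builds a chat request that carries one attached file. It declares a local copy of FileAttachment and a trimmed-down ChatRequestBody so it runs standalone, and it assumes Pydantic v2 for model_dump_json; the field names and defaults follow the models added above, while the sample values are invented.

from typing import Optional

from pydantic import BaseModel


class FileAttachment(BaseModel):
    name: str
    content: str
    file_type: str
    size: int


class ChatRequestBody(BaseModel):
    # Trimmed copy of the request model above, keeping only the fields used here.
    q: str
    n: Optional[int] = 7
    stream: Optional[bool] = False
    conversation_id: Optional[str] = None
    files: Optional[list[FileAttachment]] = []


content = "# Q3 planning\n- finalize roadmap"
attachment = FileAttachment(
    name="meeting_notes.md",
    content=content,
    file_type="markdown",
    size=len(content.encode("utf-8")),
)
body = ChatRequestBody(q="Summarize the attached notes", files=[attachment])
print(body.model_dump_json(exclude_none=True))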
@@ -337,7 +337,6 @@ def test_summarize_one_file(client_offline_chat, default_user2: KhojUser):
     # Assert
     assert response_message != ""
     assert response_message != "No files selected for summarization. Please add files using the section on the left."
-    assert response_message != "Only one file can be selected for summarization."


 @pytest.mark.django_db(transaction=True)
@@ -375,7 +374,6 @@ def test_summarize_extra_text(client_offline_chat, default_user2: KhojUser):
     # Assert
     assert response_message != ""
     assert response_message != "No files selected for summarization. Please add files using the section on the left."
-    assert response_message != "Only one file can be selected for summarization."


 @pytest.mark.django_db(transaction=True)
@@ -404,7 +402,7 @@ def test_summarize_multiple_files(client_offline_chat, default_user2: KhojUser):
     response_message = response.json()["response"]

     # Assert
-    assert response_message == "Only one file can be selected for summarization."
+    assert response_message is not None


 @pytest.mark.django_db(transaction=True)
@@ -460,7 +458,6 @@ def test_summarize_different_conversation(client_offline_chat, default_user2: Kh
     # Assert
     assert response_message != ""
     assert response_message != "No files selected for summarization. Please add files using the section on the left."
-    assert response_message != "Only one file can be selected for summarization."


 @pytest.mark.django_db(transaction=True)
@@ -312,7 +312,6 @@ def test_summarize_one_file(chat_client, default_user2: KhojUser):
     # Assert
     assert response_message != ""
     assert response_message != "No files selected for summarization. Please add files using the section on the left."
-    assert response_message != "Only one file can be selected for summarization."


 @pytest.mark.django_db(transaction=True)
@@ -344,7 +343,6 @@ def test_summarize_extra_text(chat_client, default_user2: KhojUser):
     # Assert
     assert response_message != ""
     assert response_message != "No files selected for summarization. Please add files using the section on the left."
-    assert response_message != "Only one file can be selected for summarization."


 @pytest.mark.django_db(transaction=True)
@@ -371,7 +369,7 @@ def test_summarize_multiple_files(chat_client, default_user2: KhojUser):
     response_message = response.json()["response"]

     # Assert
-    assert response_message == "Only one file can be selected for summarization."
+    assert response_message is not None


 @pytest.mark.django_db(transaction=True)
@@ -435,7 +433,6 @@ def test_summarize_different_conversation(chat_client, default_user2: KhojUser):
     assert (
         response_message_conv1 != "No files selected for summarization. Please add files using the section on the left."
     )
-    assert response_message_conv1 != "Only one file can be selected for summarization."


 @pytest.mark.django_db(transaction=True)