mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-27 09:25:06 +01:00
Research Mode: Give Khoj the ability to perform more advanced reasoning (#952)
## Overview

Khoj can now go into research mode and use a Python code interpreter. These are experimental features that are being released early for feedback and testing.

- Research mode allows Khoj to dynamically select the tools it needs to best answer the question. It is also allowed more iterations to get to a satisfactory answer. Its more dynamic train of thought is shown to improve visibility into its thinking.
- Adding the ability for Khoj to use a Python code interpreter is an adjacent capability. It can help Khoj do some data analysis and generate charts for you. A sandboxed Python environment to run code is provided using [cohere-terrarium](https://github.com/cohere-ai/cohere-terrarium?tab=readme-ov-file) and [pyodide](https://pyodide.org/).

## Analysis

Research mode (significantly?) improves Khoj's information retrieval for more complex queries requiring multi-step lookups, but takes longer to run. It can research for longer, requiring less back-and-forth with the user to find an answer. Research mode gives the most gains when used with more advanced chat models (like o1, 4o, new claude sonnet and gemini-pro-002). Smaller models improve their response quality but tend to get into repetitive loops more often.

## Next Steps

- Get community feedback on research mode: what works, what fails, what is confusing, what'd be cool to have.
- Tune Khoj's capabilities for longer autonomous runs and to generalize across a larger range of model sizes.

## Miscellaneous Improvements

- Khoj's train of thought is saved and shown for all messages, not just the latest one.
- Render charts generated by Khoj, and code run using the code tool, on the web app.
- Align chat input color to the currently selected agent color.
This commit is contained in:
commit
22f3ed3f5d
34 changed files with 1726 additions and 471 deletions
|
@ -12,7 +12,12 @@ import { processMessageChunk } from "../common/chatFunctions";
|
|||
|
||||
import "katex/dist/katex.min.css";
|
||||
|
||||
import { Context, OnlineContext, StreamMessage } from "../components/chatMessage/chatMessage";
|
||||
import {
|
||||
CodeContext,
|
||||
Context,
|
||||
OnlineContext,
|
||||
StreamMessage,
|
||||
} from "../components/chatMessage/chatMessage";
|
||||
import { useIPLocationData, useIsMobileWidth, welcomeConsole } from "../common/utils";
|
||||
import { ChatInputArea, ChatOptions } from "../components/chatInputArea/chatInputArea";
|
||||
import { useAuthenticatedData } from "../common/auth";
|
||||
|
@ -37,6 +42,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
|
|||
const [images, setImages] = useState<string[]>([]);
|
||||
const [processingMessage, setProcessingMessage] = useState(false);
|
||||
const [agentMetadata, setAgentMetadata] = useState<AgentData | null>(null);
|
||||
const [isInResearchMode, setIsInResearchMode] = useState(false);
|
||||
const chatInputRef = useRef<HTMLTextAreaElement>(null);
|
||||
|
||||
const setQueryToProcess = props.setQueryToProcess;
|
||||
|
@ -65,6 +71,10 @@ function ChatBodyData(props: ChatBodyDataProps) {
|
|||
if (storedMessage) {
|
||||
setProcessingMessage(true);
|
||||
setQueryToProcess(storedMessage);
|
||||
|
||||
if (storedMessage.trim().startsWith("/research")) {
|
||||
setIsInResearchMode(true);
|
||||
}
|
||||
}
|
||||
}, [setQueryToProcess, props.setImages]);
|
||||
|
||||
|
@ -125,6 +135,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
|
|||
isMobileWidth={props.isMobileWidth}
|
||||
setUploadedFiles={props.setUploadedFiles}
|
||||
ref={chatInputRef}
|
||||
isResearchModeEnabled={isInResearchMode}
|
||||
/>
|
||||
</div>
|
||||
</>
|
||||
|
@ -174,6 +185,7 @@ export default function Chat() {
|
|||
trainOfThought: [],
|
||||
context: [],
|
||||
onlineContext: {},
|
||||
codeContext: {},
|
||||
completed: false,
|
||||
timestamp: new Date().toISOString(),
|
||||
rawQuery: queryToProcess || "",
|
||||
|
@ -202,6 +214,7 @@ export default function Chat() {
|
|||
// Track context used for chat response
|
||||
let context: Context[] = [];
|
||||
let onlineContext: OnlineContext = {};
|
||||
let codeContext: CodeContext = {};
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
|
@ -228,11 +241,12 @@ export default function Chat() {
|
|||
}
|
||||
|
||||
// Track context used for chat response. References are rendered at the end of the chat
|
||||
({ context, onlineContext } = processMessageChunk(
|
||||
({ context, onlineContext, codeContext } = processMessageChunk(
|
||||
event,
|
||||
currentMessage,
|
||||
context,
|
||||
onlineContext,
|
||||
codeContext,
|
||||
));
|
||||
|
||||
setMessages([...messages]);
|
||||
|
|
|
@ -1,8 +1,14 @@
|
|||
import { Context, OnlineContext, StreamMessage } from "../components/chatMessage/chatMessage";
|
||||
import {
|
||||
CodeContext,
|
||||
Context,
|
||||
OnlineContext,
|
||||
StreamMessage,
|
||||
} from "../components/chatMessage/chatMessage";
|
||||
|
||||
/**
 * Reference payload that may carry any combination of notes, online and code
 * references. Every field is optional, so consumers must check for presence
 * before use.
 */
export interface RawReferenceData {
|
||||
// Notes/document references, when present.
context?: Context[];
|
||||
// Online search references, when present.
onlineContext?: OnlineContext;
|
||||
// Code execution references, when present.
codeContext?: CodeContext;
|
||||
}
|
||||
|
||||
export interface ResponseWithIntent {
|
||||
|
@ -67,10 +73,11 @@ export function processMessageChunk(
|
|||
currentMessage: StreamMessage,
|
||||
context: Context[] = [],
|
||||
onlineContext: OnlineContext = {},
|
||||
): { context: Context[]; onlineContext: OnlineContext } {
|
||||
codeContext: CodeContext = {},
|
||||
): { context: Context[]; onlineContext: OnlineContext; codeContext: CodeContext } {
|
||||
const chunk = convertMessageChunkToJson(rawChunk);
|
||||
|
||||
if (!currentMessage || !chunk || !chunk.type) return { context, onlineContext };
|
||||
if (!currentMessage || !chunk || !chunk.type) return { context, onlineContext, codeContext };
|
||||
|
||||
if (chunk.type === "status") {
|
||||
console.log(`status: ${chunk.data}`);
|
||||
|
@ -81,7 +88,8 @@ export function processMessageChunk(
|
|||
|
||||
if (references.context) context = references.context;
|
||||
if (references.onlineContext) onlineContext = references.onlineContext;
|
||||
return { context, onlineContext };
|
||||
if (references.codeContext) codeContext = references.codeContext;
|
||||
return { context, onlineContext, codeContext };
|
||||
} else if (chunk.type === "message") {
|
||||
const chunkData = chunk.data;
|
||||
// Here, handle if the response is a JSON response with an image, but the intentType is excalidraw
|
||||
|
@ -119,13 +127,41 @@ export function processMessageChunk(
|
|||
console.log(`Completed streaming: ${new Date()}`);
|
||||
|
||||
// Append any references after all the data has been streamed
|
||||
if (codeContext) currentMessage.codeContext = codeContext;
|
||||
if (onlineContext) currentMessage.onlineContext = onlineContext;
|
||||
if (context) currentMessage.context = context;
|
||||
|
||||
// Replace file links with base64 data
|
||||
currentMessage.rawResponse = renderCodeGenImageInline(
|
||||
currentMessage.rawResponse,
|
||||
codeContext,
|
||||
);
|
||||
|
||||
// Add code context files to the message
|
||||
if (codeContext) {
|
||||
Object.entries(codeContext).forEach(([key, value]) => {
|
||||
value.results.output_files?.forEach((file) => {
|
||||
if (file.filename.endsWith(".png") || file.filename.endsWith(".jpg")) {
|
||||
// Don't add the image again if it's already in the message!
|
||||
if (!currentMessage.rawResponse.includes(`![${file.filename}](`)) {
|
||||
currentMessage.rawResponse += `\n\n![${file.filename}](data:image/png;base64,${file.b64_data})`;
|
||||
}
|
||||
} else if (
|
||||
file.filename.endsWith(".txt") ||
|
||||
file.filename.endsWith(".org") ||
|
||||
file.filename.endsWith(".md")
|
||||
) {
|
||||
const decodedText = atob(file.b64_data);
|
||||
currentMessage.rawResponse += `\n\n\`\`\`\n${decodedText}\n\`\`\``;
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// Mark current message streaming as completed
|
||||
currentMessage.completed = true;
|
||||
}
|
||||
return { context, onlineContext };
|
||||
return { context, onlineContext, codeContext };
|
||||
}
|
||||
|
||||
export function handleImageResponse(imageJson: any, liveStream: boolean): ResponseWithIntent {
|
||||
|
@ -150,6 +186,22 @@ export function handleImageResponse(imageJson: any, liveStream: boolean): Respon
|
|||
return responseWithIntent;
|
||||
}
|
||||
|
||||
/**
 * Replace markdown links/images in `message` that point at a generated image
 * file with an inline base64 `data:` image.
 *
 * For every output file in `codeContext` whose name ends in png, jpg, jpeg, gif
 * or webp (case-insensitive), each `[text](...filename)` or `![text](...filename)`
 * occurrence is rewritten to `![filename](data:image/<subtype>;base64,<b64_data>)`.
 *
 * @param message - Markdown text to rewrite.
 * @param codeContext - Code results whose `output_files` supply the image data.
 *   A falsy value leaves the message untouched.
 * @returns The message with matching links replaced; all other text is unchanged.
 */
export function renderCodeGenImageInline(message: string, codeContext: CodeContext) {
    if (!codeContext) return message;

    const imageExtension = /\.(png|jpg|jpeg|gif|webp)$/i;

    Object.values(codeContext).forEach((contextData) => {
        contextData.results?.output_files?.forEach((file) => {
            const extensionMatch = file.filename.match(imageExtension);
            if (!extensionMatch) return;

            // "jpg" is not a registered image subtype; the media type is image/jpeg.
            const extension = extensionMatch[1].toLowerCase();
            const subtype = extension === "jpg" ? "jpeg" : extension;

            // Escape regex metacharacters so the filename is matched literally
            // (otherwise the "." in "a.png" would match any character).
            const escapedName = file.filename.replace(/[.*+?^${}()|[\]\\]/g, "\\$&");

            // Keep the match inside a single `[...](...)` pair: the link text may not
            // cross a "]" and the target may not cross a ")", so an earlier link on
            // the same line is never swallowed by the match.
            const linkPattern = new RegExp(
                `!?\\[[^\\]\\n]*\\]\\([^)\\n]*${escapedName}\\)`,
                "g",
            );

            const replacement = `![${file.filename}](data:image/${subtype};base64,${file.b64_data})`;

            // A function replacer inserts the text verbatim, so "$" sequences in the
            // filename are not interpreted as substitution patterns.
            message = message.replace(linkPattern, () => replacement);
        });
    });

    return message;
}
|
||||
|
||||
export function modifyFileFilterForConversation(
|
||||
conversationId: string | null,
|
||||
filenames: string[],
|
||||
|
|
|
@ -42,6 +42,13 @@ export function converColorToBgGradient(color: string) {
|
|||
return `${convertToBGGradientClass(color)} dark:border dark:border-neutral-700`;
|
||||
}
|
||||
|
||||
/**
 * Build the Tailwind focus-visible ring class for a color name.
 *
 * @param color - Candidate color name; may be undefined.
 * @returns The `-500` ring class for `color` when `tailwindColors.includes(color)`
 *   holds; the orange ring class when the color is missing, empty or not included.
 */
export function convertColorToRingClass(color: string | undefined) {
    // Guard first: anything we cannot vouch for falls back to the orange ring.
    if (!color || !tailwindColors.includes(color)) {
        return `focus-visible:ring-orange-500`;
    }

    return `focus-visible:ring-${color}-500`;
}
|
||||
|
||||
export function convertColorToBorderClass(color: string) {
|
||||
if (tailwindColors.includes(color)) {
|
||||
return `border-${color}-500`;
|
||||
|
|
|
@ -13,13 +13,14 @@ import { ScrollArea } from "@/components/ui/scroll-area";
|
|||
|
||||
import { InlineLoading } from "../loading/loading";
|
||||
|
||||
import { Lightbulb, ArrowDown } from "@phosphor-icons/react";
|
||||
import { Lightbulb, ArrowDown, XCircle } from "@phosphor-icons/react";
|
||||
|
||||
import AgentProfileCard from "../profileCard/profileCard";
|
||||
import { getIconFromIconName } from "@/app/common/iconUtils";
|
||||
import { AgentData } from "@/app/agents/page";
|
||||
import React from "react";
|
||||
import { useIsMobileWidth } from "@/app/common/utils";
|
||||
import { Button } from "@/components/ui/button";
|
||||
|
||||
interface ChatResponse {
|
||||
status: string;
|
||||
|
@ -40,24 +41,52 @@ interface ChatHistoryProps {
|
|||
customClassName?: string;
|
||||
}
|
||||
|
||||
function constructTrainOfThought(
|
||||
trainOfThought: string[],
|
||||
lastMessage: boolean,
|
||||
agentColor: string,
|
||||
key: string,
|
||||
completed: boolean = false,
|
||||
) {
|
||||
const lastIndex = trainOfThought.length - 1;
|
||||
return (
|
||||
<div className={`${styles.trainOfThought} shadow-sm`} key={key}>
|
||||
{!completed && <InlineLoading className="float-right" />}
|
||||
/**
 * Props for the collapsible train-of-thought panel rendered alongside a chat message.
 */
interface TrainOfThoughtComponentProps {
|
||||
// Thought steps to render, one entry per step.
trainOfThought: string[];
|
||||
// Combined with `completed` to decide whether the final step is rendered as primary.
lastMessage: boolean;
|
||||
// Agent color name forwarded to each rendered step.
agentColor: string;
|
||||
// NOTE(review): React treats `key` as a reserved attribute and does not forward it
// to the component, so reading `props.key` is expected to yield undefined — confirm
// and consider dropping this field.
key: string;
|
||||
// Whether the message has finished; used as the initial collapsed state.
completed?: boolean;
|
||||
}
|
||||
|
||||
{trainOfThought.map((train, index) => (
|
||||
function TrainOfThoughtComponent(props: TrainOfThoughtComponentProps) {
|
||||
const lastIndex = props.trainOfThought.length - 1;
|
||||
const [collapsed, setCollapsed] = useState(props.completed);
|
||||
|
||||
return (
|
||||
<div
|
||||
className={`${!collapsed ? styles.trainOfThought + " shadow-sm" : ""}`}
|
||||
key={props.key}
|
||||
>
|
||||
{!props.completed && <InlineLoading className="float-right" />}
|
||||
{props.completed &&
|
||||
(collapsed ? (
|
||||
<Button
|
||||
className="w-fit text-left justify-start content-start text-xs"
|
||||
onClick={() => setCollapsed(false)}
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
>
|
||||
What was my train of thought?
|
||||
</Button>
|
||||
) : (
|
||||
<Button
|
||||
className="w-fit text-left justify-start content-start text-xs p-0 h-fit"
|
||||
onClick={() => setCollapsed(true)}
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
>
|
||||
<XCircle size={16} className="mr-1" />
|
||||
Close
|
||||
</Button>
|
||||
))}
|
||||
{!collapsed &&
|
||||
props.trainOfThought.map((train, index) => (
|
||||
<TrainOfThought
|
||||
key={`train-${index}`}
|
||||
message={train}
|
||||
primary={index === lastIndex && lastMessage && !completed}
|
||||
agentColor={agentColor}
|
||||
primary={index === lastIndex && props.lastMessage && !props.completed}
|
||||
agentColor={props.agentColor}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
|
@ -254,17 +283,16 @@ export default function ChatHistory(props: ChatHistoryProps) {
|
|||
}
|
||||
|
||||
return (
|
||||
<ScrollArea className={`h-[80vh] relative`} ref={scrollAreaRef}>
|
||||
<ScrollArea className={`h-[73vh] relative`} ref={scrollAreaRef}>
|
||||
<div>
|
||||
<div className={`${styles.chatHistory} ${props.customClassName}`}>
|
||||
<div ref={sentinelRef} style={{ height: "1px" }}>
|
||||
{fetchingData && (
|
||||
<InlineLoading message="Loading Conversation" className="opacity-50" />
|
||||
)}
|
||||
{fetchingData && <InlineLoading className="opacity-50" />}
|
||||
</div>
|
||||
{data &&
|
||||
data.chat &&
|
||||
data.chat.map((chatMessage, index) => (
|
||||
<>
|
||||
<ChatMessage
|
||||
key={`${index}fullHistory`}
|
||||
ref={
|
||||
|
@ -274,7 +302,8 @@ export default function ChatHistory(props: ChatHistoryProps) {
|
|||
: // attach ref to the newest fetched message to handle scroll on fetch
|
||||
// note: stabilize index selection against last page having less messages than fetchMessageCount
|
||||
index ===
|
||||
data.chat.length - (currentPage - 1) * fetchMessageCount
|
||||
data.chat.length -
|
||||
(currentPage - 1) * fetchMessageCount
|
||||
? latestFetchedMessageRef
|
||||
: null
|
||||
}
|
||||
|
@ -284,6 +313,18 @@ export default function ChatHistory(props: ChatHistoryProps) {
|
|||
borderLeftColor={`${data?.agent?.color}-500`}
|
||||
isLastMessage={index === data.chat.length - 1}
|
||||
/>
|
||||
{chatMessage.trainOfThought && chatMessage.by === "khoj" && (
|
||||
<TrainOfThoughtComponent
|
||||
trainOfThought={chatMessage.trainOfThought?.map(
|
||||
(train) => train.data,
|
||||
)}
|
||||
lastMessage={false}
|
||||
agentColor={data?.agent?.color || "orange"}
|
||||
key={`${index}trainOfThought`}
|
||||
completed={true}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
))}
|
||||
{props.incomingMessages &&
|
||||
props.incomingMessages.map((message, index) => {
|
||||
|
@ -296,6 +337,7 @@ export default function ChatHistory(props: ChatHistoryProps) {
|
|||
message: message.rawQuery,
|
||||
context: [],
|
||||
onlineContext: {},
|
||||
codeContext: {},
|
||||
created: message.timestamp,
|
||||
by: "you",
|
||||
automationId: "",
|
||||
|
@ -304,13 +346,14 @@ export default function ChatHistory(props: ChatHistoryProps) {
|
|||
customClassName="fullHistory"
|
||||
borderLeftColor={`${data?.agent?.color}-500`}
|
||||
/>
|
||||
{message.trainOfThought &&
|
||||
constructTrainOfThought(
|
||||
message.trainOfThought,
|
||||
index === incompleteIncomingMessageIndex,
|
||||
data?.agent?.color || "orange",
|
||||
`${index}trainOfThought`,
|
||||
message.completed,
|
||||
{message.trainOfThought && (
|
||||
<TrainOfThoughtComponent
|
||||
trainOfThought={message.trainOfThought}
|
||||
lastMessage={index === incompleteIncomingMessageIndex}
|
||||
agentColor={data?.agent?.color || "orange"}
|
||||
key={`${index}trainOfThought`}
|
||||
completed={message.completed}
|
||||
/>
|
||||
)}
|
||||
<ChatMessage
|
||||
key={`${index}incoming`}
|
||||
|
@ -319,6 +362,7 @@ export default function ChatHistory(props: ChatHistoryProps) {
|
|||
message: message.rawResponse,
|
||||
context: message.context,
|
||||
onlineContext: message.onlineContext,
|
||||
codeContext: message.codeContext,
|
||||
created: message.timestamp,
|
||||
by: "khoj",
|
||||
automationId: "",
|
||||
|
@ -345,6 +389,7 @@ export default function ChatHistory(props: ChatHistoryProps) {
|
|||
message: props.pendingMessage,
|
||||
context: [],
|
||||
onlineContext: {},
|
||||
codeContext: {},
|
||||
created: new Date().getTime().toString(),
|
||||
by: "you",
|
||||
automationId: "",
|
||||
|
@ -372,7 +417,7 @@ export default function ChatHistory(props: ChatHistoryProps) {
|
|||
</div>
|
||||
)}
|
||||
</div>
|
||||
<div className={`${props.customClassName} fixed bottom-[15%] z-10`}>
|
||||
<div className={`${props.customClassName} fixed bottom-[20%] z-10`}>
|
||||
{!isNearBottom && (
|
||||
<button
|
||||
title="Scroll to bottom"
|
||||
|
|
|
@ -3,7 +3,15 @@ import React, { useEffect, useRef, useState, forwardRef } from "react";
|
|||
|
||||
import DOMPurify from "dompurify";
|
||||
import "katex/dist/katex.min.css";
|
||||
import { ArrowUp, Microphone, Paperclip, X, Stop } from "@phosphor-icons/react";
|
||||
import {
|
||||
ArrowUp,
|
||||
Microphone,
|
||||
Paperclip,
|
||||
X,
|
||||
Stop,
|
||||
ToggleLeft,
|
||||
ToggleRight,
|
||||
} from "@phosphor-icons/react";
|
||||
|
||||
import {
|
||||
Command,
|
||||
|
@ -29,7 +37,7 @@ import { Popover, PopoverContent } from "@/components/ui/popover";
|
|||
import { PopoverTrigger } from "@radix-ui/react-popover";
|
||||
import { Textarea } from "@/components/ui/textarea";
|
||||
import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip";
|
||||
import { convertToBGClass } from "@/app/common/colorUtils";
|
||||
import { convertColorToTextClass, convertToBGClass } from "@/app/common/colorUtils";
|
||||
|
||||
import LoginPrompt from "../loginPrompt/loginPrompt";
|
||||
import { uploadDataForIndexing } from "../../common/chatFunctions";
|
||||
|
@ -50,6 +58,7 @@ interface ChatInputProps {
|
|||
isMobileWidth?: boolean;
|
||||
isLoggedIn: boolean;
|
||||
agentColor?: string;
|
||||
isResearchModeEnabled?: boolean;
|
||||
}
|
||||
|
||||
export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((props, ref) => {
|
||||
|
@ -72,6 +81,11 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
|
|||
const [progressValue, setProgressValue] = useState(0);
|
||||
const [isDragAndDropping, setIsDragAndDropping] = useState(false);
|
||||
|
||||
const [showCommandList, setShowCommandList] = useState(false);
|
||||
const [useResearchMode, setUseResearchMode] = useState<boolean>(
|
||||
props.isResearchModeEnabled || false,
|
||||
);
|
||||
|
||||
const chatInputRef = ref as React.MutableRefObject<HTMLTextAreaElement>;
|
||||
useEffect(() => {
|
||||
if (!uploading) {
|
||||
|
@ -112,6 +126,12 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
|
|||
fetchImageData();
|
||||
}, [imagePaths]);
|
||||
|
||||
useEffect(() => {
|
||||
if (props.isResearchModeEnabled) {
|
||||
setUseResearchMode(props.isResearchModeEnabled);
|
||||
}
|
||||
}, [props.isResearchModeEnabled]);
|
||||
|
||||
function onSendMessage() {
|
||||
if (imageUploaded) {
|
||||
setImageUploaded(false);
|
||||
|
@ -128,7 +148,12 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
|
|||
return;
|
||||
}
|
||||
|
||||
props.sendMessage(message.trim());
|
||||
let messageToSend = message.trim();
|
||||
if (useResearchMode) {
|
||||
messageToSend = `/research ${messageToSend}`;
|
||||
}
|
||||
|
||||
props.sendMessage(messageToSend);
|
||||
setMessage("");
|
||||
}
|
||||
|
||||
|
@ -275,6 +300,12 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
|
|||
chatInputRef.current.style.height = "auto";
|
||||
chatInputRef.current.style.height =
|
||||
Math.max(chatInputRef.current.scrollHeight - 24, 64) + "px";
|
||||
|
||||
if (message.startsWith("/") && message.split(" ").length === 1) {
|
||||
setShowCommandList(true);
|
||||
} else {
|
||||
setShowCommandList(false);
|
||||
}
|
||||
}, [message]);
|
||||
|
||||
function handleDragOver(event: React.DragEvent<HTMLDivElement>) {
|
||||
|
@ -360,9 +391,9 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
|
|||
</AlertDialogContent>
|
||||
</AlertDialog>
|
||||
)}
|
||||
{message.startsWith("/") && message.split(" ").length === 1 && (
|
||||
{showCommandList && (
|
||||
<div className="flex justify-center text-center">
|
||||
<Popover open={message.startsWith("/")}>
|
||||
<Popover open={showCommandList} onOpenChange={setShowCommandList}>
|
||||
<PopoverTrigger className="flex justify-center text-center"></PopoverTrigger>
|
||||
<PopoverContent
|
||||
onOpenAutoFocus={(e) => e.preventDefault()}
|
||||
|
@ -413,6 +444,7 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
|
|||
</Popover>
|
||||
</div>
|
||||
)}
|
||||
<div>
|
||||
<div
|
||||
className={`${styles.actualInputArea} justify-between dark:bg-neutral-700 relative ${isDragAndDropping && "animate-pulse"}`}
|
||||
onDragOver={handleDragOver}
|
||||
|
@ -426,7 +458,7 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
|
|||
onChange={handleFileChange}
|
||||
style={{ display: "none" }}
|
||||
/>
|
||||
<div className="flex items-end pb-4">
|
||||
<div className="flex items-center">
|
||||
<Button
|
||||
variant={"ghost"}
|
||||
className="!bg-none p-0 m-2 h-auto text-3xl rounded-full text-gray-300 hover:text-gray-500"
|
||||
|
@ -436,11 +468,14 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
|
|||
<Paperclip className="w-8 h-8" />
|
||||
</Button>
|
||||
</div>
|
||||
<div className="flex-grow flex flex-col w-full gap-1.5 relative pb-2">
|
||||
<div className="flex-grow flex flex-col w-full gap-1.5 relative">
|
||||
<div className="flex items-center gap-2 overflow-x-auto">
|
||||
{imageUploaded &&
|
||||
imagePaths.map((path, index) => (
|
||||
<div key={index} className="relative flex-shrink-0 pb-3 pt-2 group">
|
||||
<div
|
||||
key={index}
|
||||
className="relative flex-shrink-0 pb-3 pt-2 group"
|
||||
>
|
||||
<img
|
||||
src={path}
|
||||
alt={`img-${index}`}
|
||||
|
@ -459,7 +494,10 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
|
|||
</div>
|
||||
<Textarea
|
||||
ref={chatInputRef}
|
||||
className={`border-none w-full h-16 min-h-16 max-h-[128px] md:py-4 rounded-lg resize-none dark:bg-neutral-700 ${props.isMobileWidth ? "text-md" : "text-lg"}`}
|
||||
className={`border-none focus:border-none
|
||||
focus:outline-none focus-visible:ring-transparent
|
||||
w-full h-16 min-h-16 max-h-[128px] md:py-4 rounded-lg resize-none dark:bg-neutral-700
|
||||
${props.isMobileWidth ? "text-md" : "text-lg"}`}
|
||||
placeholder="Type / to see a list of commands"
|
||||
id="message"
|
||||
autoFocus={true}
|
||||
|
@ -476,7 +514,7 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
|
|||
disabled={props.sendDisabled || recording}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex items-end pb-4">
|
||||
<div className="flex items-center">
|
||||
{recording ? (
|
||||
<TooltipProvider>
|
||||
<Tooltip>
|
||||
|
@ -530,6 +568,34 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
|
|||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
<TooltipProvider>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<Button
|
||||
variant="ghost"
|
||||
className="float-right justify-center gap-1 flex items-center p-1.5 mr-2 h-fit"
|
||||
onClick={() => {
|
||||
setUseResearchMode(!useResearchMode);
|
||||
chatInputRef?.current?.focus();
|
||||
}}
|
||||
>
|
||||
<span className="text-muted-foreground text-sm">Research Mode</span>
|
||||
{useResearchMode ? (
|
||||
<ToggleRight
|
||||
className={`w-6 h-6 inline-block ${props.agentColor ? convertColorToTextClass(props.agentColor) : convertColorToTextClass("orange")} rounded-full`}
|
||||
/>
|
||||
) : (
|
||||
<ToggleLeft className={`w-6 h-6 inline-block rounded-full`} />
|
||||
)}
|
||||
</Button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent className="text-xs">
|
||||
Research Mode allows you to get more deeply researched, detailed
|
||||
responses. Response times may be longer.
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
});
|
||||
|
|
|
@ -10,6 +10,7 @@ import { createRoot } from "react-dom/client";
|
|||
import "katex/dist/katex.min.css";
|
||||
|
||||
import { TeaserReferencesSection, constructAllReferences } from "../referencePanel/referencePanel";
|
||||
import { renderCodeGenImageInline } from "@/app/common/chatFunctions";
|
||||
|
||||
import {
|
||||
ThumbsUp,
|
||||
|
@ -26,6 +27,7 @@ import {
|
|||
Palette,
|
||||
ClipboardText,
|
||||
Check,
|
||||
Code,
|
||||
Shapes,
|
||||
} from "@phosphor-icons/react";
|
||||
|
||||
|
@ -99,6 +101,26 @@ export interface OnlineContextData {
|
|||
peopleAlsoAsk: PeopleAlsoAsk[];
|
||||
}
|
||||
|
||||
/**
 * String-keyed map of code context entries.
 * NOTE(review): what the keys represent is not visible in this file — confirm with the producer.
 */
export interface CodeContext {
|
||||
[key: string]: CodeContextData;
|
||||
}
|
||||
|
||||
/**
 * A single code context entry: a `code` string together with its `results`.
 */
export interface CodeContextData {
|
||||
// The code text for this entry.
code: string;
|
||||
// Result record associated with `code`.
results: {
|
||||
// Success flag for the result.
success: boolean;
|
||||
// Files attached to the result.
// NOTE(review): consumers access this with `?.`, which suggests it can be absent
// at runtime even though it is declared as required — confirm.
output_files: CodeContextFile[];
|
||||
// presumably captured standard output — TODO confirm
std_out: string;
|
||||
// presumably captured standard error — TODO confirm
std_err: string;
|
||||
// NOTE(review): unit of this runtime value is not visible here — confirm.
code_runtime: number;
|
||||
};
|
||||
}
|
||||
|
||||
/**
 * A file attached to a code result.
 */
export interface CodeContextFile {
|
||||
// File name; consumers branch on its extension to decide how to render the file.
filename: string;
|
||||
// Base64-encoded file contents (embedded in base64 data URIs or decoded with atob by consumers).
b64_data: string;
|
||||
}
|
||||
|
||||
interface Intent {
|
||||
type: string;
|
||||
query: string;
|
||||
|
@ -106,6 +128,11 @@ interface Intent {
|
|||
"inferred-queries": string[];
|
||||
}
|
||||
|
||||
/**
 * One stored train-of-thought entry attached to a chat message.
 */
interface TrainOfThoughtObject {
|
||||
// Entry type tag. NOTE(review): the set of valid values is not visible here — confirm.
type: string;
|
||||
// Entry text; this is the value extracted when the entries are rendered.
data: string;
|
||||
}
|
||||
|
||||
export interface SingleChatMessage {
|
||||
automationId: string;
|
||||
by: string;
|
||||
|
@ -113,6 +140,8 @@ export interface SingleChatMessage {
|
|||
created: string;
|
||||
context: Context[];
|
||||
onlineContext: OnlineContext;
|
||||
codeContext: CodeContext;
|
||||
trainOfThought?: TrainOfThoughtObject[];
|
||||
rawQuery?: string;
|
||||
intent?: Intent;
|
||||
agent?: AgentData;
|
||||
|
@ -124,6 +153,7 @@ export interface StreamMessage {
|
|||
trainOfThought: string[];
|
||||
context: Context[];
|
||||
onlineContext: OnlineContext;
|
||||
codeContext: CodeContext;
|
||||
completed: boolean;
|
||||
rawQuery: string;
|
||||
timestamp: string;
|
||||
|
@ -263,6 +293,10 @@ function chooseIconFromHeader(header: string, iconColor: string) {
|
|||
return <Palette className={`${classNames}`} />;
|
||||
}
|
||||
|
||||
if (compareHeader.includes("code")) {
|
||||
return <Code className={`${classNames}`} />;
|
||||
}
|
||||
|
||||
return <Brain className={`${classNames}`} />;
|
||||
}
|
||||
|
||||
|
@ -273,12 +307,15 @@ export function TrainOfThought(props: TrainOfThoughtProps) {
|
|||
const iconColor = props.primary ? convertColorToTextClass(props.agentColor) : "text-gray-500";
|
||||
const icon = chooseIconFromHeader(header, iconColor);
|
||||
let markdownRendered = DOMPurify.sanitize(md.render(props.message));
|
||||
|
||||
// Remove any header tags from markdownRendered
|
||||
markdownRendered = markdownRendered.replace(/<h[1-6].*?<\/h[1-6]>/g, "");
|
||||
return (
|
||||
<div
|
||||
className={`${styles.trainOfThoughtElement} break-all items-center ${props.primary ? "text-gray-400" : "text-gray-300"} ${styles.trainOfThought} ${props.primary ? styles.primary : ""}`}
|
||||
className={`${styles.trainOfThoughtElement} break-words items-center ${props.primary ? "text-gray-400" : "text-gray-300"} ${styles.trainOfThought} ${props.primary ? styles.primary : ""}`}
|
||||
>
|
||||
{icon}
|
||||
<div dangerouslySetInnerHTML={{ __html: markdownRendered }} />
|
||||
<div dangerouslySetInnerHTML={{ __html: markdownRendered }} className="break-words" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
@ -388,6 +425,48 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
|
|||
messageToRender = `${userImagesInHtml}${messageToRender}`;
|
||||
}
|
||||
|
||||
if (props.chatMessage.intent && props.chatMessage.intent.type == "text-to-image") {
|
||||
message = `![generated image](data:image/png;base64,${message})`;
|
||||
} else if (props.chatMessage.intent && props.chatMessage.intent.type == "text-to-image2") {
|
||||
message = `![generated image](${message})`;
|
||||
} else if (
|
||||
props.chatMessage.intent &&
|
||||
props.chatMessage.intent.type == "text-to-image-v3"
|
||||
) {
|
||||
message = `![generated image](data:image/webp;base64,${message})`;
|
||||
}
|
||||
if (
|
||||
props.chatMessage.intent &&
|
||||
props.chatMessage.intent.type.includes("text-to-image") &&
|
||||
props.chatMessage.intent["inferred-queries"]?.length > 0
|
||||
) {
|
||||
message += `\n\n${props.chatMessage.intent["inferred-queries"][0]}`;
|
||||
}
|
||||
|
||||
// Replace file links with base64 data
|
||||
message = renderCodeGenImageInline(message, props.chatMessage.codeContext);
|
||||
|
||||
// Add code context files to the message
|
||||
if (props.chatMessage.codeContext) {
|
||||
Object.entries(props.chatMessage.codeContext).forEach(([key, value]) => {
|
||||
value.results.output_files?.forEach((file) => {
|
||||
if (file.filename.endsWith(".png") || file.filename.endsWith(".jpg")) {
|
||||
// Don't add the image again if it's already in the message!
|
||||
if (!message.includes(`![${file.filename}](`)) {
|
||||
message += `\n\n![${file.filename}](data:image/png;base64,${file.b64_data})`;
|
||||
}
|
||||
} else if (
|
||||
file.filename.endsWith(".txt") ||
|
||||
file.filename.endsWith(".org") ||
|
||||
file.filename.endsWith(".md")
|
||||
) {
|
||||
const decodedText = atob(file.b64_data);
|
||||
message += `\n\n\`\`\`\n${decodedText}\n\`\`\``;
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// Set the message text
|
||||
setTextRendered(messageForClipboard);
|
||||
|
||||
|
@ -578,6 +657,7 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
|
|||
const allReferences = constructAllReferences(
|
||||
props.chatMessage.context,
|
||||
props.chatMessage.onlineContext,
|
||||
props.chatMessage.codeContext,
|
||||
);
|
||||
|
||||
return (
|
||||
|
@ -600,6 +680,7 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
|
|||
isMobileWidth={props.isMobileWidth}
|
||||
notesReferenceCardData={allReferences.notesReferenceCardData}
|
||||
onlineReferenceCardData={allReferences.onlineReferenceCardData}
|
||||
codeReferenceCardData={allReferences.codeReferenceCardData}
|
||||
/>
|
||||
</div>
|
||||
<div className={styles.chatFooter}>
|
||||
|
|
|
@ -11,7 +11,7 @@ const md = new markdownIt({
|
|||
typographer: true,
|
||||
});
|
||||
|
||||
import { Context, WebPage, OnlineContext } from "../chatMessage/chatMessage";
|
||||
import { Context, WebPage, OnlineContext, CodeContext } from "../chatMessage/chatMessage";
|
||||
import { Card } from "@/components/ui/card";
|
||||
|
||||
import {
|
||||
|
@ -94,11 +94,67 @@ function NotesContextReferenceCard(props: NotesContextReferenceCardProps) {
|
|||
);
|
||||
}
|
||||
|
||||
/**
 * Props for the code reference card.
 */
interface CodeContextReferenceCardProps {
|
||||
// Code text shown as the card's snippet.
code: string;
|
||||
// NOTE(review): not read by the card as visible here — confirm intended use.
output: string;
|
||||
// NOTE(review): not read by the card as visible here — confirm intended use.
error: string;
|
||||
// When true the card renders unclamped at auto width and the hover popover stays closed.
showFullContent: boolean;
|
||||
}
|
||||
|
||||
/**
 * Reference card that previews a code snippet.
 *
 * In compact mode (`showFullContent` false) the card is 200px wide with clamped
 * text, and hovering it opens a popover with a slightly longer preview. In full
 * mode the card renders unclamped at auto width and the popover stays closed.
 *
 * @param props - See CodeContextReferenceCardProps; only `code` and
 *   `showFullContent` are read here.
 */
function CodeContextReferenceCard(props: CodeContextReferenceCardProps) {
    const fileIcon = getIconFromFilename(".py", "w-6 h-6 text-muted-foreground inline-flex mr-2");
    // The snippet is rendered as a JSX text child, which React escapes on its own.
    // Running it through an HTML sanitizer first is unnecessary and can mangle code
    // containing "<", ">" or "&", so the raw code string is used as-is.
    const snippet = props.code;
    const [isHovering, setIsHovering] = useState(false);

    const cardWidthClass = props.showFullContent ? "w-auto" : "w-[200px]";
    const titleClampClass = props.showFullContent ? "block" : "line-clamp-1";
    const snippetClampClass = props.showFullContent ? "block" : "overflow-hidden line-clamp-2";

    return (
        <>
            <Popover open={isHovering && !props.showFullContent} onOpenChange={setIsHovering}>
                <PopoverTrigger asChild>
                    <Card
                        onMouseEnter={() => setIsHovering(true)}
                        onMouseLeave={() => setIsHovering(false)}
                        className={`${cardWidthClass} overflow-hidden break-words text-balance rounded-lg p-2 bg-muted border-none`}
                    >
                        {/* No stray "}" after the class name, so the muted color actually applies */}
                        <h3 className={`${titleClampClass} text-muted-foreground`}>
                            {fileIcon}
                            Code
                        </h3>
                        <p className={`${snippetClampClass}`}>{snippet}</p>
                    </Card>
                </PopoverTrigger>
                <PopoverContent className="w-[400px] mx-2">
                    <Card
                        className={`w-auto overflow-hidden break-words text-balance rounded-lg p-2 border-none`}
                    >
                        <h3 className={`line-clamp-2 text-muted-foreground`}>
                            {fileIcon}
                            Code
                        </h3>
                        <p className={`overflow-hidden line-clamp-3`}>{snippet}</p>
                    </Card>
                </PopoverContent>
            </Popover>
        </>
    );
}
|
||||
|
||||
/** Reference card data grouped by where the reference came from. */
export interface ReferencePanelData {
    /** Card data for references drawn from notes. */
    notesReferenceCardData: NotesContextReferenceData[];
    /** Card data for references drawn from online results. */
    onlineReferenceCardData: OnlineReferenceData[];
}
|
||||
|
||||
/**
 * Card data for one code reference.
 *
 * `constructAllReferences` builds one of these per code context entry that has results.
 */
export interface CodeReferenceData {
    /** The code of the entry. */
    code: string;
    /** Standard output from the entry's results (`results.std_out`). */
    output: string;
    /** Standard error from the entry's results (`results.std_err`). */
    error: string;
}
|
||||
|
||||
interface OnlineReferenceData {
|
||||
title: string;
|
||||
description: string;
|
||||
|
@ -214,9 +270,27 @@ function GenericOnlineReferenceCard(props: OnlineReferenceCardProps) {
|
|||
);
|
||||
}
|
||||
|
||||
export function constructAllReferences(contextData: Context[], onlineData: OnlineContext) {
|
||||
export function constructAllReferences(
|
||||
contextData: Context[],
|
||||
onlineData: OnlineContext,
|
||||
codeContext: CodeContext,
|
||||
) {
|
||||
const onlineReferences: OnlineReferenceData[] = [];
|
||||
const contextReferences: NotesContextReferenceData[] = [];
|
||||
const codeReferences: CodeReferenceData[] = [];
|
||||
|
||||
if (codeContext) {
|
||||
for (const [key, value] of Object.entries(codeContext)) {
|
||||
if (!value.results) {
|
||||
continue;
|
||||
}
|
||||
codeReferences.push({
|
||||
code: value.code,
|
||||
output: value.results.std_out,
|
||||
error: value.results.std_err,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (onlineData) {
|
||||
let localOnlineReferences = [];
|
||||
|
@ -298,12 +372,14 @@ export function constructAllReferences(contextData: Context[], onlineData: Onlin
|
|||
return {
|
||||
notesReferenceCardData: contextReferences,
|
||||
onlineReferenceCardData: onlineReferences,
|
||||
codeReferenceCardData: codeReferences,
|
||||
};
|
||||
}
|
||||
|
||||
export interface TeaserReferenceSectionProps {
|
||||
notesReferenceCardData: NotesContextReferenceData[];
|
||||
onlineReferenceCardData: OnlineReferenceData[];
|
||||
codeReferenceCardData: CodeReferenceData[];
|
||||
isMobileWidth: boolean;
|
||||
}
|
||||
|
||||
|
@ -315,16 +391,27 @@ export function TeaserReferencesSection(props: TeaserReferenceSectionProps) {
|
|||
}, [props.isMobileWidth]);
|
||||
|
||||
const notesDataToShow = props.notesReferenceCardData.slice(0, numTeaserSlots);
|
||||
const codeDataToShow = props.codeReferenceCardData.slice(
|
||||
0,
|
||||
numTeaserSlots - notesDataToShow.length,
|
||||
);
|
||||
const onlineDataToShow =
|
||||
notesDataToShow.length < numTeaserSlots
|
||||
? props.onlineReferenceCardData.slice(0, numTeaserSlots - notesDataToShow.length)
|
||||
notesDataToShow.length + codeDataToShow.length < numTeaserSlots
|
||||
? props.onlineReferenceCardData.slice(
|
||||
0,
|
||||
numTeaserSlots - codeDataToShow.length - notesDataToShow.length,
|
||||
)
|
||||
: [];
|
||||
|
||||
const shouldShowShowMoreButton =
|
||||
props.notesReferenceCardData.length > 0 || props.onlineReferenceCardData.length > 0;
|
||||
props.notesReferenceCardData.length > 0 ||
|
||||
props.codeReferenceCardData.length > 0 ||
|
||||
props.onlineReferenceCardData.length > 0;
|
||||
|
||||
const numReferences =
|
||||
props.notesReferenceCardData.length + props.onlineReferenceCardData.length;
|
||||
props.notesReferenceCardData.length +
|
||||
props.codeReferenceCardData.length +
|
||||
props.onlineReferenceCardData.length;
|
||||
|
||||
if (numReferences === 0) {
|
||||
return null;
|
||||
|
@ -346,6 +433,15 @@ export function TeaserReferencesSection(props: TeaserReferenceSectionProps) {
|
|||
/>
|
||||
);
|
||||
})}
|
||||
{codeDataToShow.map((code, index) => {
|
||||
return (
|
||||
<CodeContextReferenceCard
|
||||
showFullContent={false}
|
||||
{...code}
|
||||
key={`code-${index}`}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
{onlineDataToShow.map((online, index) => {
|
||||
return (
|
||||
<GenericOnlineReferenceCard
|
||||
|
@ -359,6 +455,7 @@ export function TeaserReferencesSection(props: TeaserReferenceSectionProps) {
|
|||
<ReferencePanel
|
||||
notesReferenceCardData={props.notesReferenceCardData}
|
||||
onlineReferenceCardData={props.onlineReferenceCardData}
|
||||
codeReferenceCardData={props.codeReferenceCardData}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
|
@ -369,6 +466,7 @@ export function TeaserReferencesSection(props: TeaserReferenceSectionProps) {
|
|||
interface ReferencePanelDataProps {
|
||||
notesReferenceCardData: NotesContextReferenceData[];
|
||||
onlineReferenceCardData: OnlineReferenceData[];
|
||||
codeReferenceCardData: CodeReferenceData[];
|
||||
}
|
||||
|
||||
export default function ReferencePanel(props: ReferencePanelDataProps) {
|
||||
|
@ -406,6 +504,15 @@ export default function ReferencePanel(props: ReferencePanelDataProps) {
|
|||
/>
|
||||
);
|
||||
})}
|
||||
{props.codeReferenceCardData.map((code, index) => {
|
||||
return (
|
||||
<CodeContextReferenceCard
|
||||
showFullContent={true}
|
||||
{...code}
|
||||
key={`code-${index}`}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</SheetContent>
|
||||
</Sheet>
|
||||
|
|
|
@ -14,25 +14,21 @@ interface SuggestionCardProps {
|
|||
|
||||
export default function SuggestionCard(data: SuggestionCardProps) {
|
||||
const bgColors = converColorToBgGradient(data.color);
|
||||
const cardClassName = `${styles.card} ${bgColors} md:w-full md:h-fit sm:w-full h-fit md:w-[200px] md:h-[200px] cursor-pointer`;
|
||||
const titleClassName = `${styles.title} pt-2 dark:text-white dark:font-bold`;
|
||||
const cardClassName = `${styles.card} ${bgColors} md:w-full md:h-fit sm:w-full h-fit md:w-[200px] md:h-[180px] cursor-pointer md:p-2`;
|
||||
const descriptionClassName = `${styles.text} dark:text-white`;
|
||||
|
||||
const cardContent = (
|
||||
<Card className={cardClassName}>
|
||||
<CardHeader className="m-0 p-2 pb-1 relative">
|
||||
<div className="flex flex-row md:flex-col">
|
||||
<div className="flex">
|
||||
<CardContent className="m-0 p-2">
|
||||
{convertSuggestionTitleToIconClass(data.title, data.color.toLowerCase())}
|
||||
<CardTitle className={titleClassName}>{data.title}</CardTitle>
|
||||
</div>
|
||||
</CardHeader>
|
||||
<CardContent className="m-0 p-2 pr-4 pt-1">
|
||||
<CardDescription
|
||||
className={`${descriptionClassName} sm:line-clamp-2 md:line-clamp-4`}
|
||||
className={`${descriptionClassName} sm:line-clamp-2 md:line-clamp-4 pt-1`}
|
||||
>
|
||||
{data.body}
|
||||
</CardDescription>
|
||||
</CardContent>
|
||||
</div>
|
||||
</Card>
|
||||
);
|
||||
|
||||
|
|
|
@ -1,12 +1,9 @@
|
|||
|
||||
.card {
|
||||
padding: 0.5rem;
|
||||
margin: 0.05rem;
|
||||
border-radius: 0.5rem;
|
||||
}
|
||||
|
||||
.title {
|
||||
font-size: 1.0rem;
|
||||
font-size: 1rem;
|
||||
}
|
||||
|
||||
.text {
|
||||
|
|
|
@ -47,24 +47,24 @@ const DEFAULT_COLOR = "orange";
|
|||
|
||||
export function convertSuggestionTitleToIconClass(title: string, color: string) {
|
||||
if (title === SuggestionType.Automation)
|
||||
return getIconFromIconName("Robot", color, "w-8", "h-8");
|
||||
if (title === SuggestionType.Paint) return getIconFromIconName("Palette", color, "w-8", "h-8");
|
||||
return getIconFromIconName("Robot", color, "w-6", "h-6");
|
||||
if (title === SuggestionType.Paint) return getIconFromIconName("Palette", color, "w-6", "h-6");
|
||||
if (title === SuggestionType.PopCulture)
|
||||
return getIconFromIconName("Confetti", color, "w-8", "h-8");
|
||||
if (title === SuggestionType.Travel) return getIconFromIconName("Jeep", color, "w-8", "h-8");
|
||||
if (title === SuggestionType.Learning) return getIconFromIconName("Book", color, "w-8", "h-8");
|
||||
return getIconFromIconName("Confetti", color, "w-6", "h-6");
|
||||
if (title === SuggestionType.Travel) return getIconFromIconName("Jeep", color, "w-6", "h-6");
|
||||
if (title === SuggestionType.Learning) return getIconFromIconName("Book", color, "w-6", "h-6");
|
||||
if (title === SuggestionType.Health)
|
||||
return getIconFromIconName("Asclepius", color, "w-8", "h-8");
|
||||
if (title === SuggestionType.Fun) return getIconFromIconName("Island", color, "w-8", "h-8");
|
||||
if (title === SuggestionType.Home) return getIconFromIconName("House", color, "w-8", "h-8");
|
||||
return getIconFromIconName("Asclepius", color, "w-6", "h-6");
|
||||
if (title === SuggestionType.Fun) return getIconFromIconName("Island", color, "w-6", "h-6");
|
||||
if (title === SuggestionType.Home) return getIconFromIconName("House", color, "w-6", "h-6");
|
||||
if (title === SuggestionType.Language)
|
||||
return getIconFromIconName("Translate", color, "w-8", "h-8");
|
||||
if (title === SuggestionType.Code) return getIconFromIconName("Code", color, "w-8", "h-8");
|
||||
if (title === SuggestionType.Food) return getIconFromIconName("BowlFood", color, "w-8", "h-8");
|
||||
return getIconFromIconName("Translate", color, "w-6", "h-6");
|
||||
if (title === SuggestionType.Code) return getIconFromIconName("Code", color, "w-6", "h-6");
|
||||
if (title === SuggestionType.Food) return getIconFromIconName("BowlFood", color, "w-6", "h-6");
|
||||
if (title === SuggestionType.Interviewing)
|
||||
return getIconFromIconName("Lectern", color, "w-8", "h-8");
|
||||
if (title === SuggestionType.Finance) return getIconFromIconName("Wallet", color, "w-8", "h-8");
|
||||
else return getIconFromIconName("Lightbulb", color, "w-8", "h-8");
|
||||
return getIconFromIconName("Lectern", color, "w-6", "h-6");
|
||||
if (title === SuggestionType.Finance) return getIconFromIconName("Wallet", color, "w-6", "h-6");
|
||||
else return getIconFromIconName("Lightbulb", color, "w-6", "h-6");
|
||||
}
|
||||
|
||||
export const suggestionsData: Suggestion[] = [
|
||||
|
|
|
@ -5,6 +5,7 @@ import { useAuthenticatedData } from "@/app/common/auth";
|
|||
import { useState, useEffect } from "react";
|
||||
|
||||
import ChatMessage, {
|
||||
CodeContext,
|
||||
Context,
|
||||
OnlineContext,
|
||||
OnlineContextData,
|
||||
|
@ -46,6 +47,7 @@ interface SupplementReferences {
|
|||
interface ResponseWithReferences {
|
||||
context?: Context[];
|
||||
online?: OnlineContext;
|
||||
code?: CodeContext;
|
||||
response?: string;
|
||||
}
|
||||
|
||||
|
@ -192,6 +194,7 @@ function ReferenceVerification(props: ReferenceVerificationProps) {
|
|||
context: [],
|
||||
created: new Date().toISOString(),
|
||||
onlineContext: {},
|
||||
codeContext: {},
|
||||
}}
|
||||
isMobileWidth={isMobileWidth}
|
||||
/>
|
||||
|
@ -622,6 +625,7 @@ export default function FactChecker() {
|
|||
context: [],
|
||||
created: new Date().toISOString(),
|
||||
onlineContext: {},
|
||||
codeContext: {},
|
||||
}}
|
||||
isMobileWidth={isMobileWidth}
|
||||
/>
|
||||
|
|
|
@ -116,8 +116,8 @@ function ChatBodyData(props: ChatBodyDataProps) {
|
|||
`What would you like to get done${nameSuffix}?`,
|
||||
`Hey${nameSuffix}! How can I help?`,
|
||||
`Good ${timeOfDay}${nameSuffix}! What's on your mind?`,
|
||||
`Ready to breeze through your ${["Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday"][day]}?`,
|
||||
`Want help navigating your ${["Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday"][day]} workload?`,
|
||||
`Ready to breeze through ${["Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday"][day]}?`,
|
||||
`Let's navigate your ${["Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday"][day]} workload`,
|
||||
];
|
||||
const greeting = greetings[Math.floor(Math.random() * greetings.length)];
|
||||
setGreeting(greeting);
|
||||
|
@ -305,6 +305,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
|
|||
conversationId={null}
|
||||
isMobileWidth={props.isMobileWidth}
|
||||
setUploadedFiles={props.setUploadedFiles}
|
||||
agentColor={agents.find((agent) => agent.slug === selectedAgent)?.color}
|
||||
ref={chatInputRef}
|
||||
/>
|
||||
</div>
|
||||
|
@ -386,6 +387,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
|
|||
conversationId={null}
|
||||
isMobileWidth={props.isMobileWidth}
|
||||
setUploadedFiles={props.setUploadedFiles}
|
||||
agentColor={agents.find((agent) => agent.slug === selectedAgent)?.color}
|
||||
ref={chatInputRef}
|
||||
/>
|
||||
</div>
|
||||
|
|
|
@ -188,6 +188,7 @@ export default function SharedChat() {
|
|||
trainOfThought: [],
|
||||
context: [],
|
||||
onlineContext: {},
|
||||
codeContext: {},
|
||||
completed: false,
|
||||
timestamp: new Date().toISOString(),
|
||||
rawQuery: queryToProcess || "",
|
||||
|
|
|
@ -1,34 +1,44 @@
|
|||
import type { Config } from "tailwindcss"
|
||||
import type { Config } from "tailwindcss";
|
||||
|
||||
const config = {
|
||||
safelist: [
|
||||
{
|
||||
pattern: /to-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|950)/,
|
||||
variants: ['dark'],
|
||||
pattern:
|
||||
/to-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|950)/,
|
||||
variants: ["dark"],
|
||||
},
|
||||
{
|
||||
pattern: /text-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|950)/,
|
||||
variants: ['dark'],
|
||||
pattern:
|
||||
/text-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|950)/,
|
||||
variants: ["dark"],
|
||||
},
|
||||
{
|
||||
pattern: /border-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|950)/,
|
||||
variants: ['dark'],
|
||||
pattern:
|
||||
/border-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|950)/,
|
||||
variants: ["dark"],
|
||||
},
|
||||
{
|
||||
pattern: /border-l-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|400|500|950)/,
|
||||
variants: ['dark'],
|
||||
pattern:
|
||||
/border-l-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|400|500|950)/,
|
||||
variants: ["dark"],
|
||||
},
|
||||
{
|
||||
pattern: /bg-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|400|500|950)/,
|
||||
variants: ['dark'],
|
||||
}
|
||||
pattern:
|
||||
/bg-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|400|500|950)/,
|
||||
variants: ["dark"],
|
||||
},
|
||||
{
|
||||
pattern:
|
||||
/ring-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|400|500|950)/,
|
||||
variants: ["focus-visible", "dark"],
|
||||
},
|
||||
],
|
||||
darkMode: ["class"],
|
||||
content: [
|
||||
'./pages/**/*.{ts,tsx}',
|
||||
'./components/**/*.{ts,tsx}',
|
||||
'./app/**/*.{ts,tsx}',
|
||||
'./src/**/*.{ts,tsx}',
|
||||
"./pages/**/*.{ts,tsx}",
|
||||
"./components/**/*.{ts,tsx}",
|
||||
"./app/**/*.{ts,tsx}",
|
||||
"./src/**/*.{ts,tsx}",
|
||||
],
|
||||
prefix: "",
|
||||
theme: {
|
||||
|
@ -101,9 +111,7 @@ const config = {
|
|||
},
|
||||
},
|
||||
},
|
||||
plugins: [
|
||||
require("tailwindcss-animate"),
|
||||
],
|
||||
} satisfies Config
|
||||
plugins: [require("tailwindcss-animate")],
|
||||
} satisfies Config;
|
||||
|
||||
export default config
|
||||
export default config;
|
||||
|
|
|
@ -662,6 +662,8 @@ class AgentAdapters:
|
|||
|
||||
@staticmethod
|
||||
async def ais_agent_accessible(agent: Agent, user: KhojUser) -> bool:
|
||||
agent = await Agent.objects.select_related("creator").aget(pk=agent.pk)
|
||||
|
||||
if agent.privacy_level == Agent.PrivacyLevel.PUBLIC:
|
||||
return True
|
||||
if agent.creator == user:
|
||||
|
@ -871,9 +873,13 @@ class ConversationAdapters:
|
|||
agent = await AgentAdapters.aget_readonly_agent_by_slug(agent_slug, user)
|
||||
if agent is None:
|
||||
raise HTTPException(status_code=400, detail="No such agent currently exists.")
|
||||
return await Conversation.objects.acreate(user=user, client=client_application, agent=agent, title=title)
|
||||
return await Conversation.objects.select_related("agent", "agent__creator", "agent__chat_model").acreate(
|
||||
user=user, client=client_application, agent=agent, title=title
|
||||
)
|
||||
agent = await AgentAdapters.aget_default_agent()
|
||||
return await Conversation.objects.acreate(user=user, client=client_application, agent=agent, title=title)
|
||||
return await Conversation.objects.select_related("agent", "agent__creator", "agent__chat_model").acreate(
|
||||
user=user, client=client_application, agent=agent, title=title
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_conversation_session(
|
||||
|
@ -1014,7 +1020,14 @@ class ConversationAdapters:
|
|||
"""Get default conversation config. Prefer chat model by server admin > user > first created chat model"""
|
||||
# Get the server chat settings
|
||||
server_chat_settings = ServerChatSettings.objects.first()
|
||||
if server_chat_settings is not None and server_chat_settings.chat_default is not None:
|
||||
|
||||
is_subscribed = is_user_subscribed(user) if user else False
|
||||
if server_chat_settings:
|
||||
# If the user is subscribed and the advanced model is enabled, return the advanced model
|
||||
if is_subscribed and server_chat_settings.chat_advanced:
|
||||
return server_chat_settings.chat_advanced
|
||||
# If the default model is set, return it
|
||||
if server_chat_settings.chat_default:
|
||||
return server_chat_settings.chat_default
|
||||
|
||||
# Get the user's chat settings, if the server chat settings are not set
|
||||
|
@ -1031,10 +1044,19 @@ class ConversationAdapters:
|
|||
# Get the server chat settings
|
||||
server_chat_settings: ServerChatSettings = (
|
||||
await ServerChatSettings.objects.filter()
|
||||
.prefetch_related("chat_default", "chat_default__openai_config")
|
||||
.prefetch_related(
|
||||
"chat_default", "chat_default__openai_config", "chat_advanced", "chat_advanced__openai_config"
|
||||
)
|
||||
.afirst()
|
||||
)
|
||||
if server_chat_settings is not None and server_chat_settings.chat_default is not None:
|
||||
is_subscribed = await ais_user_subscribed(user) if user else False
|
||||
|
||||
if server_chat_settings:
|
||||
# If the user is subscribed and the advanced model is enabled, return the advanced model
|
||||
if is_subscribed and server_chat_settings.chat_advanced:
|
||||
return server_chat_settings.chat_advanced
|
||||
# If the default model is set, return it
|
||||
if server_chat_settings.chat_default:
|
||||
return server_chat_settings.chat_default
|
||||
|
||||
# Get the user's chat settings, if the server chat settings are not set
|
||||
|
@ -1469,7 +1491,9 @@ class EntryAdapters:
|
|||
|
||||
@staticmethod
|
||||
async def aget_agent_entry_filepaths(agent: Agent):
|
||||
return await sync_to_async(list)(Entry.objects.filter(agent=agent).values_list("file_path", flat=True))
|
||||
return await sync_to_async(set)(
|
||||
Entry.objects.filter(agent=agent).distinct("file_path").values_list("file_path", flat=True)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_all_filenames_by_source(user: KhojUser, file_source: str):
|
||||
|
@ -1544,11 +1568,11 @@ class EntryAdapters:
|
|||
|
||||
@staticmethod
|
||||
def search_with_embeddings(
|
||||
user: KhojUser,
|
||||
raw_query: str,
|
||||
embeddings: Tensor,
|
||||
user: KhojUser,
|
||||
max_results: int = 10,
|
||||
file_type_filter: str = None,
|
||||
raw_query: str = None,
|
||||
max_distance: float = math.inf,
|
||||
agent: Agent = None,
|
||||
):
|
||||
|
|
|
@ -14,11 +14,13 @@ from khoj.processor.conversation.anthropic.utils import (
|
|||
format_messages_for_anthropic,
|
||||
)
|
||||
from khoj.processor.conversation.utils import (
|
||||
clean_json,
|
||||
construct_structured_message,
|
||||
generate_chatml_messages_with_context,
|
||||
)
|
||||
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
|
||||
from khoj.utils.rawconfig import LocationData
|
||||
from khoj.utils.yaml import yaml_dump
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -90,15 +92,13 @@ def extract_questions_anthropic(
|
|||
model_name=model,
|
||||
temperature=temperature,
|
||||
api_key=api_key,
|
||||
response_type="json_object",
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
# Extract, Clean Message from Claude's Response
|
||||
try:
|
||||
response = response.strip()
|
||||
match = re.search(r"\{.*?\}", response)
|
||||
if match:
|
||||
response = match.group()
|
||||
response = clean_json(response)
|
||||
response = json.loads(response)
|
||||
response = [q.strip() for q in response["queries"] if q.strip()]
|
||||
if not isinstance(response, list) or not response:
|
||||
|
@ -112,7 +112,7 @@ def extract_questions_anthropic(
|
|||
return questions
|
||||
|
||||
|
||||
def anthropic_send_message_to_model(messages, api_key, model, tracer={}):
|
||||
def anthropic_send_message_to_model(messages, api_key, model, response_type="text", tracer={}):
|
||||
"""
|
||||
Send message to model
|
||||
"""
|
||||
|
@ -124,6 +124,7 @@ def anthropic_send_message_to_model(messages, api_key, model, tracer={}):
|
|||
system_prompt=system_prompt,
|
||||
model_name=model,
|
||||
api_key=api_key,
|
||||
response_type=response_type,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
|
@ -132,6 +133,7 @@ def converse_anthropic(
|
|||
references,
|
||||
user_query,
|
||||
online_results: Optional[Dict[str, Dict]] = None,
|
||||
code_results: Optional[Dict[str, Dict]] = None,
|
||||
conversation_log={},
|
||||
model: Optional[str] = "claude-3-5-sonnet-20241022",
|
||||
api_key: Optional[str] = None,
|
||||
|
@ -151,7 +153,6 @@ def converse_anthropic(
|
|||
"""
|
||||
# Initialize Variables
|
||||
current_date = datetime.now()
|
||||
compiled_references = "\n\n".join({f"# File: {item['file']}\n## {item['compiled']}\n" for item in references})
|
||||
|
||||
if agent and agent.personality:
|
||||
system_prompt = prompts.custom_personality.format(
|
||||
|
@ -175,7 +176,7 @@ def converse_anthropic(
|
|||
system_prompt = f"{system_prompt}\n{user_name_prompt}"
|
||||
|
||||
# Get Conversation Primer appropriate to Conversation Type
|
||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
|
||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
|
||||
completion_func(chat_response=prompts.no_notes_found.format())
|
||||
return iter([prompts.no_notes_found.format()])
|
||||
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
||||
|
@ -183,10 +184,13 @@ def converse_anthropic(
|
|||
return iter([prompts.no_online_results_found.format()])
|
||||
|
||||
context_message = ""
|
||||
if not is_none_or_empty(compiled_references):
|
||||
context_message = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n"
|
||||
if not is_none_or_empty(references):
|
||||
context_message = f"{prompts.notes_conversation.format(query=user_query, references=yaml_dump(references))}\n\n"
|
||||
if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
|
||||
context_message += f"{prompts.online_search_conversation.format(online_results=str(online_results))}"
|
||||
context_message += f"{prompts.online_search_conversation.format(online_results=yaml_dump(online_results))}\n\n"
|
||||
if ConversationCommand.Code in conversation_commands and not is_none_or_empty(code_results):
|
||||
context_message += f"{prompts.code_executed_context.format(code_results=str(code_results))}\n\n"
|
||||
context_message = context_message.strip()
|
||||
|
||||
# Setup Prompt with Primer or Conversation History
|
||||
messages = generate_chatml_messages_with_context(
|
||||
|
|
|
@ -35,7 +35,15 @@ DEFAULT_MAX_TOKENS_ANTHROPIC = 3000
|
|||
reraise=True,
|
||||
)
|
||||
def anthropic_completion_with_backoff(
|
||||
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None, tracer={}
|
||||
messages,
|
||||
system_prompt,
|
||||
model_name,
|
||||
temperature=0,
|
||||
api_key=None,
|
||||
model_kwargs=None,
|
||||
max_tokens=None,
|
||||
response_type="text",
|
||||
tracer={},
|
||||
) -> str:
|
||||
if api_key not in anthropic_clients:
|
||||
client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key)
|
||||
|
@ -44,8 +52,11 @@ def anthropic_completion_with_backoff(
|
|||
client = anthropic_clients[api_key]
|
||||
|
||||
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
|
||||
if response_type == "json_object":
|
||||
# Prefill model response with '{' to make it output a valid JSON object
|
||||
formatted_messages += [{"role": "assistant", "content": "{"}]
|
||||
|
||||
aggregated_response = ""
|
||||
aggregated_response = "{" if response_type == "json_object" else ""
|
||||
max_tokens = max_tokens or DEFAULT_MAX_TOKENS_ANTHROPIC
|
||||
|
||||
model_kwargs = model_kwargs or dict()
|
||||
|
|
|
@ -14,11 +14,13 @@ from khoj.processor.conversation.google.utils import (
|
|||
gemini_completion_with_backoff,
|
||||
)
|
||||
from khoj.processor.conversation.utils import (
|
||||
clean_json,
|
||||
construct_structured_message,
|
||||
generate_chatml_messages_with_context,
|
||||
)
|
||||
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
|
||||
from khoj.utils.rawconfig import LocationData
|
||||
from khoj.utils.yaml import yaml_dump
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -91,10 +93,7 @@ def extract_questions_gemini(
|
|||
|
||||
# Extract, Clean Message from Gemini's Response
|
||||
try:
|
||||
response = response.strip()
|
||||
match = re.search(r"\{.*?\}", response)
|
||||
if match:
|
||||
response = match.group()
|
||||
response = clean_json(response)
|
||||
response = json.loads(response)
|
||||
response = [q.strip() for q in response["queries"] if q.strip()]
|
||||
if not isinstance(response, list) or not response:
|
||||
|
@ -117,8 +116,10 @@ def gemini_send_message_to_model(
|
|||
messages, system_prompt = format_messages_for_gemini(messages)
|
||||
|
||||
model_kwargs = {}
|
||||
if response_type == "json_object":
|
||||
model_kwargs["response_mime_type"] = "application/json"
|
||||
|
||||
# Sometimes, this causes unwanted behavior and terminates response early. Disable for now while it's flaky.
|
||||
# if response_type == "json_object":
|
||||
# model_kwargs["response_mime_type"] = "application/json"
|
||||
|
||||
# Get Response from Gemini
|
||||
return gemini_completion_with_backoff(
|
||||
|
@ -136,6 +137,7 @@ def converse_gemini(
|
|||
references,
|
||||
user_query,
|
||||
online_results: Optional[Dict[str, Dict]] = None,
|
||||
code_results: Optional[Dict[str, Dict]] = None,
|
||||
conversation_log={},
|
||||
model: Optional[str] = "gemini-1.5-flash",
|
||||
api_key: Optional[str] = None,
|
||||
|
@ -156,7 +158,6 @@ def converse_gemini(
|
|||
"""
|
||||
# Initialize Variables
|
||||
current_date = datetime.now()
|
||||
compiled_references = "\n\n".join({f"# File: {item['file']}\n## {item['compiled']}\n" for item in references})
|
||||
|
||||
if agent and agent.personality:
|
||||
system_prompt = prompts.custom_personality.format(
|
||||
|
@ -181,7 +182,7 @@ def converse_gemini(
|
|||
system_prompt = f"{system_prompt}\n{user_name_prompt}"
|
||||
|
||||
# Get Conversation Primer appropriate to Conversation Type
|
||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
|
||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
|
||||
completion_func(chat_response=prompts.no_notes_found.format())
|
||||
return iter([prompts.no_notes_found.format()])
|
||||
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
||||
|
@ -189,10 +190,13 @@ def converse_gemini(
|
|||
return iter([prompts.no_online_results_found.format()])
|
||||
|
||||
context_message = ""
|
||||
if not is_none_or_empty(compiled_references):
|
||||
context_message = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n"
|
||||
if not is_none_or_empty(references):
|
||||
context_message = f"{prompts.notes_conversation.format(query=user_query, references=yaml_dump(references))}\n\n"
|
||||
if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
|
||||
context_message += f"{prompts.online_search_conversation.format(online_results=str(online_results))}"
|
||||
context_message += f"{prompts.online_search_conversation.format(online_results=yaml_dump(online_results))}\n\n"
|
||||
if ConversationCommand.Code in conversation_commands and not is_none_or_empty(code_results):
|
||||
context_message += f"{prompts.code_executed_context.format(code_results=str(code_results))}\n\n"
|
||||
context_message = context_message.strip()
|
||||
|
||||
# Setup Prompt with Primer or Conversation History
|
||||
messages = generate_chatml_messages_with_context(
|
||||
|
|
|
@ -19,6 +19,7 @@ from khoj.utils import state
|
|||
from khoj.utils.constants import empty_escape_sequences
|
||||
from khoj.utils.helpers import ConversationCommand, in_debug_mode, is_none_or_empty
|
||||
from khoj.utils.rawconfig import LocationData
|
||||
from khoj.utils.yaml import yaml_dump
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -138,7 +139,8 @@ def filter_questions(questions: List[str]):
|
|||
def converse_offline(
|
||||
user_query,
|
||||
references=[],
|
||||
online_results=[],
|
||||
online_results={},
|
||||
code_results={},
|
||||
conversation_log={},
|
||||
model: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
|
||||
loaded_model: Union[Any, None] = None,
|
||||
|
@ -158,8 +160,6 @@ def converse_offline(
|
|||
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
|
||||
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
|
||||
tracer["chat_model"] = model
|
||||
|
||||
compiled_references = "\n\n".join({f"# File: {item['file']}\n## {item['compiled']}\n" for item in references})
|
||||
current_date = datetime.now()
|
||||
|
||||
if agent and agent.personality:
|
||||
|
@ -184,24 +184,25 @@ def converse_offline(
|
|||
system_prompt = f"{system_prompt}\n{user_name_prompt}"
|
||||
|
||||
# Get Conversation Primer appropriate to Conversation Type
|
||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
|
||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
|
||||
return iter([prompts.no_notes_found.format()])
|
||||
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
||||
completion_func(chat_response=prompts.no_online_results_found.format())
|
||||
return iter([prompts.no_online_results_found.format()])
|
||||
|
||||
context_message = ""
|
||||
if not is_none_or_empty(compiled_references):
|
||||
context_message += f"{prompts.notes_conversation_offline.format(references=compiled_references)}\n\n"
|
||||
if not is_none_or_empty(references):
|
||||
context_message = f"{prompts.notes_conversation_offline.format(references=yaml_dump(references))}\n\n"
|
||||
if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
|
||||
simplified_online_results = online_results.copy()
|
||||
for result in online_results:
|
||||
if online_results[result].get("webpages"):
|
||||
simplified_online_results[result] = online_results[result]["webpages"]
|
||||
|
||||
context_message += (
|
||||
f"{prompts.online_search_conversation_offline.format(online_results=str(simplified_online_results))}"
|
||||
)
|
||||
context_message += f"{prompts.online_search_conversation_offline.format(online_results=yaml_dump(simplified_online_results))}\n\n"
|
||||
if ConversationCommand.Code in conversation_commands and not is_none_or_empty(code_results):
|
||||
context_message += f"{prompts.code_executed_context.format(code_results=str(code_results))}\n\n"
|
||||
context_message = context_message.strip()
|
||||
|
||||
# Setup Prompt with Primer or Conversation History
|
||||
messages = generate_chatml_messages_with_context(
|
||||
|
|
|
@ -12,12 +12,13 @@ from khoj.processor.conversation.openai.utils import (
|
|||
completion_with_backoff,
|
||||
)
|
||||
from khoj.processor.conversation.utils import (
|
||||
clean_json,
|
||||
construct_structured_message,
|
||||
generate_chatml_messages_with_context,
|
||||
remove_json_codeblock,
|
||||
)
|
||||
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
|
||||
from khoj.utils.rawconfig import LocationData
|
||||
from khoj.utils.yaml import yaml_dump
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -94,8 +95,7 @@ def extract_questions(
|
|||
|
||||
# Extract, Clean Message from GPT's Response
|
||||
try:
|
||||
response = response.strip()
|
||||
response = remove_json_codeblock(response)
|
||||
response = clean_json(response)
|
||||
response = json.loads(response)
|
||||
response = [q.strip() for q in response["queries"] if q.strip()]
|
||||
if not isinstance(response, list) or not response:
|
||||
|
@ -133,6 +133,7 @@ def converse(
|
|||
references,
|
||||
user_query,
|
||||
online_results: Optional[Dict[str, Dict]] = None,
|
||||
code_results: Optional[Dict[str, Dict]] = None,
|
||||
conversation_log={},
|
||||
model: str = "gpt-4o-mini",
|
||||
api_key: Optional[str] = None,
|
||||
|
@ -154,7 +155,6 @@ def converse(
|
|||
"""
|
||||
# Initialize Variables
|
||||
current_date = datetime.now()
|
||||
compiled_references = "\n\n".join({f"# File: {item['file']}\n## {item['compiled']}\n" for item in references})
|
||||
|
||||
if agent and agent.personality:
|
||||
system_prompt = prompts.custom_personality.format(
|
||||
|
@ -178,7 +178,7 @@ def converse(
|
|||
system_prompt = f"{system_prompt}\n{user_name_prompt}"
|
||||
|
||||
# Get Conversation Primer appropriate to Conversation Type
|
||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
|
||||
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
|
||||
completion_func(chat_response=prompts.no_notes_found.format())
|
||||
return iter([prompts.no_notes_found.format()])
|
||||
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
|
||||
|
@ -186,10 +186,13 @@ def converse(
|
|||
return iter([prompts.no_online_results_found.format()])
|
||||
|
||||
context_message = ""
|
||||
if not is_none_or_empty(compiled_references):
|
||||
context_message = f"{prompts.notes_conversation.format(references=compiled_references)}\n\n"
|
||||
if not is_none_or_empty(references):
|
||||
context_message = f"{prompts.notes_conversation.format(references=yaml_dump(references))}\n\n"
|
||||
if not is_none_or_empty(online_results):
|
||||
context_message += f"{prompts.online_search_conversation.format(online_results=str(online_results))}"
|
||||
context_message += f"{prompts.online_search_conversation.format(online_results=yaml_dump(online_results))}\n\n"
|
||||
if not is_none_or_empty(code_results):
|
||||
context_message += f"{prompts.code_executed_context.format(code_results=str(code_results))}\n\n"
|
||||
context_message = context_message.strip()
|
||||
|
||||
# Setup Prompt with Primer or Conversation History
|
||||
messages = generate_chatml_messages_with_context(
|
||||
|
|
|
@ -394,21 +394,23 @@ Q: {query}
|
|||
|
||||
extract_questions = PromptTemplate.from_template(
|
||||
"""
|
||||
You are Khoj, an extremely smart and helpful document search assistant with only the ability to retrieve information from the user's notes. Disregard online search requests.
|
||||
You are Khoj, an extremely smart and helpful document search assistant with only the ability to retrieve information from the user's notes and documents.
|
||||
Construct search queries to retrieve relevant information to answer the user's question.
|
||||
- You will be provided past questions(Q) and answers(A) for context.
|
||||
- You will be provided example and actual past user questions(Q), search queries(Khoj) and answers(A) for context.
|
||||
- Add as much context from the previous questions and answers as required into your search queries.
|
||||
- Break messages into multiple search queries when required to retrieve the relevant information.
|
||||
- Break your search down into multiple search queries from a diverse set of lenses to retrieve all related documents.
|
||||
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
|
||||
- When asked a meta, vague or random questions, search for a variety of broad topics to answer the user's question.
|
||||
{personality_context}
|
||||
What searches will you perform to answer the users question? Respond with search queries as list of strings in a JSON object.
|
||||
What searches will you perform to answer the user's question? Respond with search queries as list of strings in a JSON object.
|
||||
Current Date: {day_of_week}, {current_date}
|
||||
User's Location: {location}
|
||||
{username}
|
||||
|
||||
Examples
|
||||
---
|
||||
Q: How was my trip to Cambodia?
|
||||
Khoj: {{"queries": ["How was my trip to Cambodia?"]}}
|
||||
Khoj: {{"queries": ["How was my trip to Cambodia?", "Angkor Wat temple visit", "Flight to Phnom Penh", "Expenses in Cambodia", "Stay in Cambodia"]}}
|
||||
A: The trip was amazing. You went to the Angkor Wat temple and it was beautiful.
|
||||
|
||||
Q: Who did i visit that temple with?
|
||||
|
@ -443,6 +445,8 @@ Q: Who all did I meet here yesterday?
|
|||
Khoj: {{"queries": ["Met in {location} on {yesterday_date} dt>='{yesterday_date}' dt<'{current_date}'"]}}
|
||||
A: Yesterday's note mentions your visit to your local beach with Ram and Shyam.
|
||||
|
||||
Actual
|
||||
---
|
||||
{chat_history}
|
||||
Q: {text}
|
||||
Khoj:
|
||||
|
@ -451,11 +455,11 @@ Khoj:
|
|||
|
||||
extract_questions_anthropic_system_prompt = PromptTemplate.from_template(
|
||||
"""
|
||||
You are Khoj, an extremely smart and helpful document search assistant with only the ability to retrieve information from the user's notes. Disregard online search requests.
|
||||
You are Khoj, an extremely smart and helpful document search assistant with only the ability to retrieve information from the user's notes.
|
||||
Construct search queries to retrieve relevant information to answer the user's question.
|
||||
- You will be provided past questions(User), extracted queries(Assistant) and answers(A) for context.
|
||||
- You will be provided past questions(User), search queries(Assistant) and answers(A) for context.
|
||||
- Add as much context from the previous questions and answers as required into your search queries.
|
||||
- Break messages into multiple search queries when required to retrieve the relevant information.
|
||||
- Break your search down into multiple search queries from a diverse set of lenses to retrieve all related documents.
|
||||
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
|
||||
- When asked a meta, vague or random questions, search for a variety of broad topics to answer the user's question.
|
||||
{personality_context}
|
||||
|
@ -468,7 +472,7 @@ User's Location: {location}
|
|||
Here are some examples of how you can construct search queries to answer the user's question:
|
||||
|
||||
User: How was my trip to Cambodia?
|
||||
Assistant: {{"queries": ["How was my trip to Cambodia?"]}}
|
||||
Assistant: {{"queries": ["How was my trip to Cambodia?", "Angkor Wat temple visit", "Flight to Phnom Penh", "Expenses in Cambodia", "Stay in Cambodia"]}}
|
||||
A: The trip was amazing. You went to the Angkor Wat temple and it was beautiful.
|
||||
|
||||
User: What national parks did I go to last year?
|
||||
|
@ -501,17 +505,14 @@ Assistant:
|
|||
)
|
||||
|
||||
system_prompt_extract_relevant_information = """
|
||||
As a professional analyst, create a comprehensive report of the most relevant information from a web page in response to a user's query.
|
||||
The text provided is directly from within the web page.
|
||||
The report you create should be multiple paragraphs, and it should represent the content of the website.
|
||||
Tell the user exactly what the website says in response to their query, while adhering to these guidelines:
|
||||
As a professional analyst, your job is to extract all pertinent information from documents to help answer user's query.
|
||||
You will be provided raw text directly from within the document.
|
||||
Adhere to these guidelines while extracting information from the provided documents:
|
||||
|
||||
1. Answer the user's query as specifically as possible. Include many supporting details from the website.
|
||||
2. Craft a report that is detailed, thorough, in-depth, and complex, while maintaining clarity.
|
||||
3. Rely strictly on the provided text, without including external information.
|
||||
4. Format the report in multiple paragraphs with a clear structure.
|
||||
5. Be as specific as possible in your answer to the user's query.
|
||||
6. Reproduce as much of the provided text as possible, while maintaining readability.
|
||||
1. Extract all relevant text and links from the document that can assist with further research or answer the user's query.
|
||||
2. Craft a comprehensive but compact report with all the necessary data from the document to generate an informed response.
|
||||
3. Rely strictly on the provided text to generate your summary, without including external information.
|
||||
4. Provide specific, important snippets from the document in your report to establish trust in your summary.
|
||||
""".strip()
|
||||
|
||||
extract_relevant_information = PromptTemplate.from_template(
|
||||
|
@ -519,10 +520,10 @@ extract_relevant_information = PromptTemplate.from_template(
|
|||
{personality_context}
|
||||
Target Query: {query}
|
||||
|
||||
Web Pages:
|
||||
Document:
|
||||
{corpus}
|
||||
|
||||
Collate only relevant information from the website to answer the target query.
|
||||
Collate only relevant information from the document to answer the target query.
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
@ -617,6 +618,67 @@ Khoj:
|
|||
""".strip()
|
||||
)
|
||||
|
||||
# Research-mode planner prompt.
# Placeholders filled by the caller: personality_context, max_iterations, username,
# day_of_week, current_date, location, tools, previous_iterations, chat_history.
# The model is asked to reply with a JSON object holding "scratchpad", "tool" and "query";
# an empty "tool" value signals that research is finished.
plan_function_execution = PromptTemplate.from_template(
    """
You are Khoj, a smart, creative and methodical researcher. Use the provided tool AIs to investigate information to answer query.
Create a multi-step plan and intelligently iterate on the plan based on the retrieved information to find the requested information.
{personality_context}

# Instructions
- Ask detailed queries to the tool AIs provided below, one at a time, to discover required information or run calculations. Their response will be shown to you in the next iteration.
- Break down your research process into independent, self-contained steps that can be executed sequentially to answer the user's query. Write your step-by-step plan in the scratchpad.
- Ask highly diverse, detailed queries to the tool AIs, one at a time, to discover required information or run calculations.
- NEVER repeat the same query across iterations.
- Ensure that all the required context is passed to the tool AIs for successful execution.
- Ensure that you go deeper when possible and try more broad, creative strategies when a path is not yielding useful results. Build on the results of the previous iterations.
- You are allowed upto {max_iterations} iterations to use the help of the provided tool AIs to answer the user's question.
- Stop when you have the required information by returning a JSON object with an empty "tool" field. E.g., {{"scratchpad": "I have all I need", "tool": "", "query": ""}}

# Examples
Assuming you can search the user's notes and the internet.
- When they ask for the population of their hometown
  1. Try look up their hometown in their notes. Ask the note search AI to search for their birth certificate, childhood memories, school, resume etc.
  2. If not found in their notes, try infer their hometown from their online social media profiles. Ask the online search AI to look for {username}'s biography, school, resume on linkedin, facebook, website etc.
  3. Only then try find the latest population of their hometown by reading official websites with the help of the online search and web page reading AI.
- When they ask for their computer's specs
  1. Try find their computer model in their notes.
  2. Now find webpages with their computer model's spec online and read them.
- When they ask what clothes to carry for their upcoming trip
  1. Find the itinerary of their upcoming trip in their notes.
  2. Next find the weather forecast at the destination online.
  3. Then find if they mentioned what clothes they own in their notes.

# Background Context
- Current Date: {day_of_week}, {current_date}
- User Location: {location}
- User Name: {username}

# Available Tool AIs
Which of the tool AIs listed below would you use to answer the user's question? You **only** have access to the following tool AIs:

{tools}

# Previous Iterations
{previous_iterations}

# Chat History:
{chat_history}

Return the next tool AI to use and the query to ask it. Your response should always be a valid JSON object. Do not say anything else.
Response format:
{{"scratchpad": "<your_scratchpad_to_reason_about_which_tool_to_use>", "tool": "<name_of_tool_ai>", "query": "<your_detailed_query_for_the_tool_ai>"}}
""".strip()
)

# Template for one entry of the "Previous Iterations" section above.
# Placeholders: index, tool, query, result. Not stripped, so each rendered entry
# keeps its leading and trailing newline.
previous_iteration = PromptTemplate.from_template(
    """
## Iteration {index}:
- tool: {tool}
- query: {query}
- result: {result}
"""
)
|
||||
|
||||
pick_relevant_information_collection_tools = PromptTemplate.from_template(
|
||||
"""
|
||||
You are Khoj, an extremely smart and helpful search assistant.
|
||||
|
@ -806,6 +868,53 @@ Khoj:
|
|||
""".strip()
|
||||
)
|
||||
|
||||
# Code Generation
# --
# Prompt asking the model for up to three python programs, returned as {"codes": [...]}.
# Placeholders: personality_context, current_date, location, username, context, chat_history, query.
# Library names listed below must match the importable package names (e.g. pandas, fastparquet),
# otherwise the generated code may import modules that do not exist in the sandbox.
python_code_generation_prompt = PromptTemplate.from_template(
    """
You are Khoj, an advanced python programmer. You are tasked with constructing **up to three** python programs to best answer the user query.
- The python program will run in a pyodide python sandbox with no network access.
- You can write programs to run complex calculations, analyze data, create charts, generate documents to meticulously answer the query
- The sandbox has access to the standard library, matplotlib, pandas, numpy, scipy, bs4, sympy, brotli, cryptography, fastparquet
- Do not try display images or plots in the code directly. The code should save the image or plot to a file instead.
- Write any document, charts etc. to be shared with the user to file. These files can be seen by the user.
- Use as much context from the previous questions and answers as required to generate your code.
{personality_context}
What code will you need to write, if any, to answer the user's question?
Provide code programs as a list of strings in a JSON object with key "codes".
Current Date: {current_date}
User's Location: {location}
{username}

The JSON schema is of the form {{"codes": ["code1", "code2", "code3"]}}
For example:
{{"codes": ["print('Hello, World!')", "print('Goodbye, World!')"]}}

Now it's your turn to construct python programs to answer the user's question. Provide them as a list of strings in a JSON object. Do not say anything else.
Context:
---
{context}

Chat History:
---
{chat_history}

User: {query}
Khoj:
""".strip()
)

# Context block that passes code execution results to the chat model. Placeholder: code_results.
code_executed_context = PromptTemplate.from_template(
    """
Use the provided code executions to inform your response.
Ask crisp follow-up questions to get additional context, when a helpful response cannot be provided from the provided code execution results or past conversations.

Code Execution Results:
{code_results}
""".strip()
)
|
||||
|
||||
|
||||
# Automations
|
||||
# --
|
||||
crontime_prompt = PromptTemplate.from_template(
|
||||
|
|
|
@ -6,9 +6,10 @@ import os
|
|||
import queue
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
from time import perf_counter
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import PIL.Image
|
||||
import requests
|
||||
|
@ -23,8 +24,17 @@ from khoj.database.adapters import ConversationAdapters
|
|||
from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
|
||||
from khoj.search_filter.base_filter import BaseFilter
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
from khoj.search_filter.file_filter import FileFilter
|
||||
from khoj.search_filter.word_filter import WordFilter
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import in_debug_mode, is_none_or_empty, merge_dicts
|
||||
from khoj.utils.helpers import (
|
||||
ConversationCommand,
|
||||
in_debug_mode,
|
||||
is_none_or_empty,
|
||||
merge_dicts,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
model_to_prompt_size = {
|
||||
|
@ -85,8 +95,105 @@ class ThreadedGenerator:
|
|||
self.queue.put(StopIteration)
|
||||
|
||||
|
||||
class InformationCollectionIteration:
    """Record of a single research iteration: the tool asked, the query sent and what came back.

    Attributes:
        tool: Name of the tool used in this iteration.
        query: Query sent to the tool.
        context: Optional list of note search results. Presumably dicts with a "query" key,
            as read by construct_tool_chat_history; verify against callers.
        onlineContext: Optional dict of online search results keyed by query.
        codeContext: Optional dict of code execution results keyed by query.
        summarizedResult: Optional text summary of this iteration's result.
    """

    def __init__(
        self,
        tool: str,
        query: str,
        context: Optional[list] = None,
        onlineContext: Optional[dict] = None,
        codeContext: Optional[dict] = None,
        summarizedResult: Optional[str] = None,
    ):
        self.tool = tool
        self.query = query
        self.context = context
        self.onlineContext = onlineContext
        self.codeContext = codeContext
        self.summarizedResult = summarizedResult
|
||||
|
||||
|
||||
def construct_iteration_history(
    previous_iterations: List[InformationCollectionIteration], previous_iteration_prompt: str
) -> str:
    """Render every earlier iteration with the given template and concatenate the results in order.

    The template is formatted with the keywords tool, query, result and index.
    The index is 1-based. Returns an empty string when there are no iterations.
    """
    rendered_iterations = [
        previous_iteration_prompt.format(
            tool=step.tool,
            query=step.query,
            result=step.summarizedResult,
            index=position,
        )
        for position, step in enumerate(previous_iterations, start=1)
    ]
    return "".join(rendered_iterations)
|
||||
|
||||
|
||||
def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
    """Format the last n chat messages as a plain-text "User: ... / <agent_name>: ..." transcript.

    Only messages written by khoj are rendered. Each contributes the user query stored in its
    intent plus a response line that depends on the intent type:
    - "remember", "reminder", "summarize": the message text
    - types containing "text-to-image": a placeholder instead of the image
    - types containing "excalidraw": the first inferred query

    Khoj messages without an intent, or with an intent that has no type, are skipped
    rather than raising.

    Note: the slice is `[-n:]`, so n=0 selects the whole history, not nothing.
    """
    chat_history = ""
    for chat in conversation_history.get("chat", [])[-n:]:
        if chat["by"] != "khoj":
            continue

        # Intent or its type may be absent; normalise so the substring checks below cannot hit None.
        intent = chat.get("intent") or {}
        intent_type = intent.get("type") or ""

        if intent_type in ["remember", "reminder", "summarize"]:
            chat_history += f"User: {intent['query']}\n"
            chat_history += f"{agent_name}: {chat['message']}\n"
        elif "text-to-image" in intent_type:
            chat_history += f"User: {intent['query']}\n"
            chat_history += f"{agent_name}: [generated image redacted for space]\n"
        elif "excalidraw" in intent_type:
            chat_history += f"User: {intent['query']}\n"
            chat_history += f"{agent_name}: {intent['inferred-queries'][0]}\n"
    return chat_history
|
||||
|
||||
|
||||
def construct_tool_chat_history(
    previous_iterations: List[InformationCollectionIteration], tool: Optional[ConversationCommand] = None
) -> Dict[str, list]:
    """Convert previous research iterations into a chat-log shaped dict: {"chat": [...]}.

    Each iteration becomes two entries: a "you" message holding the iteration query, and a
    "khoj" message holding the summarized result with a "remember" intent.
    The intent's "inferred-queries" depend on the tool:
    - Notes: the "query" of every item in the iteration's context
    - Online: the keys of the iteration's onlineContext
    - Code: the keys of the iteration's codeContext
    - any other tool, or None: an empty list
    """

    def extract_inferred_queries(iteration: InformationCollectionIteration) -> List[str]:
        # Missing or empty context for the selected tool yields no inferred queries.
        if tool == ConversationCommand.Notes:
            return [c["query"] for c in iteration.context] if iteration.context else []
        if tool == ConversationCommand.Online:
            return list(iteration.onlineContext.keys()) if iteration.onlineContext else []
        if tool == ConversationCommand.Code:
            return list(iteration.codeContext.keys()) if iteration.codeContext else []
        return []

    chat_history: list = []
    for iteration in previous_iterations:
        chat_history += [
            {
                "by": "you",
                "message": iteration.query,
            },
            {
                "by": "khoj",
                "intent": {
                    "type": "remember",
                    "inferred-queries": extract_inferred_queries(iteration),
                    "query": iteration.query,
                },
                "message": iteration.summarizedResult,
            },
        ]

    return {"chat": chat_history}
|
||||
|
||||
|
||||
class ChatEvent(Enum):
    """Named string constants identifying the kinds of event emitted while handling a chat request."""

    # Presumably mark the start and end of the model's streamed response; confirm against the chat router
    START_LLM_RESPONSE = "start_llm_response"
    END_LLM_RESPONSE = "end_llm_response"
    MESSAGE = "message"
    REFERENCES = "references"
    # Used as the key when tools such as run_code yield status updates
    STATUS = "status"
|
||||
|
||||
|
||||
def message_to_log(
|
||||
user_message, chat_response, user_message_metadata={}, khoj_message_metadata={}, conversation_log=[]
|
||||
user_message,
|
||||
chat_response,
|
||||
user_message_metadata={},
|
||||
khoj_message_metadata={},
|
||||
conversation_log=[],
|
||||
train_of_thought=[],
|
||||
):
|
||||
"""Create json logs from messages, metadata for conversation log"""
|
||||
default_khoj_message_metadata = {
|
||||
|
@ -114,6 +221,7 @@ def save_to_conversation_log(
|
|||
user_message_time: str = None,
|
||||
compiled_references: List[Dict[str, Any]] = [],
|
||||
online_results: Dict[str, Any] = {},
|
||||
code_results: Dict[str, Any] = {},
|
||||
inferred_queries: List[str] = [],
|
||||
intent_type: str = "remember",
|
||||
client_application: ClientApplication = None,
|
||||
|
@ -121,6 +229,7 @@ def save_to_conversation_log(
|
|||
automation_id: str = None,
|
||||
query_images: List[str] = None,
|
||||
tracer: Dict[str, Any] = {},
|
||||
train_of_thought: List[Any] = [],
|
||||
):
|
||||
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
updated_conversation = message_to_log(
|
||||
|
@ -134,9 +243,12 @@ def save_to_conversation_log(
|
|||
"context": compiled_references,
|
||||
"intent": {"inferred-queries": inferred_queries, "type": intent_type},
|
||||
"onlineContext": online_results,
|
||||
"codeContext": code_results,
|
||||
"automationId": automation_id,
|
||||
"trainOfThought": train_of_thought,
|
||||
},
|
||||
conversation_log=meta_log.get("chat", []),
|
||||
train_of_thought=train_of_thought,
|
||||
)
|
||||
ConversationAdapters.save_conversation(
|
||||
user,
|
||||
|
@ -330,9 +442,23 @@ def reciprocal_conversation_to_chatml(message_pair):
|
|||
return [ChatMessage(content=message, role=role) for message, role in zip(message_pair, ["user", "assistant"])]
|
||||
|
||||
|
||||
def remove_json_codeblock(response: str):
|
||||
"""Remove any markdown json codeblock formatting if present. Useful for non schema enforceable models"""
|
||||
return response.removeprefix("```json").removesuffix("```")
|
||||
def clean_json(response: str):
    """Strip surrounding whitespace, all newlines and a wrapping markdown code fence from a model response.

    Handles both a json-tagged fence (```json ... ```) and a bare fence (``` ... ```).
    Useful for models that cannot be forced to follow a JSON schema.
    """
    cleaned = response.strip().replace("\n", "")
    # Try the json-tagged fence first so the language tag is not left behind, then a bare fence.
    cleaned = cleaned.removeprefix("```json").removeprefix("```")
    return cleaned.removesuffix("```")
|
||||
|
||||
|
||||
def clean_code_python(code: str):
    """Strip surrounding whitespace and a wrapping markdown code fence from generated python code.

    Recognises ```python, ```py and bare ``` opening fences. Newlines inside the code are
    preserved, including the one directly after the opening fence.
    Useful for models that cannot be forced to follow an output schema.
    """
    cleaned = code.strip()
    # Longest fence first: "```python" also starts with "```py" and "```".
    for fence in ("```python", "```py", "```"):
        if cleaned.startswith(fence):
            cleaned = cleaned[len(fence) :]
            break
    return cleaned.removesuffix("```")
|
||||
|
||||
|
||||
def defilter_query(query: str):
    """Remove any query filters in query

    Passes the query through the defilter method of the word, file and date filters, in that order.
    """
    stripped_query = query
    for query_filter in (WordFilter(), FileFilter(), DateFilter()):
        stripped_query = query_filter.defilter(stripped_query)
    return stripped_query
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
@ -4,7 +4,7 @@ import logging
|
|||
import os
|
||||
import urllib.parse
|
||||
from collections import defaultdict
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import aiohttp
|
||||
from bs4 import BeautifulSoup
|
||||
|
@ -52,7 +52,8 @@ OLOSTEP_QUERY_PARAMS = {
|
|||
"expandMarkdown": "True",
|
||||
"expandHtml": "False",
|
||||
}
|
||||
MAX_WEBPAGES_TO_READ = 1
|
||||
|
||||
DEFAULT_MAX_WEBPAGES_TO_READ = 1
|
||||
|
||||
|
||||
async def search_online(
|
||||
|
@ -62,6 +63,7 @@ async def search_online(
|
|||
user: KhojUser,
|
||||
send_status_func: Optional[Callable] = None,
|
||||
custom_filters: List[str] = [],
|
||||
max_webpages_to_read: int = DEFAULT_MAX_WEBPAGES_TO_READ,
|
||||
query_images: List[str] = None,
|
||||
agent: Agent = None,
|
||||
tracer: dict = {},
|
||||
|
@ -97,7 +99,7 @@ async def search_online(
|
|||
for subquery in response_dict:
|
||||
if "answerBox" in response_dict[subquery]:
|
||||
continue
|
||||
for organic in response_dict[subquery].get("organic", [])[:MAX_WEBPAGES_TO_READ]:
|
||||
for organic in response_dict[subquery].get("organic", [])[:max_webpages_to_read]:
|
||||
link = organic.get("link")
|
||||
if link in webpages:
|
||||
webpages[link]["queries"].add(subquery)
|
||||
|
|
144
src/khoj/processor/tools/run_code.py
Normal file
144
src/khoj/processor/tools/run_code.py
Normal file
|
@ -0,0 +1,144 @@
|
|||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Callable, List, Optional
|
||||
|
||||
import aiohttp
|
||||
|
||||
from khoj.database.adapters import ais_user_subscribed
|
||||
from khoj.database.models import Agent, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.utils import (
|
||||
ChatEvent,
|
||||
clean_code_python,
|
||||
clean_json,
|
||||
construct_chat_history,
|
||||
)
|
||||
from khoj.routers.helpers import send_message_to_model_wrapper
|
||||
from khoj.utils.helpers import timer
|
||||
from khoj.utils.rawconfig import LocationData
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
SANDBOX_URL = os.getenv("KHOJ_TERRARIUM_URL", "http://localhost:8080")
|
||||
|
||||
|
||||
async def run_code(
    query: str,
    conversation_history: dict,
    context: str,
    location_data: LocationData,
    user: KhojUser,
    send_status_func: Optional[Callable] = None,
    query_images: List[str] = None,
    agent: Agent = None,
    sandbox_url: str = SANDBOX_URL,
    tracer: dict = {},
):
    """Generate python programs for the query, run them in the sandbox and yield their results.

    This is an async generator. It yields:
    - {ChatEvent.STATUS: event} for every status event produced by send_status_func, if provided
    - {query: {"code": <executed code>, "results": <sandbox response without the code>}} per program

    Raises:
        ValueError: when code generation or code execution fails. The original exception is chained as the cause.
    """
    # NOTE(review): `tracer` has a mutable default and is forwarded to the model wrapper; confirm nothing mutates it.

    # Generate Code
    if send_status_func:
        async for event in send_status_func(f"**Generate code snippets** for {query}"):
            yield {ChatEvent.STATUS: event}
    try:
        with timer("Chat actor: Generate programs to execute", logger):
            codes = await generate_python_code(
                query,
                conversation_history,
                context,
                location_data,
                user,
                query_images,
                agent,
                tracer,
            )
    except Exception as e:
        # Chain the cause so the underlying failure stays visible in tracebacks
        raise ValueError(f"Failed to generate code for {query} with error: {e}") from e

    # Run Code
    if send_status_func:
        async for event in send_status_func(f"**Running {len(codes)} code snippets**"):
            yield {ChatEvent.STATUS: event}
    try:
        # Execute all generated programs concurrently
        tasks = [execute_sandboxed_python(code, sandbox_url) for code in codes]
        with timer("Chat actor: Execute generated programs", logger):
            results = await asyncio.gather(*tasks)
        for result in results:
            # Separate the executed code from the rest of the sandbox response
            code = result.pop("code")
            logger.info(f"Executed Code:\n--@@--\n{code}\n--@@--Result:\n--@@--\n{result}\n--@@--")
            yield {query: {"code": code, "results": result}}
    except Exception as e:
        raise ValueError(f"Failed to run code for {query} with error: {e}") from e
|
||||
|
||||
|
||||
async def generate_python_code(
    q: str,
    conversation_history: dict,
    context: str,
    location_data: LocationData,
    user: KhojUser,
    query_images: List[str] = None,
    agent: Agent = None,
    tracer: dict = {},
) -> List[str]:
    """Ask the chat model for python programs that help answer the query.

    Fills the code generation prompt with the query, chat history, context, location, user name,
    agent personality and the current UTC date. Sends it to the model requesting a JSON object,
    then parses the "codes" list from the response.

    Returns:
        The non-empty, whitespace-stripped code strings from the model response.

    Raises:
        ValueError: when the response contains no usable code. JSON that cannot be parsed raises
            json.JSONDecodeError, which is a ValueError subclass.
        KeyError: when the parsed response has no "codes" key.
    """
    location = f"{location_data}" if location_data else "Unknown"
    username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
    chat_history = construct_chat_history(conversation_history)

    utc_date = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d")
    personality_context = (
        prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
    )

    code_generation_prompt = prompts.python_code_generation_prompt.format(
        current_date=utc_date,
        query=q,
        chat_history=chat_history,
        context=context,
        location=location,
        username=username,
        personality_context=personality_context,
    )

    response = await send_message_to_model_wrapper(
        code_generation_prompt,
        query_images=query_images,
        response_type="json_object",
        user=user,
        tracer=tracer,
    )

    # Parse the response and keep only non-blank programs
    response = clean_json(response)
    response = json.loads(response)
    codes = [code.strip() for code in response["codes"] if code.strip()]

    if not codes:
        raise ValueError("No code generated in model response")
    return codes
|
||||
|
||||
|
||||
async def execute_sandboxed_python(code: str, sandbox_url: str = SANDBOX_URL) -> dict[str, Any]:
    """
    Run the given Python code in the sandbox reachable at `sandbox_url`.

    The code is passed through `clean_code_python` and POSTed as JSON to the sandbox.
    On an HTTP 200 reply the decoded JSON body is returned with the cleaned code added
    under the "code" key. On any other status a failure dictionary is returned instead,
    carrying the cleaned code, `success=False` and the status in "std_err".
    """
    cleaned_code = clean_code_python(code)
    payload = {"code": cleaned_code}
    request_headers = {"Content-Type": "application/json"}

    async with aiohttp.ClientSession() as session:
        async with session.post(sandbox_url, json=payload, headers=request_headers) as response:
            # Guard clause: anything other than 200 is reported as a failed execution
            if response.status != 200:
                return {
                    "code": cleaned_code,
                    "success": False,
                    "std_err": f"Failed to execute code with {response.status}",
                }

            result: dict[str, Any] = await response.json()
            result["code"] = cleaned_code
            return result
|
|
@ -44,6 +44,7 @@ from khoj.processor.conversation.offline.chat_model import extract_questions_off
|
|||
from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
|
||||
from khoj.processor.conversation.openai.gpt import extract_questions
|
||||
from khoj.processor.conversation.openai.whisper import transcribe_audio
|
||||
from khoj.processor.conversation.utils import defilter_query
|
||||
from khoj.routers.helpers import (
|
||||
ApiUserRateLimiter,
|
||||
ChatEvent,
|
||||
|
@ -167,8 +168,8 @@ async def execute_search(
|
|||
search_futures += [
|
||||
executor.submit(
|
||||
text_search.query,
|
||||
user,
|
||||
user_query,
|
||||
user,
|
||||
t,
|
||||
question_embedding=encoded_asymmetric_query,
|
||||
max_distance=max_distance,
|
||||
|
@ -355,7 +356,7 @@ async def extract_references_and_questions(
|
|||
user = request.user.object if request.user.is_authenticated else None
|
||||
|
||||
# Initialize Variables
|
||||
compiled_references: List[Any] = []
|
||||
compiled_references: List[dict[str, str]] = []
|
||||
inferred_queries: List[str] = []
|
||||
|
||||
agent_has_entries = False
|
||||
|
@ -384,9 +385,7 @@ async def extract_references_and_questions(
|
|||
return
|
||||
|
||||
# Extract filter terms from user message
|
||||
defiltered_query = q
|
||||
for filter in [DateFilter(), WordFilter(), FileFilter()]:
|
||||
defiltered_query = filter.defilter(defiltered_query)
|
||||
defiltered_query = defilter_query(q)
|
||||
filters_in_query = q.replace(defiltered_query, "").strip()
|
||||
conversation = await sync_to_async(ConversationAdapters.get_conversation_by_id)(conversation_id)
|
||||
|
||||
|
@ -502,7 +501,8 @@ async def extract_references_and_questions(
|
|||
)
|
||||
search_results = text_search.deduplicated_search_responses(search_results)
|
||||
compiled_references = [
|
||||
{"compiled": item.additional["compiled"], "file": item.additional["file"]} for item in search_results
|
||||
{"query": q, "compiled": item.additional["compiled"], "file": item.additional["file"]}
|
||||
for q, item in zip(inferred_queries, search_results)
|
||||
]
|
||||
|
||||
yield compiled_references, inferred_queries, defiltered_query
|
||||
|
|
|
@ -6,7 +6,7 @@ import time
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from typing import Dict, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import unquote
|
||||
|
||||
from asgiref.sync import sync_to_async
|
||||
|
@ -25,10 +25,11 @@ from khoj.database.adapters import (
|
|||
)
|
||||
from khoj.database.models import Agent, KhojUser
|
||||
from khoj.processor.conversation.prompts import help_message, no_entries_found
|
||||
from khoj.processor.conversation.utils import save_to_conversation_log
|
||||
from khoj.processor.conversation.utils import defilter_query, save_to_conversation_log
|
||||
from khoj.processor.image.generate import text_to_image
|
||||
from khoj.processor.speech.text_to_speech import generate_text_to_speech
|
||||
from khoj.processor.tools.online_search import read_webpages, search_online
|
||||
from khoj.processor.tools.run_code import run_code
|
||||
from khoj.routers.api import extract_references_and_questions
|
||||
from khoj.routers.helpers import (
|
||||
ApiImageRateLimiter,
|
||||
|
@ -42,8 +43,10 @@ from khoj.routers.helpers import (
|
|||
aget_relevant_output_modes,
|
||||
construct_automation_created_message,
|
||||
create_automation,
|
||||
extract_relevant_info,
|
||||
extract_relevant_summary,
|
||||
generate_excalidraw_diagram,
|
||||
generate_summary_from_files,
|
||||
get_conversation_command,
|
||||
is_query_empty,
|
||||
is_ready_to_chat,
|
||||
|
@ -51,6 +54,10 @@ from khoj.routers.helpers import (
|
|||
update_telemetry_state,
|
||||
validate_conversation_config,
|
||||
)
|
||||
from khoj.routers.research import (
|
||||
InformationCollectionIteration,
|
||||
execute_information_collection,
|
||||
)
|
||||
from khoj.routers.storage import upload_image_to_bucket
|
||||
from khoj.utils import state
|
||||
from khoj.utils.helpers import (
|
||||
|
@ -563,7 +570,9 @@ async def chat(
|
|||
user: KhojUser = request.user.object
|
||||
event_delimiter = "␃🔚␗"
|
||||
q = unquote(q)
|
||||
train_of_thought = []
|
||||
nonlocal conversation_id
|
||||
|
||||
tracer: dict = {
|
||||
"mid": f"{uuid.uuid4()}",
|
||||
"cid": conversation_id,
|
||||
|
@ -583,7 +592,7 @@ async def chat(
|
|||
uploaded_images.append(uploaded_image)
|
||||
|
||||
async def send_event(event_type: ChatEvent, data: str | dict):
|
||||
nonlocal connection_alive, ttft
|
||||
nonlocal connection_alive, ttft, train_of_thought
|
||||
if not connection_alive or await request.is_disconnected():
|
||||
connection_alive = False
|
||||
logger.warning(f"User {user} disconnected from {common.client} client")
|
||||
|
@ -591,8 +600,11 @@ async def chat(
|
|||
try:
|
||||
if event_type == ChatEvent.END_LLM_RESPONSE:
|
||||
collect_telemetry()
|
||||
if event_type == ChatEvent.START_LLM_RESPONSE:
|
||||
elif event_type == ChatEvent.START_LLM_RESPONSE:
|
||||
ttft = time.perf_counter() - start_time
|
||||
elif event_type == ChatEvent.STATUS:
|
||||
train_of_thought.append({"type": event_type.value, "data": data})
|
||||
|
||||
if event_type == ChatEvent.MESSAGE:
|
||||
yield data
|
||||
elif event_type == ChatEvent.REFERENCES or stream:
|
||||
|
@ -681,6 +693,14 @@ async def chat(
|
|||
meta_log = conversation.conversation_log
|
||||
is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
|
||||
|
||||
researched_results = ""
|
||||
online_results: Dict = dict()
|
||||
code_results: Dict = dict()
|
||||
## Extract Document References
|
||||
compiled_references: List[Any] = []
|
||||
inferred_queries: List[Any] = []
|
||||
defiltered_query = defilter_query(q)
|
||||
|
||||
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
|
||||
conversation_commands = await aget_relevant_information_sources(
|
||||
q,
|
||||
|
@ -691,6 +711,11 @@ async def chat(
|
|||
agent=agent,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
# If we're doing research, we don't want to do anything else
|
||||
if ConversationCommand.Research in conversation_commands:
|
||||
conversation_commands = [ConversationCommand.Research]
|
||||
|
||||
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
|
||||
async for result in send_event(
|
||||
ChatEvent.STATUS, f"**Chose Data Sources to Search:** {conversation_commands_str}"
|
||||
|
@ -705,6 +730,38 @@ async def chat(
|
|||
if mode not in conversation_commands:
|
||||
conversation_commands.append(mode)
|
||||
|
||||
if conversation_commands == [ConversationCommand.Research]:
|
||||
async for research_result in execute_information_collection(
|
||||
request=request,
|
||||
user=user,
|
||||
query=defiltered_query,
|
||||
conversation_id=conversation_id,
|
||||
conversation_history=meta_log,
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||
user_name=user_name,
|
||||
location=location,
|
||||
file_filters=conversation.file_filters if conversation else [],
|
||||
tracer=tracer,
|
||||
):
|
||||
if isinstance(research_result, InformationCollectionIteration):
|
||||
if research_result.summarizedResult:
|
||||
if research_result.onlineContext:
|
||||
online_results.update(research_result.onlineContext)
|
||||
if research_result.codeContext:
|
||||
code_results.update(research_result.codeContext)
|
||||
if research_result.context:
|
||||
compiled_references.extend(research_result.context)
|
||||
|
||||
researched_results += research_result.summarizedResult
|
||||
|
||||
else:
|
||||
yield research_result
|
||||
|
||||
# researched_results = await extract_relevant_info(q, researched_results, agent)
|
||||
logger.info(f"Researched Results: {researched_results}")
|
||||
|
||||
for cmd in conversation_commands:
|
||||
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
|
||||
q = q.replace(f"/{cmd.value}", "").strip()
|
||||
|
@ -733,48 +790,24 @@ async def chat(
|
|||
async for result in send_llm_response(response_log):
|
||||
yield result
|
||||
else:
|
||||
try:
|
||||
file_object = None
|
||||
if await EntryAdapters.aagent_has_entries(agent):
|
||||
file_names = await EntryAdapters.aget_agent_entry_filepaths(agent)
|
||||
if len(file_names) > 0:
|
||||
file_object = await FileObjectAdapters.async_get_file_objects_by_name(
|
||||
None, file_names[0], agent
|
||||
)
|
||||
|
||||
if len(file_filters) > 0:
|
||||
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
|
||||
|
||||
if len(file_object) == 0:
|
||||
response_log = "Sorry, I couldn't find the full text of this file. Please re-upload the document and try again."
|
||||
async for result in send_llm_response(response_log):
|
||||
yield result
|
||||
return
|
||||
contextual_data = " ".join([file.raw_text for file in file_object])
|
||||
if not q:
|
||||
q = "Create a general summary of the file"
|
||||
async for result in send_event(
|
||||
ChatEvent.STATUS, f"**Constructing Summary Using:** {file_object[0].file_name}"
|
||||
):
|
||||
yield result
|
||||
|
||||
response = await extract_relevant_summary(
|
||||
q,
|
||||
contextual_data,
|
||||
conversation_history=meta_log,
|
||||
query_images=uploaded_images,
|
||||
async for response in generate_summary_from_files(
|
||||
q=q,
|
||||
user=user,
|
||||
file_filters=file_filters,
|
||||
meta_log=meta_log,
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
send_status_func=partial(send_event, ChatEvent.STATUS),
|
||||
tracer=tracer,
|
||||
)
|
||||
response_log = str(response)
|
||||
async for result in send_llm_response(response_log):
|
||||
yield result
|
||||
except Exception as e:
|
||||
response_log = "Error summarizing file. Please try again, or contact support."
|
||||
logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
|
||||
async for result in send_llm_response(response_log):
|
||||
):
|
||||
if isinstance(response, dict) and ChatEvent.STATUS in response:
|
||||
yield response[ChatEvent.STATUS]
|
||||
else:
|
||||
if isinstance(response, str):
|
||||
response_log = response
|
||||
async for result in send_llm_response(response):
|
||||
yield result
|
||||
|
||||
await sync_to_async(save_to_conversation_log)(
|
||||
q,
|
||||
response_log,
|
||||
|
@ -786,6 +819,7 @@ async def chat(
|
|||
conversation_id=conversation_id,
|
||||
query_images=uploaded_images,
|
||||
tracer=tracer,
|
||||
train_of_thought=train_of_thought,
|
||||
)
|
||||
return
|
||||
|
||||
|
@ -794,7 +828,7 @@ async def chat(
|
|||
if not q:
|
||||
conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
|
||||
if conversation_config == None:
|
||||
conversation_config = await ConversationAdapters.aget_default_conversation_config()
|
||||
conversation_config = await ConversationAdapters.aget_default_conversation_config(user)
|
||||
model_type = conversation_config.model_type
|
||||
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
|
||||
async for result in send_llm_response(formatted_help):
|
||||
|
@ -830,6 +864,7 @@ async def chat(
|
|||
automation_id=automation.id,
|
||||
query_images=uploaded_images,
|
||||
tracer=tracer,
|
||||
train_of_thought=train_of_thought,
|
||||
)
|
||||
async for result in send_llm_response(llm_response):
|
||||
yield result
|
||||
|
@ -837,7 +872,7 @@ async def chat(
|
|||
|
||||
# Gather Context
|
||||
## Extract Document References
|
||||
compiled_references, inferred_queries, defiltered_query = [], [], q
|
||||
if not ConversationCommand.Research in conversation_commands:
|
||||
try:
|
||||
async for result in extract_references_and_questions(
|
||||
request,
|
||||
|
@ -860,7 +895,9 @@ async def chat(
|
|||
inferred_queries.extend(result[1])
|
||||
defiltered_query = result[2]
|
||||
except Exception as e:
|
||||
error_message = f"Error searching knowledge base: {e}. Attempting to respond without document references."
|
||||
error_message = (
|
||||
f"Error searching knowledge base: {e}. Attempting to respond without document references."
|
||||
)
|
||||
logger.error(error_message, exc_info=True)
|
||||
async for result in send_event(
|
||||
ChatEvent.STATUS, "Document search failed. I'll try respond without document references"
|
||||
|
@ -874,8 +911,6 @@ async def chat(
|
|||
async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"):
|
||||
yield result
|
||||
|
||||
online_results: Dict = dict()
|
||||
|
||||
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
|
||||
async for result in send_llm_response(f"{no_entries_found.format()}"):
|
||||
yield result
|
||||
|
@ -948,6 +983,33 @@ async def chat(
|
|||
):
|
||||
yield result
|
||||
|
||||
## Gather Code Results
|
||||
if ConversationCommand.Code in conversation_commands:
|
||||
try:
|
||||
context = f"# Iteration 1:\n#---\nNotes:\n{compiled_references}\n\nOnline Results:{online_results}"
|
||||
async for result in run_code(
|
||||
defiltered_query,
|
||||
meta_log,
|
||||
context,
|
||||
location,
|
||||
user,
|
||||
partial(send_event, ChatEvent.STATUS),
|
||||
query_images=uploaded_images,
|
||||
agent=agent,
|
||||
tracer=tracer,
|
||||
):
|
||||
if isinstance(result, dict) and ChatEvent.STATUS in result:
|
||||
yield result[ChatEvent.STATUS]
|
||||
else:
|
||||
code_results = result
|
||||
async for result in send_event(ChatEvent.STATUS, f"**Ran code snippets**: {len(code_results)}"):
|
||||
yield result
|
||||
except ValueError as e:
|
||||
logger.warning(
|
||||
f"Failed to use code tool: {e}. Attempting to respond without code results",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
## Send Gathered References
|
||||
async for result in send_event(
|
||||
ChatEvent.REFERENCES,
|
||||
|
@ -955,6 +1017,7 @@ async def chat(
|
|||
"inferredQueries": inferred_queries,
|
||||
"context": compiled_references,
|
||||
"onlineContext": online_results,
|
||||
"codeContext": code_results,
|
||||
},
|
||||
):
|
||||
yield result
|
||||
|
@ -1004,6 +1067,7 @@ async def chat(
|
|||
online_results=online_results,
|
||||
query_images=uploaded_images,
|
||||
tracer=tracer,
|
||||
train_of_thought=train_of_thought,
|
||||
)
|
||||
content_obj = {
|
||||
"intentType": intent_type,
|
||||
|
@ -1061,6 +1125,7 @@ async def chat(
|
|||
online_results=online_results,
|
||||
query_images=uploaded_images,
|
||||
tracer=tracer,
|
||||
train_of_thought=train_of_thought,
|
||||
)
|
||||
|
||||
async for result in send_llm_response(json.dumps(content_obj)):
|
||||
|
@ -1076,6 +1141,7 @@ async def chat(
|
|||
conversation,
|
||||
compiled_references,
|
||||
online_results,
|
||||
code_results,
|
||||
inferred_queries,
|
||||
conversation_commands,
|
||||
user,
|
||||
|
@ -1083,8 +1149,10 @@ async def chat(
|
|||
conversation_id,
|
||||
location,
|
||||
user_name,
|
||||
researched_results,
|
||||
uploaded_images,
|
||||
tracer,
|
||||
train_of_thought,
|
||||
)
|
||||
|
||||
# Send Response
|
||||
|
|
|
@ -43,6 +43,7 @@ from khoj.database.adapters import (
|
|||
AutomationAdapters,
|
||||
ConversationAdapters,
|
||||
EntryAdapters,
|
||||
FileObjectAdapters,
|
||||
ais_user_subscribed,
|
||||
create_khoj_token,
|
||||
get_khoj_tokens,
|
||||
|
@ -87,9 +88,11 @@ from khoj.processor.conversation.offline.chat_model import (
|
|||
)
|
||||
from khoj.processor.conversation.openai.gpt import converse, send_message_to_model
|
||||
from khoj.processor.conversation.utils import (
|
||||
ChatEvent,
|
||||
ThreadedGenerator,
|
||||
clean_json,
|
||||
construct_chat_history,
|
||||
generate_chatml_messages_with_context,
|
||||
remove_json_codeblock,
|
||||
save_to_conversation_log,
|
||||
)
|
||||
from khoj.processor.speech.text_to_speech import is_eleven_labs_enabled
|
||||
|
@ -137,7 +140,7 @@ def validate_conversation_config(user: KhojUser):
|
|||
async def is_ready_to_chat(user: KhojUser):
|
||||
user_conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
|
||||
if user_conversation_config == None:
|
||||
user_conversation_config = await ConversationAdapters.aget_default_conversation_config()
|
||||
user_conversation_config = await ConversationAdapters.aget_default_conversation_config(user)
|
||||
|
||||
if user_conversation_config and user_conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
|
||||
chat_model = user_conversation_config.chat_model
|
||||
|
@ -210,21 +213,6 @@ def get_next_url(request: Request) -> str:
|
|||
return urljoin(str(request.base_url).rstrip("/"), next_path)
|
||||
|
||||
|
||||
def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
    """
    Render the most recent chat turns as a plain-text "User: ... / <agent>: ..." transcript.

    Only messages sent by "khoj" are rendered, and only for these intent types:
      - "remember", "reminder", "summarize": the agent's message is shown as-is.
      - types containing "text-to-image": the message is replaced by a short placeholder.
      - types containing "excalidraw": the first inferred query is shown instead of the message.
    Messages with any other (or no) intent type are skipped.

    Args:
        conversation_history: Conversation log; turns are read from its "chat" list.
        n: Slice size applied as `[-n:]` to the "chat" list.
        agent_name: Label used for the agent's lines in the transcript.

    Returns:
        The transcript, one newline-terminated line per speaker turn ("" when nothing qualifies).
    """
    lines = []
    for chat in conversation_history.get("chat", [])[-n:]:
        if chat["by"] != "khoj":
            continue

        intent = chat["intent"]
        # A missing or None intent type must not crash the substring checks below
        intent_type = intent.get("type") or ""

        if intent_type in ("remember", "reminder", "summarize"):
            reply = chat["message"]
        elif "text-to-image" in intent_type:
            reply = "[generated image redacted for space]"
        elif "excalidraw" in intent_type:
            reply = intent["inferred-queries"][0]
        else:
            continue

        lines.append(f"User: {intent['query']}\n")
        lines.append(f"{agent_name}: {reply}\n")

    return "".join(lines)
|
||||
|
||||
|
||||
def get_conversation_command(query: str, any_references: bool = False) -> ConversationCommand:
|
||||
if query.startswith("/notes"):
|
||||
return ConversationCommand.Notes
|
||||
|
@ -244,6 +232,10 @@ def get_conversation_command(query: str, any_references: bool = False) -> Conver
|
|||
return ConversationCommand.Summarize
|
||||
elif query.startswith("/diagram"):
|
||||
return ConversationCommand.Diagram
|
||||
elif query.startswith("/code"):
|
||||
return ConversationCommand.Code
|
||||
elif query.startswith("/research"):
|
||||
return ConversationCommand.Research
|
||||
# If no relevant notes found for the given query
|
||||
elif not any_references:
|
||||
return ConversationCommand.General
|
||||
|
@ -342,8 +334,7 @@ async def aget_relevant_information_sources(
|
|||
)
|
||||
|
||||
try:
|
||||
response = response.strip()
|
||||
response = remove_json_codeblock(response)
|
||||
response = clean_json(response)
|
||||
response = json.loads(response)
|
||||
response = [q.strip() for q in response["source"] if q.strip()]
|
||||
if not isinstance(response, list) or not response or len(response) == 0:
|
||||
|
@ -421,8 +412,7 @@ async def aget_relevant_output_modes(
|
|||
)
|
||||
|
||||
try:
|
||||
response = response.strip()
|
||||
response = remove_json_codeblock(response)
|
||||
response = clean_json(response)
|
||||
response = json.loads(response)
|
||||
|
||||
if is_none_or_empty(response):
|
||||
|
@ -483,7 +473,7 @@ async def infer_webpage_urls(
|
|||
|
||||
# Validate that the response is a non-empty, JSON-serializable list of URLs
|
||||
try:
|
||||
response = response.strip()
|
||||
response = clean_json(response)
|
||||
urls = json.loads(response)
|
||||
valid_unique_urls = {str(url).strip() for url in urls["links"] if is_valid_url(url)}
|
||||
if is_none_or_empty(valid_unique_urls):
|
||||
|
@ -534,8 +524,7 @@ async def generate_online_subqueries(
|
|||
|
||||
# Validate that the response is a non-empty, JSON-serializable list
|
||||
try:
|
||||
response = response.strip()
|
||||
response = remove_json_codeblock(response)
|
||||
response = clean_json(response)
|
||||
response = json.loads(response)
|
||||
response = [q.strip() for q in response["queries"] if q.strip()]
|
||||
if not isinstance(response, list) or not response or len(response) == 0:
|
||||
|
@ -644,6 +633,53 @@ async def extract_relevant_summary(
|
|||
return response.strip()
|
||||
|
||||
|
||||
async def generate_summary_from_files(
    q: str,
    user: KhojUser,
    file_filters: List[str],
    meta_log: dict,
    query_images: List[str] = None,
    agent: Agent = None,
    send_status_func: Optional[Callable] = None,
    tracer: dict = {},
):
    """
    Async generator that summarizes a file for the user and streams progress along the way.

    The file is looked up from the agent's entries when the agent has any, and then from the
    first entry in `file_filters` when that list is non-empty (the latter takes precedence).

    Args:
        q: The summarization query. Falls back to a generic summary request when empty.
        user: The requesting user, used for the file lookup and the model call.
        file_filters: File names to summarize; only the first one is used.
        meta_log: Conversation log passed to the summarizer as conversation history.
        query_images: Optional images forwarded to the summarizer.
        agent: Optional agent whose knowledge base may provide the file.
        send_status_func: Optional async-generator callable used to emit status updates.
        tracer: Tracing metadata passed through to the summarizer. It is not mutated here.

    Yields:
        `{ChatEvent.STATUS: <event>}` dicts for status updates, followed by a single string:
        the summary, a "file not found" notice, or a generic error message on failure.
    """
    try:
        file_object = None
        if await EntryAdapters.aagent_has_entries(agent):
            file_names = await EntryAdapters.aget_agent_entry_filepaths(agent)
            if len(file_names) > 0:
                file_object = await FileObjectAdapters.async_get_file_objects_by_name(None, file_names.pop(), agent)

        if len(file_filters) > 0:
            file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])

        # file_object stays None when neither the agent nor the filters named a file
        if file_object is None or len(file_object) == 0:
            response_log = "Sorry, I couldn't find the full text of this file."
            yield response_log
            return

        contextual_data = " ".join([file.raw_text for file in file_object])
        if not q:
            q = "Create a general summary of the file"

        # The status callback is optional; only stream status events when one was provided
        if send_status_func:
            async for result in send_status_func(f"**Constructing Summary Using:** {file_object[0].file_name}"):
                yield {ChatEvent.STATUS: result}

        response = await extract_relevant_summary(
            q,
            contextual_data,
            conversation_history=meta_log,
            query_images=query_images,
            user=user,
            agent=agent,
            tracer=tracer,
        )

        yield str(response)
    except Exception as e:
        response_log = "Error summarizing file. Please try again, or contact support."
        logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
        # Surface the error message to the caller (not a stale or unbound status event)
        yield response_log
|
||||
|
||||
|
||||
async def generate_excalidraw_diagram(
|
||||
q: str,
|
||||
conversation_history: Dict[str, Any],
|
||||
|
@ -759,10 +795,9 @@ async def generate_excalidraw_diagram_from_description(
|
|||
|
||||
with timer("Chat actor: Generate excalidraw diagram", logger):
|
||||
raw_response = await send_message_to_model_wrapper(
|
||||
message=excalidraw_diagram_generation, user=user, tracer=tracer
|
||||
query=excalidraw_diagram_generation, user=user, tracer=tracer
|
||||
)
|
||||
raw_response = raw_response.strip()
|
||||
raw_response = remove_json_codeblock(raw_response)
|
||||
raw_response = clean_json(raw_response)
|
||||
response: Dict[str, str] = json.loads(raw_response)
|
||||
if not response or not isinstance(response, List) or not isinstance(response[0], Dict):
|
||||
# TODO Some additional validation here that it's a valid Excalidraw diagram
|
||||
|
@ -839,11 +874,12 @@ async def generate_better_image_prompt(
|
|||
|
||||
|
||||
async def send_message_to_model_wrapper(
|
||||
message: str,
|
||||
query: str,
|
||||
system_message: str = "",
|
||||
response_type: str = "text",
|
||||
user: KhojUser = None,
|
||||
query_images: List[str] = None,
|
||||
context: str = "",
|
||||
tracer: dict = {},
|
||||
):
|
||||
conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
|
||||
|
@ -874,7 +910,8 @@ async def send_message_to_model_wrapper(
|
|||
|
||||
loaded_model = state.offline_chat_processor_config.loaded_model
|
||||
truncated_messages = generate_chatml_messages_with_context(
|
||||
user_message=message,
|
||||
user_message=query,
|
||||
context_message=context,
|
||||
system_message=system_message,
|
||||
model_name=chat_model,
|
||||
loaded_model=loaded_model,
|
||||
|
@ -899,7 +936,8 @@ async def send_message_to_model_wrapper(
|
|||
api_key = openai_chat_config.api_key
|
||||
api_base_url = openai_chat_config.api_base_url
|
||||
truncated_messages = generate_chatml_messages_with_context(
|
||||
user_message=message,
|
||||
user_message=query,
|
||||
context_message=context,
|
||||
system_message=system_message,
|
||||
model_name=chat_model,
|
||||
max_prompt_size=max_tokens,
|
||||
|
@ -920,7 +958,8 @@ async def send_message_to_model_wrapper(
|
|||
elif model_type == ChatModelOptions.ModelType.ANTHROPIC:
|
||||
api_key = conversation_config.openai_config.api_key
|
||||
truncated_messages = generate_chatml_messages_with_context(
|
||||
user_message=message,
|
||||
user_message=query,
|
||||
context_message=context,
|
||||
system_message=system_message,
|
||||
model_name=chat_model,
|
||||
max_prompt_size=max_tokens,
|
||||
|
@ -934,12 +973,14 @@ async def send_message_to_model_wrapper(
|
|||
messages=truncated_messages,
|
||||
api_key=api_key,
|
||||
model=chat_model,
|
||||
response_type=response_type,
|
||||
tracer=tracer,
|
||||
)
|
||||
elif model_type == ChatModelOptions.ModelType.GOOGLE:
|
||||
api_key = conversation_config.openai_config.api_key
|
||||
truncated_messages = generate_chatml_messages_with_context(
|
||||
user_message=message,
|
||||
user_message=query,
|
||||
context_message=context,
|
||||
system_message=system_message,
|
||||
model_name=chat_model,
|
||||
max_prompt_size=max_tokens,
|
||||
|
@ -1033,6 +1074,7 @@ def send_message_to_model_wrapper_sync(
|
|||
messages=truncated_messages,
|
||||
api_key=api_key,
|
||||
model=chat_model,
|
||||
response_type=response_type,
|
||||
tracer=tracer,
|
||||
)
|
||||
|
||||
|
@ -1064,6 +1106,7 @@ def generate_chat_response(
|
|||
conversation: Conversation,
|
||||
compiled_references: List[Dict] = [],
|
||||
online_results: Dict[str, Dict] = {},
|
||||
code_results: Dict[str, Dict] = {},
|
||||
inferred_queries: List[str] = [],
|
||||
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
|
||||
user: KhojUser = None,
|
||||
|
@ -1071,8 +1114,10 @@ def generate_chat_response(
|
|||
conversation_id: str = None,
|
||||
location_data: LocationData = None,
|
||||
user_name: Optional[str] = None,
|
||||
meta_research: str = "",
|
||||
query_images: Optional[List[str]] = None,
|
||||
tracer: dict = {},
|
||||
train_of_thought: List[Any] = [],
|
||||
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
|
||||
# Initialize Variables
|
||||
chat_response = None
|
||||
|
@ -1080,6 +1125,9 @@ def generate_chat_response(
|
|||
|
||||
metadata = {}
|
||||
agent = AgentAdapters.get_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None
|
||||
query_to_run = q
|
||||
if meta_research:
|
||||
query_to_run = f"AI Research: {meta_research} {q}"
|
||||
try:
|
||||
partial_completion = partial(
|
||||
save_to_conversation_log,
|
||||
|
@ -1088,11 +1136,13 @@ def generate_chat_response(
|
|||
meta_log=meta_log,
|
||||
compiled_references=compiled_references,
|
||||
online_results=online_results,
|
||||
code_results=code_results,
|
||||
inferred_queries=inferred_queries,
|
||||
client_application=client_application,
|
||||
conversation_id=conversation_id,
|
||||
query_images=query_images,
|
||||
tracer=tracer,
|
||||
train_of_thought=train_of_thought,
|
||||
)
|
||||
|
||||
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
|
||||
|
@ -1106,9 +1156,9 @@ def generate_chat_response(
|
|||
if conversation_config.model_type == "offline":
|
||||
loaded_model = state.offline_chat_processor_config.loaded_model
|
||||
chat_response = converse_offline(
|
||||
user_query=query_to_run,
|
||||
references=compiled_references,
|
||||
online_results=online_results,
|
||||
user_query=q,
|
||||
loaded_model=loaded_model,
|
||||
conversation_log=meta_log,
|
||||
completion_func=partial_completion,
|
||||
|
@ -1128,9 +1178,10 @@ def generate_chat_response(
|
|||
chat_model = conversation_config.chat_model
|
||||
chat_response = converse(
|
||||
compiled_references,
|
||||
q,
|
||||
query_to_run,
|
||||
query_images=query_images,
|
||||
online_results=online_results,
|
||||
code_results=code_results,
|
||||
conversation_log=meta_log,
|
||||
model=chat_model,
|
||||
api_key=api_key,
|
||||
|
@ -1150,9 +1201,10 @@ def generate_chat_response(
|
|||
api_key = conversation_config.openai_config.api_key
|
||||
chat_response = converse_anthropic(
|
||||
compiled_references,
|
||||
q,
|
||||
query_to_run,
|
||||
query_images=query_images,
|
||||
online_results=online_results,
|
||||
code_results=code_results,
|
||||
conversation_log=meta_log,
|
||||
model=conversation_config.chat_model,
|
||||
api_key=api_key,
|
||||
|
@ -1170,10 +1222,10 @@ def generate_chat_response(
|
|||
api_key = conversation_config.openai_config.api_key
|
||||
chat_response = converse_gemini(
|
||||
compiled_references,
|
||||
q,
|
||||
query_images=query_images,
|
||||
online_results=online_results,
|
||||
conversation_log=meta_log,
|
||||
query_to_run,
|
||||
online_results,
|
||||
code_results,
|
||||
meta_log,
|
||||
model=conversation_config.chat_model,
|
||||
api_key=api_key,
|
||||
completion_func=partial_completion,
|
||||
|
@ -1627,14 +1679,6 @@ Manage your automations [here](/automations).
|
|||
""".strip()
|
||||
|
||||
|
||||
class ChatEvent(Enum):
    """Kinds of events exchanged while producing a chat response; each value is its wire name."""

    # Markers for the start and end of the LLM's response
    START_LLM_RESPONSE = "start_llm_response"
    END_LLM_RESPONSE = "end_llm_response"

    # Events that carry a payload
    MESSAGE = "message"
    REFERENCES = "references"
    STATUS = "status"
|
||||
|
||||
|
||||
class MessageProcessor:
|
||||
def __init__(self):
|
||||
self.references = {}
|
||||
|
|
321
src/khoj/routers/research.py
Normal file
321
src/khoj/routers/research.py
Normal file
|
@ -0,0 +1,321 @@
|
|||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import yaml
|
||||
from fastapi import Request
|
||||
|
||||
from khoj.database.adapters import ConversationAdapters, EntryAdapters
|
||||
from khoj.database.models import Agent, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.utils import (
|
||||
InformationCollectionIteration,
|
||||
clean_json,
|
||||
construct_iteration_history,
|
||||
construct_tool_chat_history,
|
||||
)
|
||||
from khoj.processor.tools.online_search import read_webpages, search_online
|
||||
from khoj.processor.tools.run_code import run_code
|
||||
from khoj.routers.api import extract_references_and_questions
|
||||
from khoj.routers.helpers import (
|
||||
ChatEvent,
|
||||
construct_chat_history,
|
||||
extract_relevant_info,
|
||||
generate_summary_from_files,
|
||||
send_message_to_model_wrapper,
|
||||
)
|
||||
from khoj.utils.helpers import (
|
||||
ConversationCommand,
|
||||
function_calling_description_for_llm,
|
||||
is_none_or_empty,
|
||||
timer,
|
||||
)
|
||||
from khoj.utils.rawconfig import LocationData
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def apick_next_tool(
    query: str,
    conversation_history: dict,
    user: KhojUser = None,
    query_images: List[str] = [],
    location: LocationData = None,
    user_name: str = None,
    agent: Agent = None,
    previous_iterations_history: str = None,
    max_iterations: int = 5,
    send_status_func: Optional[Callable] = None,
    tracer: dict = {},
):
    """
    Ask the chat model which single tool to use next to answer the query.

    One tool is picked per call. The caller can invoke this again, passing the
    history of previous iterations, to refine the answer over several steps.

    Args:
        query: The user's query.
        conversation_history: Conversation log used to render the chat history in the prompt.
        user: User on whose behalf the chat model is called.
        query_images: Images attached to the query. When present, a placeholder line is prepended to the query.
        location: User location. Rendered as "Unknown" in the prompt when not set.
        user_name: Name of the user. Rendered as "Unknown" in the prompt when not set.
        agent: Agent whose personality and allowed input tools shape the prompt.
        previous_iterations_history: Rendered history of earlier tool runs, passed to the prompt as is.
        max_iterations: Iteration budget, passed to the prompt as is.
        send_status_func: Optional async generator function used to stream the model's reasoning as status events.
        tracer: Tracing metadata forwarded to the chat model call.

    Yields:
        `{ChatEvent.STATUS: event}` dicts for every status event, followed by exactly one
        `InformationCollectionIteration`. Its tool and query are both None when the model
        response could not be parsed or handled.
    """
    # NOTE(review): `query_images` and `tracer` have mutable defaults. They are only read or
    # forwarded in this function. Confirm the callees do not mutate them.

    # Offer only the tools the agent is allowed to use. No agent, or an agent with no tools, gets all of them.
    agent_tools = agent.input_tools if agent else []
    tool_options_str = "".join(
        f'- "{tool.value}": "{description}"\n'
        for tool, description in function_calling_description_for_llm.items()
        if len(agent_tools) == 0 or tool.value in agent_tools
    )

    chat_history = construct_chat_history(conversation_history, agent_name=agent.name if agent else "Khoj")

    if query_images:
        query = f"[placeholder for user attached images]\n{query}"

    personality_context = (
        prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
    )

    # Gather the date and location context to ground the planning prompt
    today = datetime.today()
    location_data = f"{location}" if location else "Unknown"

    function_planning_prompt = prompts.plan_function_execution.format(
        tools=tool_options_str,
        chat_history=chat_history,
        personality_context=personality_context,
        current_date=today.strftime("%Y-%m-%d"),
        day_of_week=today.strftime("%A"),
        username=user_name or "Unknown",
        location=location_data,
        previous_iterations=previous_iterations_history,
        max_iterations=max_iterations,
    )

    with timer("Chat actor: Infer information sources to refer", logger):
        response = await send_message_to_model_wrapper(
            query=query,
            context=function_planning_prompt,
            response_type="json_object",
            user=user,
            query_images=query_images,
            tracer=tracer,
        )

    try:
        response = clean_json(response)
        response = json.loads(response)
        selected_tool = response.get("tool", None)
        generated_query = response.get("query", None)
        scratchpad = response.get("scratchpad", None)
        logger.info(f"Response for determining relevant tools: {response}")

        # Stream the model's reasoning as its train of thought.
        # Skip it when the model gave none, so the literal text "None" is never shown to the user.
        if send_status_func and scratchpad:
            async for event in send_status_func(f"{scratchpad}"):
                yield {ChatEvent.STATUS: event}

        yield InformationCollectionIteration(
            tool=selected_tool,
            query=generated_query,
        )

    except Exception as e:
        # Fall back to an iteration without a tool or query when the response cannot be used
        logger.error(f"Invalid response for determining relevant tools: {response}. {e}", exc_info=True)
        yield InformationCollectionIteration(
            tool=None,
            query=None,
        )
|
||||
|
||||
|
||||
async def execute_information_collection(
    request: Request,
    user: KhojUser,
    query: str,
    conversation_id: str,
    conversation_history: dict,
    query_images: List[str],
    agent: Agent = None,
    send_status_func: Optional[Callable] = None,
    user_name: str = None,
    location: LocationData = None,
    file_filters: List[str] = [],
    tracer: dict = {},
):
    """
    Iteratively pick and run tools to collect the information needed to answer the query.

    Each iteration asks `apick_next_tool` for the next tool, runs it, and stores a YAML dump
    of the results on the iteration as `summarizedResult`. The loop stops after 5 iterations,
    or as soon as no known tool is picked.

    Args:
        request: The incoming request, forwarded to the notes search.
        user: User on whose behalf the tools are run.
        query: The user's query.
        conversation_id: Id of the conversation, forwarded to the notes search.
        conversation_history: Conversation log used when picking the next tool.
        query_images: Images attached to the query.
        agent: Agent used to pick and run the tools.
        send_status_func: Optional async generator function used to stream status updates.
        user_name: Name of the user.
        location: Location of the user.
        file_filters: Files to summarize when the summarize tool is picked.
        tracer: Tracing metadata forwarded to the tools.

    Yields:
        Status events as they are produced, and one `InformationCollectionIteration`
        at the end of every iteration.
    """
    # NOTE(review): `file_filters` and `tracer` have mutable defaults. They are only forwarded
    # in this function. Confirm the callees do not mutate them.
    current_iteration = 0
    MAX_ITERATIONS = 5
    previous_iterations: List[InformationCollectionIteration] = []

    def _dump(data) -> str:
        """Render tool results as block style YAML for the next planning prompt."""
        return yaml.dump(data, allow_unicode=True, sort_keys=False, default_flow_style=False)

    while current_iteration < MAX_ITERATIONS:
        online_results: Dict = dict()
        code_results: Dict = dict()
        document_results: List[Dict[str, str]] = []
        summarize_files: str = ""
        this_iteration = InformationCollectionIteration(tool=None, query=query)
        previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration)

        # Pick the tool to use in this iteration
        async for result in apick_next_tool(
            query,
            conversation_history,
            user,
            query_images,
            location,
            user_name,
            agent,
            previous_iterations_history,
            MAX_ITERATIONS,
            send_status_func,
            tracer=tracer,
        ):
            if isinstance(result, dict) and ChatEvent.STATUS in result:
                yield result[ChatEvent.STATUS]
            elif isinstance(result, InformationCollectionIteration):
                this_iteration = result

        if this_iteration.tool == ConversationCommand.Notes:
            this_iteration.context = []
            async for result in extract_references_and_questions(
                request,
                construct_tool_chat_history(previous_iterations, ConversationCommand.Notes),
                this_iteration.query,
                7,
                None,
                conversation_id,
                [ConversationCommand.Default],
                location,
                send_status_func,
                query_images,
                agent=agent,
                tracer=tracer,
            ):
                if isinstance(result, dict) and ChatEvent.STATUS in result:
                    yield result[ChatEvent.STATUS]
                elif isinstance(result, tuple):
                    document_results = result[0]
                    this_iteration.context += document_results

            # Status updates are optional. Only report the notes found when a status channel exists.
            if send_status_func and not is_none_or_empty(document_results):
                try:
                    distinct_files = {d["file"] for d in document_results}
                    distinct_headings = {d["compiled"].split("\n")[0] for d in document_results if "compiled" in d}
                    # Drop the markdown heading markers (#) from the headings
                    headings_str = "\n- " + "\n- ".join(distinct_headings).replace("#", "")
                    async for result in send_status_func(
                        f"**Found {len(distinct_headings)} Notes Across {len(distinct_files)} Files**: {headings_str}"
                    ):
                        yield result
                except Exception as e:
                    logger.error(f"Error extracting document references: {e}", exc_info=True)

        elif this_iteration.tool == ConversationCommand.Online:
            async for result in search_online(
                this_iteration.query,
                construct_tool_chat_history(previous_iterations, ConversationCommand.Online),
                location,
                user,
                send_status_func,
                [],
                max_webpages_to_read=0,
                query_images=query_images,
                agent=agent,
                tracer=tracer,
            ):
                if isinstance(result, dict) and ChatEvent.STATUS in result:
                    yield result[ChatEvent.STATUS]
                else:
                    online_results = result  # type: ignore
                    this_iteration.onlineContext = online_results

        elif this_iteration.tool == ConversationCommand.Webpage:
            try:
                async for result in read_webpages(
                    this_iteration.query,
                    construct_tool_chat_history(previous_iterations, ConversationCommand.Webpage),
                    location,
                    user,
                    send_status_func,
                    query_images=query_images,
                    agent=agent,
                    tracer=tracer,
                ):
                    if isinstance(result, dict) and ChatEvent.STATUS in result:
                        yield result[ChatEvent.STATUS]
                    else:
                        direct_web_pages: Dict[str, Dict] = result  # type: ignore

                        # Merge the read webpages into the online results, keyed by their web query
                        for web_query in direct_web_pages:
                            if online_results.get(web_query):
                                online_results[web_query]["webpages"] = direct_web_pages[web_query]["webpages"]
                            else:
                                online_results[web_query] = {"webpages": direct_web_pages[web_query]["webpages"]}
                        this_iteration.onlineContext = online_results
            except Exception as e:
                logger.error(f"Error reading webpages: {e}", exc_info=True)

        elif this_iteration.tool == ConversationCommand.Code:
            try:
                async for result in run_code(
                    this_iteration.query,
                    # Use the history of previous code runs, not webpage reads, as context for the code tool
                    construct_tool_chat_history(previous_iterations, ConversationCommand.Code),
                    "",
                    location,
                    user,
                    send_status_func,
                    query_images=query_images,
                    agent=agent,
                    tracer=tracer,
                ):
                    if isinstance(result, dict) and ChatEvent.STATUS in result:
                        yield result[ChatEvent.STATUS]
                    else:
                        code_results = result  # type: ignore
                        this_iteration.codeContext = code_results
                # Count from the local results. They stay an empty dict when the code tool returned
                # nothing, so the count cannot fail on an unset code context.
                if send_status_func:
                    async for result in send_status_func(f"**Ran code snippets**: {len(code_results)}"):
                        yield result
            except ValueError as e:
                logger.warning(
                    f"Failed to use code tool: {e}. Attempting to respond without code results",
                    exc_info=True,
                )

        elif this_iteration.tool == ConversationCommand.Summarize:
            try:
                async for result in generate_summary_from_files(
                    this_iteration.query,
                    user,
                    file_filters,
                    construct_tool_chat_history(previous_iterations),
                    query_images=query_images,
                    agent=agent,
                    send_status_func=send_status_func,
                ):
                    if isinstance(result, dict) and ChatEvent.STATUS in result:
                        yield result[ChatEvent.STATUS]
                    else:
                        summarize_files = result  # type: ignore
            except Exception as e:
                logger.error(f"Error generating summary: {e}", exc_info=True)

        else:
            # No valid tools. This is our exit condition.
            current_iteration = MAX_ITERATIONS

        current_iteration += 1

        # Compile whatever this iteration collected, so the next planning step can see it
        collected_results = [
            ("Document References", document_results),
            ("Online Results", online_results),
            ("Code Results", code_results),
            ("Summarized Files", summarize_files),
        ]
        if any(data for _, data in collected_results):
            results_data = "**Results**:\n"
            for title, data in collected_results:
                if data:
                    results_data += f"**{title}**: {_dump(data)}\n"
            this_iteration.summarizedResult = results_data

        previous_iterations.append(this_iteration)
        yield this_iteration
|
|
@ -7,8 +7,6 @@ from math import inf
|
|||
from typing import List, Tuple
|
||||
|
||||
import dateparser as dtparse
|
||||
from dateparser.search import search_dates
|
||||
from dateparser_data.settings import default_parsers
|
||||
from dateutil.relativedelta import relativedelta
|
||||
|
||||
from khoj.search_filter.base_filter import BaseFilter
|
||||
|
@ -23,7 +21,7 @@ class DateFilter(BaseFilter):
|
|||
# - dt>="yesterday" dt<"tomorrow"
|
||||
# - dt>="last week"
|
||||
# - dt:"2 years ago"
|
||||
date_regex = r"dt([:><=]{1,2})[\"'](.*?)[\"']"
|
||||
date_regex = r"dt([:><=]{1,2})[\"'‘’](.*?)[\"'‘’]"
|
||||
|
||||
def __init__(self, entry_key="compiled"):
|
||||
self.entry_key = entry_key
|
||||
|
|
|
@ -1,11 +1,10 @@
|
|||
import fnmatch
|
||||
import logging
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from typing import List
|
||||
|
||||
from khoj.search_filter.base_filter import BaseFilter
|
||||
from khoj.utils.helpers import LRU, timer
|
||||
from khoj.utils.helpers import LRU
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
|
@ -102,8 +102,8 @@ def load_embeddings(
|
|||
|
||||
|
||||
async def query(
|
||||
user: KhojUser,
|
||||
raw_query: str,
|
||||
user: KhojUser,
|
||||
type: SearchType = SearchType.All,
|
||||
question_embedding: Union[torch.Tensor, None] = None,
|
||||
max_distance: float = None,
|
||||
|
@ -130,12 +130,12 @@ async def query(
|
|||
top_k = 10
|
||||
with timer("Search Time", logger, state.device):
|
||||
hits = EntryAdapters.search_with_embeddings(
|
||||
user=user,
|
||||
raw_query=raw_query,
|
||||
embeddings=question_embedding,
|
||||
max_results=top_k,
|
||||
file_type_filter=file_type,
|
||||
raw_query=raw_query,
|
||||
max_distance=max_distance,
|
||||
user=user,
|
||||
agent=agent,
|
||||
).all()
|
||||
hits = await sync_to_async(list)(hits) # type: ignore[call-arg]
|
||||
|
|
|
@ -313,12 +313,14 @@ class ConversationCommand(str, Enum):
|
|||
Help = "help"
|
||||
Online = "online"
|
||||
Webpage = "webpage"
|
||||
Code = "code"
|
||||
Image = "image"
|
||||
Text = "text"
|
||||
Automation = "automation"
|
||||
AutomatedTask = "automated_task"
|
||||
Summarize = "summarize"
|
||||
Diagram = "diagram"
|
||||
Research = "research"
|
||||
|
||||
|
||||
command_descriptions = {
|
||||
|
@ -327,11 +329,13 @@ command_descriptions = {
|
|||
ConversationCommand.Default: "The default command when no command specified. It intelligently auto-switches between general and notes mode.",
|
||||
ConversationCommand.Online: "Search for information on the internet.",
|
||||
ConversationCommand.Webpage: "Get information from webpage suggested by you.",
|
||||
ConversationCommand.Code: "Run Python code to parse information, run complex calculations, create documents and charts.",
|
||||
ConversationCommand.Image: "Generate illustrative, creative images by describing your imagination in words.",
|
||||
ConversationCommand.Automation: "Automatically run your query at a specified time or interval.",
|
||||
ConversationCommand.Help: "Get help with how to use or setup Khoj from the documentation",
|
||||
ConversationCommand.Summarize: "Get help with a question pertaining to an entire document.",
|
||||
ConversationCommand.Diagram: "Draw a flowchart, diagram, or any other visual representation best expressed with primitives like lines, rectangles, and text.",
|
||||
ConversationCommand.Research: "Do deep research on a topic. This will take longer than usual, but give a more detailed, comprehensive answer.",
|
||||
}
|
||||
|
||||
command_descriptions_for_agent = {
|
||||
|
@ -340,6 +344,7 @@ command_descriptions_for_agent = {
|
|||
ConversationCommand.Online: "Agent can search the internet for information.",
|
||||
ConversationCommand.Webpage: "Agent can read suggested web pages for information.",
|
||||
ConversationCommand.Summarize: "Agent can read an entire document. Agents knowledge base must be a single document.",
|
||||
ConversationCommand.Research: "Agent can do deep research on a topic.",
|
||||
}
|
||||
|
||||
tool_descriptions_for_llm = {
|
||||
|
@ -348,18 +353,26 @@ tool_descriptions_for_llm = {
|
|||
ConversationCommand.Notes: "To search the user's personal knowledge base. Especially helpful if the question expects context from the user's notes or documents.",
|
||||
ConversationCommand.Online: "To search for the latest, up-to-date information from the internet. Note: **Questions about Khoj should always use this data source**",
|
||||
ConversationCommand.Webpage: "To use if the user has directly provided the webpage urls or you are certain of the webpage urls to read.",
|
||||
ConversationCommand.Code: "To run Python code in a Pyodide sandbox with no network access. Helpful when need to parse information, run complex calculations, create documents and charts for user. Matplotlib, bs4, pandas, numpy, etc. are available.",
|
||||
ConversationCommand.Summarize: "To retrieve an answer that depends on the entire document or a large text.",
|
||||
}
|
||||
|
||||
function_calling_description_for_llm = {
|
||||
ConversationCommand.Notes: "To search the user's personal knowledge base. Especially helpful if the question expects context from the user's notes or documents.",
|
||||
ConversationCommand.Online: "To search the internet for information. Provide all relevant context to ensure new searches, not previously run, are performed.",
|
||||
ConversationCommand.Webpage: "To extract information from a webpage. Useful for more detailed research from the internet. Usually used when you know the webpage links to refer to. Share the webpage link and information to extract in your query.",
|
||||
ConversationCommand.Code: "To run Python code in a Pyodide sandbox with no network access. Helpful when need to parse information, run complex calculations, create documents and charts for user. Matplotlib, bs4, pandas, numpy, etc. are available.",
|
||||
}
|
||||
|
||||
mode_descriptions_for_llm = {
|
||||
ConversationCommand.Image: "Use this if you are confident the user is requesting you to create a new picture based on their description.",
|
||||
ConversationCommand.Image: "Use this if you are confident the user is requesting you to create a new picture based on their description. This does not support generating charts or graphs.",
|
||||
ConversationCommand.Automation: "Use this if you are confident the user is requesting a response at a scheduled date, time and frequency",
|
||||
ConversationCommand.Text: "Use this if a normal text response would be sufficient for accurately responding to the query.",
|
||||
ConversationCommand.Diagram: "Use this if the user is requesting a diagram or visual representation that requires primitives like lines, rectangles, and text.",
|
||||
}
|
||||
|
||||
mode_descriptions_for_agent = {
|
||||
ConversationCommand.Image: "Agent can generate image in response.",
|
||||
ConversationCommand.Image: "Agent can generate images in response. It cannot not use this to generate charts and graphs.",
|
||||
ConversationCommand.Automation: "Agent can schedule a task to run at a scheduled date, time and frequency in response.",
|
||||
ConversationCommand.Text: "Agent can generate text in response.",
|
||||
ConversationCommand.Diagram: "Agent can generate a visual representation that requires primitives like lines, rectangles, and text.",
|
||||
|
|
|
@ -41,3 +41,7 @@ def parse_config_from_string(yaml_config: dict) -> FullConfig:
|
|||
def parse_config_from_file(yaml_config_file):
    """Load the config from the given YML file, then parse and validate it."""
    raw_config = load_config_from_file(yaml_config_file)
    return parse_config_from_string(raw_config)
|
||||
|
||||
|
||||
def yaml_dump(data):
    """Dump data as block style YAML, with unicode allowed and keys left unsorted."""
    dump_options = {
        "allow_unicode": True,
        "sort_keys": False,
        "default_flow_style": False,
    }
    return yaml.dump(data, **dump_options)
|
||||
|
|
|
@ -164,7 +164,7 @@ async def test_text_search(search_config: SearchConfig):
|
|||
query = "Load Khoj on Emacs?"
|
||||
|
||||
# Act
|
||||
hits = await text_search.query(default_user, query)
|
||||
hits = await text_search.query(query, default_user)
|
||||
results = text_search.collate_results(hits)
|
||||
results = sorted(results, key=lambda x: float(x.score))[:1]
|
||||
|
||||
|
|
Loading…
Reference in a new issue