Research Mode: Give Khoj the ability to perform more advanced reasoning (#952)

## Overview
Khoj can now go into research mode and use a Python code interpreter. These are experimental features, released early for feedback and testing.

- Research mode lets Khoj dynamically select the tools it needs to best answer a question and grants it more iterations to reach a satisfactory answer. Its more dynamic train of thought is shown to improve visibility into its reasoning.
- The Python code interpreter is an adjacent capability: it helps Khoj do data analysis and generate charts for you. Code executes in a sandboxed Python runtime built on [cohere-terrarium](https://github.com/cohere-ai/cohere-terrarium?tab=readme-ov-file) and [pyodide](https://pyodide.org/); a minimal sketch of the request flow is shown below.
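
To make the integration concrete, here is a minimal sketch of running code in a terrarium-style sandbox from TypeScript. The endpoint URL and `{code}` payload shape are assumptions for illustration; the response types mirror the `CodeContextData` interface added to the web app in this commit.

```typescript
// Minimal sketch: run Python code in a terrarium-style sandbox and collect results.
// The endpoint (http://localhost:8080) and {code} payload are assumptions for
// illustration; the response types mirror the web app's CodeContextData interface.
interface SandboxFile {
    filename: string;
    b64_data: string;
}

interface SandboxResult {
    success: boolean;
    std_out: string;
    std_err: string;
    output_files: SandboxFile[];
    code_runtime: number;
}

async function runInSandbox(code: string): Promise<SandboxResult> {
    const response = await fetch("http://localhost:8080", {
        method: "POST",
        headers: { "Content-Type": "application/json" },
        body: JSON.stringify({ code }),
    });
    if (!response.ok) throw new Error(`Sandbox request failed: ${response.status}`);
    return (await response.json()) as SandboxResult;
}

// Example: generate a chart; the sandbox returns it as a base64-encoded output file.
runInSandbox("import matplotlib.pyplot as plt\nplt.plot([1, 2, 3])\nplt.savefig('chart.png')")
    .then((result) => console.log(result.output_files.map((f) => f.filename)));
```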

## Analysis
Research mode improves Khoj's information retrieval for more complex queries that require multi-step lookups, though it takes longer to run. Because it can research for longer, it needs less back-and-forth with the user to find an answer. On the client, a message opts into research mode via the `/research` slash command; a minimal sketch follows.
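
The new chat input toggle prepends the `/research` prefix automatically, as the chat input diff below shows. A minimal sketch of that behavior (`prepareMessage` is a hypothetical helper name for illustration):

```typescript
// Sketch of the research mode toggle behavior from onSendMessage() in the diff
// below; prepareMessage is a hypothetical helper name for illustration.
function prepareMessage(message: string, useResearchMode: boolean): string {
    const trimmed = message.trim();
    return useResearchMode ? `/research ${trimmed}` : trimmed;
}

console.log(prepareMessage("compare my 2023 and 2024 travel notes", true));
// => "/research compare my 2023 and 2024 travel notes"
```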

Research mode gives the most gains when used with more advanced chat models (like o1, GPT-4o, the new Claude Sonnet and gemini-pro-002). Smaller models also improve their response quality but tend to fall into repetitive loops more often.

## Next Steps
- Get community feedback on research mode. What works, what fails, what is confusing, what'd be cool to have.
- Tune Khoj for longer autonomous runs and generalize its capabilities across a wider range of model sizes

## Miscellaneous Improvements
- Khoj's train of thought is saved and shown for all messages, not just the latest one
- Render charts and code run by Khoj's code tool on the web app (a sketch of the inline rendering approach follows this list)
- Align the chat input color with the currently selected agent's color
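
Charts come back from the code sandbox as base64-encoded image files. The web app inlines them by rewriting markdown links that reference a generated file into data URIs; here is a condensed sketch of the `renderCodeGenImageInline` logic from the diff below:

```typescript
// Condensed sketch of renderCodeGenImageInline: rewrite a markdown reference to a
// generated image file into a base64 data URI that the browser renders directly.
function inlineGeneratedImage(message: string, filename: string, b64Data: string): string {
    const extension = filename.split(".").pop();
    const regex = new RegExp(`!?\\[.*?\\]\\(.*${filename}\\)`, "g");
    return message.replace(regex, `![${filename}](data:image/${extension};base64,${b64Data})`);
}
```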
Commit 22f3ed3f5d · Debanjum, 2024-11-01 14:46:29 -07:00 · 34 changed files with 1726 additions and 471 deletions


@ -12,7 +12,12 @@ import { processMessageChunk } from "../common/chatFunctions";
import "katex/dist/katex.min.css";
import { Context, OnlineContext, StreamMessage } from "../components/chatMessage/chatMessage";
import {
CodeContext,
Context,
OnlineContext,
StreamMessage,
} from "../components/chatMessage/chatMessage";
import { useIPLocationData, useIsMobileWidth, welcomeConsole } from "../common/utils";
import { ChatInputArea, ChatOptions } from "../components/chatInputArea/chatInputArea";
import { useAuthenticatedData } from "../common/auth";
@ -37,6 +42,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
const [images, setImages] = useState<string[]>([]);
const [processingMessage, setProcessingMessage] = useState(false);
const [agentMetadata, setAgentMetadata] = useState<AgentData | null>(null);
const [isInResearchMode, setIsInResearchMode] = useState(false);
const chatInputRef = useRef<HTMLTextAreaElement>(null);
const setQueryToProcess = props.setQueryToProcess;
@ -65,6 +71,10 @@ function ChatBodyData(props: ChatBodyDataProps) {
if (storedMessage) {
setProcessingMessage(true);
setQueryToProcess(storedMessage);
if (storedMessage.trim().startsWith("/research")) {
setIsInResearchMode(true);
}
}
}, [setQueryToProcess, props.setImages]);
@ -125,6 +135,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
isMobileWidth={props.isMobileWidth}
setUploadedFiles={props.setUploadedFiles}
ref={chatInputRef}
isResearchModeEnabled={isInResearchMode}
/>
</div>
</>
@ -174,6 +185,7 @@ export default function Chat() {
trainOfThought: [],
context: [],
onlineContext: {},
codeContext: {},
completed: false,
timestamp: new Date().toISOString(),
rawQuery: queryToProcess || "",
@ -202,6 +214,7 @@ export default function Chat() {
// Track context used for chat response
let context: Context[] = [];
let onlineContext: OnlineContext = {};
let codeContext: CodeContext = {};
while (true) {
const { done, value } = await reader.read();
@ -228,11 +241,12 @@ export default function Chat() {
}
// Track context used for chat response. References are rendered at the end of the chat
({ context, onlineContext } = processMessageChunk(
({ context, onlineContext, codeContext } = processMessageChunk(
event,
currentMessage,
context,
onlineContext,
codeContext,
));
setMessages([...messages]);


@ -1,8 +1,14 @@
import { Context, OnlineContext, StreamMessage } from "../components/chatMessage/chatMessage";
import {
CodeContext,
Context,
OnlineContext,
StreamMessage,
} from "../components/chatMessage/chatMessage";
export interface RawReferenceData {
context?: Context[];
onlineContext?: OnlineContext;
codeContext?: CodeContext;
}
export interface ResponseWithIntent {
@ -67,10 +73,11 @@ export function processMessageChunk(
currentMessage: StreamMessage,
context: Context[] = [],
onlineContext: OnlineContext = {},
): { context: Context[]; onlineContext: OnlineContext } {
codeContext: CodeContext = {},
): { context: Context[]; onlineContext: OnlineContext; codeContext: CodeContext } {
const chunk = convertMessageChunkToJson(rawChunk);
if (!currentMessage || !chunk || !chunk.type) return { context, onlineContext };
if (!currentMessage || !chunk || !chunk.type) return { context, onlineContext, codeContext };
if (chunk.type === "status") {
console.log(`status: ${chunk.data}`);
@ -81,7 +88,8 @@ export function processMessageChunk(
if (references.context) context = references.context;
if (references.onlineContext) onlineContext = references.onlineContext;
return { context, onlineContext };
if (references.codeContext) codeContext = references.codeContext;
return { context, onlineContext, codeContext };
} else if (chunk.type === "message") {
const chunkData = chunk.data;
// Here, handle if the response is a JSON response with an image, but the intentType is excalidraw
@ -119,13 +127,41 @@ export function processMessageChunk(
console.log(`Completed streaming: ${new Date()}`);
// Append any references after all the data has been streamed
if (codeContext) currentMessage.codeContext = codeContext;
if (onlineContext) currentMessage.onlineContext = onlineContext;
if (context) currentMessage.context = context;
// Replace file links with base64 data
currentMessage.rawResponse = renderCodeGenImageInline(
currentMessage.rawResponse,
codeContext,
);
// Add code context files to the message
if (codeContext) {
Object.entries(codeContext).forEach(([key, value]) => {
value.results.output_files?.forEach((file) => {
if (file.filename.endsWith(".png") || file.filename.endsWith(".jpg")) {
// Don't add the image again if it's already in the message!
if (!currentMessage.rawResponse.includes(`![${file.filename}](`)) {
currentMessage.rawResponse += `\n\n![${file.filename}](data:image/png;base64,${file.b64_data})`;
}
} else if (
file.filename.endsWith(".txt") ||
file.filename.endsWith(".org") ||
file.filename.endsWith(".md")
) {
const decodedText = atob(file.b64_data);
currentMessage.rawResponse += `\n\n\`\`\`\n${decodedText}\n\`\`\``;
}
});
});
}
// Mark current message streaming as completed
currentMessage.completed = true;
}
return { context, onlineContext };
return { context, onlineContext, codeContext };
}
export function handleImageResponse(imageJson: any, liveStream: boolean): ResponseWithIntent {
@ -150,6 +186,22 @@ export function handleImageResponse(imageJson: any, liveStream: boolean): Respon
return responseWithIntent;
}
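// Replace markdown links that reference generated image files with inline base64 data URIs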
export function renderCodeGenImageInline(message: string, codeContext: CodeContext) {
if (!codeContext) return message;
Object.values(codeContext).forEach((contextData) => {
contextData.results.output_files?.forEach((file) => {
const regex = new RegExp(`!?\\[.*?\\]\\(.*${file.filename}\\)`, "g");
if (file.filename.match(/\.(png|jpg|jpeg|gif|webp)$/i)) {
const replacement = `![${file.filename}](data:image/${file.filename.split(".").pop()};base64,${file.b64_data})`;
message = message.replace(regex, replacement);
}
});
});
return message;
}
export function modifyFileFilterForConversation(
conversationId: string | null,
filenames: string[],


@ -42,6 +42,13 @@ export function converColorToBgGradient(color: string) {
return `${convertToBGGradientClass(color)} dark:border dark:border-neutral-700`;
}
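// Map an agent color to its Tailwind focus-visible ring class, falling back to orange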
export function convertColorToRingClass(color: string | undefined) {
if (color && tailwindColors.includes(color)) {
return `focus-visible:ring-${color}-500`;
}
return `focus-visible:ring-orange-500`;
}
export function convertColorToBorderClass(color: string) {
if (tailwindColors.includes(color)) {
return `border-${color}-500`;


@ -13,13 +13,14 @@ import { ScrollArea } from "@/components/ui/scroll-area";
import { InlineLoading } from "../loading/loading";
import { Lightbulb, ArrowDown } from "@phosphor-icons/react";
import { Lightbulb, ArrowDown, XCircle } from "@phosphor-icons/react";
import AgentProfileCard from "../profileCard/profileCard";
import { getIconFromIconName } from "@/app/common/iconUtils";
import { AgentData } from "@/app/agents/page";
import React from "react";
import { useIsMobileWidth } from "@/app/common/utils";
import { Button } from "@/components/ui/button";
interface ChatResponse {
status: string;
@ -40,24 +41,52 @@ interface ChatHistoryProps {
customClassName?: string;
}
function constructTrainOfThought(
trainOfThought: string[],
lastMessage: boolean,
agentColor: string,
key: string,
completed: boolean = false,
) {
const lastIndex = trainOfThought.length - 1;
return (
<div className={`${styles.trainOfThought} shadow-sm`} key={key}>
{!completed && <InlineLoading className="float-right" />}
interface TrainOfThoughtComponentProps {
trainOfThought: string[];
lastMessage: boolean;
agentColor: string;
key: string;
completed?: boolean;
}
{trainOfThought.map((train, index) => (
function TrainOfThoughtComponent(props: TrainOfThoughtComponentProps) {
const lastIndex = props.trainOfThought.length - 1;
const [collapsed, setCollapsed] = useState(props.completed);
return (
<div
className={`${!collapsed ? styles.trainOfThought + " shadow-sm" : ""}`}
key={props.key}
>
{!props.completed && <InlineLoading className="float-right" />}
{props.completed &&
(collapsed ? (
<Button
className="w-fit text-left justify-start content-start text-xs"
onClick={() => setCollapsed(false)}
variant="ghost"
size="sm"
>
What was my train of thought?
</Button>
) : (
<Button
className="w-fit text-left justify-start content-start text-xs p-0 h-fit"
onClick={() => setCollapsed(true)}
variant="ghost"
size="sm"
>
<XCircle size={16} className="mr-1" />
Close
</Button>
))}
{!collapsed &&
props.trainOfThought.map((train, index) => (
<TrainOfThought
key={`train-${index}`}
message={train}
primary={index === lastIndex && lastMessage && !completed}
agentColor={agentColor}
primary={index === lastIndex && props.lastMessage && !props.completed}
agentColor={props.agentColor}
/>
))}
</div>
@ -254,17 +283,16 @@ export default function ChatHistory(props: ChatHistoryProps) {
}
return (
<ScrollArea className={`h-[80vh] relative`} ref={scrollAreaRef}>
<ScrollArea className={`h-[73vh] relative`} ref={scrollAreaRef}>
<div>
<div className={`${styles.chatHistory} ${props.customClassName}`}>
<div ref={sentinelRef} style={{ height: "1px" }}>
{fetchingData && (
<InlineLoading message="Loading Conversation" className="opacity-50" />
)}
{fetchingData && <InlineLoading className="opacity-50" />}
</div>
{data &&
data.chat &&
data.chat.map((chatMessage, index) => (
<>
<ChatMessage
key={`${index}fullHistory`}
ref={
@ -274,7 +302,8 @@ export default function ChatHistory(props: ChatHistoryProps) {
: // attach ref to the newest fetched message to handle scroll on fetch
// note: stabilize index selection against last page having less messages than fetchMessageCount
index ===
data.chat.length - (currentPage - 1) * fetchMessageCount
data.chat.length -
(currentPage - 1) * fetchMessageCount
? latestFetchedMessageRef
: null
}
@ -284,6 +313,18 @@ export default function ChatHistory(props: ChatHistoryProps) {
borderLeftColor={`${data?.agent?.color}-500`}
isLastMessage={index === data.chat.length - 1}
/>
{chatMessage.trainOfThought && chatMessage.by === "khoj" && (
<TrainOfThoughtComponent
trainOfThought={chatMessage.trainOfThought?.map(
(train) => train.data,
)}
lastMessage={false}
agentColor={data?.agent?.color || "orange"}
key={`${index}trainOfThought`}
completed={true}
/>
)}
</>
))}
{props.incomingMessages &&
props.incomingMessages.map((message, index) => {
@ -296,6 +337,7 @@ export default function ChatHistory(props: ChatHistoryProps) {
message: message.rawQuery,
context: [],
onlineContext: {},
codeContext: {},
created: message.timestamp,
by: "you",
automationId: "",
@ -304,13 +346,14 @@ export default function ChatHistory(props: ChatHistoryProps) {
customClassName="fullHistory"
borderLeftColor={`${data?.agent?.color}-500`}
/>
{message.trainOfThought &&
constructTrainOfThought(
message.trainOfThought,
index === incompleteIncomingMessageIndex,
data?.agent?.color || "orange",
`${index}trainOfThought`,
message.completed,
{message.trainOfThought && (
<TrainOfThoughtComponent
trainOfThought={message.trainOfThought}
lastMessage={index === incompleteIncomingMessageIndex}
agentColor={data?.agent?.color || "orange"}
key={`${index}trainOfThought`}
completed={message.completed}
/>
)}
<ChatMessage
key={`${index}incoming`}
@ -319,6 +362,7 @@ export default function ChatHistory(props: ChatHistoryProps) {
message: message.rawResponse,
context: message.context,
onlineContext: message.onlineContext,
codeContext: message.codeContext,
created: message.timestamp,
by: "khoj",
automationId: "",
@ -345,6 +389,7 @@ export default function ChatHistory(props: ChatHistoryProps) {
message: props.pendingMessage,
context: [],
onlineContext: {},
codeContext: {},
created: new Date().getTime().toString(),
by: "you",
automationId: "",
@ -372,7 +417,7 @@ export default function ChatHistory(props: ChatHistoryProps) {
</div>
)}
</div>
<div className={`${props.customClassName} fixed bottom-[15%] z-10`}>
<div className={`${props.customClassName} fixed bottom-[20%] z-10`}>
{!isNearBottom && (
<button
title="Scroll to bottom"


@ -3,7 +3,15 @@ import React, { useEffect, useRef, useState, forwardRef } from "react";
import DOMPurify from "dompurify";
import "katex/dist/katex.min.css";
import { ArrowUp, Microphone, Paperclip, X, Stop } from "@phosphor-icons/react";
import {
ArrowUp,
Microphone,
Paperclip,
X,
Stop,
ToggleLeft,
ToggleRight,
} from "@phosphor-icons/react";
import {
Command,
@ -29,7 +37,7 @@ import { Popover, PopoverContent } from "@/components/ui/popover";
import { PopoverTrigger } from "@radix-ui/react-popover";
import { Textarea } from "@/components/ui/textarea";
import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from "@/components/ui/tooltip";
import { convertToBGClass } from "@/app/common/colorUtils";
import { convertColorToTextClass, convertToBGClass } from "@/app/common/colorUtils";
import LoginPrompt from "../loginPrompt/loginPrompt";
import { uploadDataForIndexing } from "../../common/chatFunctions";
@ -50,6 +58,7 @@ interface ChatInputProps {
isMobileWidth?: boolean;
isLoggedIn: boolean;
agentColor?: string;
isResearchModeEnabled?: boolean;
}
export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((props, ref) => {
@ -72,6 +81,11 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
const [progressValue, setProgressValue] = useState(0);
const [isDragAndDropping, setIsDragAndDropping] = useState(false);
const [showCommandList, setShowCommandList] = useState(false);
const [useResearchMode, setUseResearchMode] = useState<boolean>(
props.isResearchModeEnabled || false,
);
const chatInputRef = ref as React.MutableRefObject<HTMLTextAreaElement>;
useEffect(() => {
if (!uploading) {
@ -112,6 +126,12 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
fetchImageData();
}, [imagePaths]);
useEffect(() => {
if (props.isResearchModeEnabled) {
setUseResearchMode(props.isResearchModeEnabled);
}
}, [props.isResearchModeEnabled]);
function onSendMessage() {
if (imageUploaded) {
setImageUploaded(false);
@ -128,7 +148,12 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
return;
}
props.sendMessage(message.trim());
let messageToSend = message.trim();
if (useResearchMode) {
messageToSend = `/research ${messageToSend}`;
}
props.sendMessage(messageToSend);
setMessage("");
}
@ -275,6 +300,12 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
chatInputRef.current.style.height = "auto";
chatInputRef.current.style.height =
Math.max(chatInputRef.current.scrollHeight - 24, 64) + "px";
if (message.startsWith("/") && message.split(" ").length === 1) {
setShowCommandList(true);
} else {
setShowCommandList(false);
}
}, [message]);
function handleDragOver(event: React.DragEvent<HTMLDivElement>) {
@ -360,9 +391,9 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
</AlertDialogContent>
</AlertDialog>
)}
{message.startsWith("/") && message.split(" ").length === 1 && (
{showCommandList && (
<div className="flex justify-center text-center">
<Popover open={message.startsWith("/")}>
<Popover open={showCommandList} onOpenChange={setShowCommandList}>
<PopoverTrigger className="flex justify-center text-center"></PopoverTrigger>
<PopoverContent
onOpenAutoFocus={(e) => e.preventDefault()}
@ -413,6 +444,7 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
</Popover>
</div>
)}
<div>
<div
className={`${styles.actualInputArea} justify-between dark:bg-neutral-700 relative ${isDragAndDropping && "animate-pulse"}`}
onDragOver={handleDragOver}
@ -426,7 +458,7 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
onChange={handleFileChange}
style={{ display: "none" }}
/>
<div className="flex items-end pb-4">
<div className="flex items-center">
<Button
variant={"ghost"}
className="!bg-none p-0 m-2 h-auto text-3xl rounded-full text-gray-300 hover:text-gray-500"
@ -436,11 +468,14 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
<Paperclip className="w-8 h-8" />
</Button>
</div>
<div className="flex-grow flex flex-col w-full gap-1.5 relative pb-2">
<div className="flex-grow flex flex-col w-full gap-1.5 relative">
<div className="flex items-center gap-2 overflow-x-auto">
{imageUploaded &&
imagePaths.map((path, index) => (
<div key={index} className="relative flex-shrink-0 pb-3 pt-2 group">
<div
key={index}
className="relative flex-shrink-0 pb-3 pt-2 group"
>
<img
src={path}
alt={`img-${index}`}
@ -459,7 +494,10 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
</div>
<Textarea
ref={chatInputRef}
className={`border-none w-full h-16 min-h-16 max-h-[128px] md:py-4 rounded-lg resize-none dark:bg-neutral-700 ${props.isMobileWidth ? "text-md" : "text-lg"}`}
className={`border-none focus:border-none
focus:outline-none focus-visible:ring-transparent
w-full h-16 min-h-16 max-h-[128px] md:py-4 rounded-lg resize-none dark:bg-neutral-700
${props.isMobileWidth ? "text-md" : "text-lg"}`}
placeholder="Type / to see a list of commands"
id="message"
autoFocus={true}
@ -476,7 +514,7 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
disabled={props.sendDisabled || recording}
/>
</div>
<div className="flex items-end pb-4">
<div className="flex items-center">
{recording ? (
<TooltipProvider>
<Tooltip>
@ -530,6 +568,34 @@ export const ChatInputArea = forwardRef<HTMLTextAreaElement, ChatInputProps>((pr
</Button>
</div>
</div>
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<Button
variant="ghost"
className="float-right justify-center gap-1 flex items-center p-1.5 mr-2 h-fit"
onClick={() => {
setUseResearchMode(!useResearchMode);
chatInputRef?.current?.focus();
}}
>
<span className="text-muted-foreground text-sm">Research Mode</span>
{useResearchMode ? (
<ToggleRight
className={`w-6 h-6 inline-block ${props.agentColor ? convertColorToTextClass(props.agentColor) : convertColorToTextClass("orange")} rounded-full`}
/>
) : (
<ToggleLeft className={`w-6 h-6 inline-block rounded-full`} />
)}
</Button>
</TooltipTrigger>
<TooltipContent className="text-xs">
Research Mode allows you to get more deeply researched, detailed
responses. Response times may be longer.
</TooltipContent>
</Tooltip>
</TooltipProvider>
</div>
</>
);
});


@ -10,6 +10,7 @@ import { createRoot } from "react-dom/client";
import "katex/dist/katex.min.css";
import { TeaserReferencesSection, constructAllReferences } from "../referencePanel/referencePanel";
import { renderCodeGenImageInline } from "@/app/common/chatFunctions";
import {
ThumbsUp,
@ -26,6 +27,7 @@ import {
Palette,
ClipboardText,
Check,
Code,
Shapes,
} from "@phosphor-icons/react";
@ -99,6 +101,26 @@ export interface OnlineContextData {
peopleAlsoAsk: PeopleAlsoAsk[];
}
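// Output of sandboxed code executions attached to a chat message, keyed per code run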
export interface CodeContext {
[key: string]: CodeContextData;
}
export interface CodeContextData {
code: string;
results: {
success: boolean;
output_files: CodeContextFile[];
std_out: string;
std_err: string;
code_runtime: number;
};
}
export interface CodeContextFile {
filename: string;
b64_data: string;
}
interface Intent {
type: string;
query: string;
@ -106,6 +128,11 @@ interface Intent {
"inferred-queries": string[];
}
interface TrainOfThoughtObject {
type: string;
data: string;
}
export interface SingleChatMessage {
automationId: string;
by: string;
@ -113,6 +140,8 @@ export interface SingleChatMessage {
created: string;
context: Context[];
onlineContext: OnlineContext;
codeContext: CodeContext;
trainOfThought?: TrainOfThoughtObject[];
rawQuery?: string;
intent?: Intent;
agent?: AgentData;
@ -124,6 +153,7 @@ export interface StreamMessage {
trainOfThought: string[];
context: Context[];
onlineContext: OnlineContext;
codeContext: CodeContext;
completed: boolean;
rawQuery: string;
timestamp: string;
@ -263,6 +293,10 @@ function chooseIconFromHeader(header: string, iconColor: string) {
return <Palette className={`${classNames}`} />;
}
if (compareHeader.includes("code")) {
return <Code className={`${classNames}`} />;
}
return <Brain className={`${classNames}`} />;
}
@ -273,12 +307,15 @@ export function TrainOfThought(props: TrainOfThoughtProps) {
const iconColor = props.primary ? convertColorToTextClass(props.agentColor) : "text-gray-500";
const icon = chooseIconFromHeader(header, iconColor);
let markdownRendered = DOMPurify.sanitize(md.render(props.message));
// Remove any header tags from markdownRendered
markdownRendered = markdownRendered.replace(/<h[1-6].*?<\/h[1-6]>/g, "");
return (
<div
className={`${styles.trainOfThoughtElement} break-all items-center ${props.primary ? "text-gray-400" : "text-gray-300"} ${styles.trainOfThought} ${props.primary ? styles.primary : ""}`}
className={`${styles.trainOfThoughtElement} break-words items-center ${props.primary ? "text-gray-400" : "text-gray-300"} ${styles.trainOfThought} ${props.primary ? styles.primary : ""}`}
>
{icon}
<div dangerouslySetInnerHTML={{ __html: markdownRendered }} />
<div dangerouslySetInnerHTML={{ __html: markdownRendered }} className="break-words" />
</div>
);
}
@ -388,6 +425,48 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
messageToRender = `${userImagesInHtml}${messageToRender}`;
}
if (props.chatMessage.intent && props.chatMessage.intent.type == "text-to-image") {
message = `![generated image](data:image/png;base64,${message})`;
} else if (props.chatMessage.intent && props.chatMessage.intent.type == "text-to-image2") {
message = `![generated image](${message})`;
} else if (
props.chatMessage.intent &&
props.chatMessage.intent.type == "text-to-image-v3"
) {
message = `![generated image](data:image/webp;base64,${message})`;
}
if (
props.chatMessage.intent &&
props.chatMessage.intent.type.includes("text-to-image") &&
props.chatMessage.intent["inferred-queries"]?.length > 0
) {
message += `\n\n${props.chatMessage.intent["inferred-queries"][0]}`;
}
// Replace file links with base64 data
message = renderCodeGenImageInline(message, props.chatMessage.codeContext);
// Add code context files to the message
if (props.chatMessage.codeContext) {
Object.entries(props.chatMessage.codeContext).forEach(([key, value]) => {
value.results.output_files?.forEach((file) => {
if (file.filename.endsWith(".png") || file.filename.endsWith(".jpg")) {
// Don't add the image again if it's already in the message!
if (!message.includes(`![${file.filename}](`)) {
message += `\n\n![${file.filename}](data:image/png;base64,${file.b64_data})`;
}
} else if (
file.filename.endsWith(".txt") ||
file.filename.endsWith(".org") ||
file.filename.endsWith(".md")
) {
const decodedText = atob(file.b64_data);
message += `\n\n\`\`\`\n${decodedText}\n\`\`\``;
}
});
});
}
// Set the message text
setTextRendered(messageForClipboard);
@ -578,6 +657,7 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
const allReferences = constructAllReferences(
props.chatMessage.context,
props.chatMessage.onlineContext,
props.chatMessage.codeContext,
);
return (
@ -600,6 +680,7 @@ const ChatMessage = forwardRef<HTMLDivElement, ChatMessageProps>((props, ref) =>
isMobileWidth={props.isMobileWidth}
notesReferenceCardData={allReferences.notesReferenceCardData}
onlineReferenceCardData={allReferences.onlineReferenceCardData}
codeReferenceCardData={allReferences.codeReferenceCardData}
/>
</div>
<div className={styles.chatFooter}>


@ -11,7 +11,7 @@ const md = new markdownIt({
typographer: true,
});
import { Context, WebPage, OnlineContext } from "../chatMessage/chatMessage";
import { Context, WebPage, OnlineContext, CodeContext } from "../chatMessage/chatMessage";
import { Card } from "@/components/ui/card";
import {
@ -94,11 +94,67 @@ function NotesContextReferenceCard(props: NotesContextReferenceCardProps) {
);
}
interface CodeContextReferenceCardProps {
code: string;
output: string;
error: string;
showFullContent: boolean;
}
function CodeContextReferenceCard(props: CodeContextReferenceCardProps) {
const fileIcon = getIconFromFilename(".py", "w-6 h-6 text-muted-foreground inline-flex mr-2");
const snippet = DOMPurify.sanitize(props.code);
const [isHovering, setIsHovering] = useState(false);
return (
<>
<Popover open={isHovering && !props.showFullContent} onOpenChange={setIsHovering}>
<PopoverTrigger asChild>
<Card
onMouseEnter={() => setIsHovering(true)}
onMouseLeave={() => setIsHovering(false)}
className={`${props.showFullContent ? "w-auto" : "w-[200px]"} overflow-hidden break-words text-balance rounded-lg p-2 bg-muted border-none`}
>
<h3
className={`${props.showFullContent ? "block" : "line-clamp-1"} text-muted-foreground}`}
>
{fileIcon}
Code
</h3>
<p
className={`${props.showFullContent ? "block" : "overflow-hidden line-clamp-2"}`}
>
{snippet}
</p>
</Card>
</PopoverTrigger>
<PopoverContent className="w-[400px] mx-2">
<Card
className={`w-auto overflow-hidden break-words text-balance rounded-lg p-2 border-none`}
>
<h3 className={`line-clamp-2 text-muted-foreground}`}>
{fileIcon}
Code
</h3>
<p className={`overflow-hidden line-clamp-3`}>{snippet}</p>
</Card>
</PopoverContent>
</Popover>
</>
);
}
export interface ReferencePanelData {
notesReferenceCardData: NotesContextReferenceData[];
onlineReferenceCardData: OnlineReferenceData[];
}
export interface CodeReferenceData {
code: string;
output: string;
error: string;
}
interface OnlineReferenceData {
title: string;
description: string;
@ -214,9 +270,27 @@ function GenericOnlineReferenceCard(props: OnlineReferenceCardProps) {
);
}
export function constructAllReferences(contextData: Context[], onlineData: OnlineContext) {
export function constructAllReferences(
contextData: Context[],
onlineData: OnlineContext,
codeContext: CodeContext,
) {
const onlineReferences: OnlineReferenceData[] = [];
const contextReferences: NotesContextReferenceData[] = [];
const codeReferences: CodeReferenceData[] = [];
if (codeContext) {
for (const [key, value] of Object.entries(codeContext)) {
if (!value.results) {
continue;
}
codeReferences.push({
code: value.code,
output: value.results.std_out,
error: value.results.std_err,
});
}
}
if (onlineData) {
let localOnlineReferences = [];
@ -298,12 +372,14 @@ export function constructAllReferences(contextData: Context[], onlineData: Onlin
return {
notesReferenceCardData: contextReferences,
onlineReferenceCardData: onlineReferences,
codeReferenceCardData: codeReferences,
};
}
export interface TeaserReferenceSectionProps {
notesReferenceCardData: NotesContextReferenceData[];
onlineReferenceCardData: OnlineReferenceData[];
codeReferenceCardData: CodeReferenceData[];
isMobileWidth: boolean;
}
@ -315,16 +391,27 @@ export function TeaserReferencesSection(props: TeaserReferenceSectionProps) {
}, [props.isMobileWidth]);
const notesDataToShow = props.notesReferenceCardData.slice(0, numTeaserSlots);
const codeDataToShow = props.codeReferenceCardData.slice(
0,
numTeaserSlots - notesDataToShow.length,
);
const onlineDataToShow =
notesDataToShow.length < numTeaserSlots
? props.onlineReferenceCardData.slice(0, numTeaserSlots - notesDataToShow.length)
notesDataToShow.length + codeDataToShow.length < numTeaserSlots
? props.onlineReferenceCardData.slice(
0,
numTeaserSlots - codeDataToShow.length - notesDataToShow.length,
)
: [];
const shouldShowShowMoreButton =
props.notesReferenceCardData.length > 0 || props.onlineReferenceCardData.length > 0;
props.notesReferenceCardData.length > 0 ||
props.codeReferenceCardData.length > 0 ||
props.onlineReferenceCardData.length > 0;
const numReferences =
props.notesReferenceCardData.length + props.onlineReferenceCardData.length;
props.notesReferenceCardData.length +
props.codeReferenceCardData.length +
props.onlineReferenceCardData.length;
if (numReferences === 0) {
return null;
@ -346,6 +433,15 @@ export function TeaserReferencesSection(props: TeaserReferenceSectionProps) {
/>
);
})}
{codeDataToShow.map((code, index) => {
return (
<CodeContextReferenceCard
showFullContent={false}
{...code}
key={`code-${index}`}
/>
);
})}
{onlineDataToShow.map((online, index) => {
return (
<GenericOnlineReferenceCard
@ -359,6 +455,7 @@ export function TeaserReferencesSection(props: TeaserReferenceSectionProps) {
<ReferencePanel
notesReferenceCardData={props.notesReferenceCardData}
onlineReferenceCardData={props.onlineReferenceCardData}
codeReferenceCardData={props.codeReferenceCardData}
/>
)}
</div>
@ -369,6 +466,7 @@ export function TeaserReferencesSection(props: TeaserReferenceSectionProps) {
interface ReferencePanelDataProps {
notesReferenceCardData: NotesContextReferenceData[];
onlineReferenceCardData: OnlineReferenceData[];
codeReferenceCardData: CodeReferenceData[];
}
export default function ReferencePanel(props: ReferencePanelDataProps) {
@ -406,6 +504,15 @@ export default function ReferencePanel(props: ReferencePanelDataProps) {
/>
);
})}
{props.codeReferenceCardData.map((code, index) => {
return (
<CodeContextReferenceCard
showFullContent={true}
{...code}
key={`code-${index}`}
/>
);
})}
</div>
</SheetContent>
</Sheet>


@ -14,25 +14,21 @@ interface SuggestionCardProps {
export default function SuggestionCard(data: SuggestionCardProps) {
const bgColors = converColorToBgGradient(data.color);
const cardClassName = `${styles.card} ${bgColors} md:w-full md:h-fit sm:w-full h-fit md:w-[200px] md:h-[200px] cursor-pointer`;
const titleClassName = `${styles.title} pt-2 dark:text-white dark:font-bold`;
const cardClassName = `${styles.card} ${bgColors} md:w-full md:h-fit sm:w-full h-fit md:w-[200px] md:h-[180px] cursor-pointer md:p-2`;
const descriptionClassName = `${styles.text} dark:text-white`;
const cardContent = (
<Card className={cardClassName}>
<CardHeader className="m-0 p-2 pb-1 relative">
<div className="flex flex-row md:flex-col">
<div className="flex">
<CardContent className="m-0 p-2">
{convertSuggestionTitleToIconClass(data.title, data.color.toLowerCase())}
<CardTitle className={titleClassName}>{data.title}</CardTitle>
</div>
</CardHeader>
<CardContent className="m-0 p-2 pr-4 pt-1">
<CardDescription
className={`${descriptionClassName} sm:line-clamp-2 md:line-clamp-4`}
className={`${descriptionClassName} sm:line-clamp-2 md:line-clamp-4 pt-1`}
>
{data.body}
</CardDescription>
</CardContent>
</div>
</Card>
);


@ -1,12 +1,9 @@
.card {
padding: 0.5rem;
margin: 0.05rem;
border-radius: 0.5rem;
}
.title {
font-size: 1.0rem;
font-size: 1rem;
}
.text {


@ -47,24 +47,24 @@ const DEFAULT_COLOR = "orange";
export function convertSuggestionTitleToIconClass(title: string, color: string) {
if (title === SuggestionType.Automation)
return getIconFromIconName("Robot", color, "w-8", "h-8");
if (title === SuggestionType.Paint) return getIconFromIconName("Palette", color, "w-8", "h-8");
return getIconFromIconName("Robot", color, "w-6", "h-6");
if (title === SuggestionType.Paint) return getIconFromIconName("Palette", color, "w-6", "h-6");
if (title === SuggestionType.PopCulture)
return getIconFromIconName("Confetti", color, "w-8", "h-8");
if (title === SuggestionType.Travel) return getIconFromIconName("Jeep", color, "w-8", "h-8");
if (title === SuggestionType.Learning) return getIconFromIconName("Book", color, "w-8", "h-8");
return getIconFromIconName("Confetti", color, "w-6", "h-6");
if (title === SuggestionType.Travel) return getIconFromIconName("Jeep", color, "w-6", "h-6");
if (title === SuggestionType.Learning) return getIconFromIconName("Book", color, "w-6", "h-6");
if (title === SuggestionType.Health)
return getIconFromIconName("Asclepius", color, "w-8", "h-8");
if (title === SuggestionType.Fun) return getIconFromIconName("Island", color, "w-8", "h-8");
if (title === SuggestionType.Home) return getIconFromIconName("House", color, "w-8", "h-8");
return getIconFromIconName("Asclepius", color, "w-6", "h-6");
if (title === SuggestionType.Fun) return getIconFromIconName("Island", color, "w-6", "h-6");
if (title === SuggestionType.Home) return getIconFromIconName("House", color, "w-6", "h-6");
if (title === SuggestionType.Language)
return getIconFromIconName("Translate", color, "w-8", "h-8");
if (title === SuggestionType.Code) return getIconFromIconName("Code", color, "w-8", "h-8");
if (title === SuggestionType.Food) return getIconFromIconName("BowlFood", color, "w-8", "h-8");
return getIconFromIconName("Translate", color, "w-6", "h-6");
if (title === SuggestionType.Code) return getIconFromIconName("Code", color, "w-6", "h-6");
if (title === SuggestionType.Food) return getIconFromIconName("BowlFood", color, "w-6", "h-6");
if (title === SuggestionType.Interviewing)
return getIconFromIconName("Lectern", color, "w-8", "h-8");
if (title === SuggestionType.Finance) return getIconFromIconName("Wallet", color, "w-8", "h-8");
else return getIconFromIconName("Lightbulb", color, "w-8", "h-8");
return getIconFromIconName("Lectern", color, "w-6", "h-6");
if (title === SuggestionType.Finance) return getIconFromIconName("Wallet", color, "w-6", "h-6");
else return getIconFromIconName("Lightbulb", color, "w-6", "h-6");
}
export const suggestionsData: Suggestion[] = [


@ -5,6 +5,7 @@ import { useAuthenticatedData } from "@/app/common/auth";
import { useState, useEffect } from "react";
import ChatMessage, {
CodeContext,
Context,
OnlineContext,
OnlineContextData,
@ -46,6 +47,7 @@ interface SupplementReferences {
interface ResponseWithReferences {
context?: Context[];
online?: OnlineContext;
code?: CodeContext;
response?: string;
}
@ -192,6 +194,7 @@ function ReferenceVerification(props: ReferenceVerificationProps) {
context: [],
created: new Date().toISOString(),
onlineContext: {},
codeContext: {},
}}
isMobileWidth={isMobileWidth}
/>
@ -622,6 +625,7 @@ export default function FactChecker() {
context: [],
created: new Date().toISOString(),
onlineContext: {},
codeContext: {},
}}
isMobileWidth={isMobileWidth}
/>


@ -116,8 +116,8 @@ function ChatBodyData(props: ChatBodyDataProps) {
`What would you like to get done${nameSuffix}?`,
`Hey${nameSuffix}! How can I help?`,
`Good ${timeOfDay}${nameSuffix}! What's on your mind?`,
`Ready to breeze through your ${["Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday"][day]}?`,
`Want help navigating your ${["Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday"][day]} workload?`,
`Ready to breeze through ${["Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday"][day]}?`,
`Let's navigate your ${["Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday"][day]} workload`,
];
const greeting = greetings[Math.floor(Math.random() * greetings.length)];
setGreeting(greeting);
@ -305,6 +305,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
conversationId={null}
isMobileWidth={props.isMobileWidth}
setUploadedFiles={props.setUploadedFiles}
agentColor={agents.find((agent) => agent.slug === selectedAgent)?.color}
ref={chatInputRef}
/>
</div>
@ -386,6 +387,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
conversationId={null}
isMobileWidth={props.isMobileWidth}
setUploadedFiles={props.setUploadedFiles}
agentColor={agents.find((agent) => agent.slug === selectedAgent)?.color}
ref={chatInputRef}
/>
</div>


@ -188,6 +188,7 @@ export default function SharedChat() {
trainOfThought: [],
context: [],
onlineContext: {},
codeContext: {},
completed: false,
timestamp: new Date().toISOString(),
rawQuery: queryToProcess || "",


@ -1,34 +1,44 @@
import type { Config } from "tailwindcss"
import type { Config } from "tailwindcss";
const config = {
safelist: [
{
pattern: /to-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|950)/,
variants: ['dark'],
pattern:
/to-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|950)/,
variants: ["dark"],
},
{
pattern: /text-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|950)/,
variants: ['dark'],
pattern:
/text-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|950)/,
variants: ["dark"],
},
{
pattern: /border-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|950)/,
variants: ['dark'],
pattern:
/border-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|950)/,
variants: ["dark"],
},
{
pattern: /border-l-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|400|500|950)/,
variants: ['dark'],
pattern:
/border-l-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|400|500|950)/,
variants: ["dark"],
},
{
pattern: /bg-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|400|500|950)/,
variants: ['dark'],
}
pattern:
/bg-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|400|500|950)/,
variants: ["dark"],
},
{
pattern:
/ring-(blue|yellow|green|pink|purple|orange|red|slate|gray|zinc|neutral|stone|amber|lime|green|emerald|teal|cyan|sky|blue|indigo|violet|fuchsia|rose)-(50|100|200|400|500|950)/,
variants: ["focus-visible", "dark"],
},
],
darkMode: ["class"],
content: [
'./pages/**/*.{ts,tsx}',
'./components/**/*.{ts,tsx}',
'./app/**/*.{ts,tsx}',
'./src/**/*.{ts,tsx}',
"./pages/**/*.{ts,tsx}",
"./components/**/*.{ts,tsx}",
"./app/**/*.{ts,tsx}",
"./src/**/*.{ts,tsx}",
],
prefix: "",
theme: {
@ -101,9 +111,7 @@ const config = {
},
},
},
plugins: [
require("tailwindcss-animate"),
],
} satisfies Config
plugins: [require("tailwindcss-animate")],
} satisfies Config;
export default config
export default config;


@ -662,6 +662,8 @@ class AgentAdapters:
@staticmethod
async def ais_agent_accessible(agent: Agent, user: KhojUser) -> bool:
agent = await Agent.objects.select_related("creator").aget(pk=agent.pk)
if agent.privacy_level == Agent.PrivacyLevel.PUBLIC:
return True
if agent.creator == user:
@ -871,9 +873,13 @@ class ConversationAdapters:
agent = await AgentAdapters.aget_readonly_agent_by_slug(agent_slug, user)
if agent is None:
raise HTTPException(status_code=400, detail="No such agent currently exists.")
return await Conversation.objects.acreate(user=user, client=client_application, agent=agent, title=title)
return await Conversation.objects.select_related("agent", "agent__creator", "agent__chat_model").acreate(
user=user, client=client_application, agent=agent, title=title
)
agent = await AgentAdapters.aget_default_agent()
return await Conversation.objects.acreate(user=user, client=client_application, agent=agent, title=title)
return await Conversation.objects.select_related("agent", "agent__creator", "agent__chat_model").acreate(
user=user, client=client_application, agent=agent, title=title
)
@staticmethod
def create_conversation_session(
@ -1014,7 +1020,14 @@ class ConversationAdapters:
"""Get default conversation config. Prefer chat model by server admin > user > first created chat model"""
# Get the server chat settings
server_chat_settings = ServerChatSettings.objects.first()
if server_chat_settings is not None and server_chat_settings.chat_default is not None:
is_subscribed = is_user_subscribed(user) if user else False
if server_chat_settings:
# If the user is subscribed and the advanced model is enabled, return the advanced model
if is_subscribed and server_chat_settings.chat_advanced:
return server_chat_settings.chat_advanced
# If the default model is set, return it
if server_chat_settings.chat_default:
return server_chat_settings.chat_default
# Get the user's chat settings, if the server chat settings are not set
@ -1031,10 +1044,19 @@ class ConversationAdapters:
# Get the server chat settings
server_chat_settings: ServerChatSettings = (
await ServerChatSettings.objects.filter()
.prefetch_related("chat_default", "chat_default__openai_config")
.prefetch_related(
"chat_default", "chat_default__openai_config", "chat_advanced", "chat_advanced__openai_config"
)
.afirst()
)
if server_chat_settings is not None and server_chat_settings.chat_default is not None:
is_subscribed = await ais_user_subscribed(user) if user else False
if server_chat_settings:
# If the user is subscribed and the advanced model is enabled, return the advanced model
if is_subscribed and server_chat_settings.chat_advanced:
return server_chat_settings.chat_advanced
# If the default model is set, return it
if server_chat_settings.chat_default:
return server_chat_settings.chat_default
# Get the user's chat settings, if the server chat settings are not set
@ -1469,7 +1491,9 @@ class EntryAdapters:
@staticmethod
async def aget_agent_entry_filepaths(agent: Agent):
return await sync_to_async(list)(Entry.objects.filter(agent=agent).values_list("file_path", flat=True))
return await sync_to_async(set)(
Entry.objects.filter(agent=agent).distinct("file_path").values_list("file_path", flat=True)
)
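# Use distinct() and a set to deduplicate the agent's entry file paths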
@staticmethod
def get_all_filenames_by_source(user: KhojUser, file_source: str):
@ -1544,11 +1568,11 @@ class EntryAdapters:
@staticmethod
def search_with_embeddings(
user: KhojUser,
raw_query: str,
embeddings: Tensor,
user: KhojUser,
max_results: int = 10,
file_type_filter: str = None,
raw_query: str = None,
max_distance: float = math.inf,
agent: Agent = None,
):


@ -14,11 +14,13 @@ from khoj.processor.conversation.anthropic.utils import (
format_messages_for_anthropic,
)
from khoj.processor.conversation.utils import (
clean_json,
construct_structured_message,
generate_chatml_messages_with_context,
)
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
from khoj.utils.rawconfig import LocationData
from khoj.utils.yaml import yaml_dump
logger = logging.getLogger(__name__)
@ -90,15 +92,13 @@ def extract_questions_anthropic(
model_name=model,
temperature=temperature,
api_key=api_key,
response_type="json_object",
tracer=tracer,
)
# Extract, Clean Message from Claude's Response
try:
response = response.strip()
match = re.search(r"\{.*?\}", response)
if match:
response = match.group()
response = clean_json(response)
response = json.loads(response)
response = [q.strip() for q in response["queries"] if q.strip()]
if not isinstance(response, list) or not response:
@ -112,7 +112,7 @@ def extract_questions_anthropic(
return questions
def anthropic_send_message_to_model(messages, api_key, model, tracer={}):
def anthropic_send_message_to_model(messages, api_key, model, response_type="text", tracer={}):
"""
Send message to model
"""
@ -124,6 +124,7 @@ def anthropic_send_message_to_model(messages, api_key, model, tracer={}):
system_prompt=system_prompt,
model_name=model,
api_key=api_key,
response_type=response_type,
tracer=tracer,
)
@ -132,6 +133,7 @@ def converse_anthropic(
references,
user_query,
online_results: Optional[Dict[str, Dict]] = None,
code_results: Optional[Dict[str, Dict]] = None,
conversation_log={},
model: Optional[str] = "claude-3-5-sonnet-20241022",
api_key: Optional[str] = None,
@ -151,7 +153,6 @@ def converse_anthropic(
"""
# Initialize Variables
current_date = datetime.now()
compiled_references = "\n\n".join({f"# File: {item['file']}\n## {item['compiled']}\n" for item in references})
if agent and agent.personality:
system_prompt = prompts.custom_personality.format(
@ -175,7 +176,7 @@ def converse_anthropic(
system_prompt = f"{system_prompt}\n{user_name_prompt}"
# Get Conversation Primer appropriate to Conversation Type
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
completion_func(chat_response=prompts.no_notes_found.format())
return iter([prompts.no_notes_found.format()])
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
@ -183,10 +184,13 @@ def converse_anthropic(
return iter([prompts.no_online_results_found.format()])
context_message = ""
if not is_none_or_empty(compiled_references):
context_message = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n"
if not is_none_or_empty(references):
context_message = f"{prompts.notes_conversation.format(query=user_query, references=yaml_dump(references))}\n\n"
if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
context_message += f"{prompts.online_search_conversation.format(online_results=str(online_results))}"
context_message += f"{prompts.online_search_conversation.format(online_results=yaml_dump(online_results))}\n\n"
if ConversationCommand.Code in conversation_commands and not is_none_or_empty(code_results):
context_message += f"{prompts.code_executed_context.format(code_results=str(code_results))}\n\n"
context_message = context_message.strip()
# Setup Prompt with Primer or Conversation History
messages = generate_chatml_messages_with_context(


@ -35,7 +35,15 @@ DEFAULT_MAX_TOKENS_ANTHROPIC = 3000
reraise=True,
)
def anthropic_completion_with_backoff(
messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None, tracer={}
messages,
system_prompt,
model_name,
temperature=0,
api_key=None,
model_kwargs=None,
max_tokens=None,
response_type="text",
tracer={},
) -> str:
if api_key not in anthropic_clients:
client: anthropic.Anthropic = anthropic.Anthropic(api_key=api_key)
@ -44,8 +52,11 @@ def anthropic_completion_with_backoff(
client = anthropic_clients[api_key]
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
if response_type == "json_object":
# Prefill model response with '{' to make it output a valid JSON object
formatted_messages += [{"role": "assistant", "content": "{"}]
aggregated_response = ""
aggregated_response = "{" if response_type == "json_object" else ""
max_tokens = max_tokens or DEFAULT_MAX_TOKENS_ANTHROPIC
model_kwargs = model_kwargs or dict()


@ -14,11 +14,13 @@ from khoj.processor.conversation.google.utils import (
gemini_completion_with_backoff,
)
from khoj.processor.conversation.utils import (
clean_json,
construct_structured_message,
generate_chatml_messages_with_context,
)
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
from khoj.utils.rawconfig import LocationData
from khoj.utils.yaml import yaml_dump
logger = logging.getLogger(__name__)
@ -91,10 +93,7 @@ def extract_questions_gemini(
# Extract, Clean Message from Gemini's Response
try:
response = response.strip()
match = re.search(r"\{.*?\}", response)
if match:
response = match.group()
response = clean_json(response)
response = json.loads(response)
response = [q.strip() for q in response["queries"] if q.strip()]
if not isinstance(response, list) or not response:
@ -117,8 +116,10 @@ def gemini_send_message_to_model(
messages, system_prompt = format_messages_for_gemini(messages)
model_kwargs = {}
if response_type == "json_object":
model_kwargs["response_mime_type"] = "application/json"
# Sometimes, this causes unwanted behavior and terminates response early. Disable for now while it's flaky.
# if response_type == "json_object":
# model_kwargs["response_mime_type"] = "application/json"
# Get Response from Gemini
return gemini_completion_with_backoff(
@ -136,6 +137,7 @@ def converse_gemini(
references,
user_query,
online_results: Optional[Dict[str, Dict]] = None,
code_results: Optional[Dict[str, Dict]] = None,
conversation_log={},
model: Optional[str] = "gemini-1.5-flash",
api_key: Optional[str] = None,
@ -156,7 +158,6 @@ def converse_gemini(
"""
# Initialize Variables
current_date = datetime.now()
compiled_references = "\n\n".join({f"# File: {item['file']}\n## {item['compiled']}\n" for item in references})
if agent and agent.personality:
system_prompt = prompts.custom_personality.format(
@ -181,7 +182,7 @@ def converse_gemini(
system_prompt = f"{system_prompt}\n{user_name_prompt}"
# Get Conversation Primer appropriate to Conversation Type
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
completion_func(chat_response=prompts.no_notes_found.format())
return iter([prompts.no_notes_found.format()])
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
@ -189,10 +190,13 @@ def converse_gemini(
return iter([prompts.no_online_results_found.format()])
context_message = ""
if not is_none_or_empty(compiled_references):
context_message = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n"
if not is_none_or_empty(references):
context_message = f"{prompts.notes_conversation.format(query=user_query, references=yaml_dump(references))}\n\n"
if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
context_message += f"{prompts.online_search_conversation.format(online_results=str(online_results))}"
context_message += f"{prompts.online_search_conversation.format(online_results=yaml_dump(online_results))}\n\n"
if ConversationCommand.Code in conversation_commands and not is_none_or_empty(code_results):
context_message += f"{prompts.code_executed_context.format(code_results=str(code_results))}\n\n"
context_message = context_message.strip()
# Setup Prompt with Primer or Conversation History
messages = generate_chatml_messages_with_context(


@ -19,6 +19,7 @@ from khoj.utils import state
from khoj.utils.constants import empty_escape_sequences
from khoj.utils.helpers import ConversationCommand, in_debug_mode, is_none_or_empty
from khoj.utils.rawconfig import LocationData
from khoj.utils.yaml import yaml_dump
logger = logging.getLogger(__name__)
@ -138,7 +139,8 @@ def filter_questions(questions: List[str]):
def converse_offline(
user_query,
references=[],
online_results=[],
online_results={},
code_results={},
conversation_log={},
model: str = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
loaded_model: Union[Any, None] = None,
@ -158,8 +160,6 @@ def converse_offline(
assert loaded_model is None or isinstance(loaded_model, Llama), "loaded_model must be of type Llama, if configured"
offline_chat_model = loaded_model or download_model(model, max_tokens=max_prompt_size)
tracer["chat_model"] = model
compiled_references = "\n\n".join({f"# File: {item['file']}\n## {item['compiled']}\n" for item in references})
current_date = datetime.now()
if agent and agent.personality:
@ -184,24 +184,25 @@ def converse_offline(
system_prompt = f"{system_prompt}\n{user_name_prompt}"
# Get Conversation Primer appropriate to Conversation Type
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
return iter([prompts.no_notes_found.format()])
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
completion_func(chat_response=prompts.no_online_results_found.format())
return iter([prompts.no_online_results_found.format()])
context_message = ""
if not is_none_or_empty(compiled_references):
context_message += f"{prompts.notes_conversation_offline.format(references=compiled_references)}\n\n"
if not is_none_or_empty(references):
context_message = f"{prompts.notes_conversation_offline.format(references=yaml_dump(references))}\n\n"
if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
simplified_online_results = online_results.copy()
for result in online_results:
if online_results[result].get("webpages"):
simplified_online_results[result] = online_results[result]["webpages"]
context_message += (
f"{prompts.online_search_conversation_offline.format(online_results=str(simplified_online_results))}"
)
context_message += f"{prompts.online_search_conversation_offline.format(online_results=yaml_dump(simplified_online_results))}\n\n"
if ConversationCommand.Code in conversation_commands and not is_none_or_empty(code_results):
context_message += f"{prompts.code_executed_context.format(code_results=str(code_results))}\n\n"
context_message = context_message.strip()
# Setup Prompt with Primer or Conversation History
messages = generate_chatml_messages_with_context(


@ -12,12 +12,13 @@ from khoj.processor.conversation.openai.utils import (
completion_with_backoff,
)
from khoj.processor.conversation.utils import (
clean_json,
construct_structured_message,
generate_chatml_messages_with_context,
remove_json_codeblock,
)
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
from khoj.utils.rawconfig import LocationData
from khoj.utils.yaml import yaml_dump
logger = logging.getLogger(__name__)
@ -94,8 +95,7 @@ def extract_questions(
# Extract, Clean Message from GPT's Response
try:
response = response.strip()
response = remove_json_codeblock(response)
response = clean_json(response)
response = json.loads(response)
response = [q.strip() for q in response["queries"] if q.strip()]
if not isinstance(response, list) or not response:
@ -133,6 +133,7 @@ def converse(
references,
user_query,
online_results: Optional[Dict[str, Dict]] = None,
code_results: Optional[Dict[str, Dict]] = None,
conversation_log={},
model: str = "gpt-4o-mini",
api_key: Optional[str] = None,
@ -154,7 +155,6 @@ def converse(
"""
# Initialize Variables
current_date = datetime.now()
compiled_references = "\n\n".join({f"# File: {item['file']}\n## {item['compiled']}\n" for item in references})
if agent and agent.personality:
system_prompt = prompts.custom_personality.format(
@ -178,7 +178,7 @@ def converse(
system_prompt = f"{system_prompt}\n{user_name_prompt}"
# Get Conversation Primer appropriate to Conversation Type
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(references):
completion_func(chat_response=prompts.no_notes_found.format())
return iter([prompts.no_notes_found.format()])
elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
@ -186,10 +186,13 @@ def converse(
return iter([prompts.no_online_results_found.format()])
context_message = ""
if not is_none_or_empty(compiled_references):
context_message = f"{prompts.notes_conversation.format(references=compiled_references)}\n\n"
if not is_none_or_empty(references):
context_message = f"{prompts.notes_conversation.format(references=yaml_dump(references))}\n\n"
if not is_none_or_empty(online_results):
context_message += f"{prompts.online_search_conversation.format(online_results=str(online_results))}"
context_message += f"{prompts.online_search_conversation.format(online_results=yaml_dump(online_results))}\n\n"
if not is_none_or_empty(code_results):
context_message += f"{prompts.code_executed_context.format(code_results=str(code_results))}\n\n"
context_message = context_message.strip()
# Setup Prompt with Primer or Conversation History
messages = generate_chatml_messages_with_context(

View file

@@ -394,21 +394,23 @@ Q: {query}
extract_questions = PromptTemplate.from_template(
"""
You are Khoj, an extremely smart and helpful document search assistant with only the ability to retrieve information from the user's notes. Disregard online search requests.
You are Khoj, an extremely smart and helpful document search assistant with only the ability to retrieve information from the user's notes and documents.
Construct search queries to retrieve relevant information to answer the user's question.
- You will be provided past questions(Q) and answers(A) for context.
- You will be provided example and actual past user questions(Q), search queries(Khoj) and answers(A) for context.
- Add as much context from the previous questions and answers as required into your search queries.
- Break messages into multiple search queries when required to retrieve the relevant information.
- Break your search down into multiple search queries from a diverse set of lenses to retrieve all related documents.
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
- When asked meta, vague or random questions, search for a variety of broad topics to answer the user's question.
{personality_context}
What searches will you perform to answer the users question? Respond with search queries as list of strings in a JSON object.
What searches will you perform to answer the user's question? Respond with search queries as list of strings in a JSON object.
Current Date: {day_of_week}, {current_date}
User's Location: {location}
{username}
Examples
---
Q: How was my trip to Cambodia?
Khoj: {{"queries": ["How was my trip to Cambodia?"]}}
Khoj: {{"queries": ["How was my trip to Cambodia?", "Angkor Wat temple visit", "Flight to Phnom Penh", "Expenses in Cambodia", "Stay in Cambodia"]}}
A: The trip was amazing. You went to the Angkor Wat temple and it was beautiful.
Q: Who did I visit that temple with?
@@ -443,6 +445,8 @@ Q: Who all did I meet here yesterday?
Khoj: {{"queries": ["Met in {location} on {yesterday_date} dt>='{yesterday_date}' dt<'{current_date}'"]}}
A: Yesterday's note mentions your visit to your local beach with Ram and Shyam.
Actual
---
{chat_history}
Q: {text}
Khoj:
@@ -451,11 +455,11 @@ Khoj:
extract_questions_anthropic_system_prompt = PromptTemplate.from_template(
"""
You are Khoj, an extremely smart and helpful document search assistant with only the ability to retrieve information from the user's notes. Disregard online search requests.
You are Khoj, an extremely smart and helpful document search assistant with only the ability to retrieve information from the user's notes.
Construct search queries to retrieve relevant information to answer the user's question.
- You will be provided past questions(User), extracted queries(Assistant) and answers(A) for context.
- You will be provided past questions(User), search queries(Assistant) and answers(A) for context.
- Add as much context from the previous questions and answers as required into your search queries.
- Break messages into multiple search queries when required to retrieve the relevant information.
- Break your search down into multiple search queries from a diverse set of lenses to retrieve all related documents.
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
- When asked meta, vague or random questions, search for a variety of broad topics to answer the user's question.
{personality_context}
@@ -468,7 +472,7 @@ User's Location: {location}
Here are some examples of how you can construct search queries to answer the user's question:
User: How was my trip to Cambodia?
Assistant: {{"queries": ["How was my trip to Cambodia?"]}}
Assistant: {{"queries": ["How was my trip to Cambodia?", "Angkor Wat temple visit", "Flight to Phnom Penh", "Expenses in Cambodia", "Stay in Cambodia"]}}
A: The trip was amazing. You went to the Angkor Wat temple and it was beautiful.
User: What national parks did I go to last year?
@@ -501,17 +505,14 @@ Assistant:
)
system_prompt_extract_relevant_information = """
As a professional analyst, create a comprehensive report of the most relevant information from a web page in response to a user's query.
The text provided is directly from within the web page.
The report you create should be multiple paragraphs, and it should represent the content of the website.
Tell the user exactly what the website says in response to their query, while adhering to these guidelines:
As a professional analyst, your job is to extract all pertinent information from documents to help answer the user's query.
You will be provided raw text directly from within the document.
Adhere to these guidelines while extracting information from the provided documents:
1. Answer the user's query as specifically as possible. Include many supporting details from the website.
2. Craft a report that is detailed, thorough, in-depth, and complex, while maintaining clarity.
3. Rely strictly on the provided text, without including external information.
4. Format the report in multiple paragraphs with a clear structure.
5. Be as specific as possible in your answer to the user's query.
6. Reproduce as much of the provided text as possible, while maintaining readability.
1. Extract all relevant text and links from the document that can assist with further research or answer the user's query.
2. Craft a comprehensive but compact report with all the necessary data from the document to generate an informed response.
3. Rely strictly on the provided text to generate your summary, without including external information.
4. Provide specific, important snippets from the document in your report to establish trust in your summary.
""".strip()
extract_relevant_information = PromptTemplate.from_template(
@@ -519,10 +520,10 @@ extract_relevant_information = PromptTemplate.from_template(
{personality_context}
Target Query: {query}
Web Pages:
Document:
{corpus}
Collate only relevant information from the website to answer the target query.
Collate only relevant information from the document to answer the target query.
""".strip()
)
@@ -617,6 +618,67 @@ Khoj:
""".strip()
)
plan_function_execution = PromptTemplate.from_template(
"""
You are Khoj, a smart, creative and methodical researcher. Use the provided tool AIs to investigate the information needed to answer the query.
Create a multi-step plan and intelligently iterate on the plan based on the retrieved information to find the requested information.
{personality_context}
# Instructions
- Ask detailed queries to the tool AIs provided below, one at a time, to discover required information or run calculations. Their response will be shown to you in the next iteration.
- Break down your research process into independent, self-contained steps that can be executed sequentially to answer the user's query. Write your step-by-step plan in the scratchpad.
- Ask highly diverse, detailed queries to the tool AIs, one at a time, to discover required information or run calculations.
- NEVER repeat the same query across iterations.
- Ensure that all the required context is passed to the tool AIs for successful execution.
- Ensure that you go deeper when possible and try broader, more creative strategies when a path is not yielding useful results. Build on the results of the previous iterations.
- You are allowed up to {max_iterations} iterations to use the help of the provided tool AIs to answer the user's question.
- Stop when you have the required information by returning a JSON object with an empty "tool" field. E.g., {{scratchpad: "I have all I need", tool: "", query: ""}}
# Examples
Assuming you can search the user's notes and the internet.
- When they ask for the population of their hometown
1. Try looking up their hometown in their notes. Ask the note search AI to search for their birth certificate, childhood memories, school, resume etc.
2. If not found in their notes, try to infer their hometown from their online social media profiles. Ask the online search AI to look for {username}'s biography, school, resume on linkedin, facebook, website etc.
3. Only then try to find the latest population of their hometown by reading official websites with the help of the online search and web page reading AI.
- When they ask for their computer's specs
1. Try to find their computer model in their notes.
2. Then find webpages with their computer model's specs online and read them.
- When they ask what clothes to carry for their upcoming trip
1. Find the itinerary of their upcoming trip in their notes.
2. Next, find the weather forecast at the destination online.
3. Then check if they mentioned what clothes they own in their notes.
# Background Context
- Current Date: {day_of_week}, {current_date}
- User Location: {location}
- User Name: {username}
# Available Tool AIs
Which of the tool AIs listed below would you use to answer the user's question? You **only** have access to the following tool AIs:
{tools}
# Previous Iterations
{previous_iterations}
# Chat History:
{chat_history}
Return the next tool AI to use and the query to ask it. Your response should always be a valid JSON object. Do not say anything else.
Response format:
{{"scratchpad": "<your_scratchpad_to_reason_about_which_tool_to_use>", "tool": "<name_of_tool_ai>", "query": "<your_detailed_query_for_the_tool_ai>"}}
""".strip()
)
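To make the planner's contract concrete, a well-formed response to the prompt above might look like this (illustrative values; the tool name must be one of the ConversationCommand values offered in {tools}):

{"scratchpad": "Hometown not in chat history. Search the user's notes for their birth certificate or school records first.", "tool": "notes", "query": "birth certificate, childhood home, school records"}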
previous_iteration = PromptTemplate.from_template(
"""
## Iteration {index}:
- tool: {tool}
- query: {query}
- result: {result}
"""
)
pick_relevant_information_collection_tools = PromptTemplate.from_template(
"""
You are Khoj, an extremely smart and helpful search assistant.
@@ -806,6 +868,53 @@ Khoj:
""".strip()
)
# Code Generation
# --
python_code_generation_prompt = PromptTemplate.from_template(
"""
You are Khoj, an advanced python programmer. You are tasked with constructing **up to three** python programs to best answer the user query.
- The python program will run in a pyodide python sandbox with no network access.
- You can write programs to run complex calculations, analyze data, create charts and generate documents to meticulously answer the query
- The sandbox has access to the standard library, matplotlib, pandas, numpy, scipy, bs4, sympy, brotli, cryptography, fastparquet
- Do not try to display images or plots in the code directly. The code should save the image or plot to a file instead.
- Write any documents, charts etc. to be shared with the user to file. These files can be seen by the user.
- Use as much context from the previous questions and answers as required to generate your code.
{personality_context}
What code will you need to write, if any, to answer the user's question?
Provide code programs as a list of strings in a JSON object with key "codes".
Current Date: {current_date}
User's Location: {location}
{username}
The JSON schema is of the form {{"codes": ["code1", "code2", "code3"]}}
For example:
{{"codes": ["print('Hello, World!')", "print('Goodbye, World!')"]}}
Now it's your turn to construct python programs to answer the user's question. Provide them as a list of strings in a JSON object. Do not say anything else.
Context:
---
{context}
Chat History:
---
{chat_history}
User: {query}
Khoj:
""".strip()
)
code_executed_context = PromptTemplate.from_template(
"""
Use the provided code executions to inform your response.
Ask crisp follow-up questions to get additional context when a helpful response cannot be generated from the code execution results or past conversations.
Code Execution Results:
{code_results}
""".strip()
)
# Automations
# --
crontime_prompt = PromptTemplate.from_template(

View file

@@ -6,9 +6,10 @@ import os
import queue
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from io import BytesIO
from time import perf_counter
from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional
import PIL.Image
import requests
@@ -23,8 +24,17 @@ from khoj.database.adapters import ConversationAdapters
from khoj.database.models import ChatModelOptions, ClientApplication, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.offline.utils import download_model, infer_max_tokens
from khoj.search_filter.base_filter import BaseFilter
from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.word_filter import WordFilter
from khoj.utils import state
from khoj.utils.helpers import in_debug_mode, is_none_or_empty, merge_dicts
from khoj.utils.helpers import (
ConversationCommand,
in_debug_mode,
is_none_or_empty,
merge_dicts,
)
logger = logging.getLogger(__name__)
model_to_prompt_size = {
@@ -85,8 +95,105 @@ class ThreadedGenerator:
self.queue.put(StopIteration)
class InformationCollectionIteration:
def __init__(
self,
tool: str,
query: str,
context: list = None,
onlineContext: dict = None,
codeContext: dict = None,
summarizedResult: str = None,
):
self.tool = tool
self.query = query
self.context = context
self.onlineContext = onlineContext
self.codeContext = codeContext
self.summarizedResult = summarizedResult
def construct_iteration_history(
previous_iterations: List[InformationCollectionIteration], previous_iteration_prompt: str
) -> str:
previous_iterations_history = ""
for idx, iteration in enumerate(previous_iterations):
iteration_data = previous_iteration_prompt.format(
tool=iteration.tool,
query=iteration.query,
result=iteration.summarizedResult,
index=idx + 1,
)
previous_iterations_history += iteration_data
return previous_iterations_history
def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
chat_history = ""
for chat in conversation_history.get("chat", [])[-n:]:
if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder", "summarize"]:
chat_history += f"User: {chat['intent']['query']}\n"
chat_history += f"{agent_name}: {chat['message']}\n"
elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")):
chat_history += f"User: {chat['intent']['query']}\n"
chat_history += f"{agent_name}: [generated image redacted for space]\n"
elif chat["by"] == "khoj" and ("excalidraw" in chat["intent"].get("type")):
chat_history += f"User: {chat['intent']['query']}\n"
chat_history += f"{agent_name}: {chat['intent']['inferred-queries'][0]}\n"
return chat_history
def construct_tool_chat_history(
previous_iterations: List[InformationCollectionIteration], tool: ConversationCommand = None
) -> Dict[str, list]:
chat_history: list = []
inferred_query_extractor: Callable[[InformationCollectionIteration], List[str]] = lambda x: []
if tool == ConversationCommand.Notes:
inferred_query_extractor = (
lambda iteration: [c["query"] for c in iteration.context] if iteration.context else []
)
elif tool == ConversationCommand.Online:
inferred_query_extractor = (
lambda iteration: list(iteration.onlineContext.keys()) if iteration.onlineContext else []
)
elif tool == ConversationCommand.Code:
inferred_query_extractor = lambda iteration: list(iteration.codeContext.keys()) if iteration.codeContext else []
for iteration in previous_iterations:
chat_history += [
{
"by": "you",
"message": iteration.query,
},
{
"by": "khoj",
"intent": {
"type": "remember",
"inferred-queries": inferred_query_extractor(iteration),
"query": iteration.query,
},
"message": iteration.summarizedResult,
},
]
return {"chat": chat_history}
class ChatEvent(Enum):
START_LLM_RESPONSE = "start_llm_response"
END_LLM_RESPONSE = "end_llm_response"
MESSAGE = "message"
REFERENCES = "references"
STATUS = "status"
def message_to_log(
user_message, chat_response, user_message_metadata={}, khoj_message_metadata={}, conversation_log=[]
user_message,
chat_response,
user_message_metadata={},
khoj_message_metadata={},
conversation_log=[],
train_of_thought=[],
):
"""Create json logs from messages, metadata for conversation log"""
default_khoj_message_metadata = {
@@ -114,6 +221,7 @@ def save_to_conversation_log(
user_message_time: str = None,
compiled_references: List[Dict[str, Any]] = [],
online_results: Dict[str, Any] = {},
code_results: Dict[str, Any] = {},
inferred_queries: List[str] = [],
intent_type: str = "remember",
client_application: ClientApplication = None,
@@ -121,6 +229,7 @@
automation_id: str = None,
query_images: List[str] = None,
tracer: Dict[str, Any] = {},
train_of_thought: List[Any] = [],
):
user_message_time = user_message_time or datetime.now().strftime("%Y-%m-%d %H:%M:%S")
updated_conversation = message_to_log(
@@ -134,9 +243,12 @@
"context": compiled_references,
"intent": {"inferred-queries": inferred_queries, "type": intent_type},
"onlineContext": online_results,
"codeContext": code_results,
"automationId": automation_id,
"trainOfThought": train_of_thought,
},
conversation_log=meta_log.get("chat", []),
train_of_thought=train_of_thought,
)
ConversationAdapters.save_conversation(
user,
@@ -330,9 +442,23 @@ def reciprocal_conversation_to_chatml(message_pair):
return [ChatMessage(content=message, role=role) for message, role in zip(message_pair, ["user", "assistant"])]
def remove_json_codeblock(response: str):
"""Remove any markdown json codeblock formatting if present. Useful for non schema enforceable models"""
return response.removeprefix("```json").removesuffix("```")
def clean_json(response: str):
"""Remove any markdown json codeblock and newline formatting if present. Useful for non schema enforceable models"""
return response.strip().replace("\n", "").removeprefix("```json").removesuffix("```")
def clean_code_python(code: str):
"""Remove any markdown codeblock and newline formatting if present. Useful for non schema enforceable models"""
return code.strip().removeprefix("```python").removesuffix("```")
def defilter_query(query: str):
"""Remove any query filters in query"""
defiltered_query = query
filters: List[BaseFilter] = [WordFilter(), FileFilter(), DateFilter()]
for filter in filters:
defiltered_query = filter.defilter(defiltered_query)
return defiltered_query
@dataclass
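A quick illustration of the two new helpers above, as a sketch (the word-filter syntax with +"..." is assumed from WordFilter):

# clean_json strips markdown fences and newlines so responses from non schema-enforcing models still parse:
clean_json('```json\n{"queries": ["foo"]}\n```')  # -> '{"queries": ["foo"]}'
# defilter_query removes word, file and date filters before a query reaches downstream tools:
defilter_query('trip photos +"cambodia" file:"journal.org" dt>="2024-01-01"')  # -> roughly 'trip photos'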

View file

@@ -4,7 +4,7 @@ import logging
import os
import urllib.parse
from collections import defaultdict
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import aiohttp
from bs4 import BeautifulSoup
@@ -52,7 +52,8 @@ OLOSTEP_QUERY_PARAMS = {
"expandMarkdown": "True",
"expandHtml": "False",
}
MAX_WEBPAGES_TO_READ = 1
DEFAULT_MAX_WEBPAGES_TO_READ = 1
async def search_online(
@@ -62,6 +63,7 @@ async def search_online(
user: KhojUser,
send_status_func: Optional[Callable] = None,
custom_filters: List[str] = [],
max_webpages_to_read: int = DEFAULT_MAX_WEBPAGES_TO_READ,
query_images: List[str] = None,
agent: Agent = None,
tracer: dict = {},
@@ -97,7 +99,7 @@
for subquery in response_dict:
if "answerBox" in response_dict[subquery]:
continue
for organic in response_dict[subquery].get("organic", [])[:MAX_WEBPAGES_TO_READ]:
for organic in response_dict[subquery].get("organic", [])[:max_webpages_to_read]:
link = organic.get("link")
if link in webpages:
webpages[link]["queries"].add(subquery)
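Research mode exercises the new parameter by passing max_webpages_to_read=0 (see research.py further below), so pages are only read when the planner explicitly picks the webpage tool. A minimal sketch of such a call, mirroring that call site (status events omitted for brevity; runs inside an async function):

# Sketch: online search that returns snippets without reading any result pages.
async for result in search_online(
    "khoj research mode",  # query
    {},  # conversation history
    location,  # LocationData
    user,  # KhojUser
    send_status_func=None,
    custom_filters=[],
    max_webpages_to_read=0,
):
    online_results = result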

View file

@@ -0,0 +1,144 @@
import asyncio
import datetime
import json
import logging
import os
from typing import Any, Callable, List, Optional
import aiohttp
from khoj.database.adapters import ais_user_subscribed
from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import (
ChatEvent,
clean_code_python,
clean_json,
construct_chat_history,
)
from khoj.routers.helpers import send_message_to_model_wrapper
from khoj.utils.helpers import timer
from khoj.utils.rawconfig import LocationData
logger = logging.getLogger(__name__)
SANDBOX_URL = os.getenv("KHOJ_TERRARIUM_URL", "http://localhost:8080")
async def run_code(
query: str,
conversation_history: dict,
context: str,
location_data: LocationData,
user: KhojUser,
send_status_func: Optional[Callable] = None,
query_images: List[str] = None,
agent: Agent = None,
sandbox_url: str = SANDBOX_URL,
tracer: dict = {},
):
# Generate Code
if send_status_func:
async for event in send_status_func(f"**Generate code snippets** for {query}"):
yield {ChatEvent.STATUS: event}
try:
with timer("Chat actor: Generate programs to execute", logger):
codes = await generate_python_code(
query,
conversation_history,
context,
location_data,
user,
query_images,
agent,
tracer,
)
except Exception as e:
raise ValueError(f"Failed to generate code for {query} with error: {e}")
# Run Code
if send_status_func:
async for event in send_status_func(f"**Running {len(codes)} code snippets**"):
yield {ChatEvent.STATUS: event}
try:
tasks = [execute_sandboxed_python(code, sandbox_url) for code in codes]
with timer("Chat actor: Execute generated programs", logger):
results = await asyncio.gather(*tasks)
for result in results:
code = result.pop("code")
logger.info(f"Executed Code:\n--@@--\n{code}\n--@@--Result:\n--@@--\n{result}\n--@@--")
yield {query: {"code": code, "results": result}}
except Exception as e:
raise ValueError(f"Failed to run code for {query} with error: {e}")
async def generate_python_code(
q: str,
conversation_history: dict,
context: str,
location_data: LocationData,
user: KhojUser,
query_images: List[str] = None,
agent: Agent = None,
tracer: dict = {},
) -> List[str]:
location = f"{location_data}" if location_data else "Unknown"
username = prompts.user_name.format(name=user.get_full_name()) if user.get_full_name() else ""
subscribed = await ais_user_subscribed(user)
chat_history = construct_chat_history(conversation_history)
utc_date = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d")
personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
)
code_generation_prompt = prompts.python_code_generation_prompt.format(
current_date=utc_date,
query=q,
chat_history=chat_history,
context=context,
location=location,
username=username,
personality_context=personality_context,
)
response = await send_message_to_model_wrapper(
code_generation_prompt,
query_images=query_images,
response_type="json_object",
user=user,
tracer=tracer,
)
# Validate that the response is a non-empty, JSON-serializable list
response = clean_json(response)
response = json.loads(response)
codes = [code.strip() for code in response["codes"] if code.strip()]
if not isinstance(codes, list) or not codes or len(codes) == 0:
raise ValueError
return codes
async def execute_sandboxed_python(code: str, sandbox_url: str = SANDBOX_URL) -> dict[str, Any]:
"""
Takes code to run as a string and calls the terrarium API to execute it.
Returns the result of the code execution as a dictionary.
"""
headers = {"Content-Type": "application/json"}
cleaned_code = clean_code_python(code)
data = {"code": cleaned_code}
async with aiohttp.ClientSession() as session:
async with session.post(sandbox_url, json=data, headers=headers) as response:
if response.status == 200:
result: dict[str, Any] = await response.json()
result["code"] = cleaned_code
return result
else:
return {
"code": cleaned_code,
"success": False,
"std_err": f"Failed to execute code with {response.status}",
}
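For reference, a minimal way to exercise the sandbox helper directly, assuming a terrarium container is listening at the default KHOJ_TERRARIUM_URL (response fields other than "code" depend on terrarium's API):

import asyncio

# Sketch: run one snippet in the pyodide sandbox and inspect the raw response.
result = asyncio.run(execute_sandboxed_python("print(40 + 2)"))
print(result["code"])  # the cleaned code that was executed
print(result)  # remaining fields, e.g. stdout/stderr and any output files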

View file

@@ -44,6 +44,7 @@ from khoj.processor.conversation.offline.chat_model import extract_questions_off
from khoj.processor.conversation.offline.whisper import transcribe_audio_offline
from khoj.processor.conversation.openai.gpt import extract_questions
from khoj.processor.conversation.openai.whisper import transcribe_audio
from khoj.processor.conversation.utils import defilter_query
from khoj.routers.helpers import (
ApiUserRateLimiter,
ChatEvent,
@@ -167,8 +168,8 @@ async def execute_search(
search_futures += [
executor.submit(
text_search.query,
user,
user_query,
user,
t,
question_embedding=encoded_asymmetric_query,
max_distance=max_distance,
@@ -355,7 +356,7 @@ async def extract_references_and_questions(
user = request.user.object if request.user.is_authenticated else None
# Initialize Variables
compiled_references: List[Any] = []
compiled_references: List[dict[str, str]] = []
inferred_queries: List[str] = []
agent_has_entries = False
@@ -384,9 +385,7 @@ async def extract_references_and_questions(
return
# Extract filter terms from user message
defiltered_query = q
for filter in [DateFilter(), WordFilter(), FileFilter()]:
defiltered_query = filter.defilter(defiltered_query)
defiltered_query = defilter_query(q)
filters_in_query = q.replace(defiltered_query, "").strip()
conversation = await sync_to_async(ConversationAdapters.get_conversation_by_id)(conversation_id)
@@ -502,7 +501,8 @@
)
search_results = text_search.deduplicated_search_responses(search_results)
compiled_references = [
{"compiled": item.additional["compiled"], "file": item.additional["file"]} for item in search_results
{"query": q, "compiled": item.additional["compiled"], "file": item.additional["file"]}
for q, item in zip(inferred_queries, search_results)
]
yield compiled_references, inferred_queries, defiltered_query

View file

@@ -6,7 +6,7 @@ import time
import uuid
from datetime import datetime
from functools import partial
from typing import Dict, Optional
from typing import Any, Dict, List, Optional
from urllib.parse import unquote
from asgiref.sync import sync_to_async
@@ -25,10 +25,11 @@ from khoj.database.adapters import (
)
from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation.prompts import help_message, no_entries_found
from khoj.processor.conversation.utils import save_to_conversation_log
from khoj.processor.conversation.utils import defilter_query, save_to_conversation_log
from khoj.processor.image.generate import text_to_image
from khoj.processor.speech.text_to_speech import generate_text_to_speech
from khoj.processor.tools.online_search import read_webpages, search_online
from khoj.processor.tools.run_code import run_code
from khoj.routers.api import extract_references_and_questions
from khoj.routers.helpers import (
ApiImageRateLimiter,
@@ -42,8 +43,10 @@ from khoj.routers.helpers import (
aget_relevant_output_modes,
construct_automation_created_message,
create_automation,
extract_relevant_info,
extract_relevant_summary,
generate_excalidraw_diagram,
generate_summary_from_files,
get_conversation_command,
is_query_empty,
is_ready_to_chat,
@@ -51,6 +54,10 @@ from khoj.routers.helpers import (
update_telemetry_state,
validate_conversation_config,
)
from khoj.routers.research import (
InformationCollectionIteration,
execute_information_collection,
)
from khoj.routers.storage import upload_image_to_bucket
from khoj.utils import state
from khoj.utils.helpers import (
@@ -563,7 +570,9 @@ async def chat(
user: KhojUser = request.user.object
event_delimiter = "␃🔚␗"
q = unquote(q)
train_of_thought = []
nonlocal conversation_id
tracer: dict = {
"mid": f"{uuid.uuid4()}",
"cid": conversation_id,
@@ -583,7 +592,7 @@
uploaded_images.append(uploaded_image)
async def send_event(event_type: ChatEvent, data: str | dict):
nonlocal connection_alive, ttft
nonlocal connection_alive, ttft, train_of_thought
if not connection_alive or await request.is_disconnected():
connection_alive = False
logger.warning(f"User {user} disconnected from {common.client} client")
@@ -591,8 +600,11 @@
try:
if event_type == ChatEvent.END_LLM_RESPONSE:
collect_telemetry()
if event_type == ChatEvent.START_LLM_RESPONSE:
elif event_type == ChatEvent.START_LLM_RESPONSE:
ttft = time.perf_counter() - start_time
elif event_type == ChatEvent.STATUS:
train_of_thought.append({"type": event_type.value, "data": data})
if event_type == ChatEvent.MESSAGE:
yield data
elif event_type == ChatEvent.REFERENCES or stream:
@@ -681,6 +693,14 @@
meta_log = conversation.conversation_log
is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]
researched_results = ""
online_results: Dict = dict()
code_results: Dict = dict()
## Extract Document References
compiled_references: List[Any] = []
inferred_queries: List[Any] = []
defiltered_query = defilter_query(q)
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
conversation_commands = await aget_relevant_information_sources(
q,
@@ -691,6 +711,11 @@
agent=agent,
tracer=tracer,
)
# If we're doing research, we don't want to do anything else
if ConversationCommand.Research in conversation_commands:
conversation_commands = [ConversationCommand.Research]
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
async for result in send_event(
ChatEvent.STATUS, f"**Chose Data Sources to Search:** {conversation_commands_str}"
@@ -705,6 +730,38 @@
if mode not in conversation_commands:
conversation_commands.append(mode)
if conversation_commands == [ConversationCommand.Research]:
async for research_result in execute_information_collection(
request=request,
user=user,
query=defiltered_query,
conversation_id=conversation_id,
conversation_history=meta_log,
query_images=uploaded_images,
agent=agent,
send_status_func=partial(send_event, ChatEvent.STATUS),
user_name=user_name,
location=location,
file_filters=conversation.file_filters if conversation else [],
tracer=tracer,
):
if isinstance(research_result, InformationCollectionIteration):
if research_result.summarizedResult:
if research_result.onlineContext:
online_results.update(research_result.onlineContext)
if research_result.codeContext:
code_results.update(research_result.codeContext)
if research_result.context:
compiled_references.extend(research_result.context)
researched_results += research_result.summarizedResult
else:
yield research_result
# researched_results = await extract_relevant_info(q, researched_results, agent)
logger.info(f"Researched Results: {researched_results}")
for cmd in conversation_commands:
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
q = q.replace(f"/{cmd.value}", "").strip()
@@ -733,48 +790,24 @@ async def chat(
async for result in send_llm_response(response_log):
yield result
else:
try:
file_object = None
if await EntryAdapters.aagent_has_entries(agent):
file_names = await EntryAdapters.aget_agent_entry_filepaths(agent)
if len(file_names) > 0:
file_object = await FileObjectAdapters.async_get_file_objects_by_name(
None, file_names[0], agent
)
if len(file_filters) > 0:
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
if len(file_object) == 0:
response_log = "Sorry, I couldn't find the full text of this file. Please re-upload the document and try again."
async for result in send_llm_response(response_log):
yield result
return
contextual_data = " ".join([file.raw_text for file in file_object])
if not q:
q = "Create a general summary of the file"
async for result in send_event(
ChatEvent.STATUS, f"**Constructing Summary Using:** {file_object[0].file_name}"
):
yield result
response = await extract_relevant_summary(
q,
contextual_data,
conversation_history=meta_log,
query_images=uploaded_images,
async for response in generate_summary_from_files(
q=q,
user=user,
file_filters=file_filters,
meta_log=meta_log,
query_images=uploaded_images,
agent=agent,
send_status_func=partial(send_event, ChatEvent.STATUS),
tracer=tracer,
)
response_log = str(response)
async for result in send_llm_response(response_log):
yield result
except Exception as e:
response_log = "Error summarizing file. Please try again, or contact support."
logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
async for result in send_llm_response(response_log):
):
if isinstance(response, dict) and ChatEvent.STATUS in response:
yield response[ChatEvent.STATUS]
else:
if isinstance(response, str):
response_log = response
async for result in send_llm_response(response):
yield result
await sync_to_async(save_to_conversation_log)(
q,
response_log,
@@ -786,6 +819,7 @@ async def chat(
conversation_id=conversation_id,
query_images=uploaded_images,
tracer=tracer,
train_of_thought=train_of_thought,
)
return
@@ -794,7 +828,7 @@ async def chat(
if not q:
conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
if conversation_config == None:
conversation_config = await ConversationAdapters.aget_default_conversation_config()
conversation_config = await ConversationAdapters.aget_default_conversation_config(user)
model_type = conversation_config.model_type
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
async for result in send_llm_response(formatted_help):
@@ -830,6 +864,7 @@
automation_id=automation.id,
query_images=uploaded_images,
tracer=tracer,
train_of_thought=train_of_thought,
)
async for result in send_llm_response(llm_response):
yield result
@@ -837,7 +872,7 @@
# Gather Context
## Extract Document References
compiled_references, inferred_queries, defiltered_query = [], [], q
if not ConversationCommand.Research in conversation_commands:
try:
async for result in extract_references_and_questions(
request,
@@ -860,7 +895,9 @@
inferred_queries.extend(result[1])
defiltered_query = result[2]
except Exception as e:
error_message = f"Error searching knowledge base: {e}. Attempting to respond without document references."
error_message = (
f"Error searching knowledge base: {e}. Attempting to respond without document references."
)
logger.error(error_message, exc_info=True)
async for result in send_event(
ChatEvent.STATUS, "Document search failed. I'll try respond without document references"
@@ -874,8 +911,6 @@
async for result in send_event(ChatEvent.STATUS, f"**Found Relevant Notes**: {headings}"):
yield result
online_results: Dict = dict()
if conversation_commands == [ConversationCommand.Notes] and not await EntryAdapters.auser_has_entries(user):
async for result in send_llm_response(f"{no_entries_found.format()}"):
yield result
@@ -948,6 +983,33 @@
):
yield result
## Gather Code Results
if ConversationCommand.Code in conversation_commands:
try:
context = f"# Iteration 1:\n#---\nNotes:\n{compiled_references}\n\nOnline Results:{online_results}"
async for result in run_code(
defiltered_query,
meta_log,
context,
location,
user,
partial(send_event, ChatEvent.STATUS),
query_images=uploaded_images,
agent=agent,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
code_results = result
async for result in send_event(ChatEvent.STATUS, f"**Ran code snippets**: {len(code_results)}"):
yield result
except ValueError as e:
logger.warning(
f"Failed to use code tool: {e}. Attempting to respond without code results",
exc_info=True,
)
## Send Gathered References
async for result in send_event(
ChatEvent.REFERENCES,
@@ -955,6 +1017,7 @@
"inferredQueries": inferred_queries,
"context": compiled_references,
"onlineContext": online_results,
"codeContext": code_results,
},
):
yield result
@@ -1004,6 +1067,7 @@
online_results=online_results,
query_images=uploaded_images,
tracer=tracer,
train_of_thought=train_of_thought,
)
content_obj = {
"intentType": intent_type,
@@ -1061,6 +1125,7 @@
online_results=online_results,
query_images=uploaded_images,
tracer=tracer,
train_of_thought=train_of_thought,
)
async for result in send_llm_response(json.dumps(content_obj)):
@@ -1076,6 +1141,7 @@
conversation,
compiled_references,
online_results,
code_results,
inferred_queries,
conversation_commands,
user,
@@ -1083,8 +1149,10 @@
conversation_id,
location,
user_name,
researched_results,
uploaded_images,
tracer,
train_of_thought,
)
# Send Response
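As a usage sketch, assuming a local Khoj server on its default port and the GET form of this endpoint (authentication omitted), research mode is triggered by prefixing the chat query with the /research slash command:

import requests

# Hypothetical client call; the /research prefix routes into execute_information_collection.
response = requests.get(
    "http://localhost:42110/api/chat",
    params={"q": "/research what did I spend on my Cambodia trip?"},
)
print(response.text)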

View file

@@ -43,6 +43,7 @@ from khoj.database.adapters import (
AutomationAdapters,
ConversationAdapters,
EntryAdapters,
FileObjectAdapters,
ais_user_subscribed,
create_khoj_token,
get_khoj_tokens,
@@ -87,9 +88,11 @@ from khoj.processor.conversation.offline.chat_model import (
)
from khoj.processor.conversation.openai.gpt import converse, send_message_to_model
from khoj.processor.conversation.utils import (
ChatEvent,
ThreadedGenerator,
clean_json,
construct_chat_history,
generate_chatml_messages_with_context,
remove_json_codeblock,
save_to_conversation_log,
)
from khoj.processor.speech.text_to_speech import is_eleven_labs_enabled
@@ -137,7 +140,7 @@ def validate_conversation_config(user: KhojUser):
async def is_ready_to_chat(user: KhojUser):
user_conversation_config = await ConversationAdapters.aget_user_conversation_config(user)
if user_conversation_config == None:
user_conversation_config = await ConversationAdapters.aget_default_conversation_config()
user_conversation_config = await ConversationAdapters.aget_default_conversation_config(user)
if user_conversation_config and user_conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
chat_model = user_conversation_config.chat_model
@@ -210,21 +213,6 @@ def get_next_url(request: Request) -> str:
return urljoin(str(request.base_url).rstrip("/"), next_path)
def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
chat_history = ""
for chat in conversation_history.get("chat", [])[-n:]:
if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder", "summarize"]:
chat_history += f"User: {chat['intent']['query']}\n"
chat_history += f"{agent_name}: {chat['message']}\n"
elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")):
chat_history += f"User: {chat['intent']['query']}\n"
chat_history += f"{agent_name}: [generated image redacted for space]\n"
elif chat["by"] == "khoj" and ("excalidraw" in chat["intent"].get("type")):
chat_history += f"User: {chat['intent']['query']}\n"
chat_history += f"{agent_name}: {chat['intent']['inferred-queries'][0]}\n"
return chat_history
def get_conversation_command(query: str, any_references: bool = False) -> ConversationCommand:
if query.startswith("/notes"):
return ConversationCommand.Notes
@@ -244,6 +232,10 @@ def get_conversation_command(query: str, any_references: bool = False) -> Conver
return ConversationCommand.Summarize
elif query.startswith("/diagram"):
return ConversationCommand.Diagram
elif query.startswith("/code"):
return ConversationCommand.Code
elif query.startswith("/research"):
return ConversationCommand.Research
# If no relevant notes found for the given query
elif not any_references:
return ConversationCommand.General
@@ -342,8 +334,7 @@
)
try:
response = response.strip()
response = remove_json_codeblock(response)
response = clean_json(response)
response = json.loads(response)
response = [q.strip() for q in response["source"] if q.strip()]
if not isinstance(response, list) or not response or len(response) == 0:
@@ -421,8 +412,7 @@
)
try:
response = response.strip()
response = remove_json_codeblock(response)
response = clean_json(response)
response = json.loads(response)
if is_none_or_empty(response):
@@ -483,7 +473,7 @@
# Validate that the response is a non-empty, JSON-serializable list of URLs
try:
response = response.strip()
response = clean_json(response)
urls = json.loads(response)
valid_unique_urls = {str(url).strip() for url in urls["links"] if is_valid_url(url)}
if is_none_or_empty(valid_unique_urls):
@@ -534,8 +524,7 @@
# Validate that the response is a non-empty, JSON-serializable list
try:
response = response.strip()
response = remove_json_codeblock(response)
response = clean_json(response)
response = json.loads(response)
response = [q.strip() for q in response["queries"] if q.strip()]
if not isinstance(response, list) or not response or len(response) == 0:
@@ -644,6 +633,53 @@
return response.strip()
async def generate_summary_from_files(
q: str,
user: KhojUser,
file_filters: List[str],
meta_log: dict,
query_images: List[str] = None,
agent: Agent = None,
send_status_func: Optional[Callable] = None,
tracer: dict = {},
):
try:
file_object = None
if await EntryAdapters.aagent_has_entries(agent):
file_names = await EntryAdapters.aget_agent_entry_filepaths(agent)
if len(file_names) > 0:
file_object = await FileObjectAdapters.async_get_file_objects_by_name(None, file_names.pop(), agent)
if len(file_filters) > 0:
file_object = await FileObjectAdapters.async_get_file_objects_by_name(user, file_filters[0])
if len(file_object) == 0:
response_log = "Sorry, I couldn't find the full text of this file."
yield response_log
return
contextual_data = " ".join([file.raw_text for file in file_object])
if not q:
q = "Create a general summary of the file"
async for result in send_status_func(f"**Constructing Summary Using:** {file_object[0].file_name}"):
yield {ChatEvent.STATUS: result}
response = await extract_relevant_summary(
q,
contextual_data,
conversation_history=meta_log,
query_images=query_images,
user=user,
agent=agent,
tracer=tracer,
)
yield str(response)
except Exception as e:
response_log = "Error summarizing file. Please try again, or contact support."
logger.error(f"Error summarizing file for {user.email}: {e}", exc_info=True)
yield response_log
async def generate_excalidraw_diagram(
q: str,
conversation_history: Dict[str, Any],
@@ -759,10 +795,9 @@ async def generate_excalidraw_diagram_from_description(
with timer("Chat actor: Generate excalidraw diagram", logger):
raw_response = await send_message_to_model_wrapper(
message=excalidraw_diagram_generation, user=user, tracer=tracer
query=excalidraw_diagram_generation, user=user, tracer=tracer
)
raw_response = raw_response.strip()
raw_response = remove_json_codeblock(raw_response)
raw_response = clean_json(raw_response)
response: Dict[str, str] = json.loads(raw_response)
if not response or not isinstance(response, List) or not isinstance(response[0], Dict):
# TODO Some additional validation here that it's a valid Excalidraw diagram
@@ -839,11 +874,12 @@
async def send_message_to_model_wrapper(
message: str,
query: str,
system_message: str = "",
response_type: str = "text",
user: KhojUser = None,
query_images: List[str] = None,
context: str = "",
tracer: dict = {},
):
conversation_config: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config(user)
@@ -874,7 +910,8 @@
loaded_model = state.offline_chat_processor_config.loaded_model
truncated_messages = generate_chatml_messages_with_context(
user_message=message,
user_message=query,
context_message=context,
system_message=system_message,
model_name=chat_model,
loaded_model=loaded_model,
@@ -899,7 +936,8 @@
api_key = openai_chat_config.api_key
api_base_url = openai_chat_config.api_base_url
truncated_messages = generate_chatml_messages_with_context(
user_message=message,
user_message=query,
context_message=context,
system_message=system_message,
model_name=chat_model,
max_prompt_size=max_tokens,
@@ -920,7 +958,8 @@
elif model_type == ChatModelOptions.ModelType.ANTHROPIC:
api_key = conversation_config.openai_config.api_key
truncated_messages = generate_chatml_messages_with_context(
user_message=message,
user_message=query,
context_message=context,
system_message=system_message,
model_name=chat_model,
max_prompt_size=max_tokens,
@@ -934,12 +973,14 @@
messages=truncated_messages,
api_key=api_key,
model=chat_model,
response_type=response_type,
tracer=tracer,
)
elif model_type == ChatModelOptions.ModelType.GOOGLE:
api_key = conversation_config.openai_config.api_key
truncated_messages = generate_chatml_messages_with_context(
user_message=message,
user_message=query,
context_message=context,
system_message=system_message,
model_name=chat_model,
max_prompt_size=max_tokens,
@@ -1033,6 +1074,7 @@
messages=truncated_messages,
api_key=api_key,
model=chat_model,
response_type=response_type,
tracer=tracer,
)
@@ -1064,6 +1106,7 @@
conversation: Conversation,
compiled_references: List[Dict] = [],
online_results: Dict[str, Dict] = {},
code_results: Dict[str, Dict] = {},
inferred_queries: List[str] = [],
conversation_commands: List[ConversationCommand] = [ConversationCommand.Default],
user: KhojUser = None,
@@ -1071,8 +1114,10 @@
conversation_id: str = None,
location_data: LocationData = None,
user_name: Optional[str] = None,
meta_research: str = "",
query_images: Optional[List[str]] = None,
tracer: dict = {},
train_of_thought: List[Any] = [],
) -> Tuple[Union[ThreadedGenerator, Iterator[str]], Dict[str, str]]:
# Initialize Variables
chat_response = None
@@ -1080,6 +1125,9 @@
metadata = {}
agent = AgentAdapters.get_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None
query_to_run = q
if meta_research:
query_to_run = f"AI Research: {meta_research} {q}"
try:
partial_completion = partial(
save_to_conversation_log,
@@ -1088,11 +1136,13 @@
meta_log=meta_log,
compiled_references=compiled_references,
online_results=online_results,
code_results=code_results,
inferred_queries=inferred_queries,
client_application=client_application,
conversation_id=conversation_id,
query_images=query_images,
tracer=tracer,
train_of_thought=train_of_thought,
)
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
@@ -1106,9 +1156,9 @@
if conversation_config.model_type == "offline":
loaded_model = state.offline_chat_processor_config.loaded_model
chat_response = converse_offline(
user_query=query_to_run,
references=compiled_references,
online_results=online_results,
user_query=q,
loaded_model=loaded_model,
conversation_log=meta_log,
completion_func=partial_completion,
@@ -1128,9 +1178,10 @@
chat_model = conversation_config.chat_model
chat_response = converse(
compiled_references,
q,
query_to_run,
query_images=query_images,
online_results=online_results,
code_results=code_results,
conversation_log=meta_log,
model=chat_model,
api_key=api_key,
@@ -1150,9 +1201,10 @@
api_key = conversation_config.openai_config.api_key
chat_response = converse_anthropic(
compiled_references,
q,
query_to_run,
query_images=query_images,
online_results=online_results,
code_results=code_results,
conversation_log=meta_log,
model=conversation_config.chat_model,
api_key=api_key,
@@ -1170,10 +1222,10 @@
api_key = conversation_config.openai_config.api_key
chat_response = converse_gemini(
compiled_references,
q,
query_images=query_images,
online_results=online_results,
conversation_log=meta_log,
query_to_run,
online_results,
code_results,
meta_log,
model=conversation_config.chat_model,
api_key=api_key,
completion_func=partial_completion,
@@ -1627,14 +1679,6 @@
""".strip()
class ChatEvent(Enum):
START_LLM_RESPONSE = "start_llm_response"
END_LLM_RESPONSE = "end_llm_response"
MESSAGE = "message"
REFERENCES = "references"
STATUS = "status"
class MessageProcessor:
def __init__(self):
self.references = {}
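Note the send_message_to_model_wrapper signature change above: the user's ask now travels as query and tool/system material as the new context parameter. A call in the new style, mirroring apick_next_tool in research.py below (inside an async function):

# Sketch: query and planning context are now passed separately.
response = await send_message_to_model_wrapper(
    query=user_query,
    context=function_planning_prompt,  # prompt context kept out of the user turn
    response_type="json_object",
    user=user,
    tracer=tracer,
)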

View file

@@ -0,0 +1,321 @@
import json
import logging
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional
import yaml
from fastapi import Request
from khoj.database.adapters import ConversationAdapters, EntryAdapters
from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.utils import (
InformationCollectionIteration,
clean_json,
construct_iteration_history,
construct_tool_chat_history,
)
from khoj.processor.tools.online_search import read_webpages, search_online
from khoj.processor.tools.run_code import run_code
from khoj.routers.api import extract_references_and_questions
from khoj.routers.helpers import (
ChatEvent,
construct_chat_history,
extract_relevant_info,
generate_summary_from_files,
send_message_to_model_wrapper,
)
from khoj.utils.helpers import (
ConversationCommand,
function_calling_description_for_llm,
is_none_or_empty,
timer,
)
from khoj.utils.rawconfig import LocationData
logger = logging.getLogger(__name__)
async def apick_next_tool(
query: str,
conversation_history: dict,
user: KhojUser = None,
query_images: List[str] = [],
location: LocationData = None,
user_name: str = None,
agent: Agent = None,
previous_iterations_history: str = None,
max_iterations: int = 5,
send_status_func: Optional[Callable] = None,
tracer: dict = {},
):
"""
Given a query, determine which of the available tools the agent should use to answer it appropriately. Tools are picked one at a time; subsequent iterations can refine the answer.
"""
tool_options = dict()
tool_options_str = ""
agent_tools = agent.input_tools if agent else []
for tool, description in function_calling_description_for_llm.items():
tool_options[tool.value] = description
if len(agent_tools) == 0 or tool.value in agent_tools:
tool_options_str += f'- "{tool.value}": "{description}"\n'
chat_history = construct_chat_history(conversation_history, agent_name=agent.name if agent else "Khoj")
if query_images:
query = f"[placeholder for user attached images]\n{query}"
personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
)
# Extract Past User Message and Inferred Questions from Conversation Log
today = datetime.today()
location_data = f"{location}" if location else "Unknown"
function_planning_prompt = prompts.plan_function_execution.format(
tools=tool_options_str,
chat_history=chat_history,
personality_context=personality_context,
current_date=today.strftime("%Y-%m-%d"),
day_of_week=today.strftime("%A"),
username=user_name or "Unknown",
location=location_data,
previous_iterations=previous_iterations_history,
max_iterations=max_iterations,
)
with timer("Chat actor: Infer information sources to refer", logger):
response = await send_message_to_model_wrapper(
query=query,
context=function_planning_prompt,
response_type="json_object",
user=user,
query_images=query_images,
tracer=tracer,
)
try:
response = clean_json(response)
response = json.loads(response)
selected_tool = response.get("tool", None)
generated_query = response.get("query", None)
scratchpad = response.get("scratchpad", None)
logger.info(f"Response for determining relevant tools: {response}")
if send_status_func:
determined_tool_message = "**Determined Tool**: "
determined_tool_message += f"{selected_tool}({generated_query})." if selected_tool else "respond."
determined_tool_message += f"\nReason: {scratchpad}" if scratchpad else ""
async for event in send_status_func(f"{scratchpad}"):
yield {ChatEvent.STATUS: event}
yield InformationCollectionIteration(
tool=selected_tool,
query=generated_query,
)
except Exception as e:
logger.error(f"Invalid response for determining relevant tools: {response}. {e}", exc_info=True)
yield InformationCollectionIteration(
tool=None,
query=None,
)
async def execute_information_collection(
request: Request,
user: KhojUser,
query: str,
conversation_id: str,
conversation_history: dict,
query_images: List[str],
agent: Agent = None,
send_status_func: Optional[Callable] = None,
user_name: str = None,
location: LocationData = None,
file_filters: List[str] = [],
tracer: dict = {},
):
current_iteration = 0
MAX_ITERATIONS = 5
previous_iterations: List[InformationCollectionIteration] = []
while current_iteration < MAX_ITERATIONS:
online_results: Dict = dict()
code_results: Dict = dict()
document_results: List[Dict[str, str]] = []
summarize_files: str = ""
this_iteration = InformationCollectionIteration(tool=None, query=query)
previous_iterations_history = construct_iteration_history(previous_iterations, prompts.previous_iteration)
async for result in apick_next_tool(
query,
conversation_history,
user,
query_images,
location,
user_name,
agent,
previous_iterations_history,
MAX_ITERATIONS,
send_status_func,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
elif isinstance(result, InformationCollectionIteration):
this_iteration = result
if this_iteration.tool == ConversationCommand.Notes:
this_iteration.context = []
document_results = []
async for result in extract_references_and_questions(
request,
construct_tool_chat_history(previous_iterations, ConversationCommand.Notes),
this_iteration.query,
7,
None,
conversation_id,
[ConversationCommand.Default],
location,
send_status_func,
query_images,
agent=agent,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
elif isinstance(result, tuple):
document_results = result[0]
this_iteration.context += document_results
if not is_none_or_empty(document_results):
try:
distinct_files = {d["file"] for d in document_results}
distinct_headings = set([d["compiled"].split("\n")[0] for d in document_results if "compiled" in d])
# Strip '#' heading markers from the headings
headings_str = "\n- " + "\n- ".join(distinct_headings).replace("#", "")
async for result in send_status_func(
f"**Found {len(distinct_headings)} Notes Across {len(distinct_files)} Files**: {headings_str}"
):
yield result
except Exception as e:
logger.error(f"Error extracting document references: {e}", exc_info=True)
elif this_iteration.tool == ConversationCommand.Online:
async for result in search_online(
this_iteration.query,
construct_tool_chat_history(previous_iterations, ConversationCommand.Online),
location,
user,
send_status_func,
[],
max_webpages_to_read=0,
query_images=query_images,
agent=agent,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
online_results: Dict[str, Dict] = result # type: ignore
this_iteration.onlineContext = online_results
elif this_iteration.tool == ConversationCommand.Webpage:
try:
async for result in read_webpages(
this_iteration.query,
construct_tool_chat_history(previous_iterations, ConversationCommand.Webpage),
location,
user,
send_status_func,
query_images=query_images,
agent=agent,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
direct_web_pages: Dict[str, Dict] = result # type: ignore
webpages = []
for web_query in direct_web_pages:
if online_results.get(web_query):
online_results[web_query]["webpages"] = direct_web_pages[web_query]["webpages"]
else:
online_results[web_query] = {"webpages": direct_web_pages[web_query]["webpages"]}
for webpage in direct_web_pages[web_query]["webpages"]:
webpages.append(webpage["link"])
this_iteration.onlineContext = online_results
except Exception as e:
logger.error(f"Error reading webpages: {e}", exc_info=True)
elif this_iteration.tool == ConversationCommand.Code:
try:
async for result in run_code(
this_iteration.query,
construct_tool_chat_history(previous_iterations, ConversationCommand.Code),
"",
location,
user,
send_status_func,
query_images=query_images,
agent=agent,
tracer=tracer,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
code_results: Dict[str, Dict] = result # type: ignore
this_iteration.codeContext = code_results
async for result in send_status_func(f"**Ran code snippets**: {len(this_iteration.codeContext)}"):
yield result
except ValueError as e:
logger.warning(
f"Failed to use code tool: {e}. Attempting to respond without code results",
exc_info=True,
)
elif this_iteration.tool == ConversationCommand.Summarize:
try:
async for result in generate_summary_from_files(
this_iteration.query,
user,
file_filters,
construct_tool_chat_history(previous_iterations),
query_images=query_images,
agent=agent,
send_status_func=send_status_func,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
else:
summarize_files = result # type: ignore
except Exception as e:
logger.error(f"Error generating summary: {e}", exc_info=True)
else:
# No valid tools. This is our exit condition.
current_iteration = MAX_ITERATIONS
current_iteration += 1
if document_results or online_results or code_results or summarize_files:
results_data = "**Results**:\n"
if document_results:
results_data += f"**Document References**: {yaml.dump(document_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
if online_results:
results_data += f"**Online Results**: {yaml.dump(online_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
if code_results:
results_data += f"**Code Results**: {yaml.dump(code_results, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
if summarize_files:
results_data += f"**Summarized Files**: {yaml.dump(summarize_files, allow_unicode=True, sort_keys=False, default_flow_style=False)}\n"
# intermediate_result = await extract_relevant_info(this_iteration.query, results_data, agent)
this_iteration.summarizedResult = results_data
previous_iterations.append(this_iteration)
yield this_iteration
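For orientation, here is a minimal sketch of how a caller might drive this research loop. The generator and type names below are assumptions for illustration, not necessarily the actual Khoj API:

```python
# Hypothetical consumer of the research loop above; names are illustrative only.
async def run_research(query: str) -> list:
    iterations = []
    async for event in execute_information_collection(query):  # assumed entry point
        if isinstance(event, InformationCollectionIteration):  # assumed iteration type
            iterations.append(event)  # a completed research step with its summarized results
        else:
            print("status:", event)  # intermediate status updates streamed to the UI
    return iterations
```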

View file

@@ -7,8 +7,6 @@ from math import inf
from typing import List, Tuple
import dateparser as dtparse
from dateparser.search import search_dates
from dateparser_data.settings import default_parsers
from dateutil.relativedelta import relativedelta
from khoj.search_filter.base_filter import BaseFilter
@@ -23,7 +21,7 @@ class DateFilter(BaseFilter):
# - dt>="yesterday" dt<"tomorrow"
# - dt>="last week"
# - dt:"2 years ago"
date_regex = r"dt([:><=]{1,2})[\"'](.*?)[\"']"
date_regex = r"dt([:><=]{1,2})[\"'](.*?)[\"']"
def __init__(self, entry_key="compiled"):
self.entry_key = entry_key
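As a quick standalone illustration of what this regex captures (example query strings only):

```python
import re

date_regex = r"dt([:><=]{1,2})[\"'](.*?)[\"']"
# Each match yields the comparison operator and the quoted date expression
print(re.findall(date_regex, 'dt>="last week" dt<"tomorrow"'))
# [('>=', 'last week'), ('<', 'tomorrow')]
```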

View file

@@ -1,11 +1,10 @@
import fnmatch
import logging
import re
from collections import defaultdict
from typing import List
from khoj.search_filter.base_filter import BaseFilter
from khoj.utils.helpers import LRU
logger = logging.getLogger(__name__)

View file

@@ -102,8 +102,8 @@ def load_embeddings(
async def query(
raw_query: str,
user: KhojUser,
type: SearchType = SearchType.All,
question_embedding: Union[torch.Tensor, None] = None,
max_distance: float = None,
@@ -130,12 +130,12 @@ async def query(
top_k = 10
with timer("Search Time", logger, state.device):
hits = EntryAdapters.search_with_embeddings(
raw_query=raw_query,
embeddings=question_embedding,
max_results=top_k,
file_type_filter=file_type,
max_distance=max_distance,
user=user,
agent=agent,
).all()
hits = await sync_to_async(list)(hits) # type: ignore[call-arg]
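Note that this change also flips the positional argument order of `text_search.query`, so the raw query now comes before the user (mirrored in the test update at the end of this diff):

```python
# New call order: query string first, then the user
hits = await text_search.query("Load Khoj on Emacs?", user)
```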

View file

@@ -313,12 +313,14 @@ class ConversationCommand(str, Enum):
Help = "help"
Online = "online"
Webpage = "webpage"
Code = "code"
Image = "image"
Text = "text"
Automation = "automation"
AutomatedTask = "automated_task"
Summarize = "summarize"
Diagram = "diagram"
Research = "research"
command_descriptions = {
@@ -327,11 +329,13 @@ command_descriptions = {
ConversationCommand.Default: "The default command when no command specified. It intelligently auto-switches between general and notes mode.",
ConversationCommand.Online: "Search for information on the internet.",
ConversationCommand.Webpage: "Get information from webpage suggested by you.",
ConversationCommand.Code: "Run Python code to parse information, run complex calculations, create documents and charts.",
ConversationCommand.Image: "Generate illustrative, creative images by describing your imagination in words.",
ConversationCommand.Automation: "Automatically run your query at a specified time or interval.",
ConversationCommand.Help: "Get help with how to use or setup Khoj from the documentation",
ConversationCommand.Summarize: "Get help with a question pertaining to an entire document.",
ConversationCommand.Diagram: "Draw a flowchart, diagram, or any other visual representation best expressed with primitives like lines, rectangles, and text.",
ConversationCommand.Research: "Do deep research on a topic. This will take longer than usual, but give a more detailed, comprehensive answer.",
}
command_descriptions_for_agent = {
@@ -340,6 +344,7 @@ command_descriptions_for_agent = {
ConversationCommand.Online: "Agent can search the internet for information.",
ConversationCommand.Webpage: "Agent can read suggested web pages for information.",
ConversationCommand.Summarize: "Agent can read an entire document. Agents knowledge base must be a single document.",
ConversationCommand.Research: "Agent can do deep research on a topic.",
}
tool_descriptions_for_llm = {
@@ -348,18 +353,26 @@ tool_descriptions_for_llm = {
ConversationCommand.Notes: "To search the user's personal knowledge base. Especially helpful if the question expects context from the user's notes or documents.",
ConversationCommand.Online: "To search for the latest, up-to-date information from the internet. Note: **Questions about Khoj should always use this data source**",
ConversationCommand.Webpage: "To use if the user has directly provided the webpage urls or you are certain of the webpage urls to read.",
ConversationCommand.Code: "To run Python code in a Pyodide sandbox with no network access. Helpful when need to parse information, run complex calculations, create documents and charts for user. Matplotlib, bs4, pandas, numpy, etc. are available.",
ConversationCommand.Summarize: "To retrieve an answer that depends on the entire document or a large text.",
}
function_calling_description_for_llm = {
ConversationCommand.Notes: "To search the user's personal knowledge base. Especially helpful if the question expects context from the user's notes or documents.",
ConversationCommand.Online: "To search the internet for information. Provide all relevant context to ensure new searches, not previously run, are performed.",
ConversationCommand.Webpage: "To extract information from a webpage. Useful for more detailed research from the internet. Usually used when you know the webpage links to refer to. Share the webpage link and information to extract in your query.",
ConversationCommand.Code: "To run Python code in a Pyodide sandbox with no network access. Helpful when need to parse information, run complex calculations, create documents and charts for user. Matplotlib, bs4, pandas, numpy, etc. are available.",
}
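To ground the new Code tool description, here is an illustrative, self-contained snippet of the kind the Pyodide sandbox could run. This is a sketch, not code from this change, and assumes only the packages named in the description above:

```python
# Illustrative only: parse local data and chart it without network access.
import pandas as pd
import matplotlib.pyplot as plt

df = pd.DataFrame({"month": ["Jan", "Feb", "Mar"], "sales": [120, 150, 170]})
df.plot(x="month", y="sales", kind="bar", legend=False, title="Sales by Month")
plt.savefig("sales.png")  # outputs land on the sandbox filesystem, not the network
```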
mode_descriptions_for_llm = {
ConversationCommand.Image: "Use this if you are confident the user is requesting you to create a new picture based on their description.",
ConversationCommand.Image: "Use this if you are confident the user is requesting you to create a new picture based on their description. This does not support generating charts or graphs.",
ConversationCommand.Automation: "Use this if you are confident the user is requesting a response at a scheduled date, time and frequency",
ConversationCommand.Text: "Use this if a normal text response would be sufficient for accurately responding to the query.",
ConversationCommand.Diagram: "Use this if the user is requesting a diagram or visual representation that requires primitives like lines, rectangles, and text.",
}
mode_descriptions_for_agent = {
ConversationCommand.Image: "Agent can generate image in response.",
ConversationCommand.Image: "Agent can generate images in response. It cannot not use this to generate charts and graphs.",
ConversationCommand.Automation: "Agent can schedule a task to run at a scheduled date, time and frequency in response.",
ConversationCommand.Text: "Agent can generate text in response.",
ConversationCommand.Diagram: "Agent can generate a visual representation that requires primitives like lines, rectangles, and text.",

View file

@@ -41,3 +41,7 @@ def parse_config_from_string(yaml_config: dict) -> FullConfig:
def parse_config_from_file(yaml_config_file):
"Parse and validate config in YML file"
return parse_config_from_string(load_config_from_file(yaml_config_file))
def yaml_dump(data):
return yaml.dump(data, allow_unicode=True, sort_keys=False, default_flow_style=False)
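For reference, this helper preserves key insertion order and emits block-style YAML, e.g.:

```python
>>> print(yaml_dump({"name": "Khoj", "tools": ["notes", "online"]}))
name: Khoj
tools:
- notes
- online
```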

View file

@@ -164,7 +164,7 @@ async def test_text_search(search_config: SearchConfig):
query = "Load Khoj on Emacs?"
# Act
hits = await text_search.query(query, default_user)
results = text_search.collate_results(hits)
results = sorted(results, key=lambda x: float(x.score))[:1]