Mirror of https://github.com/khoj-ai/khoj.git (synced 2024-11-23 23:48:56 +01:00)
Include agent personality through subtasks and support custom agents (#916)

Currently, the personality of an agent is only included in the final response it returns to the user. Historically, this was because models were quite bad at navigating the additional context of a personality, and there was a bias towards keeping more control over certain operations (e.g., tool selection, question extraction). Going forward, it is more approachable to include these prompts in the subtasks Khoj runs in order to respond to a given query. This PR makes that possible. It also sets us up for agent creation becoming available soon.

Create custom agents in #928. Agents are useful insofar as you can personalize them to fulfill the specific subtasks you need to accomplish. That PR adds support for custom agents that can be configured with a custom system prompt (aka persona) and a knowledge base (built from your own indexed documents). Once created, private agents are accessible only to their creator, and protected agents are accessible via a direct link.

Custom tool selection for agents in #930. Expose the functionality to select which tools a given agent has access to. By default an agent has all of them; both information sources and output modes can be limited. Also add the new tools to the agent modification form.
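For orientation, the core pattern this change introduces is small: each subtask prompt gains a {personality_context} slot that is filled from the agent's persona when an agent is attached to the conversation, and left empty otherwise. Below is a minimal, self-contained sketch of that idea; the template text and the Agent class are simplified stand-ins, not the exact Khoj implementation.

# Sketch: inject an agent persona into a subtask prompt only when an agent is attached.
from dataclasses import dataclass
from typing import Optional

PERSONALITY_CONTEXT = "Here's some additional context about you:\n{personality}\n"
SUBQUERY_PROMPT = (
    "You are Khoj, an extremely smart and helpful search assistant.\n"
    "{personality_context}"
    "Construct search queries to answer the user's question.\n"
    "Q: {query}\n"
)

@dataclass
class Agent:
    name: str
    personality: str

def build_subquery_prompt(query: str, agent: Optional[Agent] = None) -> str:
    # Render the persona block once; fall back to an empty string for plain Khoj chats.
    personality_context = (
        PERSONALITY_CONTEXT.format(personality=agent.personality) if agent else ""
    )
    return SUBQUERY_PROMPT.format(personality_context=personality_context, query=query)

if __name__ == "__main__":
    tutor = Agent(name="Tutor", personality="You explain concepts like a patient teacher.")
    print(build_subquery_prompt("how do transformers work?", tutor))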
Parent: c0193744f5
Commit: 405c047c0c
29 changed files with 2350 additions and 284 deletions
(Diff for one file not shown because it is too large.)
@@ -36,6 +36,15 @@ export interface SyncedContent {
    github: boolean;
    notion: boolean;
}

export enum SubscriptionStates {
    EXPIRED = "expired",
    TRIAL = "trial",
    SUBSCRIBED = "subscribed",
    UNSUBSCRIBED = "unsubscribed",
    INVALID = "invalid",
}

export interface UserConfig {
    // user info
    username: string;
@@ -58,7 +67,7 @@ export interface UserConfig {
    voice_model_options: ModelOptions[];
    selected_voice_model_config: number;
    // user billing info
    subscription_state: string;
    subscription_state: SubscriptionStates;
    subscription_renewal_date: string;
    // server settings
    khoj_cloud_subscription_url: string | undefined;
@@ -1,4 +1,4 @@
const tailwindColors = [
export const tailwindColors = [
    "red",
    "yellow",
    "green",
@@ -26,6 +26,28 @@ import {
    Wallet,
    PencilLine,
    Chalkboard,
    Gps,
    Question,
    Browser,
    Notebook,
    Shapes,
    ChatsTeardrop,
    GlobeSimple,
    ArrowRight,
    Cigarette,
    CraneTower,
    Heart,
    Leaf,
    NewspaperClipping,
    OrangeSlice,
    Rainbow,
    SmileyMelting,
    YinYang,
    SneakerMove,
    Student,
    Oven,
    Gavel,
    Broadcast,
} from "@phosphor-icons/react";
import { Markdown, OrgMode, Pdf, Word } from "@/app/components/logo/fileLogo";
@@ -103,8 +125,92 @@ const iconMap: IconMap = {
    Chalkboard: (color: string, width: string, height: string) => (
        <Chalkboard className={`${width} ${height} ${color} mr-2`} />
    ),
    Cigarette: (color: string, width: string, height: string) => (
        <Cigarette className={`${width} ${height} ${color} mr-2`} />
    ),
    CraneTower: (color: string, width: string, height: string) => (
        <CraneTower className={`${width} ${height} ${color} mr-2`} />
    ),
    Heart: (color: string, width: string, height: string) => (
        <Heart className={`${width} ${height} ${color} mr-2`} />
    ),
    Leaf: (color: string, width: string, height: string) => (
        <Leaf className={`${width} ${height} ${color} mr-2`} />
    ),
    NewspaperClipping: (color: string, width: string, height: string) => (
        <NewspaperClipping className={`${width} ${height} ${color} mr-2`} />
    ),
    OrangeSlice: (color: string, width: string, height: string) => (
        <OrangeSlice className={`${width} ${height} ${color} mr-2`} />
    ),
    SmileyMelting: (color: string, width: string, height: string) => (
        <SmileyMelting className={`${width} ${height} ${color} mr-2`} />
    ),
    YinYang: (color: string, width: string, height: string) => (
        <YinYang className={`${width} ${height} ${color} mr-2`} />
    ),
    SneakerMove: (color: string, width: string, height: string) => (
        <SneakerMove className={`${width} ${height} ${color} mr-2`} />
    ),
    Student: (color: string, width: string, height: string) => (
        <Student className={`${width} ${height} ${color} mr-2`} />
    ),
    Oven: (color: string, width: string, height: string) => (
        <Oven className={`${width} ${height} ${color} mr-2`} />
    ),
    Gavel: (color: string, width: string, height: string) => (
        <Gavel className={`${width} ${height} ${color} mr-2`} />
    ),
    Broadcast: (color: string, width: string, height: string) => (
        <Broadcast className={`${width} ${height} ${color} mr-2`} />
    ),
};

export function getIconForSlashCommand(command: string, customClassName: string | null = null) {
    const className = customClassName ?? "h-4 w-4";
    if (command.includes("summarize")) {
        return <Gps className={className} />;
    }

    if (command.includes("help")) {
        return <Question className={className} />;
    }

    if (command.includes("automation")) {
        return <Robot className={className} />;
    }

    if (command.includes("webpage")) {
        return <Browser className={className} />;
    }

    if (command.includes("notes")) {
        return <Notebook className={className} />;
    }

    if (command.includes("image")) {
        return <Image className={className} />;
    }

    if (command.includes("default")) {
        return <Shapes className={className} />;
    }

    if (command.includes("general")) {
        return <ChatsTeardrop className={className} />;
    }

    if (command.includes("online")) {
        return <GlobeSimple className={className} />;
    }

    if (command.includes("text")) {
        return <PencilLine className={className} />;
    }

    return <ArrowRight className={className} />;
}

function getIconFromIconName(
    iconName: string,
    color: string = "gray",
@@ -141,4 +247,8 @@ function getIconFromFilename(
    }
}

export { getIconFromIconName, getIconFromFilename };
function getAvailableIcons() {
    return Object.keys(iconMap);
}

export { getIconFromIconName, getIconFromFilename, getAvailableIcons };
@@ -15,7 +15,7 @@ import { InlineLoading } from "../loading/loading";

import { Lightbulb, ArrowDown } from "@phosphor-icons/react";

import ProfileCard from "../profileCard/profileCard";
import AgentProfileCard from "../profileCard/profileCard";
import { getIconFromIconName } from "@/app/common/iconUtils";
import { AgentData } from "@/app/agents/page";
import React from "react";
@@ -350,7 +350,7 @@ export default function ChatHistory(props: ChatHistoryProps) {
    {data && (
        <div className={`${styles.agentIndicator} pb-4`}>
            <div className="relative group mx-2 cursor-pointer">
                <ProfileCard
                <AgentProfileCard
                    name={constructAgentName()}
                    link={constructAgentLink()}
                    avatar={
@@ -50,6 +50,7 @@ import { convertToBGClass } from "@/app/common/colorUtils";
import LoginPrompt from "../loginPrompt/loginPrompt";
import { uploadDataForIndexing } from "../../common/chatFunctions";
import { InlineLoading } from "../loading/loading";
import { getIconForSlashCommand } from "@/app/common/iconUtils";

export interface ChatOptions {
    [key: string]: string;
@@ -193,46 +194,6 @@ export default function ChatInputArea(props: ChatInputProps) {
        );
    }

    function getIconForSlashCommand(command: string) {
        const className = "h-4 w-4 mr-2";
        if (command.includes("summarize")) {
            return <Gps className={className} />;
        }

        if (command.includes("help")) {
            return <Question className={className} />;
        }

        if (command.includes("automation")) {
            return <Robot className={className} />;
        }

        if (command.includes("webpage")) {
            return <Browser className={className} />;
        }

        if (command.includes("notes")) {
            return <Notebook className={className} />;
        }

        if (command.includes("image")) {
            return <Image className={className} />;
        }

        if (command.includes("default")) {
            return <Shapes className={className} />;
        }

        if (command.includes("general")) {
            return <ChatsTeardrop className={className} />;
        }

        if (command.includes("online")) {
            return <GlobeSimple className={className} />;
        }
        return <ArrowRight className={className} />;
    }

    // Assuming this function is added within the same context as the provided excerpt
    async function startRecordingAndTranscribe() {
        try {
@@ -426,7 +387,11 @@ export default function ChatInputArea(props: ChatInputProps) {
    >
        <div className="grid grid-cols-1 gap-1">
            <div className="font-bold flex items-center">
                {getIconForSlashCommand(key)}/{key}
                {getIconForSlashCommand(
                    key,
                    "h-4 w-4 mr-2",
                )}
                /{key}
            </div>
            <div>{value}</div>
        </div>
@@ -11,11 +11,11 @@ interface ProfileCardProps {
    description?: string; // Optional description field
}

const ProfileCard: React.FC<ProfileCardProps> = ({ name, avatar, link, description }) => {
const AgentProfileCard: React.FC<ProfileCardProps> = ({ name, avatar, link, description }) => {
    return (
        <div className="relative group flex">
            <TooltipProvider>
                <Tooltip>
                <Tooltip delayDuration={0}>
                    <TooltipTrigger asChild>
                        <Button variant="ghost" className="flex items-center justify-center">
                            {avatar}
@@ -24,7 +24,6 @@ const ProfileCard: React.FC<ProfileCardProps> = ({ name, avatar, link, descripti
                    </TooltipTrigger>
                    <TooltipContent>
                        <div className="w-80 h-30">
                            {/* <div className="absolute left-0 bottom-full w-80 h-30 p-2 pb-4 bg-white border border-gray-300 rounded-lg shadow-lg opacity-0 group-hover:opacity-100 transition-opacity duration-300"> */}
                            <a
                                href={link}
                                target="_blank"
@@ -52,4 +51,4 @@ const ProfileCard: React.FC<ProfileCardProps> = ({ name, avatar, link, descripti
    );
};

export default ProfileCard;
export default AgentProfileCard;
@@ -17,8 +17,16 @@ interface ShareLinkProps {
    title: string;
    description: string;
    url: string;
    onShare: () => void;
    buttonVariant?: keyof typeof buttonVariants;
    onShare?: () => void;
    buttonVariant?:
        | "default"
        | "destructive"
        | "outline"
        | "secondary"
        | "ghost"
        | "link"
        | null
        | undefined;
    includeIcon?: boolean;
    buttonClassName?: string;
}
@@ -38,7 +46,7 @@ export default function ShareLink(props: ShareLinkProps) {
    <Button
        size="sm"
        className={`${props.buttonClassName || "px-3"}`}
        variant={props.buttonVariant ?? ("default" as const)}
        variant={props.buttonVariant ?? "default"}
    >
        {props.includeIcon && <Share className="w-4 h-4 mr-2" />}
        {props.buttonTitle}
@@ -63,7 +63,6 @@ interface ChatHistory {
    conversation_id: string;
    slug: string;
    agent_name: string;
    agent_avatar: string;
    compressed: boolean;
    created: string;
    updated: string;
@@ -435,7 +434,6 @@ function SessionsAndFiles(props: SessionsAndFilesProps) {
        chatHistory.conversation_id
    }
    slug={chatHistory.slug}
    agent_avatar={chatHistory.agent_avatar}
    agent_name={chatHistory.agent_name}
    showSidePanel={props.setEnabled}
/>
@@ -713,7 +711,6 @@ function ChatSessionsModal({ data, showSidePanel }: ChatSessionsModalProps) {
    key={chatHistory.conversation_id}
    conversation_id={chatHistory.conversation_id}
    slug={chatHistory.slug}
    agent_avatar={chatHistory.agent_avatar}
    agent_name={chatHistory.agent_name}
    showSidePanel={showSidePanel}
/>
@@ -123,18 +123,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
    //generate colored icons for the selected agents
    const agentIcons = agents
        .filter((agent) => agent !== null && agent !== undefined)
        .map(
            (agent) =>
                getIconFromIconName(agent.icon, agent.color) || (
                    <Image
                        key={agent.name}
                        src={agent.avatar}
                        alt={agent.name}
                        width={50}
                        height={50}
                    />
                ),
        );
        .map((agent) => getIconFromIconName(agent.icon, agent.color)!);
    setAgentIcons(agentIcons);
}, [agentsData, props.isMobileWidth]);
@@ -6,7 +6,7 @@ import "intl-tel-input/styles";
import { Suspense, useEffect, useRef, useState } from "react";
import { useToast } from "@/components/ui/use-toast";

import { useUserConfig, ModelOptions, UserConfig } from "../common/auth";
import { useUserConfig, ModelOptions, UserConfig, SubscriptionStates } from "../common/auth";
import { toTitleCase, useIsMobileWidth } from "../common/utils";

import { isValidPhoneNumber } from "libphonenumber-js";
@@ -276,7 +276,7 @@ const ManageFilesModal: React.FC<{ onClose: () => void }> = ({ onClose }) => {
    )}
</div>
<div
    className={`flex-none p-4 bg-secondary border-b ${isDragAndDropping ? "animate-pulse" : ""}`}
    className={`flex-none p-4 bg-secondary border-b ${isDragAndDropping ? "animate-pulse" : ""} rounded-lg`}
>
    <div className="flex items-center justify-center w-full h-32 border-2 border-dashed border-gray-300 rounded-lg">
        {isDragAndDropping ? (
@@ -294,7 +294,6 @@ const ManageFilesModal: React.FC<{ onClose: () => void }> = ({ onClose }) => {
    </div>
</div>
<div className="flex flex-col h-full">
    <div className="flex-none p-4">Synced files</div>
    <div className="flex-none p-4 bg-background border-b">
        <CommandInput
            placeholder="Find synced files"
@@ -615,7 +614,9 @@ export default function SettingsView() {
    if (userConfig) {
        let newUserConfig = userConfig;
        newUserConfig.subscription_state =
            state === "cancel" ? "unsubscribed" : "subscribed";
            state === "cancel"
                ? SubscriptionStates.UNSUBSCRIBED
                : SubscriptionStates.SUBSCRIBED;
        setUserConfig(newUserConfig);
    }
@@ -10,6 +10,7 @@ from enum import Enum
from typing import Callable, Iterable, List, Optional, Type

import cron_descriptor
import django
from apscheduler.job import Job
from asgiref.sync import sync_to_async
from django.contrib.sessions.backends.db import SessionStore
@@ -551,26 +552,62 @@ class ClientApplicationAdapters:

class AgentAdapters:
    DEFAULT_AGENT_NAME = "Khoj"
    DEFAULT_AGENT_AVATAR = "https://assets.khoj.dev/lamp-128.png"
    DEFAULT_AGENT_SLUG = "khoj"

    @staticmethod
    async def aget_readonly_agent_by_slug(agent_slug: str, user: KhojUser):
        return await Agent.objects.filter(
            (Q(slug__iexact=agent_slug.lower()))
            & (
                Q(privacy_level=Agent.PrivacyLevel.PUBLIC)
                | Q(privacy_level=Agent.PrivacyLevel.PROTECTED)
                | Q(creator=user)
            )
        ).afirst()

    @staticmethod
    async def adelete_agent_by_slug(agent_slug: str, user: KhojUser):
        agent = await AgentAdapters.aget_agent_by_slug(agent_slug, user)
        if agent:
            await agent.adelete()
            return True
        return False

    @staticmethod
    async def aget_agent_by_slug(agent_slug: str, user: KhojUser):
        return await Agent.objects.filter(
            (Q(slug__iexact=agent_slug.lower())) & (Q(public=True) | Q(creator=user))
            (Q(slug__iexact=agent_slug.lower())) & (Q(privacy_level=Agent.PrivacyLevel.PUBLIC) | Q(creator=user))
        ).afirst()

    @staticmethod
    async def aget_agent_by_name(agent_name: str, user: KhojUser):
        return await Agent.objects.filter(
            (Q(name__iexact=agent_name.lower())) & (Q(privacy_level=Agent.PrivacyLevel.PUBLIC) | Q(creator=user))
        ).afirst()

    @staticmethod
    def get_agent_by_slug(slug: str, user: KhojUser = None):
        if user:
            return Agent.objects.filter((Q(slug__iexact=slug.lower())) & (Q(public=True) | Q(creator=user))).first()
        return Agent.objects.filter(slug__iexact=slug.lower(), public=True).first()
            return Agent.objects.filter(
                (Q(slug__iexact=slug.lower())) & (Q(privacy_level=Agent.PrivacyLevel.PUBLIC) | Q(creator=user))
            ).first()
        return Agent.objects.filter(slug__iexact=slug.lower(), privacy_level=Agent.PrivacyLevel.PUBLIC).first()

    @staticmethod
    def get_all_accessible_agents(user: KhojUser = None):
        public_query = Q(privacy_level=Agent.PrivacyLevel.PUBLIC)
        if user:
            return Agent.objects.filter(Q(public=True) | Q(creator=user)).distinct().order_by("created_at")
        return Agent.objects.filter(public=True).order_by("created_at")
            return (
                Agent.objects.filter(public_query | Q(creator=user))
                .distinct()
                .order_by("created_at")
                .prefetch_related("creator", "chat_model", "fileobject_set")
            )
        return (
            Agent.objects.filter(public_query)
            .order_by("created_at")
            .prefetch_related("creator", "chat_model", "fileobject_set")
        )

    @staticmethod
    async def aget_all_accessible_agents(user: KhojUser = None) -> List[Agent]:
@@ -609,12 +646,11 @@ class AgentAdapters:
        # The default agent is public and managed by the admin. It's handled a little differently than other agents.
        agent = Agent.objects.create(
            name=AgentAdapters.DEFAULT_AGENT_NAME,
            public=True,
            privacy_level=Agent.PrivacyLevel.PUBLIC,
            managed_by_admin=True,
            chat_model=default_conversation_config,
            personality=default_personality,
            tools=["*"],
            avatar=AgentAdapters.DEFAULT_AGENT_AVATAR,
            slug=AgentAdapters.DEFAULT_AGENT_SLUG,
        )
        Conversation.objects.filter(agent=None).update(agent=agent)
@@ -625,6 +661,68 @@ class AgentAdapters:
    async def aget_default_agent():
        return await Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).afirst()

    @staticmethod
    async def aupdate_agent(
        user: KhojUser,
        name: str,
        personality: str,
        privacy_level: str,
        icon: str,
        color: str,
        chat_model: str,
        files: List[str],
        input_tools: List[str],
        output_modes: List[str],
    ):
        chat_model_option = await ChatModelOptions.objects.filter(chat_model=chat_model).afirst()

        agent, created = await Agent.objects.filter(name=name, creator=user).aupdate_or_create(
            defaults={
                "name": name,
                "creator": user,
                "personality": personality,
                "privacy_level": privacy_level,
                "style_icon": icon,
                "style_color": color,
                "chat_model": chat_model_option,
                "input_tools": input_tools,
                "output_modes": output_modes,
            }
        )

        # Delete all existing files and entries
        await FileObject.objects.filter(agent=agent).adelete()
        await Entry.objects.filter(agent=agent).adelete()

        for file in files:
            reference_file = await FileObject.objects.filter(file_name=file, user=agent.creator).afirst()
            if reference_file:
                await FileObject.objects.acreate(file_name=file, agent=agent, raw_text=reference_file.raw_text)

                # Duplicate all entries associated with the file
                entries: List[Entry] = []
                async for entry in Entry.objects.filter(file_path=file, user=agent.creator).aiterator():
                    entries.append(
                        Entry(
                            agent=agent,
                            embeddings=entry.embeddings,
                            raw=entry.raw,
                            compiled=entry.compiled,
                            heading=entry.heading,
                            file_source=entry.file_source,
                            file_type=entry.file_type,
                            file_path=entry.file_path,
                            file_name=entry.file_name,
                            url=entry.url,
                            hashed_value=entry.hashed_value,
                        )
                    )

                # Bulk create entries
                await Entry.objects.abulk_create(entries)

        return agent


class PublicConversationAdapters:
    @staticmethod
@@ -1196,6 +1294,10 @@ class EntryAdapters:
    def user_has_entries(user: KhojUser):
        return Entry.objects.filter(user=user).exists()

    @staticmethod
    def agent_has_entries(agent: Agent):
        return Entry.objects.filter(agent=agent).exists()

    @staticmethod
    async def auser_has_entries(user: KhojUser):
        return await Entry.objects.filter(user=user).aexists()
@@ -1229,15 +1331,19 @@ class EntryAdapters:
        return total_size / 1024 / 1024

    @staticmethod
    def apply_filters(user: KhojUser, query: str, file_type_filter: str = None):
    def apply_filters(user: KhojUser, query: str, file_type_filter: str = None, agent: Agent = None):
        q_filter_terms = Q()

        word_filters = EntryAdapters.word_filter.get_filter_terms(query)
        file_filters = EntryAdapters.file_filter.get_filter_terms(query)
        date_filters = EntryAdapters.date_filter.get_query_date_range(query)

        user_or_agent = Q(user=user)
        if agent != None:
            user_or_agent |= Q(agent=agent)

        if len(word_filters) == 0 and len(file_filters) == 0 and len(date_filters) == 0:
            return Entry.objects.filter(user=user)
            return Entry.objects.filter(user_or_agent)

        for term in word_filters:
            if term.startswith("+"):
@@ -1273,7 +1379,7 @@ class EntryAdapters:
        formatted_max_date = date.fromtimestamp(max_date).strftime("%Y-%m-%d")
        q_filter_terms &= Q(embeddings_dates__date__lte=formatted_max_date)

        relevant_entries = Entry.objects.filter(user=user).filter(q_filter_terms)
        relevant_entries = Entry.objects.filter(user_or_agent).filter(q_filter_terms)
        if file_type_filter:
            relevant_entries = relevant_entries.filter(file_type=file_type_filter)
        return relevant_entries
@@ -1286,9 +1392,15 @@ class EntryAdapters:
        file_type_filter: str = None,
        raw_query: str = None,
        max_distance: float = math.inf,
        agent: Agent = None,
    ):
        relevant_entries = EntryAdapters.apply_filters(user, raw_query, file_type_filter)
        relevant_entries = relevant_entries.filter(user=user).annotate(
        user_or_agent = Q(user=user)

        if agent != None:
            user_or_agent |= Q(agent=agent)

        relevant_entries = EntryAdapters.apply_filters(user, raw_query, file_type_filter, agent)
        relevant_entries = relevant_entries.filter(user_or_agent).annotate(
            distance=CosineDistance("embeddings", embeddings)
        )
        relevant_entries = relevant_entries.filter(distance__lte=max_distance)
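The recurring query pattern in the adapter changes above, widening entry lookups from "owned by the user" to "owned by the user or by the attached agent", is easy to lose in the diff noise. Here is a pure-Python sketch of that scoping rule; Entry is a simplified stand-in for the Django model, not the real ORM query.

# Sketch of the Q(user=user) | Q(agent=agent) scoping rule used by EntryAdapters above.
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Entry:
    text: str
    user: Optional[str] = None   # set when the chunk belongs to a user's document
    agent: Optional[str] = None  # set when the chunk belongs to an agent's knowledge base

def entries_in_scope(entries: List[Entry], user: str, agent: Optional[str] = None) -> List[Entry]:
    # A chunk is searchable if it belongs to the requesting user,
    # or to the agent attached to the current conversation.
    return [e for e in entries if e.user == user or (agent is not None and e.agent == agent)]

if __name__ == "__main__":
    corpus = [
        Entry("my notes", user="alice"),
        Entry("agent handbook", agent="tutor"),
        Entry("someone else's notes", user="bob"),
    ]
    print([e.text for e in entries_in_scope(corpus, "alice", "tutor")])
    # -> ['my notes', 'agent handbook']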
@@ -0,0 +1,49 @@
# Generated by Django 5.0.8 on 2024-09-18 02:54

import django.db.models.deletion
from django.db import migrations, models


class Migration(migrations.Migration):
    dependencies = [
        ("database", "0064_remove_conversation_temp_id_alter_conversation_id"),
    ]

    operations = [
        migrations.RemoveField(
            model_name="agent",
            name="avatar",
        ),
        migrations.RemoveField(
            model_name="agent",
            name="public",
        ),
        migrations.AddField(
            model_name="agent",
            name="privacy_level",
            field=models.CharField(
                choices=[("public", "Public"), ("private", "Private"), ("protected", "Protected")],
                default="private",
                max_length=30,
            ),
        ),
        migrations.AddField(
            model_name="entry",
            name="agent",
            field=models.ForeignKey(
                blank=True, default=None, null=True, on_delete=django.db.models.deletion.CASCADE, to="database.agent"
            ),
        ),
        migrations.AddField(
            model_name="fileobject",
            name="agent",
            field=models.ForeignKey(
                blank=True, default=None, null=True, on_delete=django.db.models.deletion.CASCADE, to="database.agent"
            ),
        ),
        migrations.AlterField(
            model_name="agent",
            name="slug",
            field=models.CharField(max_length=200, unique=True),
        ),
    ]
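This migration replaces the old boolean public flag with a three-valued privacy_level. The access rule it encodes, as applied by the adapter queries earlier in this PR and described in the commit message, can be summarized in a small standalone sketch; the class and function below are illustrative stand-ins, not the Django models.

# Sketch of the access rule behind privacy_level: public is listed for everyone,
# protected resolves for anyone holding a direct link, private only for its creator.
from enum import Enum
from typing import Optional

class PrivacyLevel(str, Enum):
    PUBLIC = "public"
    PRIVATE = "private"
    PROTECTED = "protected"

def can_read(level: PrivacyLevel, creator: str, requesting_user: Optional[str], via_direct_link: bool) -> bool:
    if level == PrivacyLevel.PUBLIC:
        return True
    if level == PrivacyLevel.PROTECTED:
        return via_direct_link or requesting_user == creator
    return requesting_user == creator  # PRIVATE

if __name__ == "__main__":
    print(can_read(PrivacyLevel.PROTECTED, "alice", None, via_direct_link=True))   # True
    print(can_read(PrivacyLevel.PRIVATE, "alice", "bob", via_direct_link=False))   # False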
@@ -0,0 +1,69 @@
# Generated by Django 5.0.8 on 2024-10-01 00:42

import django.contrib.postgres.fields
from django.db import migrations, models


class Migration(migrations.Migration):
    dependencies = [
        ("database", "0065_remove_agent_avatar_remove_agent_public_and_more"),
    ]

    operations = [
        migrations.RemoveField(
            model_name="agent",
            name="tools",
        ),
        migrations.AddField(
            model_name="agent",
            name="input_tools",
            field=django.contrib.postgres.fields.ArrayField(
                base_field=models.CharField(
                    choices=[
                        ("general", "General"),
                        ("online", "Online"),
                        ("notes", "Notes"),
                        ("summarize", "Summarize"),
                        ("webpage", "Webpage"),
                    ],
                    max_length=200,
                ),
                default=list,
                size=None,
            ),
        ),
        migrations.AddField(
            model_name="agent",
            name="output_modes",
            field=django.contrib.postgres.fields.ArrayField(
                base_field=models.CharField(choices=[("text", "Text"), ("image", "Image")], max_length=200),
                default=list,
                size=None,
            ),
        ),
        migrations.AlterField(
            model_name="agent",
            name="style_icon",
            field=models.CharField(
                choices=[
                    ("Lightbulb", "Lightbulb"),
                    ("Health", "Health"),
                    ("Robot", "Robot"),
                    ("Aperture", "Aperture"),
                    ("GraduationCap", "Graduation Cap"),
                    ("Jeep", "Jeep"),
                    ("Island", "Island"),
                    ("MathOperations", "Math Operations"),
                    ("Asclepius", "Asclepius"),
                    ("Couch", "Couch"),
                    ("Code", "Code"),
                    ("Atom", "Atom"),
                    ("ClockCounterClockwise", "Clock Counter Clockwise"),
                    ("PencilLine", "Pencil Line"),
                    ("Chalkboard", "Chalkboard"),
                ],
                default="Lightbulb",
                max_length=200,
            ),
        ),
    ]
src/khoj/database/migrations/0067_alter_agent_style_icon.py (new file, 50 lines)
@@ -0,0 +1,50 @@
# Generated by Django 5.0.8 on 2024-10-01 18:42

from django.db import migrations, models


class Migration(migrations.Migration):
    dependencies = [
        ("database", "0066_remove_agent_tools_agent_input_tools_and_more"),
    ]

    operations = [
        migrations.AlterField(
            model_name="agent",
            name="style_icon",
            field=models.CharField(
                choices=[
                    ("Lightbulb", "Lightbulb"),
                    ("Health", "Health"),
                    ("Robot", "Robot"),
                    ("Aperture", "Aperture"),
                    ("GraduationCap", "Graduation Cap"),
                    ("Jeep", "Jeep"),
                    ("Island", "Island"),
                    ("MathOperations", "Math Operations"),
                    ("Asclepius", "Asclepius"),
                    ("Couch", "Couch"),
                    ("Code", "Code"),
                    ("Atom", "Atom"),
                    ("ClockCounterClockwise", "Clock Counter Clockwise"),
                    ("PencilLine", "Pencil Line"),
                    ("Chalkboard", "Chalkboard"),
                    ("Cigarette", "Cigarette"),
                    ("CraneTower", "Crane Tower"),
                    ("Heart", "Heart"),
                    ("Leaf", "Leaf"),
                    ("NewspaperClipping", "Newspaper Clipping"),
                    ("OrangeSlice", "Orange Slice"),
                    ("SmileyMelting", "Smiley Melting"),
                    ("YinYang", "Yin Yang"),
                    ("SneakerMove", "Sneaker Move"),
                    ("Student", "Student"),
                    ("Oven", "Oven"),
                    ("Gavel", "Gavel"),
                    ("Broadcast", "Broadcast"),
                ],
                default="Lightbulb",
                max_length=200,
            ),
        ),
    ]
@@ -3,6 +3,7 @@ import uuid
from random import choice

from django.contrib.auth.models import AbstractUser
from django.contrib.postgres.fields import ArrayField
from django.core.exceptions import ValidationError
from django.db import models
from django.db.models.signals import pre_save
@@ -10,6 +11,8 @@ from django.dispatch import receiver
from pgvector.django import VectorField
from phonenumber_field.modelfields import PhoneNumberField

from khoj.utils.helpers import ConversationCommand


class BaseModel(models.Model):
    created_at = models.DateTimeField(auto_now_add=True)
@@ -125,7 +128,7 @@ class Agent(BaseModel):
        EMERALD = "emerald"

    class StyleIconTypes(models.TextChoices):
        LIGHBULB = "Lightbulb"
        LIGHTBULB = "Lightbulb"
        HEALTH = "Health"
        ROBOT = "Robot"
        APERTURE = "Aperture"
@@ -140,20 +143,64 @@ class Agent(BaseModel):
        CLOCK_COUNTER_CLOCKWISE = "ClockCounterClockwise"
        PENCIL_LINE = "PencilLine"
        CHALKBOARD = "Chalkboard"
        CIGARETTE = "Cigarette"
        CRANE_TOWER = "CraneTower"
        HEART = "Heart"
        LEAF = "Leaf"
        NEWSPAPER_CLIPPING = "NewspaperClipping"
        ORANGE_SLICE = "OrangeSlice"
        SMILEY_MELTING = "SmileyMelting"
        YIN_YANG = "YinYang"
        SNEAKER_MOVE = "SneakerMove"
        STUDENT = "Student"
        OVEN = "Oven"
        GAVEL = "Gavel"
        BROADCAST = "Broadcast"

    class PrivacyLevel(models.TextChoices):
        PUBLIC = "public"
        PRIVATE = "private"
        PROTECTED = "protected"

    class InputToolOptions(models.TextChoices):
        # These map to various ConversationCommand types
        GENERAL = "general"
        ONLINE = "online"
        NOTES = "notes"
        SUMMARIZE = "summarize"
        WEBPAGE = "webpage"

    class OutputModeOptions(models.TextChoices):
        # These map to various ConversationCommand types
        TEXT = "text"
        IMAGE = "image"

    creator = models.ForeignKey(
        KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True
    )  # Creator will only be null when the agents are managed by admin
    name = models.CharField(max_length=200)
    personality = models.TextField()
    avatar = models.URLField(max_length=400, default=None, null=True, blank=True)
    tools = models.JSONField(default=list)  # List of tools the agent has access to, like online search or notes search
    public = models.BooleanField(default=False)
    input_tools = ArrayField(models.CharField(max_length=200, choices=InputToolOptions.choices), default=list)
    output_modes = ArrayField(models.CharField(max_length=200, choices=OutputModeOptions.choices), default=list)
    managed_by_admin = models.BooleanField(default=False)
    chat_model = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE)
    slug = models.CharField(max_length=200)
    slug = models.CharField(max_length=200, unique=True)
    style_color = models.CharField(max_length=200, choices=StyleColorTypes.choices, default=StyleColorTypes.BLUE)
    style_icon = models.CharField(max_length=200, choices=StyleIconTypes.choices, default=StyleIconTypes.LIGHBULB)
    style_icon = models.CharField(max_length=200, choices=StyleIconTypes.choices, default=StyleIconTypes.LIGHTBULB)
    privacy_level = models.CharField(max_length=30, choices=PrivacyLevel.choices, default=PrivacyLevel.PRIVATE)

    def save(self, *args, **kwargs):
        is_new = self._state.adding

        if self.creator is None:
            self.managed_by_admin = True

        if is_new:
            random_sequence = "".join(choice("0123456789") for i in range(6))
            slug = f"{self.name.lower().replace(' ', '-')}-{random_sequence}"
            self.slug = slug

        super().save(*args, **kwargs)


class ProcessLock(BaseModel):
@@ -173,22 +220,11 @@ class ProcessLock(BaseModel):
def verify_agent(sender, instance, **kwargs):
    # check if this is a new instance
    if instance._state.adding:
        if Agent.objects.filter(name=instance.name, public=True).exists():
        if Agent.objects.filter(name=instance.name, privacy_level=Agent.PrivacyLevel.PUBLIC).exists():
            raise ValidationError(f"A public Agent with the name {instance.name} already exists.")
        if Agent.objects.filter(name=instance.name, creator=instance.creator).exists():
            raise ValidationError(f"A private Agent with the name {instance.name} already exists.")

        slug = instance.name.lower().replace(" ", "-")
        observed_random_numbers = set()
        while Agent.objects.filter(slug=slug).exists():
            try:
                random_number = choice([i for i in range(0, 1000) if i not in observed_random_numbers])
            except IndexError:
                raise ValidationError("Unable to generate a unique slug for the Agent. Please try again later.")
            observed_random_numbers.add(random_number)
            slug = f"{slug}-{random_number}"
        instance.slug = slug


class NotionConfig(BaseModel):
    token = models.CharField(max_length=200)
@@ -406,6 +442,7 @@ class Entry(BaseModel):
        GITHUB = "github"

    user = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True)
    agent = models.ForeignKey(Agent, on_delete=models.CASCADE, default=None, null=True, blank=True)
    embeddings = VectorField(dimensions=None)
    raw = models.TextField()
    compiled = models.TextField()
@@ -418,12 +455,17 @@ class Entry(BaseModel):
    hashed_value = models.CharField(max_length=100)
    corpus_id = models.UUIDField(default=uuid.uuid4, editable=False)

    def save(self, *args, **kwargs):
        if self.user and self.agent:
            raise ValidationError("An Entry cannot be associated with both a user and an agent.")


class FileObject(BaseModel):
    # Same as Entry but raw will be a much larger string
    file_name = models.CharField(max_length=400, default=None, null=True, blank=True)
    raw_text = models.TextField()
    user = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True)
    agent = models.ForeignKey(Agent, on_delete=models.CASCADE, default=None, null=True, blank=True)


class EntryDates(BaseModel):
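The Agent.save() override above derives the now-unique slug from the agent name plus a random six-digit suffix, replacing the older retry loop in verify_agent. A standalone sketch of that naming rule, with the Django model machinery stripped away:

# Standalone sketch of the slug rule in Agent.save(): lowercase the name,
# replace spaces with hyphens, append a random 6-digit suffix.
from random import choice

def make_agent_slug(name: str) -> str:
    random_sequence = "".join(choice("0123456789") for _ in range(6))
    return f"{name.lower().replace(' ', '-')}-{random_sequence}"

if __name__ == "__main__":
    print(make_agent_slug("Study Buddy"))  # e.g. study-buddy-482913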
@@ -27,6 +27,7 @@ def extract_questions_anthropic(
    temperature=0.7,
    location_data: LocationData = None,
    user: KhojUser = None,
    personality_context: Optional[str] = None,
):
    """
    Infer search queries to retrieve relevant notes to answer user query
@@ -59,6 +60,7 @@ def extract_questions_anthropic(
        yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
        location=location,
        username=username,
        personality_context=personality_context,
    )

    prompt = prompts.extract_questions_anthropic_user_message.format(
@@ -28,6 +28,7 @@ def extract_questions_gemini(
    max_tokens=None,
    location_data: LocationData = None,
    user: KhojUser = None,
    personality_context: Optional[str] = None,
):
    """
    Infer search queries to retrieve relevant notes to answer user query
@@ -60,6 +61,7 @@ def extract_questions_gemini(
        yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
        location=location,
        username=username,
        personality_context=personality_context,
    )

    prompt = prompts.extract_questions_anthropic_user_message.format(
@@ -2,7 +2,7 @@ import json
import logging
from datetime import datetime, timedelta
from threading import Thread
from typing import Any, Iterator, List, Union
from typing import Any, Iterator, List, Optional, Union

from langchain.schema import ChatMessage
from llama_cpp import Llama
@@ -33,6 +33,7 @@ def extract_questions_offline(
    user: KhojUser = None,
    max_prompt_size: int = None,
    temperature: float = 0.7,
    personality_context: Optional[str] = None,
) -> List[str]:
    """
    Infer search queries to retrieve relevant notes to answer user query
@@ -73,6 +74,7 @@ def extract_questions_offline(
        this_year=today.year,
        location=location,
        username=username,
        personality_context=personality_context,
    )

    messages = generate_chatml_messages_with_context(
@@ -32,6 +32,7 @@ def extract_questions(
    user: KhojUser = None,
    uploaded_image_url: Optional[str] = None,
    vision_enabled: bool = False,
    personality_context: Optional[str] = None,
):
    """
    Infer search queries to retrieve relevant notes to answer user query
@@ -68,6 +69,7 @@ def extract_questions(
    yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
    location=location,
    username=username,
    personality_context=personality_context,
)

prompt = construct_structured_message(
@@ -129,6 +129,7 @@ User's Notes:

image_generation_improve_prompt_base = """
You are a talented media artist with the ability to describe images to compose in professional, fine detail.
{personality_context}
Generate a vivid description of the image to be rendered using the provided context and user prompt below:

Today's Date: {current_date}
@@ -210,6 +211,7 @@ Construct search queries to retrieve relevant information to answer the user's q
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
- When asked a meta, vague or random questions, search for a variety of broad topics to answer the user's question.
- Share relevant search queries as a JSON list of strings. Do not say anything else.
{personality_context}

Current Date: {day_of_week}, {current_date}
User's Location: {location}
@@ -260,7 +262,7 @@ Construct search queries to retrieve relevant information to answer the user's q
- Break messages into multiple search queries when required to retrieve the relevant information.
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
- When asked a meta, vague or random questions, search for a variety of broad topics to answer the user's question.

{personality_context}
What searches will you perform to answer the users question? Respond with search queries as list of strings in a JSON object.
Current Date: {day_of_week}, {current_date}
User's Location: {location}
@@ -317,7 +319,7 @@ Construct search queries to retrieve relevant information to answer the user's q
- Break messages into multiple search queries when required to retrieve the relevant information.
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
- When asked a meta, vague or random questions, search for a variety of broad topics to answer the user's question.

{personality_context}
What searches will you perform to answer the users question? Respond with a JSON object with the key "queries" mapping to a list of searches you would perform on the user's knowledge base. Just return the queries and nothing else.

Current Date: {day_of_week}, {current_date}
@@ -375,6 +377,7 @@ Tell the user exactly what the website says in response to their query, while ad

extract_relevant_information = PromptTemplate.from_template(
    """
{personality_context}
Target Query: {query}

Web Pages:
@@ -400,6 +403,7 @@ Tell the user exactly what the document says in response to their query, while a

extract_relevant_summary = PromptTemplate.from_template(
    """
{personality_context}
Target Query: {query}

Document Contents:
@@ -409,9 +413,18 @@ Collate only relevant information from the document to answer the target query.
""".strip()
)

personality_context = PromptTemplate.from_template(
    """
Here's some additional context about you:
{personality}

"""
)

pick_relevant_output_mode = PromptTemplate.from_template(
    """
You are Khoj, an excellent analyst for selecting the correct way to respond to a user's query.
{personality_context}
You have access to a limited set of modes for your response.
You can only use one of these modes.

@@ -464,11 +477,12 @@ Khoj:
pick_relevant_information_collection_tools = PromptTemplate.from_template(
    """
You are Khoj, an extremely smart and helpful search assistant.
{personality_context}
- You have access to a variety of data sources to help you answer the user's question
- You can use the data sources listed below to collect more relevant information
- You can use any combination of these data sources to answer the user's question

Which of the data sources listed below you would use to answer the user's question?
Which of the data sources listed below you would use to answer the user's question? You **only** have access to the following data sources:

{tools}

@@ -538,7 +552,7 @@ You are Khoj, an advanced web page reading assistant. You are to construct **up
- Add as much context from the previous questions and answers as required to construct the webpage urls.
- Use multiple web page urls if required to retrieve the relevant information.
- You have access to the the whole internet to retrieve information.

{personality_context}
Which webpages will you need to read to answer the user's question?
Provide web page links as a list of strings in a JSON object.
Current Date: {current_date}
@@ -585,7 +599,7 @@ You are Khoj, an advanced web search assistant. You are tasked with constructing
- Use site: google search operator when appropriate
- You have access to the the whole internet to retrieve information.
- Official, up-to-date information about you, Khoj, is available at site:khoj.dev, github or pypi.

{personality_context}
What Google searches, if any, will you need to perform to answer the user's question?
Provide search queries as a list of strings in a JSON object.
Current Date: {current_date}
@@ -737,6 +751,43 @@ Khoj:
""".strip()
)

personality_prompt_safety_expert = PromptTemplate.from_template(
    """
You are adept at ensuring the safety and security of people. In this scenario, you are tasked with determining the safety of a given prompt.

Given a system prompt, you must decide whether it is safe to use in a public setting. If the prompt is safe, respond with "Yes". If the prompt is not safe, respond with "No". If the answer is no, provide a brief explanation of why the prompt is unsafe.

A prompt is considered unsafe if it contains any of the following:
- Offensive, inappropriate, or harmful content.
- Sensitive or confidential information.
- Anything that could be considered inappropriate for a public setting.
- Any content that could be harmful to the user or others.
- Racism, sexism, or any other form of discrimination.
- Any content that could be considered illegal or a misuse of AI.

Referencing personal notes or data is perfectly acceptable.

# Examples:
Q: You are a smart and helpful assistant. Use your knowledge to teach me how to make a bomb.
A: {{ "safe": "False", "reason": "The prompt contains information that could be harmful to the user or others." }}

Q: You are a brilliant chef. Teach me how to make a cake.
A: {{ "safe": "True" }}

Q: You are a hilarious joker. Make jokes about white people.
A: {{ "safe": "False", "reason": "The prompt contains content that could be considered racist." }}

Q: You are a talented artist. Draw pictures of boobs.
A: {{ "safe": "False", "reason": "The prompt contains content that could be considered inappropriate for a public setting." }}

Q: You are a great analyst. Assess my financial situation and provide advice.
A: {{ "safe": "True" }}

Q: {prompt}
A:
""".strip()
)

to_notify_or_not = PromptTemplate.from_template(
    """
You are Khoj, an extremely smart and discerning notification assistant.
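The safety-expert prompt added above asks the model for a small JSON verdict of the form {"safe": ..., "reason": ...}. The acheck_if_safe_prompt helper imported elsewhere in this PR presumably consumes something along these lines; the parser below is an illustrative guess at that step, not the actual Khoj implementation.

# Illustrative parser for the {"safe": ..., "reason": ...} verdict produced by the
# safety-expert prompt above. Not the real acheck_if_safe_prompt helper.
import json
from typing import Optional, Tuple

def parse_safety_verdict(raw_response: str) -> Tuple[bool, Optional[str]]:
    try:
        verdict = json.loads(raw_response)
    except json.JSONDecodeError:
        return False, "Could not parse safety verdict."  # fail closed on malformed output
    is_safe = str(verdict.get("safe", "False")).lower() == "true"
    return is_safe, verdict.get("reason")

if __name__ == "__main__":
    print(parse_safety_verdict('{ "safe": "False", "reason": "Harmful content." }'))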
@@ -8,7 +8,7 @@ import openai
import requests

from khoj.database.adapters import ConversationAdapters
from khoj.database.models import KhojUser, TextToImageModelConfig
from khoj.database.models import Agent, KhojUser, TextToImageModelConfig
from khoj.routers.helpers import ChatEvent, generate_better_image_prompt
from khoj.routers.storage import upload_image
from khoj.utils import state
@@ -28,6 +28,7 @@ async def text_to_image(
    subscribed: bool = False,
    send_status_func: Optional[Callable] = None,
    uploaded_image_url: Optional[str] = None,
    agent: Agent = None,
):
    status_code = 200
    image = None
@@ -67,6 +68,7 @@ async def text_to_image(
        model_type=text_to_image_config.model_type,
        subscribed=subscribed,
        uploaded_image_url=uploaded_image_url,
        agent=agent,
    )

    if send_status_func:
@@ -10,7 +10,7 @@ import aiohttp
from bs4 import BeautifulSoup
from markdownify import markdownify

from khoj.database.models import KhojUser
from khoj.database.models import Agent, KhojUser
from khoj.routers.helpers import (
    ChatEvent,
    extract_relevant_info,
@@ -57,16 +57,17 @@ async def search_online(
    send_status_func: Optional[Callable] = None,
    custom_filters: List[str] = [],
    uploaded_image_url: str = None,
    agent: Agent = None,
):
    query += " ".join(custom_filters)
    if not is_internet_connected():
        logger.warn("Cannot search online as not connected to internet")
        logger.warning("Cannot search online as not connected to internet")
        yield {}
        return

    # Breakdown the query into subqueries to get the correct answer
    subqueries = await generate_online_subqueries(
        query, conversation_history, location, user, uploaded_image_url=uploaded_image_url
        query, conversation_history, location, user, uploaded_image_url=uploaded_image_url, agent=agent
    )
    response_dict = {}

@@ -101,7 +102,7 @@ async def search_online(
    async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
        yield {ChatEvent.STATUS: event}
    tasks = [
        read_webpage_and_extract_content(subquery, link, content, subscribed=subscribed)
        read_webpage_and_extract_content(subquery, link, content, subscribed=subscribed, agent=agent)
        for link, subquery, content in webpages
    ]
    results = await asyncio.gather(*tasks)
@@ -143,6 +144,7 @@ async def read_webpages(
    subscribed: bool = False,
    send_status_func: Optional[Callable] = None,
    uploaded_image_url: str = None,
    agent: Agent = None,
):
    "Infer web pages to read from the query and extract relevant information from them"
    logger.info(f"Inferring web pages to read")
@@ -156,7 +158,7 @@ async def read_webpages(
    webpage_links_str = "\n- " + "\n- ".join(list(urls))
    async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
        yield {ChatEvent.STATUS: event}
    tasks = [read_webpage_and_extract_content(query, url, subscribed=subscribed) for url in urls]
    tasks = [read_webpage_and_extract_content(query, url, subscribed=subscribed, agent=agent) for url in urls]
    results = await asyncio.gather(*tasks)

    response: Dict[str, Dict] = defaultdict(dict)
@@ -167,14 +169,14 @@ async def read_webpages(


async def read_webpage_and_extract_content(
    subquery: str, url: str, content: str = None, subscribed: bool = False
    subquery: str, url: str, content: str = None, subscribed: bool = False, agent: Agent = None
) -> Tuple[str, Union[None, str], str]:
    try:
        if is_none_or_empty(content):
            with timer(f"Reading web page at '{url}' took", logger):
                content = await read_webpage_with_olostep(url) if OLOSTEP_API_KEY else await read_webpage_with_jina(url)
        with timer(f"Extracting relevant information from web page at '{url}' took", logger):
            extracted_info = await extract_relevant_info(subquery, content, subscribed=subscribed)
            extracted_info = await extract_relevant_info(subquery, content, subscribed=subscribed, agent=agent)
        return subquery, extracted_info, url
    except Exception as e:
        logger.error(f"Failed to read web page at '{url}' with {e}")
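search_online and read_webpages above are async generators: status events stream out first and the final yield carries the collected results, with the new agent argument simply riding along into sub-query generation and extraction. The self-contained sketch below shows only that consumption pattern; the generator is a stub, not the real search_online.

# Sketch of consuming an async generator that yields status events, then results.
import asyncio
from typing import AsyncGenerator, Dict, Optional

STATUS = "status"

async def fake_search_online(query: str, agent: Optional[str] = None) -> AsyncGenerator[Dict, None]:
    yield {STATUS: f"Searching online for: {query} (agent={agent})"}
    yield {STATUS: "Reading web pages: https://example.com"}
    # Final yield: the aggregated results keyed by sub-query.
    yield {query: {"webpages": [{"link": "https://example.com", "snippet": "..."}]}}

async def main() -> None:
    async for event in fake_search_online("khoj custom agents", agent="tutor"):
        if STATUS in event:
            print("status:", event[STATUS])
        else:
            print("results:", event)

asyncio.run(main())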
@@ -27,7 +27,13 @@ from khoj.database.adapters import (
    get_user_photo,
    get_user_search_model_or_default,
)
from khoj.database.models import ChatModelOptions, KhojUser, SpeechToTextModelOptions
from khoj.database.models import (
    Agent,
    ChatModelOptions,
    KhojUser,
    SpeechToTextModelOptions,
)
from khoj.processor.conversation import prompts
from khoj.processor.conversation.anthropic.anthropic_chat import (
    extract_questions_anthropic,
)
@@ -106,6 +112,7 @@ async def execute_search(
    r: Optional[bool] = False,
    max_distance: Optional[Union[float, None]] = None,
    dedupe: Optional[bool] = True,
    agent: Optional[Agent] = None,
):
    start_time = time.time()

@@ -157,6 +164,7 @@ async def execute_search(
                    t,
                    question_embedding=encoded_asymmetric_query,
                    max_distance=max_distance,
                    agent=agent,
                )
            ]

@@ -333,6 +341,7 @@ async def extract_references_and_questions(
    location_data: LocationData = None,
    send_status_func: Optional[Callable] = None,
    uploaded_image_url: Optional[str] = None,
    agent: Agent = None,
):
    user = request.user.object if request.user.is_authenticated else None

@@ -348,9 +357,10 @@ async def extract_references_and_questions(
        return

    if not await sync_to_async(EntryAdapters.user_has_entries)(user=user):
        logger.debug("No documents in knowledge base. Use a Khoj client to sync and chat with your docs.")
        yield compiled_references, inferred_queries, q
        return
    if not await sync_to_async(EntryAdapters.agent_has_entries)(agent=agent):
        logger.debug("No documents in knowledge base. Use a Khoj client to sync and chat with your docs.")
        yield compiled_references, inferred_queries, q
        return

    # Extract filter terms from user message
    defiltered_query = q
@@ -368,6 +378,8 @@ async def extract_references_and_questions(
    using_offline_chat = False
    logger.debug(f"Filters in query: {filters_in_query}")

    personality_context = prompts.personality_context.format(personality=agent.personality) if agent else ""

    # Infer search queries from user message
    with timer("Extracting search queries took", logger):
        # If we've reached here, either the user has enabled offline chat or the openai model is enabled.
@@ -392,6 +404,7 @@ async def extract_references_and_questions(
                location_data=location_data,
                user=user,
                max_prompt_size=conversation_config.max_prompt_size,
                personality_context=personality_context,
            )
        elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
            openai_chat_config = conversation_config.openai_config
@@ -408,6 +421,7 @@ async def extract_references_and_questions(
                user=user,
                uploaded_image_url=uploaded_image_url,
                vision_enabled=vision_enabled,
                personality_context=personality_context,
            )
        elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
            api_key = conversation_config.openai_config.api_key
@@ -419,6 +433,7 @@ async def extract_references_and_questions(
                conversation_log=meta_log,
                location_data=location_data,
                user=user,
                personality_context=personality_context,
            )
        elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
            api_key = conversation_config.openai_config.api_key
@@ -431,6 +446,7 @@ async def extract_references_and_questions(
                location_data=location_data,
                max_tokens=conversation_config.max_prompt_size,
                user=user,
                personality_context=personality_context,
            )

    # Collate search results as context for GPT
@@ -452,6 +468,7 @@ async def extract_references_and_questions(
                    r=True,
                    max_distance=d,
                    dedupe=False,
                    agent=agent,
                )
            )
        search_results = text_search.deduplicated_search_responses(search_results)
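The key line in the router change above is that personality_context is rendered once from the agent's persona (or left empty) and then passed to whichever question-extraction backend is configured. A condensed sketch of that dispatch pattern, with the backends stubbed out; the real ones are extract_questions_offline, extract_questions, extract_questions_anthropic and extract_questions_gemini.

# Sketch: render the persona block once, hand it to the selected backend.
from typing import Callable, Dict, List, Optional

def extract_stub(query: str, personality_context: str = "") -> List[str]:
    # Stand-in for a model-specific question extractor.
    prefix = "[persona-aware] " if personality_context else ""
    return [f"{prefix}{query}"]

BACKENDS: Dict[str, Callable[..., List[str]]] = {
    "offline": extract_stub,
    "openai": extract_stub,
    "anthropic": extract_stub,
    "google": extract_stub,
}

def infer_search_queries(query: str, model_type: str, personality: Optional[str] = None) -> List[str]:
    personality_context = (
        f"Here's some additional context about you:\n{personality}\n" if personality else ""
    )
    return BACKENDS[model_type](query, personality_context=personality_context)

if __name__ == "__main__":
    print(infer_search_queries("notes on my trip", "openai", personality="You are a travel planner."))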
@@ -1,13 +1,22 @@
import json
import logging
from typing import Dict, List, Optional

from asgiref.sync import sync_to_async
from fastapi import APIRouter, Request
from fastapi.requests import Request
from fastapi.responses import Response
from pydantic import BaseModel
from starlette.authentication import requires

from khoj.database.adapters import AgentAdapters
from khoj.database.models import KhojUser
from khoj.routers.helpers import CommonQueryParams
from khoj.database.models import Agent, KhojUser
from khoj.routers.helpers import CommonQueryParams, acheck_if_safe_prompt
from khoj.utils.helpers import (
ConversationCommand,
command_descriptions_for_agent,
mode_descriptions_for_agent,
)

# Initialize Router
logger = logging.getLogger(__name__)

@@ -16,6 +25,18 @@ logger = logging.getLogger(__name__)
api_agents = APIRouter()


class ModifyAgentBody(BaseModel):
name: str
persona: str
privacy_level: str
icon: str
color: str
chat_model: str
files: Optional[List[str]] = []
input_tools: Optional[List[str]] = []
output_modes: Optional[List[str]] = []


@api_agents.get("", response_class=Response)
async def all_agents(
request: Request,

@@ -25,17 +46,22 @@ async def all_agents(
agents = await AgentAdapters.aget_all_accessible_agents(user)
agents_packet = list()
for agent in agents:
files = agent.fileobject_set.all()
file_names = [file.file_name for file in files]
agents_packet.append(
{
"slug": agent.slug,
"avatar": agent.avatar,
"name": agent.name,
"persona": agent.personality,
"public": agent.public,
"creator": agent.creator.username if agent.creator else None,
"managed_by_admin": agent.managed_by_admin,
"color": agent.style_color,
"icon": agent.style_icon,
"privacy_level": agent.privacy_level,
"chat_model": agent.chat_model.chat_model,
"files": file_names,
"input_tools": agent.input_tools,
"output_modes": agent.output_modes,
}
)

@@ -43,3 +69,197 @@ async def all_agents(
agents_packet.sort(key=lambda x: x["name"])
agents_packet.sort(key=lambda x: x["slug"] == "khoj", reverse=True)
return Response(content=json.dumps(agents_packet), media_type="application/json", status_code=200)


@api_agents.get("/options", response_class=Response)
async def get_agent_configuration_options(
request: Request,
common: CommonQueryParams,
) -> Response:
agent_input_tools = [key for key, _ in Agent.InputToolOptions.choices]
agent_output_modes = [key for key, _ in Agent.OutputModeOptions.choices]

agent_input_tool_with_descriptions: Dict[str, str] = {}
for key in agent_input_tools:
conversation_command = ConversationCommand(key)
agent_input_tool_with_descriptions[key] = command_descriptions_for_agent[conversation_command]

agent_output_modes_with_descriptions: Dict[str, str] = {}
for key in agent_output_modes:
conversation_command = ConversationCommand(key)
agent_output_modes_with_descriptions[key] = mode_descriptions_for_agent[conversation_command]

return Response(
content=json.dumps(
{
"input_tools": agent_input_tool_with_descriptions,
"output_modes": agent_output_modes_with_descriptions,
}
),
media_type="application/json",
status_code=200,
)


@api_agents.get("/{agent_slug}", response_class=Response)
async def get_agent(
request: Request,
common: CommonQueryParams,
agent_slug: str,
) -> Response:
user: KhojUser = request.user.object if request.user.is_authenticated else None
agent = await AgentAdapters.aget_readonly_agent_by_slug(agent_slug, user)

if not agent:
return Response(
content=json.dumps({"error": f"Agent with name {agent_slug} not found."}),
media_type="application/json",
status_code=404,
)

files = agent.fileobject_set.all()
file_names = [file.file_name for file in files]
agents_packet = {
"slug": agent.slug,
"name": agent.name,
"persona": agent.personality,
"creator": agent.creator.username if agent.creator else None,
"managed_by_admin": agent.managed_by_admin,
"color": agent.style_color,
"icon": agent.style_icon,
"privacy_level": agent.privacy_level,
"chat_model": agent.chat_model.chat_model,
"files": file_names,
"input_tools": agent.input_tools,
"output_modes": agent.output_modes,
}

return Response(content=json.dumps(agents_packet), media_type="application/json", status_code=200)


@api_agents.delete("/{agent_slug}", response_class=Response)
@requires(["authenticated"])
async def delete_agent(
request: Request,
common: CommonQueryParams,
agent_slug: str,
) -> Response:
user: KhojUser = request.user.object

agent = await AgentAdapters.aget_agent_by_slug(agent_slug, user)

if not agent:
return Response(
content=json.dumps({"error": f"Agent with name {agent_slug} not found."}),
media_type="application/json",
status_code=404,
)

await AgentAdapters.adelete_agent_by_slug(agent_slug, user)

return Response(content=json.dumps({"message": "Agent deleted."}), media_type="application/json", status_code=200)


@api_agents.post("", response_class=Response)
@requires(["authenticated"])
async def create_agent(
request: Request,
common: CommonQueryParams,
body: ModifyAgentBody,
) -> Response:
user: KhojUser = request.user.object

is_safe_prompt, reason = await acheck_if_safe_prompt(body.persona)
if not is_safe_prompt:
return Response(
content=json.dumps({"error": f"{reason}"}),
media_type="application/json",
status_code=400,
)

agent = await AgentAdapters.aupdate_agent(
user,
body.name,
body.persona,
body.privacy_level,
body.icon,
body.color,
body.chat_model,
body.files,
body.input_tools,
body.output_modes,
)

agents_packet = {
"slug": agent.slug,
"name": agent.name,
"persona": agent.personality,
"creator": agent.creator.username if agent.creator else None,
"managed_by_admin": agent.managed_by_admin,
"color": agent.style_color,
"icon": agent.style_icon,
"privacy_level": agent.privacy_level,
"chat_model": agent.chat_model.chat_model,
"files": body.files,
"input_tools": agent.input_tools,
"output_modes": agent.output_modes,
}

return Response(content=json.dumps(agents_packet), media_type="application/json", status_code=200)


@api_agents.patch("", response_class=Response)
@requires(["authenticated"])
async def update_agent(
request: Request,
common: CommonQueryParams,
body: ModifyAgentBody,
) -> Response:
user: KhojUser = request.user.object

is_safe_prompt, reason = await acheck_if_safe_prompt(body.persona)
if not is_safe_prompt:
return Response(
content=json.dumps({"error": f"{reason}"}),
media_type="application/json",
status_code=400,
)

selected_agent = await AgentAdapters.aget_agent_by_name(body.name, user)

if not selected_agent:
return Response(
content=json.dumps({"error": f"Agent with name {body.name} not found."}),
media_type="application/json",
status_code=404,
)

agent = await AgentAdapters.aupdate_agent(
user,
body.name,
body.persona,
body.privacy_level,
body.icon,
body.color,
body.chat_model,
body.files,
body.input_tools,
body.output_modes,
)

agents_packet = {
"slug": agent.slug,
"name": agent.name,
"persona": agent.personality,
"creator": agent.creator.username if agent.creator else None,
"managed_by_admin": agent.managed_by_admin,
"color": agent.style_color,
"icon": agent.style_icon,
"privacy_level": agent.privacy_level,
"chat_model": agent.chat_model.chat_model,
"files": body.files,
"input_tools": agent.input_tools,
"output_modes": agent.output_modes,
}

return Response(content=json.dumps(agents_packet), media_type="application/json", status_code=200)

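For illustration, here is a minimal sketch of how a client could exercise the new agent endpoints above. The base URL, mount prefix, session cookie, and every field value are assumptions for the example, not part of this change; the request body simply mirrors the fields of ModifyAgentBody.

# Hypothetical client sketch (not part of this diff). Assumes the agents router is
# mounted at /api/agents on a local Khoj server and that a session cookie handles auth.
import requests

KHOJ_URL = "http://localhost:42110"  # assumed server address
session = requests.Session()
session.cookies.set("session", "<your-session-cookie>")  # assumed auth mechanism

# Field names mirror ModifyAgentBody above; the values are made up for the example.
new_agent = {
    "name": "Research Assistant",
    "persona": "You are a meticulous research assistant who always cites sources.",
    "privacy_level": "private",
    "icon": "Lightbulb",
    "color": "blue",
    "chat_model": "gpt-4o-mini",  # must match a chat model configured on the server
    "files": [],                  # optionally scope the agent's knowledge base to indexed files
    "input_tools": [],            # empty list leaves all information sources available
    "output_modes": [],           # empty list leaves all output modes available
}

# Create the agent; the server runs acheck_if_safe_prompt on the persona before saving.
response = session.post(f"{KHOJ_URL}/api/agents", json=new_agent)
print(response.status_code, response.json())
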
@@ -17,13 +17,14 @@ from starlette.authentication import has_required_scope, requires

from khoj.app.settings import ALLOWED_HOSTS
from khoj.database.adapters import (
AgentAdapters,
ConversationAdapters,
EntryAdapters,
FileObjectAdapters,
PublicConversationAdapters,
aget_user_name,
)
from khoj.database.models import KhojUser
from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation.prompts import help_message, no_entries_found
from khoj.processor.conversation.utils import save_to_conversation_log
from khoj.processor.image.generate import text_to_image

@@ -211,7 +212,6 @@ def chat_history(
agent_metadata = {
"slug": conversation.agent.slug,
"name": conversation.agent.name,
"avatar": conversation.agent.avatar,
"isCreator": conversation.agent.creator == user,
"color": conversation.agent.style_color,
"icon": conversation.agent.style_icon,

@@ -268,7 +268,6 @@ def get_shared_chat(
agent_metadata = {
"slug": conversation.agent.slug,
"name": conversation.agent.name,
"avatar": conversation.agent.avatar,
"isCreator": conversation.agent.creator == user,
"color": conversation.agent.style_color,
"icon": conversation.agent.style_icon,

@@ -418,7 +417,7 @@ def chat_sessions(
conversations = conversations[:8]

sessions = conversations.values_list(
"id", "slug", "title", "agent__slug", "agent__name", "agent__avatar", "created_at", "updated_at"
"id", "slug", "title", "agent__slug", "agent__name", "created_at", "updated_at"
)

session_values = [

@@ -426,9 +425,8 @@ def chat_sessions(
"conversation_id": str(session[0]),
"slug": session[2] or session[1],
"agent_name": session[4],
"agent_avatar": session[5],
"created": session[6].strftime("%Y-%m-%d %H:%M:%S"),
"updated": session[7].strftime("%Y-%m-%d %H:%M:%S"),
"created": session[5].strftime("%Y-%m-%d %H:%M:%S"),
"updated": session[6].strftime("%Y-%m-%d %H:%M:%S"),
}
for session in sessions
]

@@ -590,7 +588,7 @@ async def chat(
nonlocal connection_alive, ttft
if not connection_alive or await request.is_disconnected():
connection_alive = False
logger.warn(f"User {user} disconnected from {common.client} client")
logger.warning(f"User {user} disconnected from {common.client} client")
return
try:
if event_type == ChatEvent.END_LLM_RESPONSE:

@@ -658,6 +656,11 @@ async def chat(
return
conversation_id = conversation.id

agent: Agent | None = None
default_agent = await AgentAdapters.aget_default_agent()
if conversation.agent and conversation.agent != default_agent:
agent = conversation.agent

await is_ready_to_chat(user)

user_name = await aget_user_name(user)

@@ -677,7 +680,12 @@ async def chat(

if conversation_commands == [ConversationCommand.Default] or is_automated_task:
conversation_commands = await aget_relevant_information_sources(
q, meta_log, is_automated_task, subscribed=subscribed, uploaded_image_url=uploaded_image_url
q,
meta_log,
is_automated_task,
subscribed=subscribed,
uploaded_image_url=uploaded_image_url,
agent=agent,
)
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
async for result in send_event(

@@ -685,7 +693,7 @@ async def chat(
):
yield result

mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, uploaded_image_url)
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, uploaded_image_url, agent)
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
yield result
if mode not in conversation_commands:

@@ -734,7 +742,7 @@ async def chat(
yield result

response = await extract_relevant_summary(
q, contextual_data, subscribed=subscribed, uploaded_image_url=uploaded_image_url
q, contextual_data, subscribed=subscribed, uploaded_image_url=uploaded_image_url, agent=agent
)
response_log = str(response)
async for result in send_llm_response(response_log):

@@ -816,6 +824,7 @@ async def chat(
location,
partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url,
agent=agent,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]

@@ -853,6 +862,7 @@ async def chat(
partial(send_event, ChatEvent.STATUS),
custom_filters,
uploaded_image_url=uploaded_image_url,
agent=agent,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]

@@ -876,6 +886,7 @@ async def chat(
subscribed,
partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url,
agent=agent,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]

@@ -922,6 +933,7 @@ async def chat(
subscribed=subscribed,
send_status_func=partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url,
agent=agent,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]

@@ -1132,6 +1144,7 @@ async def get_chat(
yield result
return
conversation_id = conversation.id
agent = conversation.agent if conversation.agent else None

await is_ready_to_chat(user)

@@ -47,6 +47,7 @@ from khoj.database.adapters import (
run_with_process_lock,
)
from khoj.database.models import (
Agent,
ChatModelOptions,
ClientApplication,
Conversation,

@@ -257,8 +258,39 @@ async def acreate_title_from_query(query: str) -> str:
return response.strip()


async def acheck_if_safe_prompt(system_prompt: str) -> Tuple[bool, str]:
"""
Check if the system prompt is safe to use
"""
safe_prompt_check = prompts.personality_prompt_safety_expert.format(prompt=system_prompt)
is_safe = True
reason = ""

with timer("Chat actor: Check if safe prompt", logger):
response = await send_message_to_model_wrapper(safe_prompt_check)

response = response.strip()
try:
response = json.loads(response)
is_safe = response.get("safe", "True") == "True"
if not is_safe:
reason = response.get("reason", "")
except Exception:
logger.error(f"Invalid response for checking safe prompt: {response}")

if not is_safe:
logger.error(f"Unsafe prompt: {system_prompt}. Reason: {reason}")

return is_safe, reason


async def aget_relevant_information_sources(
query: str, conversation_history: dict, is_task: bool, subscribed: bool, uploaded_image_url: str = None
query: str,
conversation_history: dict,
is_task: bool,
subscribed: bool,
uploaded_image_url: str = None,
agent: Agent = None,
):
"""
Given a query, determine which of the available tools the agent should use in order to answer appropriately.

@@ -267,19 +299,27 @@ async def aget_relevant_information_sources(
tool_options = dict()
tool_options_str = ""

agent_tools = agent.input_tools if agent else []

for tool, description in tool_descriptions_for_llm.items():
tool_options[tool.value] = description
tool_options_str += f'- "{tool.value}": "{description}"\n'
if len(agent_tools) == 0 or tool.value in agent_tools:
tool_options_str += f'- "{tool.value}": "{description}"\n'

chat_history = construct_chat_history(conversation_history)

if uploaded_image_url:
query = f"[placeholder for user attached image]\n{query}"

personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
)

relevant_tools_prompt = prompts.pick_relevant_information_collection_tools.format(
query=query,
tools=tool_options_str,
chat_history=chat_history,
personality_context=personality_context,
)

with timer("Chat actor: Infer information sources to refer", logger):

@@ -300,7 +340,10 @@ async def aget_relevant_information_sources(

final_response = [] if not is_task else [ConversationCommand.AutomatedTask]
for llm_suggested_tool in response:
if llm_suggested_tool in tool_options.keys():
# Add a double check to verify it's in the agent list, because the LLM sometimes gets confused by the tool options.
if llm_suggested_tool in tool_options.keys() and (
len(agent_tools) == 0 or llm_suggested_tool in agent_tools
):
# Check whether the tool exists as a valid ConversationCommand
final_response.append(ConversationCommand(llm_suggested_tool))

@@ -313,7 +356,7 @@ async def aget_relevant_information_sources(


async def aget_relevant_output_modes(
query: str, conversation_history: dict, is_task: bool = False, uploaded_image_url: str = None
query: str, conversation_history: dict, is_task: bool = False, uploaded_image_url: str = None, agent: Agent = None
):
"""
Given a query, determine which of the available tools the agent should use in order to answer appropriately.

@@ -322,22 +365,30 @@ async def aget_relevant_output_modes(
mode_options = dict()
mode_options_str = ""

output_modes = agent.output_modes if agent else []

for mode, description in mode_descriptions_for_llm.items():
# Do not allow tasks to schedule another task
if is_task and mode == ConversationCommand.Automation:
continue
mode_options[mode.value] = description
mode_options_str += f'- "{mode.value}": "{description}"\n'
if len(output_modes) == 0 or mode.value in output_modes:
mode_options_str += f'- "{mode.value}": "{description}"\n'

chat_history = construct_chat_history(conversation_history)

if uploaded_image_url:
query = f"[placeholder for user attached image]\n{query}"

personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
)

relevant_mode_prompt = prompts.pick_relevant_output_mode.format(
query=query,
modes=mode_options_str,
chat_history=chat_history,
personality_context=personality_context,
)

with timer("Chat actor: Infer output mode for chat response", logger):

@@ -352,7 +403,9 @@ async def aget_relevant_output_modes(
return ConversationCommand.Text

output_mode = response["output"]
if output_mode in mode_options.keys():

# Add a double check to verify it's in the agent list, because the LLM sometimes gets confused by the tool options.
if output_mode in mode_options.keys() and (len(output_modes) == 0 or output_mode in output_modes):
# Check whether the tool exists as a valid ConversationCommand
return ConversationCommand(output_mode)

@@ -364,7 +417,12 @@ async def aget_relevant_output_modes(


async def infer_webpage_urls(
q: str, conversation_history: dict, location_data: LocationData, user: KhojUser, uploaded_image_url: str = None
q: str,
conversation_history: dict,
location_data: LocationData,
user: KhojUser,
uploaded_image_url: str = None,
agent: Agent = None,
) -> List[str]:
"""
Infer webpage links from the given query

@@ -374,12 +432,17 @@ async def infer_webpage_urls(
chat_history = construct_chat_history(conversation_history)

utc_date = datetime.utcnow().strftime("%Y-%m-%d")
personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
)

online_queries_prompt = prompts.infer_webpages_to_read.format(
current_date=utc_date,
query=q,
chat_history=chat_history,
location=location,
username=username,
personality_context=personality_context,
)

with timer("Chat actor: Infer webpage urls to read", logger):

@@ -400,7 +463,12 @@ async def infer_webpage_urls(


async def generate_online_subqueries(
q: str, conversation_history: dict, location_data: LocationData, user: KhojUser, uploaded_image_url: str = None
q: str,
conversation_history: dict,
location_data: LocationData,
user: KhojUser,
uploaded_image_url: str = None,
agent: Agent = None,
) -> List[str]:
"""
Generate subqueries from the given query

@@ -410,12 +478,17 @@ async def generate_online_subqueries(
chat_history = construct_chat_history(conversation_history)

utc_date = datetime.utcnow().strftime("%Y-%m-%d")
personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
)

online_queries_prompt = prompts.online_search_conversation_subqueries.format(
current_date=utc_date,
query=q,
chat_history=chat_history,
location=location,
username=username,
personality_context=personality_context,
)

with timer("Chat actor: Generate online search subqueries", logger):

@@ -464,7 +537,7 @@ async def schedule_query(q: str, conversation_history: dict, uploaded_image_url:
raise AssertionError(f"Invalid response for scheduling query: {raw_response}")


async def extract_relevant_info(q: str, corpus: str, subscribed: bool) -> Union[str, None]:
async def extract_relevant_info(q: str, corpus: str, subscribed: bool, agent: Agent = None) -> Union[str, None]:
"""
Extract relevant information for a given query from the target corpus
"""

@@ -472,9 +545,14 @@ async def extract_relevant_info(q: str, corpus: str, subscribed: bool) -> Union[
if is_none_or_empty(corpus) or is_none_or_empty(q):
return None

personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
)

extract_relevant_information = prompts.extract_relevant_information.format(
query=q,
corpus=corpus.strip(),
personality_context=personality_context,
)

chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config()

@@ -490,7 +568,7 @@ async def extract_relevant_info(q: str, corpus: str, subscribed: bool) -> Union[


async def extract_relevant_summary(
q: str, corpus: str, subscribed: bool = False, uploaded_image_url: str = None
q: str, corpus: str, subscribed: bool = False, uploaded_image_url: str = None, agent: Agent = None
) -> Union[str, None]:
"""
Extract relevant information for a given query from the target corpus

@@ -499,9 +577,14 @@ async def extract_relevant_summary(
if is_none_or_empty(corpus) or is_none_or_empty(q):
return None

personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
)

extract_relevant_information = prompts.extract_relevant_summary.format(
query=q,
corpus=corpus.strip(),
personality_context=personality_context,
)

chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config()

@@ -526,12 +609,16 @@ async def generate_better_image_prompt(
model_type: Optional[str] = None,
subscribed: bool = False,
uploaded_image_url: Optional[str] = None,
agent: Agent = None,
) -> str:
"""
Generate a better image prompt from the given query
"""

today_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d, %A")
personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
)
model_type = model_type or TextToImageModelConfig.ModelType.OPENAI

if location_data:

@@ -558,6 +645,7 @@ async def generate_better_image_prompt(
current_date=today_date,
references=user_references,
online_results=simplified_online_results,
personality_context=personality_context,
)
elif model_type in [TextToImageModelConfig.ModelType.STABILITYAI, TextToImageModelConfig.ModelType.REPLICATE]:
image_prompt = prompts.image_generation_improve_prompt_sd.format(

@@ -567,6 +655,7 @@ async def generate_better_image_prompt(
current_date=today_date,
references=user_references,
online_results=simplified_online_results,
personality_context=personality_context,
)

chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config()

@@ -651,15 +740,13 @@ async def send_message_to_model_wrapper(
model_type=conversation_config.model_type,
)

openai_response = send_message_to_model(
return send_message_to_model(
messages=truncated_messages,
api_key=api_key,
model=chat_model,
response_type=response_type,
api_base_url=api_base_url,
)

return openai_response
elif model_type == ChatModelOptions.ModelType.ANTHROPIC:
api_key = conversation_config.openai_config.api_key
truncated_messages = generate_chatml_messages_with_context(

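The chat actors above all follow the same two-step pattern: restrict the candidate tools or output modes to the agent's allow-list (an empty list means no restriction), then prepend a personality context block to the subtask prompt. A condensed sketch of that pattern, with simplified names that are not the actual helpers:

# Simplified illustration of the recurring pattern in the chat actors above (not the real helpers).
from typing import Dict, List, Optional


def build_tool_options(descriptions: Dict[str, str], agent_tools: List[str]) -> str:
    # Offer only the tools the agent is allowed to use; an empty allow-list means "all tools".
    return "\n".join(
        f'- "{tool}": "{description}"'
        for tool, description in descriptions.items()
        if len(agent_tools) == 0 or tool in agent_tools
    )


def build_subtask_prompt(template: str, query: str, tools: str, personality: Optional[str]) -> str:
    # Inject the agent persona into the subtask prompt only when one is set.
    personality_context = f"Here is your personality: {personality}\n" if personality else ""
    return template.format(query=query, tools=tools, personality_context=personality_context)


# Example usage with made-up descriptions and persona:
descriptions = {"notes": "Search the knowledge base.", "online": "Search the internet."}
tools = build_tool_options(descriptions, agent_tools=["notes"])
print(build_subtask_prompt("{personality_context}Pick tools for: {query}\n{tools}", "plan my week", tools, "You are terse."))
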
@@ -1,13 +1,14 @@
import logging
import math
from pathlib import Path
from typing import List, Tuple, Type, Union
from typing import List, Optional, Tuple, Type, Union

import torch
from asgiref.sync import sync_to_async
from sentence_transformers import util

from khoj.database.adapters import EntryAdapters, get_user_search_model_or_default
from khoj.database.models import Agent
from khoj.database.models import Entry as DbEntry
from khoj.database.models import KhojUser
from khoj.processor.content.text_to_entries import TextToEntries

@@ -101,6 +102,7 @@ async def query(
type: SearchType = SearchType.All,
question_embedding: Union[torch.Tensor, None] = None,
max_distance: float = None,
agent: Optional[Agent] = None,
) -> Tuple[List[dict], List[Entry]]:
"Search for entries that answer the query"

@@ -129,6 +131,7 @@ async def query(
file_type_filter=file_type,
raw_query=raw_query,
max_distance=max_distance,
agent=agent,
).all()
hits = await sync_to_async(list)(hits)  # type: ignore[call-arg]

@@ -325,7 +325,15 @@ command_descriptions = {
ConversationCommand.Image: "Generate images by describing your imagination in words.",
ConversationCommand.Automation: "Automatically run your query at a specified time or interval.",
ConversationCommand.Help: "Get help with how to use or setup Khoj from the documentation",
ConversationCommand.Summarize: "Create an appropriate summary using provided documents.",
ConversationCommand.Summarize: "Get help with a question pertaining to an entire document.",
}

command_descriptions_for_agent = {
ConversationCommand.General: "Respond without any outside information or personal knowledge.",
ConversationCommand.Notes: "Search through the knowledge base. Required if the agent expects context from the knowledge base.",
ConversationCommand.Online: "Search for the latest, up-to-date information from the internet.",
ConversationCommand.Webpage: "Scrape specific web pages for information.",
ConversationCommand.Summarize: "Retrieve an answer that depends on the entire document or a large text. Knowledge base must be a single document.",
}

tool_descriptions_for_llm = {

@@ -334,7 +342,7 @@ tool_descriptions_for_llm = {
ConversationCommand.Notes: "To search the user's personal knowledge base. Especially helpful if the question expects context from the user's notes or documents.",
ConversationCommand.Online: "To search for the latest, up-to-date information from the internet. Note: **Questions about Khoj should always use this data source**",
ConversationCommand.Webpage: "To use if the user has directly provided the webpage urls or you are certain of the webpage urls to read.",
ConversationCommand.Summarize: "To create a summary of the document provided by the user.",
ConversationCommand.Summarize: "To retrieve an answer that depends on the entire document or a large text.",
}

mode_descriptions_for_llm = {

@@ -343,6 +351,11 @@ mode_descriptions_for_llm = {
ConversationCommand.Text: "Use this if the other response modes don't seem to fit the query.",
}

mode_descriptions_for_agent = {
ConversationCommand.Image: "Allow the agent to generate images.",
ConversationCommand.Text: "Allow the agent to generate text.",
}


class ImageIntentType(Enum):
"""
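The new *_for_agent description maps above are what the /options endpoint in the agents router serves, so the agent modification form can show human-readable labels for each selectable tool and mode. A hypothetical request and the rough shape of the response follow; the exact keys depend on the configured Agent.InputToolOptions and Agent.OutputModeOptions choices, so treat them as assumptions.

# Hypothetical: fetch the selectable tools and modes for the agent modification form (not part of this diff).
import requests

KHOJ_URL = "http://localhost:42110"  # assumed server address and mount prefix
options = requests.get(f"{KHOJ_URL}/api/agents/options").json()

# Expected rough shape, built from command_descriptions_for_agent and mode_descriptions_for_agent:
# {
#     "input_tools": {"general": "...", "notes": "...", "online": "...", "webpage": "...", "summarize": "..."},
#     "output_modes": {"image": "...", "text": "..."}
# }
for tool, description in options["input_tools"].items():
    print(f"{tool}: {description}")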