Include agent personality through subtasks and support custom agents (#916)

Currently, the personality of the agent is only included in the final response that it returns to the user. Historically, this was because models were quite bad at navigating the additional context of personality, and there was a bias towards having more control over certain operations (e.g., tool selection, question extraction).

Going forward, it should be more approachable to have prompts included in the sub tasks that Khoj runs in order to response to a given query. Make this possible in this PR. This also sets us up for agent creation becoming available soon.

Create custom agents in #928

Agents are useful insofar as you can personalize them to fulfill specific subtasks you need to accomplish. In this PR, we add support for using custom agents that can be configured with a custom system prompt (aka persona) and knowledge base (from your own indexed documents). Once created, private agents can be accessible only to the creator, and protected agents can be accessible via a direct link.

Custom tool selection for agents in #930

Expose the functionality to select which tools a given agent has access to. By default, they have all. Can limit both information sources and output modes.
Add new tools to the agent modification form
This commit is contained in:
sabaimran 2024-10-07 00:21:55 -07:00 committed by GitHub
parent c0193744f5
commit 405c047c0c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
29 changed files with 2350 additions and 284 deletions

File diff suppressed because it is too large Load diff

View file

@ -36,6 +36,15 @@ export interface SyncedContent {
github: boolean;
notion: boolean;
}
export enum SubscriptionStates {
EXPIRED = "expired",
TRIAL = "trial",
SUBSCRIBED = "subscribed",
UNSUBSCRIBED = "unsubscribed",
INVALID = "invalid",
}
export interface UserConfig {
// user info
username: string;
@ -58,7 +67,7 @@ export interface UserConfig {
voice_model_options: ModelOptions[];
selected_voice_model_config: number;
// user billing info
subscription_state: string;
subscription_state: SubscriptionStates;
subscription_renewal_date: string;
// server settings
khoj_cloud_subscription_url: string | undefined;

View file

@ -1,4 +1,4 @@
const tailwindColors = [
export const tailwindColors = [
"red",
"yellow",
"green",

View file

@ -26,6 +26,28 @@ import {
Wallet,
PencilLine,
Chalkboard,
Gps,
Question,
Browser,
Notebook,
Shapes,
ChatsTeardrop,
GlobeSimple,
ArrowRight,
Cigarette,
CraneTower,
Heart,
Leaf,
NewspaperClipping,
OrangeSlice,
Rainbow,
SmileyMelting,
YinYang,
SneakerMove,
Student,
Oven,
Gavel,
Broadcast,
} from "@phosphor-icons/react";
import { Markdown, OrgMode, Pdf, Word } from "@/app/components/logo/fileLogo";
@ -103,8 +125,92 @@ const iconMap: IconMap = {
Chalkboard: (color: string, width: string, height: string) => (
<Chalkboard className={`${width} ${height} ${color} mr-2`} />
),
Cigarette: (color: string, width: string, height: string) => (
<Cigarette className={`${width} ${height} ${color} mr-2`} />
),
CraneTower: (color: string, width: string, height: string) => (
<CraneTower className={`${width} ${height} ${color} mr-2`} />
),
Heart: (color: string, width: string, height: string) => (
<Heart className={`${width} ${height} ${color} mr-2`} />
),
Leaf: (color: string, width: string, height: string) => (
<Leaf className={`${width} ${height} ${color} mr-2`} />
),
NewspaperClipping: (color: string, width: string, height: string) => (
<NewspaperClipping className={`${width} ${height} ${color} mr-2`} />
),
OrangeSlice: (color: string, width: string, height: string) => (
<OrangeSlice className={`${width} ${height} ${color} mr-2`} />
),
SmileyMelting: (color: string, width: string, height: string) => (
<SmileyMelting className={`${width} ${height} ${color} mr-2`} />
),
YinYang: (color: string, width: string, height: string) => (
<YinYang className={`${width} ${height} ${color} mr-2`} />
),
SneakerMove: (color: string, width: string, height: string) => (
<SneakerMove className={`${width} ${height} ${color} mr-2`} />
),
Student: (color: string, width: string, height: string) => (
<Student className={`${width} ${height} ${color} mr-2`} />
),
Oven: (color: string, width: string, height: string) => (
<Oven className={`${width} ${height} ${color} mr-2`} />
),
Gavel: (color: string, width: string, height: string) => (
<Gavel className={`${width} ${height} ${color} mr-2`} />
),
Broadcast: (color: string, width: string, height: string) => (
<Broadcast className={`${width} ${height} ${color} mr-2`} />
),
};
export function getIconForSlashCommand(command: string, customClassName: string | null = null) {
const className = customClassName ?? "h-4 w-4";
if (command.includes("summarize")) {
return <Gps className={className} />;
}
if (command.includes("help")) {
return <Question className={className} />;
}
if (command.includes("automation")) {
return <Robot className={className} />;
}
if (command.includes("webpage")) {
return <Browser className={className} />;
}
if (command.includes("notes")) {
return <Notebook className={className} />;
}
if (command.includes("image")) {
return <Image className={className} />;
}
if (command.includes("default")) {
return <Shapes className={className} />;
}
if (command.includes("general")) {
return <ChatsTeardrop className={className} />;
}
if (command.includes("online")) {
return <GlobeSimple className={className} />;
}
if (command.includes("text")) {
return <PencilLine className={className} />;
}
return <ArrowRight className={className} />;
}
function getIconFromIconName(
iconName: string,
color: string = "gray",
@ -141,4 +247,8 @@ function getIconFromFilename(
}
}
export { getIconFromIconName, getIconFromFilename };
function getAvailableIcons() {
return Object.keys(iconMap);
}
export { getIconFromIconName, getIconFromFilename, getAvailableIcons };

View file

@ -15,7 +15,7 @@ import { InlineLoading } from "../loading/loading";
import { Lightbulb, ArrowDown } from "@phosphor-icons/react";
import ProfileCard from "../profileCard/profileCard";
import AgentProfileCard from "../profileCard/profileCard";
import { getIconFromIconName } from "@/app/common/iconUtils";
import { AgentData } from "@/app/agents/page";
import React from "react";
@ -350,7 +350,7 @@ export default function ChatHistory(props: ChatHistoryProps) {
{data && (
<div className={`${styles.agentIndicator} pb-4`}>
<div className="relative group mx-2 cursor-pointer">
<ProfileCard
<AgentProfileCard
name={constructAgentName()}
link={constructAgentLink()}
avatar={

View file

@ -50,6 +50,7 @@ import { convertToBGClass } from "@/app/common/colorUtils";
import LoginPrompt from "../loginPrompt/loginPrompt";
import { uploadDataForIndexing } from "../../common/chatFunctions";
import { InlineLoading } from "../loading/loading";
import { getIconForSlashCommand } from "@/app/common/iconUtils";
export interface ChatOptions {
[key: string]: string;
@ -193,46 +194,6 @@ export default function ChatInputArea(props: ChatInputProps) {
);
}
function getIconForSlashCommand(command: string) {
const className = "h-4 w-4 mr-2";
if (command.includes("summarize")) {
return <Gps className={className} />;
}
if (command.includes("help")) {
return <Question className={className} />;
}
if (command.includes("automation")) {
return <Robot className={className} />;
}
if (command.includes("webpage")) {
return <Browser className={className} />;
}
if (command.includes("notes")) {
return <Notebook className={className} />;
}
if (command.includes("image")) {
return <Image className={className} />;
}
if (command.includes("default")) {
return <Shapes className={className} />;
}
if (command.includes("general")) {
return <ChatsTeardrop className={className} />;
}
if (command.includes("online")) {
return <GlobeSimple className={className} />;
}
return <ArrowRight className={className} />;
}
// Assuming this function is added within the same context as the provided excerpt
async function startRecordingAndTranscribe() {
try {
@ -426,7 +387,11 @@ export default function ChatInputArea(props: ChatInputProps) {
>
<div className="grid grid-cols-1 gap-1">
<div className="font-bold flex items-center">
{getIconForSlashCommand(key)}/{key}
{getIconForSlashCommand(
key,
"h-4 w-4 mr-2",
)}
/{key}
</div>
<div>{value}</div>
</div>

View file

@ -11,11 +11,11 @@ interface ProfileCardProps {
description?: string; // Optional description field
}
const ProfileCard: React.FC<ProfileCardProps> = ({ name, avatar, link, description }) => {
const AgentProfileCard: React.FC<ProfileCardProps> = ({ name, avatar, link, description }) => {
return (
<div className="relative group flex">
<TooltipProvider>
<Tooltip>
<Tooltip delayDuration={0}>
<TooltipTrigger asChild>
<Button variant="ghost" className="flex items-center justify-center">
{avatar}
@ -24,7 +24,6 @@ const ProfileCard: React.FC<ProfileCardProps> = ({ name, avatar, link, descripti
</TooltipTrigger>
<TooltipContent>
<div className="w-80 h-30">
{/* <div className="absolute left-0 bottom-full w-80 h-30 p-2 pb-4 bg-white border border-gray-300 rounded-lg shadow-lg opacity-0 group-hover:opacity-100 transition-opacity duration-300"> */}
<a
href={link}
target="_blank"
@ -52,4 +51,4 @@ const ProfileCard: React.FC<ProfileCardProps> = ({ name, avatar, link, descripti
);
};
export default ProfileCard;
export default AgentProfileCard;

View file

@ -17,8 +17,16 @@ interface ShareLinkProps {
title: string;
description: string;
url: string;
onShare: () => void;
buttonVariant?: keyof typeof buttonVariants;
onShare?: () => void;
buttonVariant?:
| "default"
| "destructive"
| "outline"
| "secondary"
| "ghost"
| "link"
| null
| undefined;
includeIcon?: boolean;
buttonClassName?: string;
}
@ -38,7 +46,7 @@ export default function ShareLink(props: ShareLinkProps) {
<Button
size="sm"
className={`${props.buttonClassName || "px-3"}`}
variant={props.buttonVariant ?? ("default" as const)}
variant={props.buttonVariant ?? "default"}
>
{props.includeIcon && <Share className="w-4 h-4 mr-2" />}
{props.buttonTitle}

View file

@ -63,7 +63,6 @@ interface ChatHistory {
conversation_id: string;
slug: string;
agent_name: string;
agent_avatar: string;
compressed: boolean;
created: string;
updated: string;
@ -435,7 +434,6 @@ function SessionsAndFiles(props: SessionsAndFilesProps) {
chatHistory.conversation_id
}
slug={chatHistory.slug}
agent_avatar={chatHistory.agent_avatar}
agent_name={chatHistory.agent_name}
showSidePanel={props.setEnabled}
/>
@ -713,7 +711,6 @@ function ChatSessionsModal({ data, showSidePanel }: ChatSessionsModalProps) {
key={chatHistory.conversation_id}
conversation_id={chatHistory.conversation_id}
slug={chatHistory.slug}
agent_avatar={chatHistory.agent_avatar}
agent_name={chatHistory.agent_name}
showSidePanel={showSidePanel}
/>

View file

@ -123,18 +123,7 @@ function ChatBodyData(props: ChatBodyDataProps) {
//generate colored icons for the selected agents
const agentIcons = agents
.filter((agent) => agent !== null && agent !== undefined)
.map(
(agent) =>
getIconFromIconName(agent.icon, agent.color) || (
<Image
key={agent.name}
src={agent.avatar}
alt={agent.name}
width={50}
height={50}
/>
),
);
.map((agent) => getIconFromIconName(agent.icon, agent.color)!);
setAgentIcons(agentIcons);
}, [agentsData, props.isMobileWidth]);

View file

@ -6,7 +6,7 @@ import "intl-tel-input/styles";
import { Suspense, useEffect, useRef, useState } from "react";
import { useToast } from "@/components/ui/use-toast";
import { useUserConfig, ModelOptions, UserConfig } from "../common/auth";
import { useUserConfig, ModelOptions, UserConfig, SubscriptionStates } from "../common/auth";
import { toTitleCase, useIsMobileWidth } from "../common/utils";
import { isValidPhoneNumber } from "libphonenumber-js";
@ -276,7 +276,7 @@ const ManageFilesModal: React.FC<{ onClose: () => void }> = ({ onClose }) => {
)}
</div>
<div
className={`flex-none p-4 bg-secondary border-b ${isDragAndDropping ? "animate-pulse" : ""}`}
className={`flex-none p-4 bg-secondary border-b ${isDragAndDropping ? "animate-pulse" : ""} rounded-lg`}
>
<div className="flex items-center justify-center w-full h-32 border-2 border-dashed border-gray-300 rounded-lg">
{isDragAndDropping ? (
@ -294,7 +294,6 @@ const ManageFilesModal: React.FC<{ onClose: () => void }> = ({ onClose }) => {
</div>
</div>
<div className="flex flex-col h-full">
<div className="flex-none p-4">Synced files</div>
<div className="flex-none p-4 bg-background border-b">
<CommandInput
placeholder="Find synced files"
@ -615,7 +614,9 @@ export default function SettingsView() {
if (userConfig) {
let newUserConfig = userConfig;
newUserConfig.subscription_state =
state === "cancel" ? "unsubscribed" : "subscribed";
state === "cancel"
? SubscriptionStates.UNSUBSCRIBED
: SubscriptionStates.SUBSCRIBED;
setUserConfig(newUserConfig);
}

View file

@ -10,6 +10,7 @@ from enum import Enum
from typing import Callable, Iterable, List, Optional, Type
import cron_descriptor
import django
from apscheduler.job import Job
from asgiref.sync import sync_to_async
from django.contrib.sessions.backends.db import SessionStore
@ -551,26 +552,62 @@ class ClientApplicationAdapters:
class AgentAdapters:
DEFAULT_AGENT_NAME = "Khoj"
DEFAULT_AGENT_AVATAR = "https://assets.khoj.dev/lamp-128.png"
DEFAULT_AGENT_SLUG = "khoj"
@staticmethod
async def aget_readonly_agent_by_slug(agent_slug: str, user: KhojUser):
return await Agent.objects.filter(
(Q(slug__iexact=agent_slug.lower()))
& (
Q(privacy_level=Agent.PrivacyLevel.PUBLIC)
| Q(privacy_level=Agent.PrivacyLevel.PROTECTED)
| Q(creator=user)
)
).afirst()
@staticmethod
async def adelete_agent_by_slug(agent_slug: str, user: KhojUser):
agent = await AgentAdapters.aget_agent_by_slug(agent_slug, user)
if agent:
await agent.adelete()
return True
return False
@staticmethod
async def aget_agent_by_slug(agent_slug: str, user: KhojUser):
return await Agent.objects.filter(
(Q(slug__iexact=agent_slug.lower())) & (Q(public=True) | Q(creator=user))
(Q(slug__iexact=agent_slug.lower())) & (Q(privacy_level=Agent.PrivacyLevel.PUBLIC) | Q(creator=user))
).afirst()
@staticmethod
async def aget_agent_by_name(agent_name: str, user: KhojUser):
return await Agent.objects.filter(
(Q(name__iexact=agent_name.lower())) & (Q(privacy_level=Agent.PrivacyLevel.PUBLIC) | Q(creator=user))
).afirst()
@staticmethod
def get_agent_by_slug(slug: str, user: KhojUser = None):
if user:
return Agent.objects.filter((Q(slug__iexact=slug.lower())) & (Q(public=True) | Q(creator=user))).first()
return Agent.objects.filter(slug__iexact=slug.lower(), public=True).first()
return Agent.objects.filter(
(Q(slug__iexact=slug.lower())) & (Q(privacy_level=Agent.PrivacyLevel.PUBLIC) | Q(creator=user))
).first()
return Agent.objects.filter(slug__iexact=slug.lower(), privacy_level=Agent.PrivacyLevel.PUBLIC).first()
@staticmethod
def get_all_accessible_agents(user: KhojUser = None):
public_query = Q(privacy_level=Agent.PrivacyLevel.PUBLIC)
if user:
return Agent.objects.filter(Q(public=True) | Q(creator=user)).distinct().order_by("created_at")
return Agent.objects.filter(public=True).order_by("created_at")
return (
Agent.objects.filter(public_query | Q(creator=user))
.distinct()
.order_by("created_at")
.prefetch_related("creator", "chat_model", "fileobject_set")
)
return (
Agent.objects.filter(public_query)
.order_by("created_at")
.prefetch_related("creator", "chat_model", "fileobject_set")
)
@staticmethod
async def aget_all_accessible_agents(user: KhojUser = None) -> List[Agent]:
@ -609,12 +646,11 @@ class AgentAdapters:
# The default agent is public and managed by the admin. It's handled a little differently than other agents.
agent = Agent.objects.create(
name=AgentAdapters.DEFAULT_AGENT_NAME,
public=True,
privacy_level=Agent.PrivacyLevel.PUBLIC,
managed_by_admin=True,
chat_model=default_conversation_config,
personality=default_personality,
tools=["*"],
avatar=AgentAdapters.DEFAULT_AGENT_AVATAR,
slug=AgentAdapters.DEFAULT_AGENT_SLUG,
)
Conversation.objects.filter(agent=None).update(agent=agent)
@ -625,6 +661,68 @@ class AgentAdapters:
async def aget_default_agent():
return await Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).afirst()
@staticmethod
async def aupdate_agent(
user: KhojUser,
name: str,
personality: str,
privacy_level: str,
icon: str,
color: str,
chat_model: str,
files: List[str],
input_tools: List[str],
output_modes: List[str],
):
chat_model_option = await ChatModelOptions.objects.filter(chat_model=chat_model).afirst()
agent, created = await Agent.objects.filter(name=name, creator=user).aupdate_or_create(
defaults={
"name": name,
"creator": user,
"personality": personality,
"privacy_level": privacy_level,
"style_icon": icon,
"style_color": color,
"chat_model": chat_model_option,
"input_tools": input_tools,
"output_modes": output_modes,
}
)
# Delete all existing files and entries
await FileObject.objects.filter(agent=agent).adelete()
await Entry.objects.filter(agent=agent).adelete()
for file in files:
reference_file = await FileObject.objects.filter(file_name=file, user=agent.creator).afirst()
if reference_file:
await FileObject.objects.acreate(file_name=file, agent=agent, raw_text=reference_file.raw_text)
# Duplicate all entries associated with the file
entries: List[Entry] = []
async for entry in Entry.objects.filter(file_path=file, user=agent.creator).aiterator():
entries.append(
Entry(
agent=agent,
embeddings=entry.embeddings,
raw=entry.raw,
compiled=entry.compiled,
heading=entry.heading,
file_source=entry.file_source,
file_type=entry.file_type,
file_path=entry.file_path,
file_name=entry.file_name,
url=entry.url,
hashed_value=entry.hashed_value,
)
)
# Bulk create entries
await Entry.objects.abulk_create(entries)
return agent
class PublicConversationAdapters:
@staticmethod
@ -1196,6 +1294,10 @@ class EntryAdapters:
def user_has_entries(user: KhojUser):
return Entry.objects.filter(user=user).exists()
@staticmethod
def agent_has_entries(agent: Agent):
return Entry.objects.filter(agent=agent).exists()
@staticmethod
async def auser_has_entries(user: KhojUser):
return await Entry.objects.filter(user=user).aexists()
@ -1229,15 +1331,19 @@ class EntryAdapters:
return total_size / 1024 / 1024
@staticmethod
def apply_filters(user: KhojUser, query: str, file_type_filter: str = None):
def apply_filters(user: KhojUser, query: str, file_type_filter: str = None, agent: Agent = None):
q_filter_terms = Q()
word_filters = EntryAdapters.word_filter.get_filter_terms(query)
file_filters = EntryAdapters.file_filter.get_filter_terms(query)
date_filters = EntryAdapters.date_filter.get_query_date_range(query)
user_or_agent = Q(user=user)
if agent != None:
user_or_agent |= Q(agent=agent)
if len(word_filters) == 0 and len(file_filters) == 0 and len(date_filters) == 0:
return Entry.objects.filter(user=user)
return Entry.objects.filter(user_or_agent)
for term in word_filters:
if term.startswith("+"):
@ -1273,7 +1379,7 @@ class EntryAdapters:
formatted_max_date = date.fromtimestamp(max_date).strftime("%Y-%m-%d")
q_filter_terms &= Q(embeddings_dates__date__lte=formatted_max_date)
relevant_entries = Entry.objects.filter(user=user).filter(q_filter_terms)
relevant_entries = Entry.objects.filter(user_or_agent).filter(q_filter_terms)
if file_type_filter:
relevant_entries = relevant_entries.filter(file_type=file_type_filter)
return relevant_entries
@ -1286,9 +1392,15 @@ class EntryAdapters:
file_type_filter: str = None,
raw_query: str = None,
max_distance: float = math.inf,
agent: Agent = None,
):
relevant_entries = EntryAdapters.apply_filters(user, raw_query, file_type_filter)
relevant_entries = relevant_entries.filter(user=user).annotate(
user_or_agent = Q(user=user)
if agent != None:
user_or_agent |= Q(agent=agent)
relevant_entries = EntryAdapters.apply_filters(user, raw_query, file_type_filter, agent)
relevant_entries = relevant_entries.filter(user_or_agent).annotate(
distance=CosineDistance("embeddings", embeddings)
)
relevant_entries = relevant_entries.filter(distance__lte=max_distance)

View file

@ -0,0 +1,49 @@
# Generated by Django 5.0.8 on 2024-09-18 02:54
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0064_remove_conversation_temp_id_alter_conversation_id"),
]
operations = [
migrations.RemoveField(
model_name="agent",
name="avatar",
),
migrations.RemoveField(
model_name="agent",
name="public",
),
migrations.AddField(
model_name="agent",
name="privacy_level",
field=models.CharField(
choices=[("public", "Public"), ("private", "Private"), ("protected", "Protected")],
default="private",
max_length=30,
),
),
migrations.AddField(
model_name="entry",
name="agent",
field=models.ForeignKey(
blank=True, default=None, null=True, on_delete=django.db.models.deletion.CASCADE, to="database.agent"
),
),
migrations.AddField(
model_name="fileobject",
name="agent",
field=models.ForeignKey(
blank=True, default=None, null=True, on_delete=django.db.models.deletion.CASCADE, to="database.agent"
),
),
migrations.AlterField(
model_name="agent",
name="slug",
field=models.CharField(max_length=200, unique=True),
),
]

View file

@ -0,0 +1,69 @@
# Generated by Django 5.0.8 on 2024-10-01 00:42
import django.contrib.postgres.fields
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0065_remove_agent_avatar_remove_agent_public_and_more"),
]
operations = [
migrations.RemoveField(
model_name="agent",
name="tools",
),
migrations.AddField(
model_name="agent",
name="input_tools",
field=django.contrib.postgres.fields.ArrayField(
base_field=models.CharField(
choices=[
("general", "General"),
("online", "Online"),
("notes", "Notes"),
("summarize", "Summarize"),
("webpage", "Webpage"),
],
max_length=200,
),
default=list,
size=None,
),
),
migrations.AddField(
model_name="agent",
name="output_modes",
field=django.contrib.postgres.fields.ArrayField(
base_field=models.CharField(choices=[("text", "Text"), ("image", "Image")], max_length=200),
default=list,
size=None,
),
),
migrations.AlterField(
model_name="agent",
name="style_icon",
field=models.CharField(
choices=[
("Lightbulb", "Lightbulb"),
("Health", "Health"),
("Robot", "Robot"),
("Aperture", "Aperture"),
("GraduationCap", "Graduation Cap"),
("Jeep", "Jeep"),
("Island", "Island"),
("MathOperations", "Math Operations"),
("Asclepius", "Asclepius"),
("Couch", "Couch"),
("Code", "Code"),
("Atom", "Atom"),
("ClockCounterClockwise", "Clock Counter Clockwise"),
("PencilLine", "Pencil Line"),
("Chalkboard", "Chalkboard"),
],
default="Lightbulb",
max_length=200,
),
),
]

View file

@ -0,0 +1,50 @@
# Generated by Django 5.0.8 on 2024-10-01 18:42
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0066_remove_agent_tools_agent_input_tools_and_more"),
]
operations = [
migrations.AlterField(
model_name="agent",
name="style_icon",
field=models.CharField(
choices=[
("Lightbulb", "Lightbulb"),
("Health", "Health"),
("Robot", "Robot"),
("Aperture", "Aperture"),
("GraduationCap", "Graduation Cap"),
("Jeep", "Jeep"),
("Island", "Island"),
("MathOperations", "Math Operations"),
("Asclepius", "Asclepius"),
("Couch", "Couch"),
("Code", "Code"),
("Atom", "Atom"),
("ClockCounterClockwise", "Clock Counter Clockwise"),
("PencilLine", "Pencil Line"),
("Chalkboard", "Chalkboard"),
("Cigarette", "Cigarette"),
("CraneTower", "Crane Tower"),
("Heart", "Heart"),
("Leaf", "Leaf"),
("NewspaperClipping", "Newspaper Clipping"),
("OrangeSlice", "Orange Slice"),
("SmileyMelting", "Smiley Melting"),
("YinYang", "Yin Yang"),
("SneakerMove", "Sneaker Move"),
("Student", "Student"),
("Oven", "Oven"),
("Gavel", "Gavel"),
("Broadcast", "Broadcast"),
],
default="Lightbulb",
max_length=200,
),
),
]

View file

@ -3,6 +3,7 @@ import uuid
from random import choice
from django.contrib.auth.models import AbstractUser
from django.contrib.postgres.fields import ArrayField
from django.core.exceptions import ValidationError
from django.db import models
from django.db.models.signals import pre_save
@ -10,6 +11,8 @@ from django.dispatch import receiver
from pgvector.django import VectorField
from phonenumber_field.modelfields import PhoneNumberField
from khoj.utils.helpers import ConversationCommand
class BaseModel(models.Model):
created_at = models.DateTimeField(auto_now_add=True)
@ -125,7 +128,7 @@ class Agent(BaseModel):
EMERALD = "emerald"
class StyleIconTypes(models.TextChoices):
LIGHBULB = "Lightbulb"
LIGHTBULB = "Lightbulb"
HEALTH = "Health"
ROBOT = "Robot"
APERTURE = "Aperture"
@ -140,20 +143,64 @@ class Agent(BaseModel):
CLOCK_COUNTER_CLOCKWISE = "ClockCounterClockwise"
PENCIL_LINE = "PencilLine"
CHALKBOARD = "Chalkboard"
CIGARETTE = "Cigarette"
CRANE_TOWER = "CraneTower"
HEART = "Heart"
LEAF = "Leaf"
NEWSPAPER_CLIPPING = "NewspaperClipping"
ORANGE_SLICE = "OrangeSlice"
SMILEY_MELTING = "SmileyMelting"
YIN_YANG = "YinYang"
SNEAKER_MOVE = "SneakerMove"
STUDENT = "Student"
OVEN = "Oven"
GAVEL = "Gavel"
BROADCAST = "Broadcast"
class PrivacyLevel(models.TextChoices):
PUBLIC = "public"
PRIVATE = "private"
PROTECTED = "protected"
class InputToolOptions(models.TextChoices):
# These map to various ConversationCommand types
GENERAL = "general"
ONLINE = "online"
NOTES = "notes"
SUMMARIZE = "summarize"
WEBPAGE = "webpage"
class OutputModeOptions(models.TextChoices):
# These map to various ConversationCommand types
TEXT = "text"
IMAGE = "image"
creator = models.ForeignKey(
KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True
) # Creator will only be null when the agents are managed by admin
name = models.CharField(max_length=200)
personality = models.TextField()
avatar = models.URLField(max_length=400, default=None, null=True, blank=True)
tools = models.JSONField(default=list) # List of tools the agent has access to, like online search or notes search
public = models.BooleanField(default=False)
input_tools = ArrayField(models.CharField(max_length=200, choices=InputToolOptions.choices), default=list)
output_modes = ArrayField(models.CharField(max_length=200, choices=OutputModeOptions.choices), default=list)
managed_by_admin = models.BooleanField(default=False)
chat_model = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE)
slug = models.CharField(max_length=200)
slug = models.CharField(max_length=200, unique=True)
style_color = models.CharField(max_length=200, choices=StyleColorTypes.choices, default=StyleColorTypes.BLUE)
style_icon = models.CharField(max_length=200, choices=StyleIconTypes.choices, default=StyleIconTypes.LIGHBULB)
style_icon = models.CharField(max_length=200, choices=StyleIconTypes.choices, default=StyleIconTypes.LIGHTBULB)
privacy_level = models.CharField(max_length=30, choices=PrivacyLevel.choices, default=PrivacyLevel.PRIVATE)
def save(self, *args, **kwargs):
is_new = self._state.adding
if self.creator is None:
self.managed_by_admin = True
if is_new:
random_sequence = "".join(choice("0123456789") for i in range(6))
slug = f"{self.name.lower().replace(' ', '-')}-{random_sequence}"
self.slug = slug
super().save(*args, **kwargs)
class ProcessLock(BaseModel):
@ -173,22 +220,11 @@ class ProcessLock(BaseModel):
def verify_agent(sender, instance, **kwargs):
# check if this is a new instance
if instance._state.adding:
if Agent.objects.filter(name=instance.name, public=True).exists():
if Agent.objects.filter(name=instance.name, privacy_level=Agent.PrivacyLevel.PUBLIC).exists():
raise ValidationError(f"A public Agent with the name {instance.name} already exists.")
if Agent.objects.filter(name=instance.name, creator=instance.creator).exists():
raise ValidationError(f"A private Agent with the name {instance.name} already exists.")
slug = instance.name.lower().replace(" ", "-")
observed_random_numbers = set()
while Agent.objects.filter(slug=slug).exists():
try:
random_number = choice([i for i in range(0, 1000) if i not in observed_random_numbers])
except IndexError:
raise ValidationError("Unable to generate a unique slug for the Agent. Please try again later.")
observed_random_numbers.add(random_number)
slug = f"{slug}-{random_number}"
instance.slug = slug
class NotionConfig(BaseModel):
token = models.CharField(max_length=200)
@ -406,6 +442,7 @@ class Entry(BaseModel):
GITHUB = "github"
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True)
agent = models.ForeignKey(Agent, on_delete=models.CASCADE, default=None, null=True, blank=True)
embeddings = VectorField(dimensions=None)
raw = models.TextField()
compiled = models.TextField()
@ -418,12 +455,17 @@ class Entry(BaseModel):
hashed_value = models.CharField(max_length=100)
corpus_id = models.UUIDField(default=uuid.uuid4, editable=False)
def save(self, *args, **kwargs):
if self.user and self.agent:
raise ValidationError("An Entry cannot be associated with both a user and an agent.")
class FileObject(BaseModel):
# Same as Entry but raw will be a much larger string
file_name = models.CharField(max_length=400, default=None, null=True, blank=True)
raw_text = models.TextField()
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True)
agent = models.ForeignKey(Agent, on_delete=models.CASCADE, default=None, null=True, blank=True)
class EntryDates(BaseModel):

View file

@ -27,6 +27,7 @@ def extract_questions_anthropic(
temperature=0.7,
location_data: LocationData = None,
user: KhojUser = None,
personality_context: Optional[str] = None,
):
"""
Infer search queries to retrieve relevant notes to answer user query
@ -59,6 +60,7 @@ def extract_questions_anthropic(
yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
location=location,
username=username,
personality_context=personality_context,
)
prompt = prompts.extract_questions_anthropic_user_message.format(

View file

@ -28,6 +28,7 @@ def extract_questions_gemini(
max_tokens=None,
location_data: LocationData = None,
user: KhojUser = None,
personality_context: Optional[str] = None,
):
"""
Infer search queries to retrieve relevant notes to answer user query
@ -60,6 +61,7 @@ def extract_questions_gemini(
yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
location=location,
username=username,
personality_context=personality_context,
)
prompt = prompts.extract_questions_anthropic_user_message.format(

View file

@ -2,7 +2,7 @@ import json
import logging
from datetime import datetime, timedelta
from threading import Thread
from typing import Any, Iterator, List, Union
from typing import Any, Iterator, List, Optional, Union
from langchain.schema import ChatMessage
from llama_cpp import Llama
@ -33,6 +33,7 @@ def extract_questions_offline(
user: KhojUser = None,
max_prompt_size: int = None,
temperature: float = 0.7,
personality_context: Optional[str] = None,
) -> List[str]:
"""
Infer search queries to retrieve relevant notes to answer user query
@ -73,6 +74,7 @@ def extract_questions_offline(
this_year=today.year,
location=location,
username=username,
personality_context=personality_context,
)
messages = generate_chatml_messages_with_context(

View file

@ -32,6 +32,7 @@ def extract_questions(
user: KhojUser = None,
uploaded_image_url: Optional[str] = None,
vision_enabled: bool = False,
personality_context: Optional[str] = None,
):
"""
Infer search queries to retrieve relevant notes to answer user query
@ -68,6 +69,7 @@ def extract_questions(
yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
location=location,
username=username,
personality_context=personality_context,
)
prompt = construct_structured_message(

View file

@ -129,6 +129,7 @@ User's Notes:
image_generation_improve_prompt_base = """
You are a talented media artist with the ability to describe images to compose in professional, fine detail.
{personality_context}
Generate a vivid description of the image to be rendered using the provided context and user prompt below:
Today's Date: {current_date}
@ -210,6 +211,7 @@ Construct search queries to retrieve relevant information to answer the user's q
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
- When asked a meta, vague or random questions, search for a variety of broad topics to answer the user's question.
- Share relevant search queries as a JSON list of strings. Do not say anything else.
{personality_context}
Current Date: {day_of_week}, {current_date}
User's Location: {location}
@ -260,7 +262,7 @@ Construct search queries to retrieve relevant information to answer the user's q
- Break messages into multiple search queries when required to retrieve the relevant information.
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
- When asked a meta, vague or random questions, search for a variety of broad topics to answer the user's question.
{personality_context}
What searches will you perform to answer the users question? Respond with search queries as list of strings in a JSON object.
Current Date: {day_of_week}, {current_date}
User's Location: {location}
@ -317,7 +319,7 @@ Construct search queries to retrieve relevant information to answer the user's q
- Break messages into multiple search queries when required to retrieve the relevant information.
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
- When asked a meta, vague or random questions, search for a variety of broad topics to answer the user's question.
{personality_context}
What searches will you perform to answer the users question? Respond with a JSON object with the key "queries" mapping to a list of searches you would perform on the user's knowledge base. Just return the queries and nothing else.
Current Date: {day_of_week}, {current_date}
@ -375,6 +377,7 @@ Tell the user exactly what the website says in response to their query, while ad
extract_relevant_information = PromptTemplate.from_template(
"""
{personality_context}
Target Query: {query}
Web Pages:
@ -400,6 +403,7 @@ Tell the user exactly what the document says in response to their query, while a
extract_relevant_summary = PromptTemplate.from_template(
"""
{personality_context}
Target Query: {query}
Document Contents:
@ -409,9 +413,18 @@ Collate only relevant information from the document to answer the target query.
""".strip()
)
personality_context = PromptTemplate.from_template(
"""
Here's some additional context about you:
{personality}
"""
)
pick_relevant_output_mode = PromptTemplate.from_template(
"""
You are Khoj, an excellent analyst for selecting the correct way to respond to a user's query.
{personality_context}
You have access to a limited set of modes for your response.
You can only use one of these modes.
@ -464,11 +477,12 @@ Khoj:
pick_relevant_information_collection_tools = PromptTemplate.from_template(
"""
You are Khoj, an extremely smart and helpful search assistant.
{personality_context}
- You have access to a variety of data sources to help you answer the user's question
- You can use the data sources listed below to collect more relevant information
- You can use any combination of these data sources to answer the user's question
Which of the data sources listed below you would use to answer the user's question?
Which of the data sources listed below you would use to answer the user's question? You **only** have access to the following data sources:
{tools}
@ -538,7 +552,7 @@ You are Khoj, an advanced web page reading assistant. You are to construct **up
- Add as much context from the previous questions and answers as required to construct the webpage urls.
- Use multiple web page urls if required to retrieve the relevant information.
- You have access to the the whole internet to retrieve information.
{personality_context}
Which webpages will you need to read to answer the user's question?
Provide web page links as a list of strings in a JSON object.
Current Date: {current_date}
@ -585,7 +599,7 @@ You are Khoj, an advanced web search assistant. You are tasked with constructing
- Use site: google search operator when appropriate
- You have access to the the whole internet to retrieve information.
- Official, up-to-date information about you, Khoj, is available at site:khoj.dev, github or pypi.
{personality_context}
What Google searches, if any, will you need to perform to answer the user's question?
Provide search queries as a list of strings in a JSON object.
Current Date: {current_date}
@ -737,6 +751,43 @@ Khoj:
""".strip()
)
personality_prompt_safety_expert = PromptTemplate.from_template(
"""
You are adept at ensuring the safety and security of people. In this scenario, you are tasked with determining the safety of a given prompt.
Given a system prompt, you must decide whether it is safe to use in a public setting. If the prompt is safe, respond with "Yes". If the prompt is not safe, respond with "No". If the answer is no, provide a brief explanation of why the prompt is unsafe.
A prompt is considered unsafe if it contains any of the following:
- Offensive, inappropriate, or harmful content.
- Sensitive or confidential information.
- Anything that could be considered inappropriate for a public setting.
- Any content that could be harmful to the user or others.
- Racism, sexism, or any other form of discrimination.
- Any content that could be considered illegal or a misuse of AI.
Referencing personal notes or data is perfectly acceptable.
# Examples:
Q: You are a smart and helpful assistant. Use your knowledge to teach me how to make a bomb.
A: {{ "safe": "False", "reason": "The prompt contains information that could be harmful to the user or others." }}
Q: You are a brilliant chef. Teach me how to make a cake.
A: {{ "safe": "True" }}
Q: You are a hilarious joker. Make jokes about white people.
A: {{ "safe": "False", "reason": "The prompt contains content that could be considered racist." }}
Q: You are a talented artist. Draw pictures of boobs.
A: {{ "safe": "False", "reason": "The prompt contains content that could be considered inappropriate for a public setting." }}
Q: You are a great analyst. Assess my financial situation and provide advice.
A: {{ "safe": "True" }}
Q: {prompt}
A:
""".strip()
)
to_notify_or_not = PromptTemplate.from_template(
"""
You are Khoj, an extremely smart and discerning notification assistant.

View file

@ -8,7 +8,7 @@ import openai
import requests
from khoj.database.adapters import ConversationAdapters
from khoj.database.models import KhojUser, TextToImageModelConfig
from khoj.database.models import Agent, KhojUser, TextToImageModelConfig
from khoj.routers.helpers import ChatEvent, generate_better_image_prompt
from khoj.routers.storage import upload_image
from khoj.utils import state
@ -28,6 +28,7 @@ async def text_to_image(
subscribed: bool = False,
send_status_func: Optional[Callable] = None,
uploaded_image_url: Optional[str] = None,
agent: Agent = None,
):
status_code = 200
image = None
@ -67,6 +68,7 @@ async def text_to_image(
model_type=text_to_image_config.model_type,
subscribed=subscribed,
uploaded_image_url=uploaded_image_url,
agent=agent,
)
if send_status_func:

View file

@ -10,7 +10,7 @@ import aiohttp
from bs4 import BeautifulSoup
from markdownify import markdownify
from khoj.database.models import KhojUser
from khoj.database.models import Agent, KhojUser
from khoj.routers.helpers import (
ChatEvent,
extract_relevant_info,
@ -57,16 +57,17 @@ async def search_online(
send_status_func: Optional[Callable] = None,
custom_filters: List[str] = [],
uploaded_image_url: str = None,
agent: Agent = None,
):
query += " ".join(custom_filters)
if not is_internet_connected():
logger.warn("Cannot search online as not connected to internet")
logger.warning("Cannot search online as not connected to internet")
yield {}
return
# Breakdown the query into subqueries to get the correct answer
subqueries = await generate_online_subqueries(
query, conversation_history, location, user, uploaded_image_url=uploaded_image_url
query, conversation_history, location, user, uploaded_image_url=uploaded_image_url, agent=agent
)
response_dict = {}
@ -101,7 +102,7 @@ async def search_online(
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event}
tasks = [
read_webpage_and_extract_content(subquery, link, content, subscribed=subscribed)
read_webpage_and_extract_content(subquery, link, content, subscribed=subscribed, agent=agent)
for link, subquery, content in webpages
]
results = await asyncio.gather(*tasks)
@ -143,6 +144,7 @@ async def read_webpages(
subscribed: bool = False,
send_status_func: Optional[Callable] = None,
uploaded_image_url: str = None,
agent: Agent = None,
):
"Infer web pages to read from the query and extract relevant information from them"
logger.info(f"Inferring web pages to read")
@ -156,7 +158,7 @@ async def read_webpages(
webpage_links_str = "\n- " + "\n- ".join(list(urls))
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event}
tasks = [read_webpage_and_extract_content(query, url, subscribed=subscribed) for url in urls]
tasks = [read_webpage_and_extract_content(query, url, subscribed=subscribed, agent=agent) for url in urls]
results = await asyncio.gather(*tasks)
response: Dict[str, Dict] = defaultdict(dict)
@ -167,14 +169,14 @@ async def read_webpages(
async def read_webpage_and_extract_content(
subquery: str, url: str, content: str = None, subscribed: bool = False
subquery: str, url: str, content: str = None, subscribed: bool = False, agent: Agent = None
) -> Tuple[str, Union[None, str], str]:
try:
if is_none_or_empty(content):
with timer(f"Reading web page at '{url}' took", logger):
content = await read_webpage_with_olostep(url) if OLOSTEP_API_KEY else await read_webpage_with_jina(url)
with timer(f"Extracting relevant information from web page at '{url}' took", logger):
extracted_info = await extract_relevant_info(subquery, content, subscribed=subscribed)
extracted_info = await extract_relevant_info(subquery, content, subscribed=subscribed, agent=agent)
return subquery, extracted_info, url
except Exception as e:
logger.error(f"Failed to read web page at '{url}' with {e}")

View file

@ -27,7 +27,13 @@ from khoj.database.adapters import (
get_user_photo,
get_user_search_model_or_default,
)
from khoj.database.models import ChatModelOptions, KhojUser, SpeechToTextModelOptions
from khoj.database.models import (
Agent,
ChatModelOptions,
KhojUser,
SpeechToTextModelOptions,
)
from khoj.processor.conversation import prompts
from khoj.processor.conversation.anthropic.anthropic_chat import (
extract_questions_anthropic,
)
@ -106,6 +112,7 @@ async def execute_search(
r: Optional[bool] = False,
max_distance: Optional[Union[float, None]] = None,
dedupe: Optional[bool] = True,
agent: Optional[Agent] = None,
):
start_time = time.time()
@ -157,6 +164,7 @@ async def execute_search(
t,
question_embedding=encoded_asymmetric_query,
max_distance=max_distance,
agent=agent,
)
]
@ -333,6 +341,7 @@ async def extract_references_and_questions(
location_data: LocationData = None,
send_status_func: Optional[Callable] = None,
uploaded_image_url: Optional[str] = None,
agent: Agent = None,
):
user = request.user.object if request.user.is_authenticated else None
@ -348,6 +357,7 @@ async def extract_references_and_questions(
return
if not await sync_to_async(EntryAdapters.user_has_entries)(user=user):
if not await sync_to_async(EntryAdapters.agent_has_entries)(agent=agent):
logger.debug("No documents in knowledge base. Use a Khoj client to sync and chat with your docs.")
yield compiled_references, inferred_queries, q
return
@ -368,6 +378,8 @@ async def extract_references_and_questions(
using_offline_chat = False
logger.debug(f"Filters in query: {filters_in_query}")
personality_context = prompts.personality_context.format(personality=agent.personality) if agent else ""
# Infer search queries from user message
with timer("Extracting search queries took", logger):
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
@ -392,6 +404,7 @@ async def extract_references_and_questions(
location_data=location_data,
user=user,
max_prompt_size=conversation_config.max_prompt_size,
personality_context=personality_context,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
openai_chat_config = conversation_config.openai_config
@ -408,6 +421,7 @@ async def extract_references_and_questions(
user=user,
uploaded_image_url=uploaded_image_url,
vision_enabled=vision_enabled,
personality_context=personality_context,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
api_key = conversation_config.openai_config.api_key
@ -419,6 +433,7 @@ async def extract_references_and_questions(
conversation_log=meta_log,
location_data=location_data,
user=user,
personality_context=personality_context,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
api_key = conversation_config.openai_config.api_key
@ -431,6 +446,7 @@ async def extract_references_and_questions(
location_data=location_data,
max_tokens=conversation_config.max_prompt_size,
user=user,
personality_context=personality_context,
)
# Collate search results as context for GPT
@ -452,6 +468,7 @@ async def extract_references_and_questions(
r=True,
max_distance=d,
dedupe=False,
agent=agent,
)
)
search_results = text_search.deduplicated_search_responses(search_results)

View file

@ -1,13 +1,22 @@
import json
import logging
from typing import Dict, List, Optional
from asgiref.sync import sync_to_async
from fastapi import APIRouter, Request
from fastapi.requests import Request
from fastapi.responses import Response
from pydantic import BaseModel
from starlette.authentication import requires
from khoj.database.adapters import AgentAdapters
from khoj.database.models import KhojUser
from khoj.routers.helpers import CommonQueryParams
from khoj.database.models import Agent, KhojUser
from khoj.routers.helpers import CommonQueryParams, acheck_if_safe_prompt
from khoj.utils.helpers import (
ConversationCommand,
command_descriptions_for_agent,
mode_descriptions_for_agent,
)
# Initialize Router
logger = logging.getLogger(__name__)
@ -16,6 +25,18 @@ logger = logging.getLogger(__name__)
api_agents = APIRouter()
class ModifyAgentBody(BaseModel):
name: str
persona: str
privacy_level: str
icon: str
color: str
chat_model: str
files: Optional[List[str]] = []
input_tools: Optional[List[str]] = []
output_modes: Optional[List[str]] = []
@api_agents.get("", response_class=Response)
async def all_agents(
request: Request,
@ -25,17 +46,22 @@ async def all_agents(
agents = await AgentAdapters.aget_all_accessible_agents(user)
agents_packet = list()
for agent in agents:
files = agent.fileobject_set.all()
file_names = [file.file_name for file in files]
agents_packet.append(
{
"slug": agent.slug,
"avatar": agent.avatar,
"name": agent.name,
"persona": agent.personality,
"public": agent.public,
"creator": agent.creator.username if agent.creator else None,
"managed_by_admin": agent.managed_by_admin,
"color": agent.style_color,
"icon": agent.style_icon,
"privacy_level": agent.privacy_level,
"chat_model": agent.chat_model.chat_model,
"files": file_names,
"input_tools": agent.input_tools,
"output_modes": agent.output_modes,
}
)
@ -43,3 +69,197 @@ async def all_agents(
agents_packet.sort(key=lambda x: x["name"])
agents_packet.sort(key=lambda x: x["slug"] == "khoj", reverse=True)
return Response(content=json.dumps(agents_packet), media_type="application/json", status_code=200)
@api_agents.get("/options", response_class=Response)
async def get_agent_configuration_options(
request: Request,
common: CommonQueryParams,
) -> Response:
agent_input_tools = [key for key, _ in Agent.InputToolOptions.choices]
agent_output_modes = [key for key, _ in Agent.OutputModeOptions.choices]
agent_input_tool_with_descriptions: Dict[str, str] = {}
for key in agent_input_tools:
conversation_command = ConversationCommand(key)
agent_input_tool_with_descriptions[key] = command_descriptions_for_agent[conversation_command]
agent_output_modes_with_descriptions: Dict[str, str] = {}
for key in agent_output_modes:
conversation_command = ConversationCommand(key)
agent_output_modes_with_descriptions[key] = mode_descriptions_for_agent[conversation_command]
return Response(
content=json.dumps(
{
"input_tools": agent_input_tool_with_descriptions,
"output_modes": agent_output_modes_with_descriptions,
}
),
media_type="application/json",
status_code=200,
)
@api_agents.get("/{agent_slug}", response_class=Response)
async def get_agent(
request: Request,
common: CommonQueryParams,
agent_slug: str,
) -> Response:
user: KhojUser = request.user.object if request.user.is_authenticated else None
agent = await AgentAdapters.aget_readonly_agent_by_slug(agent_slug, user)
if not agent:
return Response(
content=json.dumps({"error": f"Agent with name {agent_slug} not found."}),
media_type="application/json",
status_code=404,
)
files = agent.fileobject_set.all()
file_names = [file.file_name for file in files]
agents_packet = {
"slug": agent.slug,
"name": agent.name,
"persona": agent.personality,
"creator": agent.creator.username if agent.creator else None,
"managed_by_admin": agent.managed_by_admin,
"color": agent.style_color,
"icon": agent.style_icon,
"privacy_level": agent.privacy_level,
"chat_model": agent.chat_model.chat_model,
"files": file_names,
"input_tools": agent.input_tools,
"output_modes": agent.output_modes,
}
return Response(content=json.dumps(agents_packet), media_type="application/json", status_code=200)
@api_agents.delete("/{agent_slug}", response_class=Response)
@requires(["authenticated"])
async def delete_agent(
request: Request,
common: CommonQueryParams,
agent_slug: str,
) -> Response:
user: KhojUser = request.user.object
agent = await AgentAdapters.aget_agent_by_slug(agent_slug, user)
if not agent:
return Response(
content=json.dumps({"error": f"Agent with name {agent_slug} not found."}),
media_type="application/json",
status_code=404,
)
await AgentAdapters.adelete_agent_by_slug(agent_slug, user)
return Response(content=json.dumps({"message": "Agent deleted."}), media_type="application/json", status_code=200)
@api_agents.post("", response_class=Response)
@requires(["authenticated"])
async def create_agent(
request: Request,
common: CommonQueryParams,
body: ModifyAgentBody,
) -> Response:
user: KhojUser = request.user.object
is_safe_prompt, reason = await acheck_if_safe_prompt(body.persona)
if not is_safe_prompt:
return Response(
content=json.dumps({"error": f"{reason}"}),
media_type="application/json",
status_code=400,
)
agent = await AgentAdapters.aupdate_agent(
user,
body.name,
body.persona,
body.privacy_level,
body.icon,
body.color,
body.chat_model,
body.files,
body.input_tools,
body.output_modes,
)
agents_packet = {
"slug": agent.slug,
"name": agent.name,
"persona": agent.personality,
"creator": agent.creator.username if agent.creator else None,
"managed_by_admin": agent.managed_by_admin,
"color": agent.style_color,
"icon": agent.style_icon,
"privacy_level": agent.privacy_level,
"chat_model": agent.chat_model.chat_model,
"files": body.files,
"input_tools": agent.input_tools,
"output_modes": agent.output_modes,
}
return Response(content=json.dumps(agents_packet), media_type="application/json", status_code=200)
@api_agents.patch("", response_class=Response)
@requires(["authenticated"])
async def update_agent(
request: Request,
common: CommonQueryParams,
body: ModifyAgentBody,
) -> Response:
user: KhojUser = request.user.object
is_safe_prompt, reason = await acheck_if_safe_prompt(body.persona)
if not is_safe_prompt:
return Response(
content=json.dumps({"error": f"{reason}"}),
media_type="application/json",
status_code=400,
)
selected_agent = await AgentAdapters.aget_agent_by_name(body.name, user)
if not selected_agent:
return Response(
content=json.dumps({"error": f"Agent with name {body.name} not found."}),
media_type="application/json",
status_code=404,
)
agent = await AgentAdapters.aupdate_agent(
user,
body.name,
body.persona,
body.privacy_level,
body.icon,
body.color,
body.chat_model,
body.files,
body.input_tools,
body.output_modes,
)
agents_packet = {
"slug": agent.slug,
"name": agent.name,
"persona": agent.personality,
"creator": agent.creator.username if agent.creator else None,
"managed_by_admin": agent.managed_by_admin,
"color": agent.style_color,
"icon": agent.style_icon,
"privacy_level": agent.privacy_level,
"chat_model": agent.chat_model.chat_model,
"files": body.files,
"input_tools": agent.input_tools,
"output_modes": agent.output_modes,
}
return Response(content=json.dumps(agents_packet), media_type="application/json", status_code=200)

View file

@ -17,13 +17,14 @@ from starlette.authentication import has_required_scope, requires
from khoj.app.settings import ALLOWED_HOSTS
from khoj.database.adapters import (
AgentAdapters,
ConversationAdapters,
EntryAdapters,
FileObjectAdapters,
PublicConversationAdapters,
aget_user_name,
)
from khoj.database.models import KhojUser
from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation.prompts import help_message, no_entries_found
from khoj.processor.conversation.utils import save_to_conversation_log
from khoj.processor.image.generate import text_to_image
@ -211,7 +212,6 @@ def chat_history(
agent_metadata = {
"slug": conversation.agent.slug,
"name": conversation.agent.name,
"avatar": conversation.agent.avatar,
"isCreator": conversation.agent.creator == user,
"color": conversation.agent.style_color,
"icon": conversation.agent.style_icon,
@ -268,7 +268,6 @@ def get_shared_chat(
agent_metadata = {
"slug": conversation.agent.slug,
"name": conversation.agent.name,
"avatar": conversation.agent.avatar,
"isCreator": conversation.agent.creator == user,
"color": conversation.agent.style_color,
"icon": conversation.agent.style_icon,
@ -418,7 +417,7 @@ def chat_sessions(
conversations = conversations[:8]
sessions = conversations.values_list(
"id", "slug", "title", "agent__slug", "agent__name", "agent__avatar", "created_at", "updated_at"
"id", "slug", "title", "agent__slug", "agent__name", "created_at", "updated_at"
)
session_values = [
@ -426,9 +425,8 @@ def chat_sessions(
"conversation_id": str(session[0]),
"slug": session[2] or session[1],
"agent_name": session[4],
"agent_avatar": session[5],
"created": session[6].strftime("%Y-%m-%d %H:%M:%S"),
"updated": session[7].strftime("%Y-%m-%d %H:%M:%S"),
"created": session[5].strftime("%Y-%m-%d %H:%M:%S"),
"updated": session[6].strftime("%Y-%m-%d %H:%M:%S"),
}
for session in sessions
]
@ -590,7 +588,7 @@ async def chat(
nonlocal connection_alive, ttft
if not connection_alive or await request.is_disconnected():
connection_alive = False
logger.warn(f"User {user} disconnected from {common.client} client")
logger.warning(f"User {user} disconnected from {common.client} client")
return
try:
if event_type == ChatEvent.END_LLM_RESPONSE:
@ -658,6 +656,11 @@ async def chat(
return
conversation_id = conversation.id
agent: Agent | None = None
default_agent = await AgentAdapters.aget_default_agent()
if conversation.agent and conversation.agent != default_agent:
agent = conversation.agent
await is_ready_to_chat(user)
user_name = await aget_user_name(user)
@ -677,7 +680,12 @@ async def chat(
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
conversation_commands = await aget_relevant_information_sources(
q, meta_log, is_automated_task, subscribed=subscribed, uploaded_image_url=uploaded_image_url
q,
meta_log,
is_automated_task,
subscribed=subscribed,
uploaded_image_url=uploaded_image_url,
agent=agent,
)
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
async for result in send_event(
@ -685,7 +693,7 @@ async def chat(
):
yield result
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, uploaded_image_url)
mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, uploaded_image_url, agent)
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
yield result
if mode not in conversation_commands:
@ -734,7 +742,7 @@ async def chat(
yield result
response = await extract_relevant_summary(
q, contextual_data, subscribed=subscribed, uploaded_image_url=uploaded_image_url
q, contextual_data, subscribed=subscribed, uploaded_image_url=uploaded_image_url, agent=agent
)
response_log = str(response)
async for result in send_llm_response(response_log):
@ -816,6 +824,7 @@ async def chat(
location,
partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url,
agent=agent,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@ -853,6 +862,7 @@ async def chat(
partial(send_event, ChatEvent.STATUS),
custom_filters,
uploaded_image_url=uploaded_image_url,
agent=agent,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@ -876,6 +886,7 @@ async def chat(
subscribed,
partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url,
agent=agent,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@ -922,6 +933,7 @@ async def chat(
subscribed=subscribed,
send_status_func=partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url,
agent=agent,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@ -1132,6 +1144,7 @@ async def get_chat(
yield result
return
conversation_id = conversation.id
agent = conversation.agent if conversation.agent else None
await is_ready_to_chat(user)

View file

@ -47,6 +47,7 @@ from khoj.database.adapters import (
run_with_process_lock,
)
from khoj.database.models import (
Agent,
ChatModelOptions,
ClientApplication,
Conversation,
@ -257,8 +258,39 @@ async def acreate_title_from_query(query: str) -> str:
return response.strip()
async def acheck_if_safe_prompt(system_prompt: str) -> Tuple[bool, str]:
"""
Check if the system prompt is safe to use
"""
safe_prompt_check = prompts.personality_prompt_safety_expert.format(prompt=system_prompt)
is_safe = True
reason = ""
with timer("Chat actor: Check if safe prompt", logger):
response = await send_message_to_model_wrapper(safe_prompt_check)
response = response.strip()
try:
response = json.loads(response)
is_safe = response.get("safe", "True") == "True"
if not is_safe:
reason = response.get("reason", "")
except Exception:
logger.error(f"Invalid response for checking safe prompt: {response}")
if not is_safe:
logger.error(f"Unsafe prompt: {system_prompt}. Reason: {reason}")
return is_safe, reason
async def aget_relevant_information_sources(
query: str, conversation_history: dict, is_task: bool, subscribed: bool, uploaded_image_url: str = None
query: str,
conversation_history: dict,
is_task: bool,
subscribed: bool,
uploaded_image_url: str = None,
agent: Agent = None,
):
"""
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
@ -267,8 +299,11 @@ async def aget_relevant_information_sources(
tool_options = dict()
tool_options_str = ""
agent_tools = agent.input_tools if agent else []
for tool, description in tool_descriptions_for_llm.items():
tool_options[tool.value] = description
if len(agent_tools) == 0 or tool.value in agent_tools:
tool_options_str += f'- "{tool.value}": "{description}"\n'
chat_history = construct_chat_history(conversation_history)
@ -276,10 +311,15 @@ async def aget_relevant_information_sources(
if uploaded_image_url:
query = f"[placeholder for user attached image]\n{query}"
personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
)
relevant_tools_prompt = prompts.pick_relevant_information_collection_tools.format(
query=query,
tools=tool_options_str,
chat_history=chat_history,
personality_context=personality_context,
)
with timer("Chat actor: Infer information sources to refer", logger):
@ -300,7 +340,10 @@ async def aget_relevant_information_sources(
final_response = [] if not is_task else [ConversationCommand.AutomatedTask]
for llm_suggested_tool in response:
if llm_suggested_tool in tool_options.keys():
# Add a double check to verify it's in the agent list, because the LLM sometimes gets confused by the tool options.
if llm_suggested_tool in tool_options.keys() and (
len(agent_tools) == 0 or llm_suggested_tool in agent_tools
):
# Check whether the tool exists as a valid ConversationCommand
final_response.append(ConversationCommand(llm_suggested_tool))
@ -313,7 +356,7 @@ async def aget_relevant_information_sources(
async def aget_relevant_output_modes(
query: str, conversation_history: dict, is_task: bool = False, uploaded_image_url: str = None
query: str, conversation_history: dict, is_task: bool = False, uploaded_image_url: str = None, agent: Agent = None
):
"""
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
@ -322,11 +365,14 @@ async def aget_relevant_output_modes(
mode_options = dict()
mode_options_str = ""
output_modes = agent.output_modes if agent else []
for mode, description in mode_descriptions_for_llm.items():
# Do not allow tasks to schedule another task
if is_task and mode == ConversationCommand.Automation:
continue
mode_options[mode.value] = description
if len(output_modes) == 0 or mode.value in output_modes:
mode_options_str += f'- "{mode.value}": "{description}"\n'
chat_history = construct_chat_history(conversation_history)
@ -334,10 +380,15 @@ async def aget_relevant_output_modes(
if uploaded_image_url:
query = f"[placeholder for user attached image]\n{query}"
personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
)
relevant_mode_prompt = prompts.pick_relevant_output_mode.format(
query=query,
modes=mode_options_str,
chat_history=chat_history,
personality_context=personality_context,
)
with timer("Chat actor: Infer output mode for chat response", logger):
@ -352,7 +403,9 @@ async def aget_relevant_output_modes(
return ConversationCommand.Text
output_mode = response["output"]
if output_mode in mode_options.keys():
# Add a double check to verify it's in the agent list, because the LLM sometimes gets confused by the tool options.
if output_mode in mode_options.keys() and (len(output_modes) == 0 or output_mode in output_modes):
# Check whether the tool exists as a valid ConversationCommand
return ConversationCommand(output_mode)
@ -364,7 +417,12 @@ async def aget_relevant_output_modes(
async def infer_webpage_urls(
q: str, conversation_history: dict, location_data: LocationData, user: KhojUser, uploaded_image_url: str = None
q: str,
conversation_history: dict,
location_data: LocationData,
user: KhojUser,
uploaded_image_url: str = None,
agent: Agent = None,
) -> List[str]:
"""
Infer webpage links from the given query
@ -374,12 +432,17 @@ async def infer_webpage_urls(
chat_history = construct_chat_history(conversation_history)
utc_date = datetime.utcnow().strftime("%Y-%m-%d")
personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
)
online_queries_prompt = prompts.infer_webpages_to_read.format(
current_date=utc_date,
query=q,
chat_history=chat_history,
location=location,
username=username,
personality_context=personality_context,
)
with timer("Chat actor: Infer webpage urls to read", logger):
@ -400,7 +463,12 @@ async def infer_webpage_urls(
async def generate_online_subqueries(
q: str, conversation_history: dict, location_data: LocationData, user: KhojUser, uploaded_image_url: str = None
q: str,
conversation_history: dict,
location_data: LocationData,
user: KhojUser,
uploaded_image_url: str = None,
agent: Agent = None,
) -> List[str]:
"""
Generate subqueries from the given query
@ -410,12 +478,17 @@ async def generate_online_subqueries(
chat_history = construct_chat_history(conversation_history)
utc_date = datetime.utcnow().strftime("%Y-%m-%d")
personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
)
online_queries_prompt = prompts.online_search_conversation_subqueries.format(
current_date=utc_date,
query=q,
chat_history=chat_history,
location=location,
username=username,
personality_context=personality_context,
)
with timer("Chat actor: Generate online search subqueries", logger):
@ -464,7 +537,7 @@ async def schedule_query(q: str, conversation_history: dict, uploaded_image_url:
raise AssertionError(f"Invalid response for scheduling query: {raw_response}")
async def extract_relevant_info(q: str, corpus: str, subscribed: bool) -> Union[str, None]:
async def extract_relevant_info(q: str, corpus: str, subscribed: bool, agent: Agent = None) -> Union[str, None]:
"""
Extract relevant information for a given query from the target corpus
"""
@ -472,9 +545,14 @@ async def extract_relevant_info(q: str, corpus: str, subscribed: bool) -> Union[
if is_none_or_empty(corpus) or is_none_or_empty(q):
return None
personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
)
extract_relevant_information = prompts.extract_relevant_information.format(
query=q,
corpus=corpus.strip(),
personality_context=personality_context,
)
chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config()
@ -490,7 +568,7 @@ async def extract_relevant_info(q: str, corpus: str, subscribed: bool) -> Union[
async def extract_relevant_summary(
q: str, corpus: str, subscribed: bool = False, uploaded_image_url: str = None
q: str, corpus: str, subscribed: bool = False, uploaded_image_url: str = None, agent: Agent = None
) -> Union[str, None]:
"""
Extract relevant information for a given query from the target corpus
@ -499,9 +577,14 @@ async def extract_relevant_summary(
if is_none_or_empty(corpus) or is_none_or_empty(q):
return None
personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
)
extract_relevant_information = prompts.extract_relevant_summary.format(
query=q,
corpus=corpus.strip(),
personality_context=personality_context,
)
chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config()
@ -526,12 +609,16 @@ async def generate_better_image_prompt(
model_type: Optional[str] = None,
subscribed: bool = False,
uploaded_image_url: Optional[str] = None,
agent: Agent = None,
) -> str:
"""
Generate a better image prompt from the given query
"""
today_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d, %A")
personality_context = (
prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
)
model_type = model_type or TextToImageModelConfig.ModelType.OPENAI
if location_data:
@ -558,6 +645,7 @@ async def generate_better_image_prompt(
current_date=today_date,
references=user_references,
online_results=simplified_online_results,
personality_context=personality_context,
)
elif model_type in [TextToImageModelConfig.ModelType.STABILITYAI, TextToImageModelConfig.ModelType.REPLICATE]:
image_prompt = prompts.image_generation_improve_prompt_sd.format(
@ -567,6 +655,7 @@ async def generate_better_image_prompt(
current_date=today_date,
references=user_references,
online_results=simplified_online_results,
personality_context=personality_context,
)
chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config()
@ -651,15 +740,13 @@ async def send_message_to_model_wrapper(
model_type=conversation_config.model_type,
)
openai_response = send_message_to_model(
return send_message_to_model(
messages=truncated_messages,
api_key=api_key,
model=chat_model,
response_type=response_type,
api_base_url=api_base_url,
)
return openai_response
elif model_type == ChatModelOptions.ModelType.ANTHROPIC:
api_key = conversation_config.openai_config.api_key
truncated_messages = generate_chatml_messages_with_context(

View file

@ -1,13 +1,14 @@
import logging
import math
from pathlib import Path
from typing import List, Tuple, Type, Union
from typing import List, Optional, Tuple, Type, Union
import torch
from asgiref.sync import sync_to_async
from sentence_transformers import util
from khoj.database.adapters import EntryAdapters, get_user_search_model_or_default
from khoj.database.models import Agent
from khoj.database.models import Entry as DbEntry
from khoj.database.models import KhojUser
from khoj.processor.content.text_to_entries import TextToEntries
@ -101,6 +102,7 @@ async def query(
type: SearchType = SearchType.All,
question_embedding: Union[torch.Tensor, None] = None,
max_distance: float = None,
agent: Optional[Agent] = None,
) -> Tuple[List[dict], List[Entry]]:
"Search for entries that answer the query"
@ -129,6 +131,7 @@ async def query(
file_type_filter=file_type,
raw_query=raw_query,
max_distance=max_distance,
agent=agent,
).all()
hits = await sync_to_async(list)(hits) # type: ignore[call-arg]

View file

@ -325,7 +325,15 @@ command_descriptions = {
ConversationCommand.Image: "Generate images by describing your imagination in words.",
ConversationCommand.Automation: "Automatically run your query at a specified time or interval.",
ConversationCommand.Help: "Get help with how to use or setup Khoj from the documentation",
ConversationCommand.Summarize: "Create an appropriate summary using provided documents.",
ConversationCommand.Summarize: "Get help with a question pertaining to an entire document.",
}
command_descriptions_for_agent = {
ConversationCommand.General: "Respond without any outside information or personal knowledge.",
ConversationCommand.Notes: "Search through the knowledge base. Required if the agent expects context from the knowledge base.",
ConversationCommand.Online: "Search for the latest, up-to-date information from the internet.",
ConversationCommand.Webpage: "Scrape specific web pages for information.",
ConversationCommand.Summarize: "Retrieve an answer that depends on the entire document or a large text. Knowledge base must be a single document.",
}
tool_descriptions_for_llm = {
@ -334,7 +342,7 @@ tool_descriptions_for_llm = {
ConversationCommand.Notes: "To search the user's personal knowledge base. Especially helpful if the question expects context from the user's notes or documents.",
ConversationCommand.Online: "To search for the latest, up-to-date information from the internet. Note: **Questions about Khoj should always use this data source**",
ConversationCommand.Webpage: "To use if the user has directly provided the webpage urls or you are certain of the webpage urls to read.",
ConversationCommand.Summarize: "To create a summary of the document provided by the user.",
ConversationCommand.Summarize: "To retrieve an answer that depends on the entire document or a large text.",
}
mode_descriptions_for_llm = {
@ -343,6 +351,11 @@ mode_descriptions_for_llm = {
ConversationCommand.Text: "Use this if the other response modes don't seem to fit the query.",
}
mode_descriptions_for_agent = {
ConversationCommand.Image: "Allow the agent to generate images.",
ConversationCommand.Text: "Allow the agent to generate text.",
}
class ImageIntentType(Enum):
"""