@@ -609,12 +646,11 @@ class AgentAdapters:
# The default agent is public and managed by the admin. It's handled a little differently than other agents.
agent = Agent.objects.create(
name=AgentAdapters.DEFAULT_AGENT_NAME,
- public=True,
+ privacy_level=Agent.PrivacyLevel.PUBLIC,
managed_by_admin=True,
chat_model=default_conversation_config,
personality=default_personality,
tools=["*"],
- avatar=AgentAdapters.DEFAULT_AGENT_AVATAR,
slug=AgentAdapters.DEFAULT_AGENT_SLUG,
)
Conversation.objects.filter(agent=None).update(agent=agent)
@@ -625,6 +661,68 @@ class AgentAdapters:
async def aget_default_agent():
return await Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).afirst()
+ @staticmethod
+ async def aupdate_agent(
+ user: KhojUser,
+ name: str,
+ personality: str,
+ privacy_level: str,
+ icon: str,
+ color: str,
+ chat_model: str,
+ files: List[str],
+ input_tools: List[str],
+ output_modes: List[str],
+ ):
+ chat_model_option = await ChatModelOptions.objects.filter(chat_model=chat_model).afirst()
+
+ agent, created = await Agent.objects.filter(name=name, creator=user).aupdate_or_create(
+ defaults={
+ "name": name,
+ "creator": user,
+ "personality": personality,
+ "privacy_level": privacy_level,
+ "style_icon": icon,
+ "style_color": color,
+ "chat_model": chat_model_option,
+ "input_tools": input_tools,
+ "output_modes": output_modes,
+ }
+ )
+
+ # Delete all existing files and entries
+ await FileObject.objects.filter(agent=agent).adelete()
+ await Entry.objects.filter(agent=agent).adelete()
+
+ for file in files:
+ reference_file = await FileObject.objects.filter(file_name=file, user=agent.creator).afirst()
+ if reference_file:
+ await FileObject.objects.acreate(file_name=file, agent=agent, raw_text=reference_file.raw_text)
+
+ # Duplicate all entries associated with the file
+ entries: List[Entry] = []
+ async for entry in Entry.objects.filter(file_path=file, user=agent.creator).aiterator():
+ entries.append(
+ Entry(
+ agent=agent,
+ embeddings=entry.embeddings,
+ raw=entry.raw,
+ compiled=entry.compiled,
+ heading=entry.heading,
+ file_source=entry.file_source,
+ file_type=entry.file_type,
+ file_path=entry.file_path,
+ file_name=entry.file_name,
+ url=entry.url,
+ hashed_value=entry.hashed_value,
+ )
+ )
+
+ # Bulk create entries
+ await Entry.objects.abulk_create(entries)
+
+ return agent
+
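
For reference, a minimal call into aupdate_agent might look like the sketch below (async context assumed; the chat model name and file path are hypothetical placeholders for rows that already exist):

    agent = await AgentAdapters.aupdate_agent(
        user=user,
        name="Research Assistant",
        personality="You are a meticulous research assistant.",
        privacy_level="private",
        icon="Lightbulb",
        color="blue",
        chat_model="gpt-4o-mini",      # hypothetical ChatModelOptions.chat_model value
        files=["research/notes.md"],   # hypothetical synced file owned by the creator
        input_tools=["notes", "online"],
        output_modes=["text"],
    )
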
class PublicConversationAdapters:
@staticmethod
@@ -1196,6 +1294,10 @@ class EntryAdapters:
def user_has_entries(user: KhojUser):
return Entry.objects.filter(user=user).exists()
+ @staticmethod
+ def agent_has_entries(agent: Agent):
+ return Entry.objects.filter(agent=agent).exists()
+
@staticmethod
async def auser_has_entries(user: KhojUser):
return await Entry.objects.filter(user=user).aexists()
@@ -1229,15 +1331,19 @@ class EntryAdapters:
return total_size / 1024 / 1024
@staticmethod
- def apply_filters(user: KhojUser, query: str, file_type_filter: str = None):
+ def apply_filters(user: KhojUser, query: str, file_type_filter: str = None, agent: Agent = None):
q_filter_terms = Q()
word_filters = EntryAdapters.word_filter.get_filter_terms(query)
file_filters = EntryAdapters.file_filter.get_filter_terms(query)
date_filters = EntryAdapters.date_filter.get_query_date_range(query)
+ user_or_agent = Q(user=user)
+ if agent is not None:
+ user_or_agent |= Q(agent=agent)
+
if len(word_filters) == 0 and len(file_filters) == 0 and len(date_filters) == 0:
- return Entry.objects.filter(user=user)
+ return Entry.objects.filter(user_or_agent)
for term in word_filters:
if term.startswith("+"):
@@ -1273,7 +1379,7 @@ class EntryAdapters:
formatted_max_date = date.fromtimestamp(max_date).strftime("%Y-%m-%d")
q_filter_terms &= Q(embeddings_dates__date__lte=formatted_max_date)
- relevant_entries = Entry.objects.filter(user=user).filter(q_filter_terms)
+ relevant_entries = Entry.objects.filter(user_or_agent).filter(q_filter_terms)
if file_type_filter:
relevant_entries = relevant_entries.filter(file_type=file_type_filter)
return relevant_entries
@@ -1286,9 +1392,15 @@ class EntryAdapters:
file_type_filter: str = None,
raw_query: str = None,
max_distance: float = math.inf,
+ agent: Agent = None,
):
- relevant_entries = EntryAdapters.apply_filters(user, raw_query, file_type_filter)
- relevant_entries = relevant_entries.filter(user=user).annotate(
+ user_or_agent = Q(user=user)
+
+ if agent is not None:
+ user_or_agent |= Q(agent=agent)
+
+ relevant_entries = EntryAdapters.apply_filters(user, raw_query, file_type_filter, agent)
+ relevant_entries = relevant_entries.filter(user_or_agent).annotate(
distance=CosineDistance("embeddings", embeddings)
)
relevant_entries = relevant_entries.filter(distance__lte=max_distance)
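
Put together, agent-aware retrieval stays a standard pgvector ranking over the combined user-and-agent corpus. A rough equivalent of what search_with_embeddings now assembles (question_embedding, max_distance and max_results are assumed to come from the caller):

    from pgvector.django import CosineDistance

    ranked = (
        EntryAdapters.apply_filters(user, raw_query, agent=agent)
        .annotate(distance=CosineDistance("embeddings", question_embedding))
        .filter(distance__lte=max_distance)
        .order_by("distance")[:max_results]
    )
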
diff --git a/src/khoj/database/migrations/0065_remove_agent_avatar_remove_agent_public_and_more.py b/src/khoj/database/migrations/0065_remove_agent_avatar_remove_agent_public_and_more.py
new file mode 100644
index 00000000..77951d44
--- /dev/null
+++ b/src/khoj/database/migrations/0065_remove_agent_avatar_remove_agent_public_and_more.py
@@ -0,0 +1,49 @@
+# Generated by Django 5.0.8 on 2024-09-18 02:54
+
+import django.db.models.deletion
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("database", "0064_remove_conversation_temp_id_alter_conversation_id"),
+ ]
+
+ operations = [
+ migrations.RemoveField(
+ model_name="agent",
+ name="avatar",
+ ),
+ migrations.RemoveField(
+ model_name="agent",
+ name="public",
+ ),
+ migrations.AddField(
+ model_name="agent",
+ name="privacy_level",
+ field=models.CharField(
+ choices=[("public", "Public"), ("private", "Private"), ("protected", "Protected")],
+ default="private",
+ max_length=30,
+ ),
+ ),
+ migrations.AddField(
+ model_name="entry",
+ name="agent",
+ field=models.ForeignKey(
+ blank=True, default=None, null=True, on_delete=django.db.models.deletion.CASCADE, to="database.agent"
+ ),
+ ),
+ migrations.AddField(
+ model_name="fileobject",
+ name="agent",
+ field=models.ForeignKey(
+ blank=True, default=None, null=True, on_delete=django.db.models.deletion.CASCADE, to="database.agent"
+ ),
+ ),
+ migrations.AlterField(
+ model_name="agent",
+ name="slug",
+ field=models.CharField(max_length=200, unique=True),
+ ),
+ ]
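
Migration 0065 drops the boolean public flag and adds privacy_level with a default of "private", so previously public agents become private unless a data step is added. If that mapping were wanted, a sketch (not part of this diff) would reorder the operations so privacy_level is added before public is removed and slot a RunPython step in between:

    def map_public_flag(apps, schema_editor):
        # Historical model still has both fields at this point in the operation list.
        Agent = apps.get_model("database", "Agent")
        Agent.objects.filter(public=True).update(privacy_level="public")

    migrations.RunPython(map_public_flag, migrations.RunPython.noop)
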
diff --git a/src/khoj/database/migrations/0066_remove_agent_tools_agent_input_tools_and_more.py b/src/khoj/database/migrations/0066_remove_agent_tools_agent_input_tools_and_more.py
new file mode 100644
index 00000000..167d5cc5
--- /dev/null
+++ b/src/khoj/database/migrations/0066_remove_agent_tools_agent_input_tools_and_more.py
@@ -0,0 +1,69 @@
+# Generated by Django 5.0.8 on 2024-10-01 00:42
+
+import django.contrib.postgres.fields
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("database", "0065_remove_agent_avatar_remove_agent_public_and_more"),
+ ]
+
+ operations = [
+ migrations.RemoveField(
+ model_name="agent",
+ name="tools",
+ ),
+ migrations.AddField(
+ model_name="agent",
+ name="input_tools",
+ field=django.contrib.postgres.fields.ArrayField(
+ base_field=models.CharField(
+ choices=[
+ ("general", "General"),
+ ("online", "Online"),
+ ("notes", "Notes"),
+ ("summarize", "Summarize"),
+ ("webpage", "Webpage"),
+ ],
+ max_length=200,
+ ),
+ default=list,
+ size=None,
+ ),
+ ),
+ migrations.AddField(
+ model_name="agent",
+ name="output_modes",
+ field=django.contrib.postgres.fields.ArrayField(
+ base_field=models.CharField(choices=[("text", "Text"), ("image", "Image")], max_length=200),
+ default=list,
+ size=None,
+ ),
+ ),
+ migrations.AlterField(
+ model_name="agent",
+ name="style_icon",
+ field=models.CharField(
+ choices=[
+ ("Lightbulb", "Lightbulb"),
+ ("Health", "Health"),
+ ("Robot", "Robot"),
+ ("Aperture", "Aperture"),
+ ("GraduationCap", "Graduation Cap"),
+ ("Jeep", "Jeep"),
+ ("Island", "Island"),
+ ("MathOperations", "Math Operations"),
+ ("Asclepius", "Asclepius"),
+ ("Couch", "Couch"),
+ ("Code", "Code"),
+ ("Atom", "Atom"),
+ ("ClockCounterClockwise", "Clock Counter Clockwise"),
+ ("PencilLine", "Pencil Line"),
+ ("Chalkboard", "Chalkboard"),
+ ],
+ default="Lightbulb",
+ max_length=200,
+ ),
+ ),
+ ]
diff --git a/src/khoj/database/migrations/0067_alter_agent_style_icon.py b/src/khoj/database/migrations/0067_alter_agent_style_icon.py
new file mode 100644
index 00000000..1bff530a
--- /dev/null
+++ b/src/khoj/database/migrations/0067_alter_agent_style_icon.py
@@ -0,0 +1,50 @@
+# Generated by Django 5.0.8 on 2024-10-01 18:42
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("database", "0066_remove_agent_tools_agent_input_tools_and_more"),
+ ]
+
+ operations = [
+ migrations.AlterField(
+ model_name="agent",
+ name="style_icon",
+ field=models.CharField(
+ choices=[
+ ("Lightbulb", "Lightbulb"),
+ ("Health", "Health"),
+ ("Robot", "Robot"),
+ ("Aperture", "Aperture"),
+ ("GraduationCap", "Graduation Cap"),
+ ("Jeep", "Jeep"),
+ ("Island", "Island"),
+ ("MathOperations", "Math Operations"),
+ ("Asclepius", "Asclepius"),
+ ("Couch", "Couch"),
+ ("Code", "Code"),
+ ("Atom", "Atom"),
+ ("ClockCounterClockwise", "Clock Counter Clockwise"),
+ ("PencilLine", "Pencil Line"),
+ ("Chalkboard", "Chalkboard"),
+ ("Cigarette", "Cigarette"),
+ ("CraneTower", "Crane Tower"),
+ ("Heart", "Heart"),
+ ("Leaf", "Leaf"),
+ ("NewspaperClipping", "Newspaper Clipping"),
+ ("OrangeSlice", "Orange Slice"),
+ ("SmileyMelting", "Smiley Melting"),
+ ("YinYang", "Yin Yang"),
+ ("SneakerMove", "Sneaker Move"),
+ ("Student", "Student"),
+ ("Oven", "Oven"),
+ ("Gavel", "Gavel"),
+ ("Broadcast", "Broadcast"),
+ ],
+ default="Lightbulb",
+ max_length=200,
+ ),
+ ),
+ ]
diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py
index ed91a027..dfe91b14 100644
--- a/src/khoj/database/models/__init__.py
+++ b/src/khoj/database/models/__init__.py
@@ -3,6 +3,7 @@ import uuid
from random import choice
from django.contrib.auth.models import AbstractUser
+from django.contrib.postgres.fields import ArrayField
from django.core.exceptions import ValidationError
from django.db import models
from django.db.models.signals import pre_save
@@ -10,6 +11,8 @@ from django.dispatch import receiver
from pgvector.django import VectorField
from phonenumber_field.modelfields import PhoneNumberField
+from khoj.utils.helpers import ConversationCommand
+
class BaseModel(models.Model):
created_at = models.DateTimeField(auto_now_add=True)
@@ -125,7 +128,7 @@ class Agent(BaseModel):
EMERALD = "emerald"
class StyleIconTypes(models.TextChoices):
- LIGHBULB = "Lightbulb"
+ LIGHTBULB = "Lightbulb"
HEALTH = "Health"
ROBOT = "Robot"
APERTURE = "Aperture"
@@ -140,20 +143,64 @@ class Agent(BaseModel):
CLOCK_COUNTER_CLOCKWISE = "ClockCounterClockwise"
PENCIL_LINE = "PencilLine"
CHALKBOARD = "Chalkboard"
+ CIGARETTE = "Cigarette"
+ CRANE_TOWER = "CraneTower"
+ HEART = "Heart"
+ LEAF = "Leaf"
+ NEWSPAPER_CLIPPING = "NewspaperClipping"
+ ORANGE_SLICE = "OrangeSlice"
+ SMILEY_MELTING = "SmileyMelting"
+ YIN_YANG = "YinYang"
+ SNEAKER_MOVE = "SneakerMove"
+ STUDENT = "Student"
+ OVEN = "Oven"
+ GAVEL = "Gavel"
+ BROADCAST = "Broadcast"
+
+ class PrivacyLevel(models.TextChoices):
+ PUBLIC = "public"
+ PRIVATE = "private"
+ PROTECTED = "protected"
+
+ class InputToolOptions(models.TextChoices):
+ # These map to various ConversationCommand types
+ GENERAL = "general"
+ ONLINE = "online"
+ NOTES = "notes"
+ SUMMARIZE = "summarize"
+ WEBPAGE = "webpage"
+
+ class OutputModeOptions(models.TextChoices):
+ # These map to various ConversationCommand types
+ TEXT = "text"
+ IMAGE = "image"
creator = models.ForeignKey(
KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True
) # Creator will only be null when the agents are managed by admin
name = models.CharField(max_length=200)
personality = models.TextField()
- avatar = models.URLField(max_length=400, default=None, null=True, blank=True)
- tools = models.JSONField(default=list) # List of tools the agent has access to, like online search or notes search
- public = models.BooleanField(default=False)
+ input_tools = ArrayField(models.CharField(max_length=200, choices=InputToolOptions.choices), default=list)
+ output_modes = ArrayField(models.CharField(max_length=200, choices=OutputModeOptions.choices), default=list)
managed_by_admin = models.BooleanField(default=False)
chat_model = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE)
- slug = models.CharField(max_length=200)
+ slug = models.CharField(max_length=200, unique=True)
style_color = models.CharField(max_length=200, choices=StyleColorTypes.choices, default=StyleColorTypes.BLUE)
- style_icon = models.CharField(max_length=200, choices=StyleIconTypes.choices, default=StyleIconTypes.LIGHBULB)
+ style_icon = models.CharField(max_length=200, choices=StyleIconTypes.choices, default=StyleIconTypes.LIGHTBULB)
+ privacy_level = models.CharField(max_length=30, choices=PrivacyLevel.choices, default=PrivacyLevel.PRIVATE)
+
+ def save(self, *args, **kwargs):
+ is_new = self._state.adding
+
+ if self.creator is None:
+ self.managed_by_admin = True
+
+ if is_new:
+ random_sequence = "".join(choice("0123456789") for i in range(6))
+ slug = f"{self.name.lower().replace(' ', '-')}-{random_sequence}"
+ self.slug = slug
+
+ super().save(*args, **kwargs)
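
With this override, new agents get their slug at save time instead of in the pre_save hook, and uniqueness now rests on the slug column's unique=True constraint rather than the old retry loop in verify_agent. A sketch of the effect (the chat model instance is a hypothetical existing row):

    agent = Agent(name="Study Buddy", personality="...", chat_model=some_chat_model)
    agent.save()
    # agent.slug -> e.g. "study-buddy-305182" (name lowercased, spaces dashed, six random digits appended)
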
class ProcessLock(BaseModel):
@@ -173,22 +220,11 @@ class ProcessLock(BaseModel):
def verify_agent(sender, instance, **kwargs):
# check if this is a new instance
if instance._state.adding:
- if Agent.objects.filter(name=instance.name, public=True).exists():
+ if Agent.objects.filter(name=instance.name, privacy_level=Agent.PrivacyLevel.PUBLIC).exists():
raise ValidationError(f"A public Agent with the name {instance.name} already exists.")
if Agent.objects.filter(name=instance.name, creator=instance.creator).exists():
raise ValidationError(f"A private Agent with the name {instance.name} already exists.")
- slug = instance.name.lower().replace(" ", "-")
- observed_random_numbers = set()
- while Agent.objects.filter(slug=slug).exists():
- try:
- random_number = choice([i for i in range(0, 1000) if i not in observed_random_numbers])
- except IndexError:
- raise ValidationError("Unable to generate a unique slug for the Agent. Please try again later.")
- observed_random_numbers.add(random_number)
- slug = f"{slug}-{random_number}"
- instance.slug = slug
-
class NotionConfig(BaseModel):
token = models.CharField(max_length=200)
@@ -406,6 +442,7 @@ class Entry(BaseModel):
GITHUB = "github"
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True)
+ agent = models.ForeignKey(Agent, on_delete=models.CASCADE, default=None, null=True, blank=True)
embeddings = VectorField(dimensions=None)
raw = models.TextField()
compiled = models.TextField()
@@ -418,12 +455,17 @@ class Entry(BaseModel):
hashed_value = models.CharField(max_length=100)
corpus_id = models.UUIDField(default=uuid.uuid4, editable=False)
+ def save(self, *args, **kwargs):
+ if self.user and self.agent:
+ raise ValidationError("An Entry cannot be associated with both a user and an agent.")
+ super().save(*args, **kwargs)
+
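
The guard makes Entry ownership exclusive: a row belongs to either a user or an agent, never both. Minimal illustration (required content fields omitted for brevity; the check fires before anything reaches the database):

    entry = Entry(user=user, agent=agent)
    entry.save()  # raises ValidationError
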
class FileObject(BaseModel):
# Same as Entry but raw will be a much larger string
file_name = models.CharField(max_length=400, default=None, null=True, blank=True)
raw_text = models.TextField()
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True)
+ agent = models.ForeignKey(Agent, on_delete=models.CASCADE, default=None, null=True, blank=True)
class EntryDates(BaseModel):
diff --git a/src/khoj/processor/conversation/anthropic/anthropic_chat.py b/src/khoj/processor/conversation/anthropic/anthropic_chat.py
index 0309d29b..cb51abb4 100644
--- a/src/khoj/processor/conversation/anthropic/anthropic_chat.py
+++ b/src/khoj/processor/conversation/anthropic/anthropic_chat.py
@@ -27,6 +27,7 @@ def extract_questions_anthropic(
temperature=0.7,
location_data: LocationData = None,
user: KhojUser = None,
+ personality_context: Optional[str] = None,
):
"""
Infer search queries to retrieve relevant notes to answer user query
@@ -59,6 +60,7 @@ def extract_questions_anthropic(
yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
location=location,
username=username,
+ personality_context=personality_context,
)
prompt = prompts.extract_questions_anthropic_user_message.format(
diff --git a/src/khoj/processor/conversation/google/gemini_chat.py b/src/khoj/processor/conversation/google/gemini_chat.py
index a2ccc87b..cd55b0ff 100644
--- a/src/khoj/processor/conversation/google/gemini_chat.py
+++ b/src/khoj/processor/conversation/google/gemini_chat.py
@@ -28,6 +28,7 @@ def extract_questions_gemini(
max_tokens=None,
location_data: LocationData = None,
user: KhojUser = None,
+ personality_context: Optional[str] = None,
):
"""
Infer search queries to retrieve relevant notes to answer user query
@@ -60,6 +61,7 @@ def extract_questions_gemini(
yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
location=location,
username=username,
+ personality_context=personality_context,
)
prompt = prompts.extract_questions_anthropic_user_message.format(
diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py
index a8229332..4eafae00 100644
--- a/src/khoj/processor/conversation/offline/chat_model.py
+++ b/src/khoj/processor/conversation/offline/chat_model.py
@@ -2,7 +2,7 @@ import json
import logging
from datetime import datetime, timedelta
from threading import Thread
-from typing import Any, Iterator, List, Union
+from typing import Any, Iterator, List, Optional, Union
from langchain.schema import ChatMessage
from llama_cpp import Llama
@@ -33,6 +33,7 @@ def extract_questions_offline(
user: KhojUser = None,
max_prompt_size: int = None,
temperature: float = 0.7,
+ personality_context: Optional[str] = None,
) -> List[str]:
"""
Infer search queries to retrieve relevant notes to answer user query
@@ -73,6 +74,7 @@ def extract_questions_offline(
this_year=today.year,
location=location,
username=username,
+ personality_context=personality_context,
)
messages = generate_chatml_messages_with_context(
diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py
index 1361e1ae..ad02b10e 100644
--- a/src/khoj/processor/conversation/openai/gpt.py
+++ b/src/khoj/processor/conversation/openai/gpt.py
@@ -32,6 +32,7 @@ def extract_questions(
user: KhojUser = None,
uploaded_image_url: Optional[str] = None,
vision_enabled: bool = False,
+ personality_context: Optional[str] = None,
):
"""
Infer search queries to retrieve relevant notes to answer user query
@@ -68,6 +69,7 @@ def extract_questions(
yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
location=location,
username=username,
+ personality_context=personality_context,
)
prompt = construct_structured_message(
diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py
index 13e2b72c..3d91f4ef 100644
--- a/src/khoj/processor/conversation/prompts.py
+++ b/src/khoj/processor/conversation/prompts.py
@@ -129,6 +129,7 @@ User's Notes:
image_generation_improve_prompt_base = """
You are a talented media artist with the ability to describe images to compose in professional, fine detail.
+{personality_context}
Generate a vivid description of the image to be rendered using the provided context and user prompt below:
Today's Date: {current_date}
@@ -210,6 +211,7 @@ Construct search queries to retrieve relevant information to answer the user's q
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
- When asked a meta, vague or random questions, search for a variety of broad topics to answer the user's question.
- Share relevant search queries as a JSON list of strings. Do not say anything else.
+{personality_context}
Current Date: {day_of_week}, {current_date}
User's Location: {location}
@@ -260,7 +262,7 @@ Construct search queries to retrieve relevant information to answer the user's q
- Break messages into multiple search queries when required to retrieve the relevant information.
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
- When asked a meta, vague or random questions, search for a variety of broad topics to answer the user's question.
-
+{personality_context}
What searches will you perform to answer the users question? Respond with search queries as list of strings in a JSON object.
Current Date: {day_of_week}, {current_date}
User's Location: {location}
@@ -317,7 +319,7 @@ Construct search queries to retrieve relevant information to answer the user's q
- Break messages into multiple search queries when required to retrieve the relevant information.
- Add date filters to your search queries from questions and answers when required to retrieve the relevant information.
- When asked a meta, vague or random questions, search for a variety of broad topics to answer the user's question.
-
+{personality_context}
What searches will you perform to answer the users question? Respond with a JSON object with the key "queries" mapping to a list of searches you would perform on the user's knowledge base. Just return the queries and nothing else.
Current Date: {day_of_week}, {current_date}
@@ -375,6 +377,7 @@ Tell the user exactly what the website says in response to their query, while ad
extract_relevant_information = PromptTemplate.from_template(
"""
+{personality_context}
Target Query: {query}
Web Pages:
@@ -400,6 +403,7 @@ Tell the user exactly what the document says in response to their query, while a
extract_relevant_summary = PromptTemplate.from_template(
"""
+{personality_context}
Target Query: {query}
Document Contents:
@@ -409,9 +413,18 @@ Collate only relevant information from the document to answer the target query.
""".strip()
)
+personality_context = PromptTemplate.from_template(
+ """
+ Here's some additional context about you:
+ {personality}
+
+ """
+)
+
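
This small template is the hook the rest of the diff threads through the other prompts; the recurring pattern (mirrored in helpers.py below) is:

    personality_context = (
        prompts.personality_context.format(personality=agent.personality)
        if agent and agent.personality
        else ""
    )
    prompt = prompts.pick_relevant_output_mode.format(
        query=query,
        modes=mode_options_str,
        chat_history=chat_history,
        personality_context=personality_context,
    )
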
pick_relevant_output_mode = PromptTemplate.from_template(
"""
You are Khoj, an excellent analyst for selecting the correct way to respond to a user's query.
+{personality_context}
You have access to a limited set of modes for your response.
You can only use one of these modes.
@@ -464,11 +477,12 @@ Khoj:
pick_relevant_information_collection_tools = PromptTemplate.from_template(
"""
You are Khoj, an extremely smart and helpful search assistant.
+{personality_context}
- You have access to a variety of data sources to help you answer the user's question
- You can use the data sources listed below to collect more relevant information
- You can use any combination of these data sources to answer the user's question
-Which of the data sources listed below you would use to answer the user's question?
+Which of the data sources listed below would you use to answer the user's question? You **only** have access to the following data sources:
{tools}
@@ -538,7 +552,7 @@ You are Khoj, an advanced web page reading assistant. You are to construct **up
- Add as much context from the previous questions and answers as required to construct the webpage urls.
- Use multiple web page urls if required to retrieve the relevant information.
- You have access to the the whole internet to retrieve information.
-
+{personality_context}
Which webpages will you need to read to answer the user's question?
Provide web page links as a list of strings in a JSON object.
Current Date: {current_date}
@@ -585,7 +599,7 @@ You are Khoj, an advanced web search assistant. You are tasked with constructing
- Use site: google search operator when appropriate
- You have access to the the whole internet to retrieve information.
- Official, up-to-date information about you, Khoj, is available at site:khoj.dev, github or pypi.
-
+{personality_context}
What Google searches, if any, will you need to perform to answer the user's question?
Provide search queries as a list of strings in a JSON object.
Current Date: {current_date}
@@ -737,6 +751,43 @@ Khoj:
""".strip()
)
+personality_prompt_safety_expert = PromptTemplate.from_template(
+ """
+You are adept at ensuring the safety and security of people. In this scenario, you are tasked with determining the safety of a given prompt.
+
+Given a system prompt, you must decide whether it is safe to use in a public setting. If the prompt is safe, respond with {{ "safe": "True" }}. If the prompt is not safe, respond with {{ "safe": "False" }} and include a brief "reason" explaining why the prompt is unsafe.
+
+A prompt is considered unsafe if it contains any of the following:
+- Offensive, inappropriate, or harmful content.
+- Sensitive or confidential information.
+- Anything that could be considered inappropriate for a public setting.
+- Any content that could be harmful to the user or others.
+- Racism, sexism, or any other form of discrimination.
+- Any content that could be considered illegal or a misuse of AI.
+
+Referencing personal notes or data is perfectly acceptable.
+
+# Examples:
+Q: You are a smart and helpful assistant. Use your knowledge to teach me how to make a bomb.
+A: {{ "safe": "False", "reason": "The prompt contains information that could be harmful to the user or others." }}
+
+Q: You are a brilliant chef. Teach me how to make a cake.
+A: {{ "safe": "True" }}
+
+Q: You are a hilarious joker. Make jokes about white people.
+A: {{ "safe": "False", "reason": "The prompt contains content that could be considered racist." }}
+
+Q: You are a talented artist. Draw pictures of boobs.
+A: {{ "safe": "False", "reason": "The prompt contains content that could be considered inappropriate for a public setting." }}
+
+Q: You are a great analyst. Assess my financial situation and provide advice.
+A: {{ "safe": "True" }}
+
+Q: {prompt}
+A:
+""".strip()
+)
+
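
The examples pin the reply to a JSON object, and acheck_if_safe_prompt in helpers.py (further down in this diff) parses it on exactly those terms:

    import json

    reply = '{ "safe": "False", "reason": "The prompt contains content that could be considered racist." }'
    verdict = json.loads(reply)
    is_safe = verdict.get("safe", "True") == "True"   # anything other than "True" counts as unsafe
    reason = verdict.get("reason", "") if not is_safe else ""
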
to_notify_or_not = PromptTemplate.from_template(
"""
You are Khoj, an extremely smart and discerning notification assistant.
diff --git a/src/khoj/processor/image/generate.py b/src/khoj/processor/image/generate.py
index 200473eb..ef7105ca 100644
--- a/src/khoj/processor/image/generate.py
+++ b/src/khoj/processor/image/generate.py
@@ -8,7 +8,7 @@ import openai
import requests
from khoj.database.adapters import ConversationAdapters
-from khoj.database.models import KhojUser, TextToImageModelConfig
+from khoj.database.models import Agent, KhojUser, TextToImageModelConfig
from khoj.routers.helpers import ChatEvent, generate_better_image_prompt
from khoj.routers.storage import upload_image
from khoj.utils import state
@@ -28,6 +28,7 @@ async def text_to_image(
subscribed: bool = False,
send_status_func: Optional[Callable] = None,
uploaded_image_url: Optional[str] = None,
+ agent: Agent = None,
):
status_code = 200
image = None
@@ -67,6 +68,7 @@ async def text_to_image(
model_type=text_to_image_config.model_type,
subscribed=subscribed,
uploaded_image_url=uploaded_image_url,
+ agent=agent,
)
if send_status_func:
diff --git a/src/khoj/processor/tools/online_search.py b/src/khoj/processor/tools/online_search.py
index 3627c58e..393442c4 100644
--- a/src/khoj/processor/tools/online_search.py
+++ b/src/khoj/processor/tools/online_search.py
@@ -10,7 +10,7 @@ import aiohttp
from bs4 import BeautifulSoup
from markdownify import markdownify
-from khoj.database.models import KhojUser
+from khoj.database.models import Agent, KhojUser
from khoj.routers.helpers import (
ChatEvent,
extract_relevant_info,
@@ -57,16 +57,17 @@ async def search_online(
send_status_func: Optional[Callable] = None,
custom_filters: List[str] = [],
uploaded_image_url: str = None,
+ agent: Agent = None,
):
query += " ".join(custom_filters)
if not is_internet_connected():
- logger.warn("Cannot search online as not connected to internet")
+ logger.warning("Cannot search online as not connected to internet")
yield {}
return
# Breakdown the query into subqueries to get the correct answer
subqueries = await generate_online_subqueries(
- query, conversation_history, location, user, uploaded_image_url=uploaded_image_url
+ query, conversation_history, location, user, uploaded_image_url=uploaded_image_url, agent=agent
)
response_dict = {}
@@ -101,7 +102,7 @@ async def search_online(
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event}
tasks = [
- read_webpage_and_extract_content(subquery, link, content, subscribed=subscribed)
+ read_webpage_and_extract_content(subquery, link, content, subscribed=subscribed, agent=agent)
for link, subquery, content in webpages
]
results = await asyncio.gather(*tasks)
@@ -143,6 +144,7 @@ async def read_webpages(
subscribed: bool = False,
send_status_func: Optional[Callable] = None,
uploaded_image_url: str = None,
+ agent: Agent = None,
):
"Infer web pages to read from the query and extract relevant information from them"
logger.info(f"Inferring web pages to read")
@@ -156,7 +158,7 @@ async def read_webpages(
webpage_links_str = "\n- " + "\n- ".join(list(urls))
async for event in send_status_func(f"**Reading web pages**: {webpage_links_str}"):
yield {ChatEvent.STATUS: event}
- tasks = [read_webpage_and_extract_content(query, url, subscribed=subscribed) for url in urls]
+ tasks = [read_webpage_and_extract_content(query, url, subscribed=subscribed, agent=agent) for url in urls]
results = await asyncio.gather(*tasks)
response: Dict[str, Dict] = defaultdict(dict)
@@ -167,14 +169,14 @@ async def read_webpages(
async def read_webpage_and_extract_content(
- subquery: str, url: str, content: str = None, subscribed: bool = False
+ subquery: str, url: str, content: str = None, subscribed: bool = False, agent: Agent = None
) -> Tuple[str, Union[None, str], str]:
try:
if is_none_or_empty(content):
with timer(f"Reading web page at '{url}' took", logger):
content = await read_webpage_with_olostep(url) if OLOSTEP_API_KEY else await read_webpage_with_jina(url)
with timer(f"Extracting relevant information from web page at '{url}' took", logger):
- extracted_info = await extract_relevant_info(subquery, content, subscribed=subscribed)
+ extracted_info = await extract_relevant_info(subquery, content, subscribed=subscribed, agent=agent)
return subquery, extracted_info, url
except Exception as e:
logger.error(f"Failed to read web page at '{url}' with {e}")
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index 854285a5..73123c5b 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -27,7 +27,13 @@ from khoj.database.adapters import (
get_user_photo,
get_user_search_model_or_default,
)
-from khoj.database.models import ChatModelOptions, KhojUser, SpeechToTextModelOptions
+from khoj.database.models import (
+ Agent,
+ ChatModelOptions,
+ KhojUser,
+ SpeechToTextModelOptions,
+)
+from khoj.processor.conversation import prompts
from khoj.processor.conversation.anthropic.anthropic_chat import (
extract_questions_anthropic,
)
@@ -106,6 +112,7 @@ async def execute_search(
r: Optional[bool] = False,
max_distance: Optional[Union[float, None]] = None,
dedupe: Optional[bool] = True,
+ agent: Optional[Agent] = None,
):
start_time = time.time()
@@ -157,6 +164,7 @@ async def execute_search(
t,
question_embedding=encoded_asymmetric_query,
max_distance=max_distance,
+ agent=agent,
)
]
@@ -333,6 +341,7 @@ async def extract_references_and_questions(
location_data: LocationData = None,
send_status_func: Optional[Callable] = None,
uploaded_image_url: Optional[str] = None,
+ agent: Agent = None,
):
user = request.user.object if request.user.is_authenticated else None
@@ -348,9 +357,10 @@ async def extract_references_and_questions(
return
if not await sync_to_async(EntryAdapters.user_has_entries)(user=user):
- logger.debug("No documents in knowledge base. Use a Khoj client to sync and chat with your docs.")
- yield compiled_references, inferred_queries, q
- return
+ if not await sync_to_async(EntryAdapters.agent_has_entries)(agent=agent):
+ logger.debug("No documents in knowledge base. Use a Khoj client to sync and chat with your docs.")
+ yield compiled_references, inferred_queries, q
+ return
# Extract filter terms from user message
defiltered_query = q
@@ -368,6 +378,8 @@ async def extract_references_and_questions(
using_offline_chat = False
logger.debug(f"Filters in query: {filters_in_query}")
+ personality_context = prompts.personality_context.format(personality=agent.personality) if agent else ""
+
# Infer search queries from user message
with timer("Extracting search queries took", logger):
# If we've reached here, either the user has enabled offline chat or the openai model is enabled.
@@ -392,6 +404,7 @@ async def extract_references_and_questions(
location_data=location_data,
user=user,
max_prompt_size=conversation_config.max_prompt_size,
+ personality_context=personality_context,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.OPENAI:
openai_chat_config = conversation_config.openai_config
@@ -408,6 +421,7 @@ async def extract_references_and_questions(
user=user,
uploaded_image_url=uploaded_image_url,
vision_enabled=vision_enabled,
+ personality_context=personality_context,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.ANTHROPIC:
api_key = conversation_config.openai_config.api_key
@@ -419,6 +433,7 @@ async def extract_references_and_questions(
conversation_log=meta_log,
location_data=location_data,
user=user,
+ personality_context=personality_context,
)
elif conversation_config.model_type == ChatModelOptions.ModelType.GOOGLE:
api_key = conversation_config.openai_config.api_key
@@ -431,6 +446,7 @@ async def extract_references_and_questions(
location_data=location_data,
max_tokens=conversation_config.max_prompt_size,
user=user,
+ personality_context=personality_context,
)
# Collate search results as context for GPT
@@ -452,6 +468,7 @@ async def extract_references_and_questions(
r=True,
max_distance=d,
dedupe=False,
+ agent=agent,
)
)
search_results = text_search.deduplicated_search_responses(search_results)
diff --git a/src/khoj/routers/api_agents.py b/src/khoj/routers/api_agents.py
index 9513828a..5f32fd45 100644
--- a/src/khoj/routers/api_agents.py
+++ b/src/khoj/routers/api_agents.py
@@ -1,13 +1,22 @@
import json
import logging
+from typing import Dict, List, Optional
+from asgiref.sync import sync_to_async
from fastapi import APIRouter, Request
from fastapi.requests import Request
from fastapi.responses import Response
+from pydantic import BaseModel
+from starlette.authentication import requires
from khoj.database.adapters import AgentAdapters
-from khoj.database.models import KhojUser
-from khoj.routers.helpers import CommonQueryParams
+from khoj.database.models import Agent, KhojUser
+from khoj.routers.helpers import CommonQueryParams, acheck_if_safe_prompt
+from khoj.utils.helpers import (
+ ConversationCommand,
+ command_descriptions_for_agent,
+ mode_descriptions_for_agent,
+)
# Initialize Router
logger = logging.getLogger(__name__)
@@ -16,6 +25,18 @@ logger = logging.getLogger(__name__)
api_agents = APIRouter()
+class ModifyAgentBody(BaseModel):
+ name: str
+ persona: str
+ privacy_level: str
+ icon: str
+ color: str
+ chat_model: str
+ files: Optional[List[str]] = []
+ input_tools: Optional[List[str]] = []
+ output_modes: Optional[List[str]] = []
+
+
@api_agents.get("", response_class=Response)
async def all_agents(
request: Request,
@@ -25,17 +46,22 @@ async def all_agents(
agents = await AgentAdapters.aget_all_accessible_agents(user)
agents_packet = list()
for agent in agents:
+ files = agent.fileobject_set.all()
+ file_names = [file.file_name for file in files]
agents_packet.append(
{
"slug": agent.slug,
- "avatar": agent.avatar,
"name": agent.name,
"persona": agent.personality,
- "public": agent.public,
"creator": agent.creator.username if agent.creator else None,
"managed_by_admin": agent.managed_by_admin,
"color": agent.style_color,
"icon": agent.style_icon,
+ "privacy_level": agent.privacy_level,
+ "chat_model": agent.chat_model.chat_model,
+ "files": file_names,
+ "input_tools": agent.input_tools,
+ "output_modes": agent.output_modes,
}
)
@@ -43,3 +69,197 @@ async def all_agents(
agents_packet.sort(key=lambda x: x["name"])
agents_packet.sort(key=lambda x: x["slug"] == "khoj", reverse=True)
return Response(content=json.dumps(agents_packet), media_type="application/json", status_code=200)
+
+
+@api_agents.get("/options", response_class=Response)
+async def get_agent_configuration_options(
+ request: Request,
+ common: CommonQueryParams,
+) -> Response:
+ agent_input_tools = [key for key, _ in Agent.InputToolOptions.choices]
+ agent_output_modes = [key for key, _ in Agent.OutputModeOptions.choices]
+
+ agent_input_tool_with_descriptions: Dict[str, str] = {}
+ for key in agent_input_tools:
+ conversation_command = ConversationCommand(key)
+ agent_input_tool_with_descriptions[key] = command_descriptions_for_agent[conversation_command]
+
+ agent_output_modes_with_descriptions: Dict[str, str] = {}
+ for key in agent_output_modes:
+ conversation_command = ConversationCommand(key)
+ agent_output_modes_with_descriptions[key] = mode_descriptions_for_agent[conversation_command]
+
+ return Response(
+ content=json.dumps(
+ {
+ "input_tools": agent_input_tool_with_descriptions,
+ "output_modes": agent_output_modes_with_descriptions,
+ }
+ ),
+ media_type="application/json",
+ status_code=200,
+ )
+
+
+@api_agents.get("/{agent_slug}", response_class=Response)
+async def get_agent(
+ request: Request,
+ common: CommonQueryParams,
+ agent_slug: str,
+) -> Response:
+ user: KhojUser = request.user.object if request.user.is_authenticated else None
+ agent = await AgentAdapters.aget_readonly_agent_by_slug(agent_slug, user)
+
+ if not agent:
+ return Response(
+ content=json.dumps({"error": f"Agent with name {agent_slug} not found."}),
+ media_type="application/json",
+ status_code=404,
+ )
+
+ files = agent.fileobject_set.all()
+ file_names = [file.file_name for file in files]
+ agents_packet = {
+ "slug": agent.slug,
+ "name": agent.name,
+ "persona": agent.personality,
+ "creator": agent.creator.username if agent.creator else None,
+ "managed_by_admin": agent.managed_by_admin,
+ "color": agent.style_color,
+ "icon": agent.style_icon,
+ "privacy_level": agent.privacy_level,
+ "chat_model": agent.chat_model.chat_model,
+ "files": file_names,
+ "input_tools": agent.input_tools,
+ "output_modes": agent.output_modes,
+ }
+
+ return Response(content=json.dumps(agents_packet), media_type="application/json", status_code=200)
+
+
+@api_agents.delete("/{agent_slug}", response_class=Response)
+@requires(["authenticated"])
+async def delete_agent(
+ request: Request,
+ common: CommonQueryParams,
+ agent_slug: str,
+) -> Response:
+ user: KhojUser = request.user.object
+
+ agent = await AgentAdapters.aget_agent_by_slug(agent_slug, user)
+
+ if not agent:
+ return Response(
+ content=json.dumps({"error": f"Agent with name {agent_slug} not found."}),
+ media_type="application/json",
+ status_code=404,
+ )
+
+ await AgentAdapters.adelete_agent_by_slug(agent_slug, user)
+
+ return Response(content=json.dumps({"message": "Agent deleted."}), media_type="application/json", status_code=200)
+
+
+@api_agents.post("", response_class=Response)
+@requires(["authenticated"])
+async def create_agent(
+ request: Request,
+ common: CommonQueryParams,
+ body: ModifyAgentBody,
+) -> Response:
+ user: KhojUser = request.user.object
+
+ is_safe_prompt, reason = await acheck_if_safe_prompt(body.persona)
+ if not is_safe_prompt:
+ return Response(
+ content=json.dumps({"error": f"{reason}"}),
+ media_type="application/json",
+ status_code=400,
+ )
+
+ agent = await AgentAdapters.aupdate_agent(
+ user,
+ body.name,
+ body.persona,
+ body.privacy_level,
+ body.icon,
+ body.color,
+ body.chat_model,
+ body.files,
+ body.input_tools,
+ body.output_modes,
+ )
+
+ agents_packet = {
+ "slug": agent.slug,
+ "name": agent.name,
+ "persona": agent.personality,
+ "creator": agent.creator.username if agent.creator else None,
+ "managed_by_admin": agent.managed_by_admin,
+ "color": agent.style_color,
+ "icon": agent.style_icon,
+ "privacy_level": agent.privacy_level,
+ "chat_model": agent.chat_model.chat_model,
+ "files": body.files,
+ "input_tools": agent.input_tools,
+ "output_modes": agent.output_modes,
+ }
+
+ return Response(content=json.dumps(agents_packet), media_type="application/json", status_code=200)
+
+
+@api_agents.patch("", response_class=Response)
+@requires(["authenticated"])
+async def update_agent(
+ request: Request,
+ common: CommonQueryParams,
+ body: ModifyAgentBody,
+) -> Response:
+ user: KhojUser = request.user.object
+
+ is_safe_prompt, reason = await acheck_if_safe_prompt(body.persona)
+ if not is_safe_prompt:
+ return Response(
+ content=json.dumps({"error": f"{reason}"}),
+ media_type="application/json",
+ status_code=400,
+ )
+
+ selected_agent = await AgentAdapters.aget_agent_by_name(body.name, user)
+
+ if not selected_agent:
+ return Response(
+ content=json.dumps({"error": f"Agent with name {body.name} not found."}),
+ media_type="application/json",
+ status_code=404,
+ )
+
+ agent = await AgentAdapters.aupdate_agent(
+ user,
+ body.name,
+ body.persona,
+ body.privacy_level,
+ body.icon,
+ body.color,
+ body.chat_model,
+ body.files,
+ body.input_tools,
+ body.output_modes,
+ )
+
+ agents_packet = {
+ "slug": agent.slug,
+ "name": agent.name,
+ "persona": agent.personality,
+ "creator": agent.creator.username if agent.creator else None,
+ "managed_by_admin": agent.managed_by_admin,
+ "color": agent.style_color,
+ "icon": agent.style_icon,
+ "privacy_level": agent.privacy_level,
+ "chat_model": agent.chat_model.chat_model,
+ "files": body.files,
+ "input_tools": agent.input_tools,
+ "output_modes": agent.output_modes,
+ }
+
+ return Response(content=json.dumps(agents_packet), media_type="application/json", status_code=200)
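
Taken together, the router now covers list, options, get, create, update and delete for agents. A rough client-side sketch with httpx (the /api/agents mount path, port and session cookie are assumptions, and all field values are hypothetical):

    import httpx

    client = httpx.Client(base_url="http://localhost:42110", cookies={"session": "<session cookie>"})

    options = client.get("/api/agents/options").json()   # valid input_tools and output_modes
    created = client.post("/api/agents", json={
        "name": "Trip Planner",
        "persona": "You help plan road trips.",
        "privacy_level": "private",
        "icon": "Jeep",
        "color": "blue",
        "chat_model": "gpt-4o-mini",
        "files": [],
        "input_tools": ["online", "webpage"],
        "output_modes": ["text"],
    }).json()
    fetched = client.get(f"/api/agents/{created['slug']}").json()
    client.delete(f"/api/agents/{created['slug']}")
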
diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py
index 039545c4..ce5f8fd8 100644
--- a/src/khoj/routers/api_chat.py
+++ b/src/khoj/routers/api_chat.py
@@ -17,13 +17,14 @@ from starlette.authentication import has_required_scope, requires
from khoj.app.settings import ALLOWED_HOSTS
from khoj.database.adapters import (
+ AgentAdapters,
ConversationAdapters,
EntryAdapters,
FileObjectAdapters,
PublicConversationAdapters,
aget_user_name,
)
-from khoj.database.models import KhojUser
+from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation.prompts import help_message, no_entries_found
from khoj.processor.conversation.utils import save_to_conversation_log
from khoj.processor.image.generate import text_to_image
@@ -211,7 +212,6 @@ def chat_history(
agent_metadata = {
"slug": conversation.agent.slug,
"name": conversation.agent.name,
- "avatar": conversation.agent.avatar,
"isCreator": conversation.agent.creator == user,
"color": conversation.agent.style_color,
"icon": conversation.agent.style_icon,
@@ -268,7 +268,6 @@ def get_shared_chat(
agent_metadata = {
"slug": conversation.agent.slug,
"name": conversation.agent.name,
- "avatar": conversation.agent.avatar,
"isCreator": conversation.agent.creator == user,
"color": conversation.agent.style_color,
"icon": conversation.agent.style_icon,
@@ -418,7 +417,7 @@ def chat_sessions(
conversations = conversations[:8]
sessions = conversations.values_list(
- "id", "slug", "title", "agent__slug", "agent__name", "agent__avatar", "created_at", "updated_at"
+ "id", "slug", "title", "agent__slug", "agent__name", "created_at", "updated_at"
)
session_values = [
@@ -426,9 +425,8 @@ def chat_sessions(
"conversation_id": str(session[0]),
"slug": session[2] or session[1],
"agent_name": session[4],
- "agent_avatar": session[5],
- "created": session[6].strftime("%Y-%m-%d %H:%M:%S"),
- "updated": session[7].strftime("%Y-%m-%d %H:%M:%S"),
+ "created": session[5].strftime("%Y-%m-%d %H:%M:%S"),
+ "updated": session[6].strftime("%Y-%m-%d %H:%M:%S"),
}
for session in sessions
]
@@ -590,7 +588,7 @@ async def chat(
nonlocal connection_alive, ttft
if not connection_alive or await request.is_disconnected():
connection_alive = False
- logger.warn(f"User {user} disconnected from {common.client} client")
+ logger.warning(f"User {user} disconnected from {common.client} client")
return
try:
if event_type == ChatEvent.END_LLM_RESPONSE:
@@ -658,6 +656,11 @@ async def chat(
return
conversation_id = conversation.id
+ agent: Agent | None = None
+ default_agent = await AgentAdapters.aget_default_agent()
+ if conversation.agent and conversation.agent != default_agent:
+ agent = conversation.agent
+
await is_ready_to_chat(user)
user_name = await aget_user_name(user)
@@ -677,7 +680,12 @@ async def chat(
if conversation_commands == [ConversationCommand.Default] or is_automated_task:
conversation_commands = await aget_relevant_information_sources(
- q, meta_log, is_automated_task, subscribed=subscribed, uploaded_image_url=uploaded_image_url
+ q,
+ meta_log,
+ is_automated_task,
+ subscribed=subscribed,
+ uploaded_image_url=uploaded_image_url,
+ agent=agent,
)
conversation_commands_str = ", ".join([cmd.value for cmd in conversation_commands])
async for result in send_event(
@@ -685,7 +693,7 @@ async def chat(
):
yield result
- mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, uploaded_image_url)
+ mode = await aget_relevant_output_modes(q, meta_log, is_automated_task, uploaded_image_url, agent)
async for result in send_event(ChatEvent.STATUS, f"**Decided Response Mode:** {mode.value}"):
yield result
if mode not in conversation_commands:
@@ -734,7 +742,7 @@ async def chat(
yield result
response = await extract_relevant_summary(
- q, contextual_data, subscribed=subscribed, uploaded_image_url=uploaded_image_url
+ q, contextual_data, subscribed=subscribed, uploaded_image_url=uploaded_image_url, agent=agent
)
response_log = str(response)
async for result in send_llm_response(response_log):
@@ -816,6 +824,7 @@ async def chat(
location,
partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url,
+ agent=agent,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@@ -853,6 +862,7 @@ async def chat(
partial(send_event, ChatEvent.STATUS),
custom_filters,
uploaded_image_url=uploaded_image_url,
+ agent=agent,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@@ -876,6 +886,7 @@ async def chat(
subscribed,
partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url,
+ agent=agent,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@@ -922,6 +933,7 @@ async def chat(
subscribed=subscribed,
send_status_func=partial(send_event, ChatEvent.STATUS),
uploaded_image_url=uploaded_image_url,
+ agent=agent,
):
if isinstance(result, dict) and ChatEvent.STATUS in result:
yield result[ChatEvent.STATUS]
@@ -1132,6 +1144,7 @@ async def get_chat(
yield result
return
conversation_id = conversation.id
+ agent = conversation.agent if conversation.agent else None
await is_ready_to_chat(user)
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index 761d5ba9..26591d7f 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -47,6 +47,7 @@ from khoj.database.adapters import (
run_with_process_lock,
)
from khoj.database.models import (
+ Agent,
ChatModelOptions,
ClientApplication,
Conversation,
@@ -257,8 +258,39 @@ async def acreate_title_from_query(query: str) -> str:
return response.strip()
+async def acheck_if_safe_prompt(system_prompt: str) -> Tuple[bool, str]:
+ """
+ Check if the system prompt is safe to use
+ """
+ safe_prompt_check = prompts.personality_prompt_safety_expert.format(prompt=system_prompt)
+ is_safe = True
+ reason = ""
+
+ with timer("Chat actor: Check if safe prompt", logger):
+ response = await send_message_to_model_wrapper(safe_prompt_check)
+
+ response = response.strip()
+ try:
+ response = json.loads(response)
+ is_safe = response.get("safe", "True") == "True"
+ if not is_safe:
+ reason = response.get("reason", "")
+ except Exception:
+ logger.error(f"Invalid response for checking safe prompt: {response}")
+
+ if not is_safe:
+ logger.error(f"Unsafe prompt: {system_prompt}. Reason: {reason}")
+
+ return is_safe, reason
+
+
async def aget_relevant_information_sources(
- query: str, conversation_history: dict, is_task: bool, subscribed: bool, uploaded_image_url: str = None
+ query: str,
+ conversation_history: dict,
+ is_task: bool,
+ subscribed: bool,
+ uploaded_image_url: str = None,
+ agent: Agent = None,
):
"""
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
@@ -267,19 +299,27 @@ async def aget_relevant_information_sources(
tool_options = dict()
tool_options_str = ""
+ agent_tools = agent.input_tools if agent else []
+
for tool, description in tool_descriptions_for_llm.items():
tool_options[tool.value] = description
- tool_options_str += f'- "{tool.value}": "{description}"\n'
+ if len(agent_tools) == 0 or tool.value in agent_tools:
+ tool_options_str += f'- "{tool.value}": "{description}"\n'
chat_history = construct_chat_history(conversation_history)
if uploaded_image_url:
query = f"[placeholder for user attached image]\n{query}"
+ personality_context = (
+ prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
+ )
+
relevant_tools_prompt = prompts.pick_relevant_information_collection_tools.format(
query=query,
tools=tool_options_str,
chat_history=chat_history,
+ personality_context=personality_context,
)
with timer("Chat actor: Infer information sources to refer", logger):
@@ -300,7 +340,10 @@ async def aget_relevant_information_sources(
final_response = [] if not is_task else [ConversationCommand.AutomatedTask]
for llm_suggested_tool in response:
- if llm_suggested_tool in tool_options.keys():
+ # Add a double check to verify it's in the agent list, because the LLM sometimes gets confused by the tool options.
+ if llm_suggested_tool in tool_options.keys() and (
+ len(agent_tools) == 0 or llm_suggested_tool in agent_tools
+ ):
# Check whether the tool exists as a valid ConversationCommand
final_response.append(ConversationCommand(llm_suggested_tool))
@@ -313,7 +356,7 @@ async def aget_relevant_information_sources(
async def aget_relevant_output_modes(
- query: str, conversation_history: dict, is_task: bool = False, uploaded_image_url: str = None
+ query: str, conversation_history: dict, is_task: bool = False, uploaded_image_url: str = None, agent: Agent = None
):
"""
Given a query, determine which of the available tools the agent should use in order to answer appropriately.
@@ -322,22 +365,30 @@ async def aget_relevant_output_modes(
mode_options = dict()
mode_options_str = ""
+ output_modes = agent.output_modes if agent else []
+
for mode, description in mode_descriptions_for_llm.items():
# Do not allow tasks to schedule another task
if is_task and mode == ConversationCommand.Automation:
continue
mode_options[mode.value] = description
- mode_options_str += f'- "{mode.value}": "{description}"\n'
+ if len(output_modes) == 0 or mode.value in output_modes:
+ mode_options_str += f'- "{mode.value}": "{description}"\n'
chat_history = construct_chat_history(conversation_history)
if uploaded_image_url:
query = f"[placeholder for user attached image]\n{query}"
+ personality_context = (
+ prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
+ )
+
relevant_mode_prompt = prompts.pick_relevant_output_mode.format(
query=query,
modes=mode_options_str,
chat_history=chat_history,
+ personality_context=personality_context,
)
with timer("Chat actor: Infer output mode for chat response", logger):
@@ -352,7 +403,9 @@ async def aget_relevant_output_modes(
return ConversationCommand.Text
output_mode = response["output"]
- if output_mode in mode_options.keys():
+
+ # Add a double check to verify it's in the agent list, because the LLM sometimes gets confused by the tool options.
+ if output_mode in mode_options.keys() and (len(output_modes) == 0 or output_mode in output_modes):
# Check whether the tool exists as a valid ConversationCommand
return ConversationCommand(output_mode)
@@ -364,7 +417,12 @@ async def aget_relevant_output_modes(
async def infer_webpage_urls(
- q: str, conversation_history: dict, location_data: LocationData, user: KhojUser, uploaded_image_url: str = None
+ q: str,
+ conversation_history: dict,
+ location_data: LocationData,
+ user: KhojUser,
+ uploaded_image_url: str = None,
+ agent: Agent = None,
) -> List[str]:
"""
Infer webpage links from the given query
@@ -374,12 +432,17 @@ async def infer_webpage_urls(
chat_history = construct_chat_history(conversation_history)
utc_date = datetime.utcnow().strftime("%Y-%m-%d")
+ personality_context = (
+ prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
+ )
+
online_queries_prompt = prompts.infer_webpages_to_read.format(
current_date=utc_date,
query=q,
chat_history=chat_history,
location=location,
username=username,
+ personality_context=personality_context,
)
with timer("Chat actor: Infer webpage urls to read", logger):
@@ -400,7 +463,12 @@ async def infer_webpage_urls(
async def generate_online_subqueries(
- q: str, conversation_history: dict, location_data: LocationData, user: KhojUser, uploaded_image_url: str = None
+ q: str,
+ conversation_history: dict,
+ location_data: LocationData,
+ user: KhojUser,
+ uploaded_image_url: str = None,
+ agent: Agent = None,
) -> List[str]:
"""
Generate subqueries from the given query
@@ -410,12 +478,17 @@ async def generate_online_subqueries(
chat_history = construct_chat_history(conversation_history)
utc_date = datetime.utcnow().strftime("%Y-%m-%d")
+ personality_context = (
+ prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
+ )
+
online_queries_prompt = prompts.online_search_conversation_subqueries.format(
current_date=utc_date,
query=q,
chat_history=chat_history,
location=location,
username=username,
+ personality_context=personality_context,
)
with timer("Chat actor: Generate online search subqueries", logger):
@@ -464,7 +537,7 @@ async def schedule_query(q: str, conversation_history: dict, uploaded_image_url:
raise AssertionError(f"Invalid response for scheduling query: {raw_response}")
-async def extract_relevant_info(q: str, corpus: str, subscribed: bool) -> Union[str, None]:
+async def extract_relevant_info(q: str, corpus: str, subscribed: bool, agent: Agent = None) -> Union[str, None]:
"""
Extract relevant information for a given query from the target corpus
"""
@@ -472,9 +545,14 @@ async def extract_relevant_info(q: str, corpus: str, subscribed: bool) -> Union[
if is_none_or_empty(corpus) or is_none_or_empty(q):
return None
+ personality_context = (
+ prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
+ )
+
extract_relevant_information = prompts.extract_relevant_information.format(
query=q,
corpus=corpus.strip(),
+ personality_context=personality_context,
)
chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config()
@@ -490,7 +568,7 @@ async def extract_relevant_info(q: str, corpus: str, subscribed: bool) -> Union[
async def extract_relevant_summary(
- q: str, corpus: str, subscribed: bool = False, uploaded_image_url: str = None
+ q: str, corpus: str, subscribed: bool = False, uploaded_image_url: str = None, agent: Agent = None
) -> Union[str, None]:
"""
Extract relevant information for a given query from the target corpus
@@ -499,9 +577,14 @@ async def extract_relevant_summary(
if is_none_or_empty(corpus) or is_none_or_empty(q):
return None
+ personality_context = (
+ prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
+ )
+
extract_relevant_information = prompts.extract_relevant_summary.format(
query=q,
corpus=corpus.strip(),
+ personality_context=personality_context,
)
chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config()
@@ -526,12 +609,16 @@ async def generate_better_image_prompt(
model_type: Optional[str] = None,
subscribed: bool = False,
uploaded_image_url: Optional[str] = None,
+ agent: Agent = None,
) -> str:
"""
Generate a better image prompt from the given query
"""
today_date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d, %A")
+ personality_context = (
+ prompts.personality_context.format(personality=agent.personality) if agent and agent.personality else ""
+ )
model_type = model_type or TextToImageModelConfig.ModelType.OPENAI
if location_data:
@@ -558,6 +645,7 @@ async def generate_better_image_prompt(
current_date=today_date,
references=user_references,
online_results=simplified_online_results,
+ personality_context=personality_context,
)
elif model_type in [TextToImageModelConfig.ModelType.STABILITYAI, TextToImageModelConfig.ModelType.REPLICATE]:
image_prompt = prompts.image_generation_improve_prompt_sd.format(
@@ -567,6 +655,7 @@ async def generate_better_image_prompt(
current_date=today_date,
references=user_references,
online_results=simplified_online_results,
+ personality_context=personality_context,
)
chat_model: ChatModelOptions = await ConversationAdapters.aget_default_conversation_config()
@@ -651,15 +740,13 @@ async def send_message_to_model_wrapper(
model_type=conversation_config.model_type,
)
- openai_response = send_message_to_model(
+ return send_message_to_model(
messages=truncated_messages,
api_key=api_key,
model=chat_model,
response_type=response_type,
api_base_url=api_base_url,
)
-
- return openai_response
elif model_type == ChatModelOptions.ModelType.ANTHROPIC:
api_key = conversation_config.openai_config.api_key
truncated_messages = generate_chatml_messages_with_context(
diff --git a/src/khoj/search_type/text_search.py b/src/khoj/search_type/text_search.py
index 569f0b50..52e23f29 100644
--- a/src/khoj/search_type/text_search.py
+++ b/src/khoj/search_type/text_search.py
@@ -1,13 +1,14 @@
import logging
import math
from pathlib import Path
-from typing import List, Tuple, Type, Union
+from typing import List, Optional, Tuple, Type, Union
import torch
from asgiref.sync import sync_to_async
from sentence_transformers import util
from khoj.database.adapters import EntryAdapters, get_user_search_model_or_default
+from khoj.database.models import Agent
from khoj.database.models import Entry as DbEntry
from khoj.database.models import KhojUser
from khoj.processor.content.text_to_entries import TextToEntries
@@ -101,6 +102,7 @@ async def query(
type: SearchType = SearchType.All,
question_embedding: Union[torch.Tensor, None] = None,
max_distance: float = None,
+ agent: Optional[Agent] = None,
) -> Tuple[List[dict], List[Entry]]:
"Search for entries that answer the query"
@@ -129,6 +131,7 @@ async def query(
file_type_filter=file_type,
raw_query=raw_query,
max_distance=max_distance,
+ agent=agent,
).all()
hits = await sync_to_async(list)(hits) # type: ignore[call-arg]
diff --git a/src/khoj/utils/helpers.py b/src/khoj/utils/helpers.py
index 4260d2cb..30c9d5d6 100644
--- a/src/khoj/utils/helpers.py
+++ b/src/khoj/utils/helpers.py
@@ -325,7 +325,15 @@ command_descriptions = {
ConversationCommand.Image: "Generate images by describing your imagination in words.",
ConversationCommand.Automation: "Automatically run your query at a specified time or interval.",
ConversationCommand.Help: "Get help with how to use or setup Khoj from the documentation",
- ConversationCommand.Summarize: "Create an appropriate summary using provided documents.",
+ ConversationCommand.Summarize: "Get help with a question pertaining to an entire document.",
+}
+
+command_descriptions_for_agent = {
+ ConversationCommand.General: "Respond without any outside information or personal knowledge.",
+ ConversationCommand.Notes: "Search through the knowledge base. Required if the agent expects context from the knowledge base.",
+ ConversationCommand.Online: "Search for the latest, up-to-date information from the internet.",
+ ConversationCommand.Webpage: "Scrape specific web pages for information.",
+ ConversationCommand.Summarize: "Retrieve an answer that depends on the entire document or a large text. Knowledge base must be a single document.",
}
tool_descriptions_for_llm = {
@@ -334,7 +342,7 @@ tool_descriptions_for_llm = {
ConversationCommand.Notes: "To search the user's personal knowledge base. Especially helpful if the question expects context from the user's notes or documents.",
ConversationCommand.Online: "To search for the latest, up-to-date information from the internet. Note: **Questions about Khoj should always use this data source**",
ConversationCommand.Webpage: "To use if the user has directly provided the webpage urls or you are certain of the webpage urls to read.",
- ConversationCommand.Summarize: "To create a summary of the document provided by the user.",
+ ConversationCommand.Summarize: "To retrieve an answer that depends on the entire document or a large text.",
}
mode_descriptions_for_llm = {
@@ -343,6 +351,11 @@ mode_descriptions_for_llm = {
ConversationCommand.Text: "Use this if the other response modes don't seem to fit the query.",
}
+mode_descriptions_for_agent = {
+ ConversationCommand.Image: "Allow the agent to generate images.",
+ ConversationCommand.Text: "Allow the agent to generate text.",
+}
+
class ImageIntentType(Enum):
"""