Customize default behaviors for conversations without agents or with default agents

This commit is contained in:
sabaimran 2024-03-11 14:20:28 +05:30
parent 9b88976f36
commit 352168d6c2
8 changed files with 66 additions and 8 deletions

View file

@ -21,6 +21,7 @@ from starlette.middleware.sessions import SessionMiddleware
from starlette.requests import HTTPConnection from starlette.requests import HTTPConnection
from khoj.database.adapters import ( from khoj.database.adapters import (
AgentAdapters,
ClientApplicationAdapters, ClientApplicationAdapters,
ConversationAdapters, ConversationAdapters,
SubscriptionState, SubscriptionState,
@ -229,11 +230,16 @@ def configure_server(
state.SearchType = configure_search_types() state.SearchType = configure_search_types()
state.search_models = configure_search(state.search_models, state.config.search_type) state.search_models = configure_search(state.search_models, state.config.search_type)
setup_default_agent()
initialize_content(regenerate, search_type, init, user) initialize_content(regenerate, search_type, init, user)
except Exception as e: except Exception as e:
raise e raise e
def setup_default_agent():
AgentAdapters.create_default_agent()
def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, init=False, user: KhojUser = None): def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, init=False, user: KhojUser = None):
# Initialize Content from Config # Initialize Content from Config
if state.search_models: if state.search_models:

View file

@ -38,6 +38,7 @@ from khoj.database.models import (
UserRequests, UserRequests,
UserSearchModelConfig, UserSearchModelConfig,
) )
from khoj.processor.conversation import prompts
from khoj.search_filter.date_filter import DateFilter from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.word_filter import WordFilter from khoj.search_filter.word_filter import WordFilter
@ -393,10 +394,52 @@ class ClientApplicationAdapters:
class AgentAdapters: class AgentAdapters:
DEFAULT_AGENT_NAME = "khoj"
DEFAULT_AGENT_AVATAR = "https://khoj-web-bucket.s3.amazonaws.com/lamp-128.png"
@staticmethod @staticmethod
async def aget_agent_by_id(agent_id: int): async def aget_agent_by_id(agent_id: int):
return await Agent.objects.filter(id=agent_id).afirst() return await Agent.objects.filter(id=agent_id).afirst()
@staticmethod
def get_all_acessible_agents(user: KhojUser = None):
return Agent.objects.filter(Q(public=True) | Q(creator=user)).distinct()
@staticmethod
def get_conversation_agent_by_id(agent_id: int):
agent = Agent.objects.filter(id=agent_id).first()
if agent == AgentAdapters.get_default_agent():
# If the agent is set to the default agent, then return None and let the default application code be used
return None
return agent
@staticmethod
def get_default_agent():
return Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).first()
@staticmethod
def create_default_agent():
# First delete the existing default
Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).delete()
default_conversation_config = ConversationAdapters.get_default_conversation_config()
default_personality = prompts.personality.format(current_date="placeholder")
# The default agent is public and managed by the admin. It's handled a little differently than other agents.
return Agent.objects.create(
name=AgentAdapters.DEFAULT_AGENT_NAME,
public=True,
managed_by_admin=True,
chat_model=default_conversation_config,
tuning=default_personality,
tools=["*"],
avatar=AgentAdapters.DEFAULT_AGENT_AVATAR,
)
@staticmethod
async def aget_default_agent():
return await Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).afirst()
class ConversationAdapters: class ConversationAdapters:
@staticmethod @staticmethod

View file

@ -1,7 +1,10 @@
import uuid import uuid
from django.contrib.auth.models import AbstractUser from django.contrib.auth.models import AbstractUser
from django.core.exceptions import ValidationError
from django.db import models from django.db import models
from django.db.models.signals import pre_save
from django.dispatch import receiver
from pgvector.django import VectorField from pgvector.django import VectorField
from phonenumber_field.modelfields import PhoneNumberField from phonenumber_field.modelfields import PhoneNumberField
@ -91,6 +94,13 @@ class Agent(BaseModel):
chat_model = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE) chat_model = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE)
@receiver(pre_save, sender=Agent)
def check_public_name(sender, instance, **kwargs):
if instance.public:
if Agent.objects.filter(name=instance.name, public=True).exists():
raise ValidationError(f"A public Agent with the name {instance.name} already exists.")
class NotionConfig(BaseModel): class NotionConfig(BaseModel):
token = models.CharField(max_length=200) token = models.CharField(max_length=200)
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)

View file

@ -161,7 +161,7 @@ def converse_offline(
system_prompt = "" system_prompt = ""
current_date = datetime.now().strftime("%Y-%m-%d") current_date = datetime.now().strftime("%Y-%m-%d")
if agent: if agent and agent.tuning:
system_prompt = prompts.custom_system_prompt_message_gpt4all.format( system_prompt = prompts.custom_system_prompt_message_gpt4all.format(
name=agent.name, bio=agent.tuning, current_date=current_date name=agent.name, bio=agent.tuning, current_date=current_date
) )

View file

@ -132,7 +132,7 @@ def converse(
system_prompt = "" system_prompt = ""
if agent: if agent and agent.tuning:
system_prompt = prompts.custom_personality.format(name=agent.name, bio=agent.tuning, current_date=current_date) system_prompt = prompts.custom_personality.format(name=agent.name, bio=agent.tuning, current_date=current_date)
else: else:
system_prompt = prompts.personality.format(current_date=current_date) system_prompt = prompts.personality.format(current_date=current_date)

View file

@ -35,7 +35,7 @@ You were created by Khoj Inc. with the following capabilities:
Today is {current_date} in UTC. Today is {current_date} in UTC.
Here's a bio about you: {bio} Instructions:\n{bio}
""".strip() """.strip()
) )
@ -89,7 +89,7 @@ You are {name}, a personal agent on Khoj.
Today is {current_date} in UTC. Today is {current_date} in UTC.
Here is your instruction set:\n{bio} Instructions:\n{bio}
""".strip() """.strip()
) )

View file

@ -13,7 +13,7 @@ from fastapi.requests import Request
from fastapi.responses import Response from fastapi.responses import Response
from starlette.authentication import requires from starlette.authentication import requires
from khoj.configure import configure_server, initialize_content from khoj.configure import initialize_content
from khoj.database.adapters import ( from khoj.database.adapters import (
ConversationAdapters, ConversationAdapters,
EntryAdapters, EntryAdapters,

View file

@ -10,9 +10,8 @@ import openai
from fastapi import Depends, Header, HTTPException, Request, UploadFile from fastapi import Depends, Header, HTTPException, Request, UploadFile
from starlette.authentication import has_required_scope from starlette.authentication import has_required_scope
from khoj.database.adapters import ConversationAdapters, EntryAdapters from khoj.database.adapters import AgentAdapters, ConversationAdapters, EntryAdapters
from khoj.database.models import ( from khoj.database.models import (
Agent,
ChatModelOptions, ChatModelOptions,
ClientApplication, ClientApplication,
Conversation, Conversation,
@ -377,7 +376,7 @@ def generate_chat_response(
logger.debug(f"Conversation Types: {conversation_commands}") logger.debug(f"Conversation Types: {conversation_commands}")
metadata = {} metadata = {}
agent = conversation.agent agent = AgentAdapters.get_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None
try: try:
partial_completion = partial( partial_completion = partial(