From 352168d6c29efb43435a1bd87bb3e70ffd57de7a Mon Sep 17 00:00:00 2001 From: sabaimran Date: Mon, 11 Mar 2024 14:20:28 +0530 Subject: [PATCH] Customize default behaviors for conversations without agents or with default agents --- src/khoj/configure.py | 6 +++ src/khoj/database/adapters/__init__.py | 43 +++++++++++++++++++ src/khoj/database/models/__init__.py | 10 +++++ .../conversation/offline/chat_model.py | 2 +- src/khoj/processor/conversation/openai/gpt.py | 2 +- src/khoj/processor/conversation/prompts.py | 4 +- src/khoj/routers/api.py | 2 +- src/khoj/routers/helpers.py | 5 +-- 8 files changed, 66 insertions(+), 8 deletions(-) diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 95b32018..4cc2ff3a 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -21,6 +21,7 @@ from starlette.middleware.sessions import SessionMiddleware from starlette.requests import HTTPConnection from khoj.database.adapters import ( + AgentAdapters, ClientApplicationAdapters, ConversationAdapters, SubscriptionState, @@ -229,11 +230,16 @@ def configure_server( state.SearchType = configure_search_types() state.search_models = configure_search(state.search_models, state.config.search_type) + setup_default_agent() initialize_content(regenerate, search_type, init, user) except Exception as e: raise e +def setup_default_agent(): + AgentAdapters.create_default_agent() + + def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, init=False, user: KhojUser = None): # Initialize Content from Config if state.search_models: diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index c639277f..54cce808 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -38,6 +38,7 @@ from khoj.database.models import ( UserRequests, UserSearchModelConfig, ) +from khoj.processor.conversation import prompts from khoj.search_filter.date_filter import DateFilter from khoj.search_filter.file_filter import FileFilter from khoj.search_filter.word_filter import WordFilter @@ -393,10 +394,52 @@ class ClientApplicationAdapters: class AgentAdapters: + DEFAULT_AGENT_NAME = "khoj" + DEFAULT_AGENT_AVATAR = "https://khoj-web-bucket.s3.amazonaws.com/lamp-128.png" + @staticmethod async def aget_agent_by_id(agent_id: int): return await Agent.objects.filter(id=agent_id).afirst() + @staticmethod + def get_all_acessible_agents(user: KhojUser = None): + return Agent.objects.filter(Q(public=True) | Q(creator=user)).distinct() + + @staticmethod + def get_conversation_agent_by_id(agent_id: int): + agent = Agent.objects.filter(id=agent_id).first() + if agent == AgentAdapters.get_default_agent(): + # If the agent is set to the default agent, then return None and let the default application code be used + return None + return agent + + @staticmethod + def get_default_agent(): + return Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).first() + + @staticmethod + def create_default_agent(): + # First delete the existing default + Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).delete() + + default_conversation_config = ConversationAdapters.get_default_conversation_config() + default_personality = prompts.personality.format(current_date="placeholder") + + # The default agent is public and managed by the admin. It's handled a little differently than other agents. + return Agent.objects.create( + name=AgentAdapters.DEFAULT_AGENT_NAME, + public=True, + managed_by_admin=True, + chat_model=default_conversation_config, + tuning=default_personality, + tools=["*"], + avatar=AgentAdapters.DEFAULT_AGENT_AVATAR, + ) + + @staticmethod + async def aget_default_agent(): + return await Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).afirst() + class ConversationAdapters: @staticmethod diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index fc2ca4ac..5272a1f6 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -1,7 +1,10 @@ import uuid from django.contrib.auth.models import AbstractUser +from django.core.exceptions import ValidationError from django.db import models +from django.db.models.signals import pre_save +from django.dispatch import receiver from pgvector.django import VectorField from phonenumber_field.modelfields import PhoneNumberField @@ -91,6 +94,13 @@ class Agent(BaseModel): chat_model = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE) +@receiver(pre_save, sender=Agent) +def check_public_name(sender, instance, **kwargs): + if instance.public: + if Agent.objects.filter(name=instance.name, public=True).exists(): + raise ValidationError(f"A public Agent with the name {instance.name} already exists.") + + class NotionConfig(BaseModel): token = models.CharField(max_length=200) user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py index 4c51a085..aceaacaa 100644 --- a/src/khoj/processor/conversation/offline/chat_model.py +++ b/src/khoj/processor/conversation/offline/chat_model.py @@ -161,7 +161,7 @@ def converse_offline( system_prompt = "" current_date = datetime.now().strftime("%Y-%m-%d") - if agent: + if agent and agent.tuning: system_prompt = prompts.custom_system_prompt_message_gpt4all.format( name=agent.name, bio=agent.tuning, current_date=current_date ) diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index 7d520633..4d8ebb3c 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -132,7 +132,7 @@ def converse( system_prompt = "" - if agent: + if agent and agent.tuning: system_prompt = prompts.custom_personality.format(name=agent.name, bio=agent.tuning, current_date=current_date) else: system_prompt = prompts.personality.format(current_date=current_date) diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index 99ae7d7c..799c390e 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -35,7 +35,7 @@ You were created by Khoj Inc. with the following capabilities: Today is {current_date} in UTC. -Here's a bio about you: {bio} +Instructions:\n{bio} """.strip() ) @@ -89,7 +89,7 @@ You are {name}, a personal agent on Khoj. Today is {current_date} in UTC. -Here is your instruction set:\n{bio} +Instructions:\n{bio} """.strip() ) diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index b54da90b..8f010fc2 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -13,7 +13,7 @@ from fastapi.requests import Request from fastapi.responses import Response from starlette.authentication import requires -from khoj.configure import configure_server, initialize_content +from khoj.configure import initialize_content from khoj.database.adapters import ( ConversationAdapters, EntryAdapters, diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index ef93608b..b0fcc4a0 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -10,9 +10,8 @@ import openai from fastapi import Depends, Header, HTTPException, Request, UploadFile from starlette.authentication import has_required_scope -from khoj.database.adapters import ConversationAdapters, EntryAdapters +from khoj.database.adapters import AgentAdapters, ConversationAdapters, EntryAdapters from khoj.database.models import ( - Agent, ChatModelOptions, ClientApplication, Conversation, @@ -377,7 +376,7 @@ def generate_chat_response( logger.debug(f"Conversation Types: {conversation_commands}") metadata = {} - agent = conversation.agent + agent = AgentAdapters.get_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None try: partial_completion = partial(