mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-27 17:35:07 +01:00
Customize default behaviors for conversations without agents or with default agents
This commit is contained in:
parent
9b88976f36
commit
352168d6c2
8 changed files with 66 additions and 8 deletions
|
@ -21,6 +21,7 @@ from starlette.middleware.sessions import SessionMiddleware
|
|||
from starlette.requests import HTTPConnection
|
||||
|
||||
from khoj.database.adapters import (
|
||||
AgentAdapters,
|
||||
ClientApplicationAdapters,
|
||||
ConversationAdapters,
|
||||
SubscriptionState,
|
||||
|
@ -229,11 +230,16 @@ def configure_server(
|
|||
|
||||
state.SearchType = configure_search_types()
|
||||
state.search_models = configure_search(state.search_models, state.config.search_type)
|
||||
setup_default_agent()
|
||||
initialize_content(regenerate, search_type, init, user)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
def setup_default_agent():
|
||||
AgentAdapters.create_default_agent()
|
||||
|
||||
|
||||
def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, init=False, user: KhojUser = None):
|
||||
# Initialize Content from Config
|
||||
if state.search_models:
|
||||
|
|
|
@ -38,6 +38,7 @@ from khoj.database.models import (
|
|||
UserRequests,
|
||||
UserSearchModelConfig,
|
||||
)
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
from khoj.search_filter.file_filter import FileFilter
|
||||
from khoj.search_filter.word_filter import WordFilter
|
||||
|
@ -393,10 +394,52 @@ class ClientApplicationAdapters:
|
|||
|
||||
|
||||
class AgentAdapters:
|
||||
DEFAULT_AGENT_NAME = "khoj"
|
||||
DEFAULT_AGENT_AVATAR = "https://khoj-web-bucket.s3.amazonaws.com/lamp-128.png"
|
||||
|
||||
@staticmethod
|
||||
async def aget_agent_by_id(agent_id: int):
|
||||
return await Agent.objects.filter(id=agent_id).afirst()
|
||||
|
||||
@staticmethod
|
||||
def get_all_acessible_agents(user: KhojUser = None):
|
||||
return Agent.objects.filter(Q(public=True) | Q(creator=user)).distinct()
|
||||
|
||||
@staticmethod
|
||||
def get_conversation_agent_by_id(agent_id: int):
|
||||
agent = Agent.objects.filter(id=agent_id).first()
|
||||
if agent == AgentAdapters.get_default_agent():
|
||||
# If the agent is set to the default agent, then return None and let the default application code be used
|
||||
return None
|
||||
return agent
|
||||
|
||||
@staticmethod
|
||||
def get_default_agent():
|
||||
return Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).first()
|
||||
|
||||
@staticmethod
|
||||
def create_default_agent():
|
||||
# First delete the existing default
|
||||
Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).delete()
|
||||
|
||||
default_conversation_config = ConversationAdapters.get_default_conversation_config()
|
||||
default_personality = prompts.personality.format(current_date="placeholder")
|
||||
|
||||
# The default agent is public and managed by the admin. It's handled a little differently than other agents.
|
||||
return Agent.objects.create(
|
||||
name=AgentAdapters.DEFAULT_AGENT_NAME,
|
||||
public=True,
|
||||
managed_by_admin=True,
|
||||
chat_model=default_conversation_config,
|
||||
tuning=default_personality,
|
||||
tools=["*"],
|
||||
avatar=AgentAdapters.DEFAULT_AGENT_AVATAR,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def aget_default_agent():
|
||||
return await Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).afirst()
|
||||
|
||||
|
||||
class ConversationAdapters:
|
||||
@staticmethod
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
import uuid
|
||||
|
||||
from django.contrib.auth.models import AbstractUser
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.db import models
|
||||
from django.db.models.signals import pre_save
|
||||
from django.dispatch import receiver
|
||||
from pgvector.django import VectorField
|
||||
from phonenumber_field.modelfields import PhoneNumberField
|
||||
|
||||
|
@ -91,6 +94,13 @@ class Agent(BaseModel):
|
|||
chat_model = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE)
|
||||
|
||||
|
||||
@receiver(pre_save, sender=Agent)
|
||||
def check_public_name(sender, instance, **kwargs):
|
||||
if instance.public:
|
||||
if Agent.objects.filter(name=instance.name, public=True).exists():
|
||||
raise ValidationError(f"A public Agent with the name {instance.name} already exists.")
|
||||
|
||||
|
||||
class NotionConfig(BaseModel):
|
||||
token = models.CharField(max_length=200)
|
||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||
|
|
|
@ -161,7 +161,7 @@ def converse_offline(
|
|||
system_prompt = ""
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
|
||||
if agent:
|
||||
if agent and agent.tuning:
|
||||
system_prompt = prompts.custom_system_prompt_message_gpt4all.format(
|
||||
name=agent.name, bio=agent.tuning, current_date=current_date
|
||||
)
|
||||
|
|
|
@ -132,7 +132,7 @@ def converse(
|
|||
|
||||
system_prompt = ""
|
||||
|
||||
if agent:
|
||||
if agent and agent.tuning:
|
||||
system_prompt = prompts.custom_personality.format(name=agent.name, bio=agent.tuning, current_date=current_date)
|
||||
else:
|
||||
system_prompt = prompts.personality.format(current_date=current_date)
|
||||
|
|
|
@ -35,7 +35,7 @@ You were created by Khoj Inc. with the following capabilities:
|
|||
|
||||
Today is {current_date} in UTC.
|
||||
|
||||
Here's a bio about you: {bio}
|
||||
Instructions:\n{bio}
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
@ -89,7 +89,7 @@ You are {name}, a personal agent on Khoj.
|
|||
|
||||
Today is {current_date} in UTC.
|
||||
|
||||
Here is your instruction set:\n{bio}
|
||||
Instructions:\n{bio}
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ from fastapi.requests import Request
|
|||
from fastapi.responses import Response
|
||||
from starlette.authentication import requires
|
||||
|
||||
from khoj.configure import configure_server, initialize_content
|
||||
from khoj.configure import initialize_content
|
||||
from khoj.database.adapters import (
|
||||
ConversationAdapters,
|
||||
EntryAdapters,
|
||||
|
|
|
@ -10,9 +10,8 @@ import openai
|
|||
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
||||
from starlette.authentication import has_required_scope
|
||||
|
||||
from khoj.database.adapters import ConversationAdapters, EntryAdapters
|
||||
from khoj.database.adapters import AgentAdapters, ConversationAdapters, EntryAdapters
|
||||
from khoj.database.models import (
|
||||
Agent,
|
||||
ChatModelOptions,
|
||||
ClientApplication,
|
||||
Conversation,
|
||||
|
@ -377,7 +376,7 @@ def generate_chat_response(
|
|||
logger.debug(f"Conversation Types: {conversation_commands}")
|
||||
|
||||
metadata = {}
|
||||
agent = conversation.agent
|
||||
agent = AgentAdapters.get_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None
|
||||
|
||||
try:
|
||||
partial_completion = partial(
|
||||
|
|
Loading…
Reference in a new issue