mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-24 07:55:07 +01:00
Customize default behaviors for conversations without agents or with default agents
This commit is contained in:
parent
9b88976f36
commit
352168d6c2
8 changed files with 66 additions and 8 deletions
|
@ -21,6 +21,7 @@ from starlette.middleware.sessions import SessionMiddleware
|
||||||
from starlette.requests import HTTPConnection
|
from starlette.requests import HTTPConnection
|
||||||
|
|
||||||
from khoj.database.adapters import (
|
from khoj.database.adapters import (
|
||||||
|
AgentAdapters,
|
||||||
ClientApplicationAdapters,
|
ClientApplicationAdapters,
|
||||||
ConversationAdapters,
|
ConversationAdapters,
|
||||||
SubscriptionState,
|
SubscriptionState,
|
||||||
|
@ -229,11 +230,16 @@ def configure_server(
|
||||||
|
|
||||||
state.SearchType = configure_search_types()
|
state.SearchType = configure_search_types()
|
||||||
state.search_models = configure_search(state.search_models, state.config.search_type)
|
state.search_models = configure_search(state.search_models, state.config.search_type)
|
||||||
|
setup_default_agent()
|
||||||
initialize_content(regenerate, search_type, init, user)
|
initialize_content(regenerate, search_type, init, user)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
|
def setup_default_agent():
|
||||||
|
AgentAdapters.create_default_agent()
|
||||||
|
|
||||||
|
|
||||||
def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, init=False, user: KhojUser = None):
|
def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, init=False, user: KhojUser = None):
|
||||||
# Initialize Content from Config
|
# Initialize Content from Config
|
||||||
if state.search_models:
|
if state.search_models:
|
||||||
|
|
|
@ -38,6 +38,7 @@ from khoj.database.models import (
|
||||||
UserRequests,
|
UserRequests,
|
||||||
UserSearchModelConfig,
|
UserSearchModelConfig,
|
||||||
)
|
)
|
||||||
|
from khoj.processor.conversation import prompts
|
||||||
from khoj.search_filter.date_filter import DateFilter
|
from khoj.search_filter.date_filter import DateFilter
|
||||||
from khoj.search_filter.file_filter import FileFilter
|
from khoj.search_filter.file_filter import FileFilter
|
||||||
from khoj.search_filter.word_filter import WordFilter
|
from khoj.search_filter.word_filter import WordFilter
|
||||||
|
@ -393,10 +394,52 @@ class ClientApplicationAdapters:
|
||||||
|
|
||||||
|
|
||||||
class AgentAdapters:
|
class AgentAdapters:
|
||||||
|
DEFAULT_AGENT_NAME = "khoj"
|
||||||
|
DEFAULT_AGENT_AVATAR = "https://khoj-web-bucket.s3.amazonaws.com/lamp-128.png"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def aget_agent_by_id(agent_id: int):
|
async def aget_agent_by_id(agent_id: int):
|
||||||
return await Agent.objects.filter(id=agent_id).afirst()
|
return await Agent.objects.filter(id=agent_id).afirst()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_all_acessible_agents(user: KhojUser = None):
|
||||||
|
return Agent.objects.filter(Q(public=True) | Q(creator=user)).distinct()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_conversation_agent_by_id(agent_id: int):
|
||||||
|
agent = Agent.objects.filter(id=agent_id).first()
|
||||||
|
if agent == AgentAdapters.get_default_agent():
|
||||||
|
# If the agent is set to the default agent, then return None and let the default application code be used
|
||||||
|
return None
|
||||||
|
return agent
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_default_agent():
|
||||||
|
return Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).first()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_default_agent():
|
||||||
|
# First delete the existing default
|
||||||
|
Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).delete()
|
||||||
|
|
||||||
|
default_conversation_config = ConversationAdapters.get_default_conversation_config()
|
||||||
|
default_personality = prompts.personality.format(current_date="placeholder")
|
||||||
|
|
||||||
|
# The default agent is public and managed by the admin. It's handled a little differently than other agents.
|
||||||
|
return Agent.objects.create(
|
||||||
|
name=AgentAdapters.DEFAULT_AGENT_NAME,
|
||||||
|
public=True,
|
||||||
|
managed_by_admin=True,
|
||||||
|
chat_model=default_conversation_config,
|
||||||
|
tuning=default_personality,
|
||||||
|
tools=["*"],
|
||||||
|
avatar=AgentAdapters.DEFAULT_AGENT_AVATAR,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def aget_default_agent():
|
||||||
|
return await Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).afirst()
|
||||||
|
|
||||||
|
|
||||||
class ConversationAdapters:
|
class ConversationAdapters:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
@ -1,7 +1,10 @@
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from django.contrib.auth.models import AbstractUser
|
from django.contrib.auth.models import AbstractUser
|
||||||
|
from django.core.exceptions import ValidationError
|
||||||
from django.db import models
|
from django.db import models
|
||||||
|
from django.db.models.signals import pre_save
|
||||||
|
from django.dispatch import receiver
|
||||||
from pgvector.django import VectorField
|
from pgvector.django import VectorField
|
||||||
from phonenumber_field.modelfields import PhoneNumberField
|
from phonenumber_field.modelfields import PhoneNumberField
|
||||||
|
|
||||||
|
@ -91,6 +94,13 @@ class Agent(BaseModel):
|
||||||
chat_model = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE)
|
chat_model = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE)
|
||||||
|
|
||||||
|
|
||||||
|
@receiver(pre_save, sender=Agent)
|
||||||
|
def check_public_name(sender, instance, **kwargs):
|
||||||
|
if instance.public:
|
||||||
|
if Agent.objects.filter(name=instance.name, public=True).exists():
|
||||||
|
raise ValidationError(f"A public Agent with the name {instance.name} already exists.")
|
||||||
|
|
||||||
|
|
||||||
class NotionConfig(BaseModel):
|
class NotionConfig(BaseModel):
|
||||||
token = models.CharField(max_length=200)
|
token = models.CharField(max_length=200)
|
||||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||||
|
|
|
@ -161,7 +161,7 @@ def converse_offline(
|
||||||
system_prompt = ""
|
system_prompt = ""
|
||||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||||
|
|
||||||
if agent:
|
if agent and agent.tuning:
|
||||||
system_prompt = prompts.custom_system_prompt_message_gpt4all.format(
|
system_prompt = prompts.custom_system_prompt_message_gpt4all.format(
|
||||||
name=agent.name, bio=agent.tuning, current_date=current_date
|
name=agent.name, bio=agent.tuning, current_date=current_date
|
||||||
)
|
)
|
||||||
|
|
|
@ -132,7 +132,7 @@ def converse(
|
||||||
|
|
||||||
system_prompt = ""
|
system_prompt = ""
|
||||||
|
|
||||||
if agent:
|
if agent and agent.tuning:
|
||||||
system_prompt = prompts.custom_personality.format(name=agent.name, bio=agent.tuning, current_date=current_date)
|
system_prompt = prompts.custom_personality.format(name=agent.name, bio=agent.tuning, current_date=current_date)
|
||||||
else:
|
else:
|
||||||
system_prompt = prompts.personality.format(current_date=current_date)
|
system_prompt = prompts.personality.format(current_date=current_date)
|
||||||
|
|
|
@ -35,7 +35,7 @@ You were created by Khoj Inc. with the following capabilities:
|
||||||
|
|
||||||
Today is {current_date} in UTC.
|
Today is {current_date} in UTC.
|
||||||
|
|
||||||
Here's a bio about you: {bio}
|
Instructions:\n{bio}
|
||||||
""".strip()
|
""".strip()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -89,7 +89,7 @@ You are {name}, a personal agent on Khoj.
|
||||||
|
|
||||||
Today is {current_date} in UTC.
|
Today is {current_date} in UTC.
|
||||||
|
|
||||||
Here is your instruction set:\n{bio}
|
Instructions:\n{bio}
|
||||||
""".strip()
|
""".strip()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -13,7 +13,7 @@ from fastapi.requests import Request
|
||||||
from fastapi.responses import Response
|
from fastapi.responses import Response
|
||||||
from starlette.authentication import requires
|
from starlette.authentication import requires
|
||||||
|
|
||||||
from khoj.configure import configure_server, initialize_content
|
from khoj.configure import initialize_content
|
||||||
from khoj.database.adapters import (
|
from khoj.database.adapters import (
|
||||||
ConversationAdapters,
|
ConversationAdapters,
|
||||||
EntryAdapters,
|
EntryAdapters,
|
||||||
|
|
|
@ -10,9 +10,8 @@ import openai
|
||||||
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
||||||
from starlette.authentication import has_required_scope
|
from starlette.authentication import has_required_scope
|
||||||
|
|
||||||
from khoj.database.adapters import ConversationAdapters, EntryAdapters
|
from khoj.database.adapters import AgentAdapters, ConversationAdapters, EntryAdapters
|
||||||
from khoj.database.models import (
|
from khoj.database.models import (
|
||||||
Agent,
|
|
||||||
ChatModelOptions,
|
ChatModelOptions,
|
||||||
ClientApplication,
|
ClientApplication,
|
||||||
Conversation,
|
Conversation,
|
||||||
|
@ -377,7 +376,7 @@ def generate_chat_response(
|
||||||
logger.debug(f"Conversation Types: {conversation_commands}")
|
logger.debug(f"Conversation Types: {conversation_commands}")
|
||||||
|
|
||||||
metadata = {}
|
metadata = {}
|
||||||
agent = conversation.agent
|
agent = AgentAdapters.get_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
partial_completion = partial(
|
partial_completion = partial(
|
||||||
|
|
Loading…
Reference in a new issue