From 8abc8ded82e6bdef629d6c055a0c64f8f52e45d2 Mon Sep 17 00:00:00 2001 From: sabaimran <65192171+sabaimran@users.noreply.github.com> Date: Sat, 23 Mar 2024 09:39:38 -0700 Subject: [PATCH 1/6] Part 1: Server-side changes to support agents integrated with Conversations (#671) * Initial pass at backend changes to support agents - Add a db model for Agents, attaching them to conversations - When an agent is added to a conversation, override the system prompt to tweak the instructions - Agents can be configured with prompt modification, model specification, a profile picture, and other things - Admin-configured models will not be editable by individual users - Add unit tests to verify agent behavior. Unit tests demonstrate imperfect adherence to prompt specifications * Customize default behaviors for conversations without agents or with default agents * Use agent_id for getting correct agent * Merge migrations * Simplify some variable definitions, add additional security checks for agents * Rename agent.tuning -> agent.personality --- src/khoj/configure.py | 6 ++ src/khoj/database/adapters/__init__.py | 81 ++++++++++++-- src/khoj/database/admin.py | 2 + .../0031_agent_conversation_agent.py | 52 +++++++++ .../migrations/0032_merge_20240322_0427.py | 14 +++ .../0033_rename_tuning_agent_personality.py | 17 +++ src/khoj/database/models/__init__.py | 46 ++++++-- .../conversation/offline/chat_model.py | 14 ++- src/khoj/processor/conversation/openai/gpt.py | 11 +- src/khoj/processor/conversation/prompts.py | 32 ++++++ src/khoj/routers/api.py | 2 +- src/khoj/routers/api_chat.py | 4 +- src/khoj/routers/helpers.py | 9 +- tests/conftest.py | 23 ++++ tests/test_gpt4all_chat_actors.py | 41 +++++++ tests/test_gpt4all_chat_director.py | 101 ++++++++++++++---- tests/test_openai_chat_actors.py | 36 +++++++ tests/test_openai_chat_director.py | 96 +++++++++++++---- 18 files changed, 527 insertions(+), 60 deletions(-) create mode 100644 
src/khoj/database/migrations/0031_agent_conversation_agent.py create mode 100644 src/khoj/database/migrations/0032_merge_20240322_0427.py create mode 100644 src/khoj/database/migrations/0033_rename_tuning_agent_personality.py diff --git a/src/khoj/configure.py b/src/khoj/configure.py index add685e3..fb3a93ce 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -21,6 +21,7 @@ from starlette.middleware.sessions import SessionMiddleware from starlette.requests import HTTPConnection from khoj.database.adapters import ( + AgentAdapters, ClientApplicationAdapters, ConversationAdapters, SubscriptionState, @@ -229,11 +230,16 @@ def configure_server( state.SearchType = configure_search_types() state.search_models = configure_search(state.search_models, state.config.search_type) + setup_default_agent() initialize_content(regenerate, search_type, init, user) except Exception as e: raise e +def setup_default_agent(): + AgentAdapters.create_default_agent() + + def initialize_content(regenerate: bool, search_type: Optional[SearchType] = None, init=False, user: KhojUser = None): # Initialize Content from Config if state.search_models: diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 882e8896..cb317275 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -16,6 +16,7 @@ from pgvector.django import CosineDistance from torch import Tensor from khoj.database.models import ( + Agent, ChatModelOptions, ClientApplication, Conversation, @@ -37,6 +38,7 @@ from khoj.database.models import ( UserRequests, UserSearchModelConfig, ) +from khoj.processor.conversation import prompts from khoj.search_filter.date_filter import DateFilter from khoj.search_filter.file_filter import FileFilter from khoj.search_filter.word_filter import WordFilter @@ -391,6 +393,58 @@ class ClientApplicationAdapters: return await ClientApplication.objects.filter(client_id=client_id, 
client_secret=client_secret).afirst() +class AgentAdapters: + DEFAULT_AGENT_NAME = "khoj" + DEFAULT_AGENT_AVATAR = "https://khoj-web-bucket.s3.amazonaws.com/lamp-128.png" + + @staticmethod + async def aget_agent_by_id(agent_id: int, user: KhojUser): + agent = await Agent.objects.filter(id=agent_id).afirst() + # Check if it's accessible to the user + if agent and (agent.public or agent.creator == user): + return agent + return None + + @staticmethod + def get_all_accessible_agents(user: KhojUser = None): + return Agent.objects.filter(Q(public=True) | Q(creator=user)).distinct() + + @staticmethod + def get_conversation_agent_by_id(agent_id: int): + agent = Agent.objects.filter(id=agent_id).first() + if agent == AgentAdapters.get_default_agent(): + # If the agent is set to the default agent, then return None and let the default application code be used + return None + return agent + + @staticmethod + def get_default_agent(): + return Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).first() + + @staticmethod + def create_default_agent(): + # First delete the existing default + Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).delete() + + default_conversation_config = ConversationAdapters.get_default_conversation_config() + default_personality = prompts.personality.format(current_date="placeholder") + + # The default agent is public and managed by the admin. It's handled a little differently than other agents. 
+ return Agent.objects.create( + name=AgentAdapters.DEFAULT_AGENT_NAME, + public=True, + managed_by_admin=True, + chat_model=default_conversation_config, + personality=default_personality, + tools=["*"], + avatar=AgentAdapters.DEFAULT_AGENT_AVATAR, + ) + + @staticmethod + async def aget_default_agent(): + return await Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).afirst() + + class ConversationAdapters: @staticmethod def get_conversation_by_user( @@ -431,7 +485,12 @@ class ConversationAdapters: return Conversation.objects.filter(id=conversation_id).first() @staticmethod - async def acreate_conversation_session(user: KhojUser, client_application: ClientApplication = None): + async def acreate_conversation_session( + user: KhojUser, client_application: ClientApplication = None, agent_id: int = None + ): + if agent_id: + agent = await AgentAdapters.aget_agent_by_id(agent_id, user) + return await Conversation.objects.acreate(user=user, client=client_application, agent=agent) return await Conversation.objects.acreate(user=user, client=client_application) @staticmethod @@ -443,9 +502,14 @@ class ConversationAdapters: elif title: return await Conversation.objects.filter(user=user, client=client_application, title=title).afirst() else: - return await ( - Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").afirst() - ) or await Conversation.objects.acreate(user=user, client=client_application) + conversation = Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at") + + if await conversation.aexists(): + return await conversation.prefetch_related("agent").afirst() + + return await ( + Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at").afirst() + ) or await Conversation.objects.acreate(user=user, client=client_application) @staticmethod async def adelete_conversation_by_user( @@ -603,9 +667,14 @@ class ConversationAdapters: return 
random.sample(all_questions, max_results) @staticmethod - def get_valid_conversation_config(user: KhojUser): + def get_valid_conversation_config(user: KhojUser, conversation: Conversation): offline_chat_config = ConversationAdapters.get_offline_chat_conversation_config() - conversation_config = ConversationAdapters.get_conversation_config(user) + + if conversation.agent and conversation.agent.chat_model: + conversation_config = conversation.agent.chat_model + else: + conversation_config = ConversationAdapters.get_conversation_config(user) + if conversation_config is None: conversation_config = ConversationAdapters.get_default_conversation_config() diff --git a/src/khoj/database/admin.py b/src/khoj/database/admin.py index 22400d14..cc1be7e4 100644 --- a/src/khoj/database/admin.py +++ b/src/khoj/database/admin.py @@ -6,6 +6,7 @@ from django.contrib.auth.admin import UserAdmin from django.http import HttpResponse from khoj.database.models import ( + Agent, ChatModelOptions, ClientApplication, Conversation, @@ -50,6 +51,7 @@ admin.site.register(ReflectiveQuestion) admin.site.register(UserSearchModelConfig) admin.site.register(TextToImageModelConfig) admin.site.register(ClientApplication) +admin.site.register(Agent) @admin.register(Entry) diff --git a/src/khoj/database/migrations/0031_agent_conversation_agent.py b/src/khoj/database/migrations/0031_agent_conversation_agent.py new file mode 100644 index 00000000..16586499 --- /dev/null +++ b/src/khoj/database/migrations/0031_agent_conversation_agent.py @@ -0,0 +1,52 @@ +# Generated by Django 4.2.10 on 2024-03-11 05:12 + +import django.db.models.deletion +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0030_conversation_slug_and_title"), + ] + + operations = [ + migrations.CreateModel( + name="Agent", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + 
("created_at", models.DateTimeField(auto_now_add=True)), + ("updated_at", models.DateTimeField(auto_now=True)), + ("name", models.CharField(max_length=200)), + ("tuning", models.TextField()), + ("avatar", models.URLField(blank=True, default=None, max_length=400, null=True)), + ("tools", models.JSONField(default=list)), + ("public", models.BooleanField(default=False)), + ("managed_by_admin", models.BooleanField(default=False)), + ( + "chat_model", + models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to="database.chatmodeloptions"), + ), + ( + "creator", + models.ForeignKey( + blank=True, + default=None, + null=True, + on_delete=django.db.models.deletion.CASCADE, + to=settings.AUTH_USER_MODEL, + ), + ), + ], + options={ + "abstract": False, + }, + ), + migrations.AddField( + model_name="conversation", + name="agent", + field=models.ForeignKey( + blank=True, default=None, null=True, on_delete=django.db.models.deletion.SET_NULL, to="database.agent" + ), + ), + ] diff --git a/src/khoj/database/migrations/0032_merge_20240322_0427.py b/src/khoj/database/migrations/0032_merge_20240322_0427.py new file mode 100644 index 00000000..aee557c0 --- /dev/null +++ b/src/khoj/database/migrations/0032_merge_20240322_0427.py @@ -0,0 +1,14 @@ +# Generated by Django 4.2.10 on 2024-03-22 04:27 + +from typing import List + +from django.db import migrations + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0031_agent_conversation_agent"), + ("database", "0031_alter_googleuser_locale"), + ] + + operations: List[str] = [] diff --git a/src/khoj/database/migrations/0033_rename_tuning_agent_personality.py b/src/khoj/database/migrations/0033_rename_tuning_agent_personality.py new file mode 100644 index 00000000..089c86c5 --- /dev/null +++ b/src/khoj/database/migrations/0033_rename_tuning_agent_personality.py @@ -0,0 +1,17 @@ +# Generated by Django 4.2.10 on 2024-03-23 16:01 + +from django.db import migrations + + +class Migration(migrations.Migration): + 
dependencies = [ + ("database", "0032_merge_20240322_0427"), + ] + + operations = [ + migrations.RenameField( + model_name="agent", + old_name="tuning", + new_name="personality", + ), + ] diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index 3f8f50b4..b8eeb8b1 100644 --- a/src/khoj/database/models/__init__.py +++ b/src/khoj/database/models/__init__.py @@ -1,7 +1,10 @@ import uuid from django.contrib.auth.models import AbstractUser +from django.core.exceptions import ValidationError from django.db import models +from django.db.models.signals import pre_save +from django.dispatch import receiver from pgvector.django import VectorField from phonenumber_field.modelfields import PhoneNumberField @@ -69,6 +72,37 @@ class Subscription(BaseModel): renewal_date = models.DateTimeField(null=True, default=None, blank=True) +class ChatModelOptions(BaseModel): + class ModelType(models.TextChoices): + OPENAI = "openai" + OFFLINE = "offline" + + max_prompt_size = models.IntegerField(default=None, null=True, blank=True) + tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True) + chat_model = models.CharField(max_length=200, default="mistral-7b-instruct-v0.1.Q4_0.gguf") + model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE) + + +class Agent(BaseModel): + creator = models.ForeignKey( + KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True + ) # Creator will only be null when the agents are managed by admin + name = models.CharField(max_length=200) + personality = models.TextField() + avatar = models.URLField(max_length=400, default=None, null=True, blank=True) + tools = models.JSONField(default=list) # List of tools the agent has access to, like online search or notes search + public = models.BooleanField(default=False) + managed_by_admin = models.BooleanField(default=False) + chat_model = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE) + + 
+@receiver(pre_save, sender=Agent) +def check_public_name(sender, instance, **kwargs): + if instance.public: + if Agent.objects.filter(name=instance.name, public=True).exists(): + raise ValidationError(f"A public Agent with the name {instance.name} already exists.") + + class NotionConfig(BaseModel): token = models.CharField(max_length=200) user = models.ForeignKey(KhojUser, on_delete=models.CASCADE) @@ -153,17 +187,6 @@ class SpeechToTextModelOptions(BaseModel): model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE) -class ChatModelOptions(BaseModel): - class ModelType(models.TextChoices): - OPENAI = "openai" - OFFLINE = "offline" - - max_prompt_size = models.IntegerField(default=None, null=True, blank=True) - tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True) - chat_model = models.CharField(max_length=200, default="mistral-7b-instruct-v0.1.Q4_0.gguf") - model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE) - - class UserConversationConfig(BaseModel): user = models.OneToOneField(KhojUser, on_delete=models.CASCADE) setting = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True) @@ -180,6 +203,7 @@ class Conversation(BaseModel): client = models.ForeignKey(ClientApplication, on_delete=models.CASCADE, default=None, null=True, blank=True) slug = models.CharField(max_length=200, default=None, null=True, blank=True) title = models.CharField(max_length=200, default=None, null=True, blank=True) + agent = models.ForeignKey(Agent, on_delete=models.SET_NULL, default=None, null=True, blank=True) class ReflectiveQuestion(BaseModel): diff --git a/src/khoj/processor/conversation/offline/chat_model.py b/src/khoj/processor/conversation/offline/chat_model.py index 437bdd3d..6348bb1a 100644 --- a/src/khoj/processor/conversation/offline/chat_model.py +++ b/src/khoj/processor/conversation/offline/chat_model.py @@ -6,6 
+6,7 @@ from typing import Any, Iterator, List, Union from langchain.schema import ChatMessage +from khoj.database.models import Agent from khoj.processor.conversation import prompts from khoj.processor.conversation.utils import ( ThreadedGenerator, @@ -141,6 +142,7 @@ def converse_offline( tokenizer_name=None, location_data: LocationData = None, user_name: str = None, + agent: Agent = None, ) -> Union[ThreadedGenerator, Iterator[str]]: """ Converse with user using Llama @@ -156,6 +158,15 @@ def converse_offline( # Initialize Variables compiled_references_message = "\n\n".join({f"{item}" for item in references}) + current_date = datetime.now().strftime("%Y-%m-%d") + + if agent and agent.personality: + system_prompt = prompts.custom_system_prompt_message_gpt4all.format( + name=agent.name, bio=agent.personality, current_date=current_date + ) + else: + system_prompt = prompts.system_prompt_message_gpt4all.format(current_date=current_date) + conversation_primer = prompts.query_prompt.format(query=user_query) if location_data: @@ -185,10 +196,9 @@ def converse_offline( conversation_primer = f"{prompts.notes_conversation_gpt4all.format(references=compiled_references_message)}\n{conversation_primer}" # Setup Prompt with Primer or Conversation History - current_date = datetime.now().strftime("%Y-%m-%d") messages = generate_chatml_messages_with_context( conversation_primer, - prompts.system_prompt_message_gpt4all.format(current_date=current_date), + system_prompt, conversation_log, model_name=model, max_prompt_size=max_prompt_size, diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py index 644bb961..a76fd6e9 100644 --- a/src/khoj/processor/conversation/openai/gpt.py +++ b/src/khoj/processor/conversation/openai/gpt.py @@ -5,6 +5,7 @@ from typing import Optional from langchain.schema import ChatMessage +from khoj.database.models import Agent from khoj.processor.conversation import prompts from 
khoj.processor.conversation.openai.utils import ( chat_completion_with_backoff, @@ -115,6 +116,7 @@ def converse( tokenizer_name=None, location_data: LocationData = None, user_name: str = None, + agent: Agent = None, ): """ Converse with user using OpenAI's ChatGPT @@ -125,6 +127,13 @@ def converse( conversation_primer = prompts.query_prompt.format(query=user_query) + if agent and agent.personality: + system_prompt = prompts.custom_personality.format( + name=agent.name, bio=agent.personality, current_date=current_date + ) + else: + system_prompt = prompts.personality.format(current_date=current_date) + if location_data: location = f"{location_data.city}, {location_data.region}, {location_data.country}" location_prompt = prompts.user_location.format(location=location) @@ -152,7 +161,7 @@ def converse( # Setup Prompt with Primer or Conversation History messages = generate_chatml_messages_with_context( conversation_primer, - prompts.personality.format(current_date=current_date), + system_prompt, conversation_log, model, max_prompt_size, diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py index b6465ca5..3bdb683a 100644 --- a/src/khoj/processor/conversation/prompts.py +++ b/src/khoj/processor/conversation/prompts.py @@ -21,6 +21,24 @@ Today is {current_date} in UTC. """.strip() ) +custom_personality = PromptTemplate.from_template( + """ +Your are {name}, a personal agent on Khoj. +Use your general knowledge and past conversation with the user as context to inform your responses. +You were created by Khoj Inc. with the following capabilities: + +- You *CAN REMEMBER ALL NOTES and PERSONAL INFORMATION FOREVER* that the user ever shares with you. +- Users can share files and other information with you using the Khoj Desktop, Obsidian or Emacs app. They can also drag and drop their files into the chat window. 
+- Say "I don't know" or "I don't understand" if you don't know what to say or if you don't know the answer to a question. +- Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided notes or past conversations. +- Sometimes the user will share personal information that needs to be remembered, like an account ID or a residential address. These can be acknowledged with a simple "Got it" or "Okay". + +Today is {current_date} in UTC. + +Instructions:\n{bio} +""".strip() +) + ## General Conversation ## -- general_conversation = PromptTemplate.from_template( @@ -61,6 +79,20 @@ Today is {current_date} in UTC. """.strip() ) +custom_system_prompt_message_gpt4all = PromptTemplate.from_template( + """ +You are {name}, a personal agent on Khoj. +- Use your general knowledge and past conversation with the user as context to inform your responses. +- If you do not know the answer, say 'I don't know.' +- Think step-by-step and ask questions to get the necessary information to answer the user's question. +- Do not print verbatim Notes unless necessary. + +Today is {current_date} in UTC. + +Instructions:\n{bio} + """.strip() +) + system_prompt_message_extract_questions_gpt4all = f"""You are Khoj, a kind and intelligent personal assistant. When the user asks you a question, you ask follow-up questions to clarify the necessary information you need in order to answer from the user's perspective. - Write the question as if you can search for the answer on the user's personal notes. - Try to be as specific as possible. Instead of saying "they" or "it" or "he", use the name of the person or thing you are referring to. For example, instead of saying "Which store did they go to?", say "Which store did Alice and Bob go to?". 
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index aa615113..fded31f3 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -13,7 +13,7 @@ from fastapi.requests import Request from fastapi.responses import Response from starlette.authentication import requires -from khoj.configure import configure_server, initialize_content +from khoj.configure import initialize_content from khoj.database.adapters import ( ConversationAdapters, EntryAdapters, diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index ff83b95e..e9d87f92 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -148,11 +148,12 @@ def chat_sessions( async def create_chat_session( request: Request, common: CommonQueryParams, + agent_id: Optional[int] = None, ): user = request.user.object # Create new Conversation Session - conversation = await ConversationAdapters.acreate_conversation_session(user, request.user.client_app) + conversation = await ConversationAdapters.acreate_conversation_session(user, request.user.client_app, agent_id) response = {"conversation_id": conversation.id} @@ -341,6 +342,7 @@ async def chat( llm_response, chat_metadata = await agenerate_chat_response( defiltered_query, meta_log, + conversation, compiled_references, online_results, inferred_queries, diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 724d640a..bdcba09d 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -10,10 +10,11 @@ import openai from fastapi import Depends, Header, HTTPException, Request, UploadFile from starlette.authentication import has_required_scope -from khoj.database.adapters import ConversationAdapters, EntryAdapters +from khoj.database.adapters import AgentAdapters, ConversationAdapters, EntryAdapters from khoj.database.models import ( ChatModelOptions, ClientApplication, + Conversation, KhojUser, Subscription, TextToImageModelConfig, @@ -364,6 +365,7 @@ async def 
send_message_to_model_wrapper( def generate_chat_response( q: str, meta_log: dict, + conversation: Conversation, compiled_references: List[str] = [], online_results: Dict[str, Any] = {}, inferred_queries: List[str] = [], @@ -379,6 +381,7 @@ def generate_chat_response( logger.debug(f"Conversation Types: {conversation_commands}") metadata = {} + agent = AgentAdapters.get_conversation_agent_by_id(conversation.agent.id) if conversation.agent else None try: partial_completion = partial( @@ -393,7 +396,7 @@ def generate_chat_response( conversation_id=conversation_id, ) - conversation_config = ConversationAdapters.get_valid_conversation_config(user) + conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation) if conversation_config.model_type == "offline": if state.gpt4all_processor_config is None or state.gpt4all_processor_config.loaded_model is None: state.gpt4all_processor_config = GPT4AllProcessorModel(conversation_config.chat_model) @@ -412,6 +415,7 @@ def generate_chat_response( tokenizer_name=conversation_config.tokenizer, location_data=location_data, user_name=user_name, + agent=agent, ) elif conversation_config.model_type == "openai": @@ -431,6 +435,7 @@ def generate_chat_response( tokenizer_name=conversation_config.tokenizer, location_data=location_data, user_name=user_name, + agent=agent, ) metadata.update({"chat_model": conversation_config.chat_model}) diff --git a/tests/conftest.py b/tests/conftest.py index a7ff1512..8533ba4f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,6 +12,7 @@ from khoj.configure import ( configure_search_types, ) from khoj.database.models import ( + Agent, GithubConfig, GithubRepoConfig, KhojApiUser, @@ -181,6 +182,28 @@ def api_user4(default_user4): ) +@pytest.mark.django_db +@pytest.fixture +def offline_agent(): + chat_model = ChatModelOptionsFactory() + return Agent.objects.create( + name="Accountant", + chat_model=chat_model, + personality="You are a certified CPA. 
You are able to tell me how much I've spent based on my notes. Regardless of what I ask, you should always respond with the total amount I've spent. ALWAYS RESPOND WITH A SUMMARY TOTAL OF HOW MUCH MONEY I HAVE SPENT.", + ) + + +@pytest.mark.django_db +@pytest.fixture +def openai_agent(): + chat_model = ChatModelOptionsFactory(chat_model="gpt-3.5-turbo", model_type="openai") + return Agent.objects.create( + name="Accountant", + chat_model=chat_model, + personality="You are a certified CPA. You are able to tell me how much I've spent based on my notes. Regardless of what I ask, you should always respond with the total amount I've spent.", + ) + + @pytest.fixture(scope="session") def search_models(search_config: SearchConfig): search_models = SearchModels() diff --git a/tests/test_gpt4all_chat_actors.py b/tests/test_gpt4all_chat_actors.py index 5cce4fc5..c0e0e54b 100644 --- a/tests/test_gpt4all_chat_actors.py +++ b/tests/test_gpt4all_chat_actors.py @@ -465,6 +465,47 @@ My sister, Aiyla is married to Tolga. 
They have 3 kids, Yildiz, Ali and Ahmet."" ) +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.chatquality +def test_agent_prompt_should_be_used(loaded_model, offline_agent): + "Chat actor should ask be tuned to think like an accountant based on the agent definition" + # Arrange + context = [ + f"""I went to the store and bought some bananas for 2.20""", + f"""I went to the store and bought some apples for 1.30""", + f"""I went to the store and bought some oranges for 6.00""", + ] + + # Act + response_gen = converse_offline( + references=context, # Assume context retrieved from notes for the user_query + user_query="What did I buy?", + loaded_model=loaded_model, + ) + response = "".join([response_chunk for response_chunk in response_gen]) + + # Assert that the model without the agent prompt does not include the summary of purchases + expected_responses = ["9.50", "9.5"] + assert all([expected_response not in response for expected_response in expected_responses]), ( + "Expected chat actor to summarize values of purchases" + response + ) + + # Act + response_gen = converse_offline( + references=context, # Assume context retrieved from notes for the user_query + user_query="What did I buy?", + loaded_model=loaded_model, + agent=offline_agent, + ) + response = "".join([response_chunk for response_chunk in response_gen]) + + # Assert that the model with the agent prompt does include the summary of purchases + expected_responses = ["9.50", "9.5"] + assert any([expected_response in response for expected_response in expected_responses]), ( + "Expected chat actor to summarize values of purchases" + response + ) + + # ---------------------------------------------------------------------------------------------------- def test_chat_does_not_exceed_prompt_size(loaded_model): "Ensure chat context and response together do not exceed max prompt size for the model" diff --git a/tests/test_gpt4all_chat_director.py 
b/tests/test_gpt4all_chat_director.py index 87ed5116..eea6cafa 100644 --- a/tests/test_gpt4all_chat_director.py +++ b/tests/test_gpt4all_chat_director.py @@ -6,6 +6,7 @@ import pytest from faker import Faker from freezegun import freeze_time +from khoj.database.models import Agent, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.utils import message_to_log from khoj.routers.helpers import aget_relevant_information_sources @@ -26,20 +27,20 @@ def generate_history(message_list): # Generate conversation logs conversation_log = {"chat": []} for user_message, gpt_message, context in message_list: - conversation_log["chat"] += message_to_log( + message_to_log( user_message, gpt_message, {"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}}, + conversation_log=conversation_log.get("chat", []), ) return conversation_log -def populate_chat_history(message_list, user): +def create_conversation(message_list, user, agent=None): # Generate conversation logs conversation_log = generate_history(message_list) - # Update Conversation Metadata Logs in Database - ConversationFactory(user=user, conversation_log=conversation_log) + return ConversationFactory(user=user, conversation_log=conversation_log, agent=agent) # Tests @@ -114,7 +115,7 @@ def test_answer_from_chat_history(client_offline_chat, default_user2): ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. 
How can I help?", []), ("When was I born?", "You were born on 1st April 1984.", []), ] - populate_chat_history(message_list, default_user2) + create_conversation(message_list, default_user2) # Act response = client_offline_chat.get(f'/api/chat?q="What is my name?"&stream=true') @@ -141,7 +142,7 @@ def test_answer_from_currently_retrieved_content(client_offline_chat, default_us ["Testatron was born on 1st April 1984 in Testville."], ), ] - populate_chat_history(message_list, default_user2) + create_conversation(message_list, default_user2) # Act response = client_offline_chat.get(f'/api/chat?q="Where was Xi Li born?"') @@ -165,7 +166,7 @@ def test_answer_from_chat_history_and_previously_retrieved_content(client_offlin ["Testatron was born on 1st April 1984 in Testville."], ), ] - populate_chat_history(message_list, default_user2) + create_conversation(message_list, default_user2) # Act response = client_offline_chat.get(f'/api/chat?q="Where was I born?"') @@ -191,7 +192,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content(client_offline ("Hello, my name is Xi Li. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []), ("When was I born?", "You were born on 1st April 1984.", []), ] - populate_chat_history(message_list, default_user2) + create_conversation(message_list, default_user2) # Act response = client_offline_chat.get(f'/api/chat?q="Where was I born?"') @@ -217,7 +218,7 @@ def test_no_answer_in_chat_history_or_retrieved_content(client_offline_chat, def ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. 
How can I help?", []), ("When was I born?", "You were born on 1st April 1984.", []), ] - populate_chat_history(message_list, default_user2) + create_conversation(message_list, default_user2) # Act response = client_offline_chat.get(f'/api/chat?q="Where was I born?"&stream=true') @@ -238,7 +239,7 @@ def test_answer_using_general_command(client_offline_chat, default_user2): # Arrange query = urllib.parse.quote("/general Where was Xi Li born?") message_list = [] - populate_chat_history(message_list, default_user2) + create_conversation(message_list, default_user2) # Act response = client_offline_chat.get(f"/api/chat?q={query}&stream=true") @@ -256,7 +257,7 @@ def test_answer_from_retrieved_content_using_notes_command(client_offline_chat, # Arrange query = urllib.parse.quote("/notes Where was Xi Li born?") message_list = [] - populate_chat_history(message_list, default_user2) + create_conversation(message_list, default_user2) # Act response = client_offline_chat.get(f"/api/chat?q={query}&stream=true") @@ -275,7 +276,7 @@ def test_answer_using_file_filter(client_offline_chat, default_user2): no_answer_query = urllib.parse.quote('Where was Xi Li born? file:"Namita.markdown"') answer_query = urllib.parse.quote('Where was Xi Li born? 
file:"Xi Li.markdown"') message_list = [] - populate_chat_history(message_list, default_user2) + create_conversation(message_list, default_user2) # Act no_answer_response = client_offline_chat.get(f"/api/chat?q={no_answer_query}&stream=true").content.decode("utf-8") @@ -293,7 +294,7 @@ def test_answer_not_known_using_notes_command(client_offline_chat, default_user2 # Arrange query = urllib.parse.quote("/notes Where was Testatron born?") message_list = [] - populate_chat_history(message_list, default_user2) + create_conversation(message_list, default_user2) # Act response = client_offline_chat.get(f"/api/chat?q={query}&stream=true") @@ -351,7 +352,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(client ("When was I born?", "You were born on 1st April 1984.", []), ("Where was I born?", "You were born Testville.", []), ] - populate_chat_history(message_list, default_user2) + create_conversation(message_list, default_user2) # Act response = client_offline_chat.get( @@ -394,14 +395,14 @@ def test_ask_for_clarification_if_not_enough_context_in_question(client_offline_ @pytest.mark.xfail(reason="Chat director not capable of answering this question yet") @pytest.mark.chatquality @pytest.mark.django_db(transaction=True) -def test_answer_in_chat_history_beyond_lookback_window(client_offline_chat, default_user2): +def test_answer_in_chat_history_beyond_lookback_window(client_offline_chat, default_user2: KhojUser): # Arrange message_list = [ ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. 
How can I help?", []), ("When was I born?", "You were born on 1st April 1984.", []), ("Where was I born?", "You were born Testville.", []), ] - populate_chat_history(message_list, default_user2) + create_conversation(message_list, default_user2) # Act response = client_offline_chat.get(f'/api/chat?q="What is my name?"&stream=true') @@ -415,13 +416,77 @@ def test_answer_in_chat_history_beyond_lookback_window(client_offline_chat, defa ) +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.chatquality +@pytest.mark.django_db(transaction=True) +def test_answer_in_chat_history_by_conversation_id(client_offline_chat, default_user2: KhojUser): + # Arrange + message_list = [ + ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []), + ("When was I born?", "You were born on 1st April 1984.", []), + ("What's my favorite color", "Your favorite color is green.", []), + ("Where was I born?", "You were born Testville.", []), + ] + message_list2 = [ + ("Hello, my name is Julia. Who are you?", "Hi, I am Khoj, a personal assistant. 
How can I help?", []), + ("When was I born?", "You were born on 14th August 1947.", []), + ("What's my favorite color", "Your favorite color is maroon.", []), + ("Where was I born?", "You were born in a potato farm.", []), + ] + conversation = create_conversation(message_list, default_user2) + create_conversation(message_list2, default_user2) + + # Act + response = client_offline_chat.get( + f'/api/chat?q="What is my favorite color?"&conversation_id={conversation.id}&stream=true' + ) + response_message = response.content.decode("utf-8") + + # Assert + expected_responses = ["green"] + assert response.status_code == 200 + assert any([expected_response in response_message.lower() for expected_response in expected_responses]), ( + "Expected green in response, but got: " + response_message + ) + + +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.xfail(reason="Chat director not great at adhering to agent instructions yet") +@pytest.mark.chatquality +@pytest.mark.django_db(transaction=True) +def test_answer_in_chat_history_by_conversation_id_with_agent( + client_offline_chat, default_user2: KhojUser, offline_agent: Agent +): + # Arrange + message_list = [ + ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. 
How can I help?", []), + ("When was I born?", "You were born on 1st April 1984.", []), + ("What's my favorite color", "Your favorite color is green.", []), + ("Where was I born?", "You were born Testville.", []), + ("What did I buy?", "You bought an apple for 2.00, an orange for 3.00, and a potato for 8.00", []), + ] + conversation = create_conversation(message_list, default_user2, offline_agent) + + # Act + query = urllib.parse.quote("/general What did I eat for breakfast?") + response = client_offline_chat.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true") + response_message = response.content.decode("utf-8") + + # Assert that agent only responds with the summary of spending + expected_responses = ["13.00", "13", "13.0", "thirteen"] + assert response.status_code == 200 + assert any([expected_response in response_message.lower() for expected_response in expected_responses]), ( + "Expected green in response, but got: " + response_message + ) + + @pytest.mark.chatquality @pytest.mark.django_db(transaction=True) def test_answer_chat_history_very_long(client_offline_chat, default_user2): # Arrange message_list = [(" ".join([fake.paragraph() for _ in range(50)]), fake.sentence(), []) for _ in range(10)] - populate_chat_history(message_list, default_user2) + create_conversation(message_list, default_user2) # Act response = client_offline_chat.get(f'/api/chat?q="What is my name?"&stream=true') @@ -525,7 +590,7 @@ async def test_get_correct_tools_with_chat_history(client_offline_chat): ), ("What's up in New York City?", "A Pride parade has recently been held in New York City, on July 31st.", []), ] - chat_history = populate_chat_history(chat_log) + chat_history = create_conversation(chat_log) # Act tools = await aget_relevant_information_sources(user_query, chat_history) diff --git a/tests/test_openai_chat_actors.py b/tests/test_openai_chat_actors.py index 8db577e9..b08205de 100644 --- a/tests/test_openai_chat_actors.py +++ 
b/tests/test_openai_chat_actors.py @@ -413,6 +413,42 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet."" ) +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.chatquality +def test_agent_prompt_should_be_used(openai_agent): + "Chat actor should ask be tuned to think like an accountant based on the agent definition" + # Arrange + context = [ + f"""I went to the store and bought some bananas for 2.20""", + f"""I went to the store and bought some apples for 1.30""", + f"""I went to the store and bought some oranges for 6.00""", + ] + expected_responses = ["9.50", "9.5"] + + # Act + response_gen = converse( + references=context, # Assume context retrieved from notes for the user_query + user_query="What did I buy?", + api_key=api_key, + ) + no_agent_response = "".join([response_chunk for response_chunk in response_gen]) + response_gen = converse( + references=context, # Assume context retrieved from notes for the user_query + user_query="What did I buy?", + api_key=api_key, + agent=openai_agent, + ) + agent_response = "".join([response_chunk for response_chunk in response_gen]) + + # Assert that the model without the agent prompt does not include the summary of purchases + assert all([expected_response not in no_agent_response for expected_response in expected_responses]), ( + "Expected chat actor to summarize values of purchases" + no_agent_response + ) + assert any([expected_response in agent_response for expected_response in expected_responses]), ( + "Expected chat actor to summarize values of purchases" + agent_response + ) + + # ---------------------------------------------------------------------------------------------------- @pytest.mark.anyio @pytest.mark.django_db(transaction=True) diff --git a/tests/test_openai_chat_director.py b/tests/test_openai_chat_director.py index 890605b1..fffaa0d9 100644 --- a/tests/test_openai_chat_director.py +++ 
b/tests/test_openai_chat_director.py @@ -5,13 +5,10 @@ from urllib.parse import quote import pytest from freezegun import freeze_time -from khoj.database.models import KhojUser +from khoj.database.models import Agent, KhojUser from khoj.processor.conversation import prompts from khoj.processor.conversation.utils import message_to_log -from khoj.routers.helpers import ( - aget_relevant_information_sources, - aget_relevant_output_modes, -) +from khoj.routers.helpers import aget_relevant_information_sources from tests.helpers import ConversationFactory # Initialize variables for tests @@ -29,20 +26,21 @@ def generate_history(message_list): # Generate conversation logs conversation_log = {"chat": []} for user_message, gpt_message, context in message_list: - conversation_log["chat"] += message_to_log( + message_to_log( user_message, gpt_message, {"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}}, + conversation_log=conversation_log.get("chat", []), ) return conversation_log -def populate_chat_history(message_list, user): +def create_conversation(message_list, user, agent=None): # Generate conversation logs conversation_log = generate_history(message_list) # Update Conversation Metadata Logs in Database - ConversationFactory(user=user, conversation_log=conversation_log) + return ConversationFactory(user=user, conversation_log=conversation_log, agent=agent) # Tests @@ -116,7 +114,7 @@ def test_answer_from_chat_history(chat_client, default_user2: KhojUser): ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. 
How can I help?", []), ("When was I born?", "You were born on 1st April 1984.", []), ] - populate_chat_history(message_list, default_user2) + create_conversation(message_list, default_user2) # Act response = chat_client.get(f'/api/chat?q="What is my name?"&stream=true') @@ -143,7 +141,7 @@ def test_answer_from_currently_retrieved_content(chat_client, default_user2: Kho ["Testatron was born on 1st April 1984 in Testville."], ), ] - populate_chat_history(message_list, default_user2) + create_conversation(message_list, default_user2) # Act response = chat_client.get(f'/api/chat?q="Where was Xi Li born?"') @@ -167,7 +165,7 @@ def test_answer_from_chat_history_and_previously_retrieved_content(chat_client_n ["Testatron was born on 1st April 1984 in Testville."], ), ] - populate_chat_history(message_list, default_user2) + create_conversation(message_list, default_user2) # Act response = chat_client_no_background.get(f'/api/chat?q="Where was I born?"') @@ -190,7 +188,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content(chat_client, d ("Hello, my name is Xi Li. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []), ("When was I born?", "You were born on 1st April 1984.", []), ] - populate_chat_history(message_list, default_user2) + create_conversation(message_list, default_user2) # Act response = chat_client.get(f'/api/chat?q="Where was I born?"') @@ -215,7 +213,7 @@ def test_no_answer_in_chat_history_or_retrieved_content(chat_client, default_use ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. 
How can I help?", []), ("When was I born?", "You were born on 1st April 1984.", []), ] - populate_chat_history(message_list, default_user2) + create_conversation(message_list, default_user2) # Act response = chat_client.get(f'/api/chat?q="Where was I born?"&stream=true') @@ -244,7 +242,7 @@ def test_answer_using_general_command(chat_client, default_user2: KhojUser): # Arrange query = urllib.parse.quote("/general Where was Xi Li born?") message_list = [] - populate_chat_history(message_list, default_user2) + create_conversation(message_list, default_user2) # Act response = chat_client.get(f"/api/chat?q={query}&stream=true") @@ -262,7 +260,7 @@ def test_answer_from_retrieved_content_using_notes_command(chat_client, default_ # Arrange query = urllib.parse.quote("/notes Where was Xi Li born?") message_list = [] - populate_chat_history(message_list, default_user2) + create_conversation(message_list, default_user2) # Act response = chat_client.get(f"/api/chat?q={query}&stream=true") @@ -280,7 +278,7 @@ def test_answer_not_known_using_notes_command(chat_client_no_background, default # Arrange query = urllib.parse.quote("/notes Where was Testatron born?") message_list = [] - populate_chat_history(message_list, default_user2) + create_conversation(message_list, default_user2) # Act response = chat_client_no_background.get(f"/api/chat?q={query}&stream=true") @@ -335,7 +333,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c ("When was I born?", "You were born on 1st April 1984.", []), ("Where was I born?", "You were born Testville.", []), ] - populate_chat_history(message_list, default_user2) + create_conversation(message_list, default_user2) # Act response = chat_client.get(f'/api/chat?q="Write a haiku about unit testing. 
Do not say anything else."&stream=true') @@ -387,7 +385,7 @@ def test_answer_in_chat_history_beyond_lookback_window(chat_client, default_user ("When was I born?", "You were born on 1st April 1984.", []), ("Where was I born?", "You were born Testville.", []), ] - populate_chat_history(message_list, default_user2) + create_conversation(message_list, default_user2) # Act response = chat_client.get(f'/api/chat?q="What is my name?"&stream=true') @@ -401,6 +399,68 @@ def test_answer_in_chat_history_beyond_lookback_window(chat_client, default_user ) +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.chatquality +@pytest.mark.django_db(transaction=True) +def test_answer_in_chat_history_by_conversation_id(chat_client, default_user2: KhojUser): + # Arrange + message_list = [ + ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []), + ("When was I born?", "You were born on 1st April 1984.", []), + ("What's my favorite color", "Your favorite color is green.", []), + ("Where was I born?", "You were born Testville.", []), + ] + message_list2 = [ + ("Hello, my name is Julia. Who are you?", "Hi, I am Khoj, a personal assistant. 
How can I help?", []), + ("When was I born?", "You were born on 14th August 1947.", []), + ("What's my favorite color", "Your favorite color is maroon.", []), + ("Where was I born?", "You were born in a potato farm.", []), + ] + conversation = create_conversation(message_list, default_user2) + create_conversation(message_list2, default_user2) + + # Act + query = urllib.parse.quote("/general What is my favorite color?") + response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true") + response_message = response.content.decode("utf-8") + + # Assert + expected_responses = ["green"] + assert response.status_code == 200 + assert any([expected_response in response_message.lower() for expected_response in expected_responses]), ( + "Expected green in response, but got: " + response_message + ) + + +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.chatquality +@pytest.mark.django_db(transaction=True) +def test_answer_in_chat_history_by_conversation_id_with_agent( + chat_client, default_user2: KhojUser, openai_agent: Agent +): + # Arrange + message_list = [ + ("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. 
How can I help?", []), + ("When was I born?", "You were born on 1st April 1984.", []), + ("What's my favorite color", "Your favorite color is green.", []), + ("Where was I born?", "You were born Testville.", []), + ("What did I buy?", "You bought an apple for 2.00, an orange for 3.00, and a potato for 8.00", []), + ] + conversation = create_conversation(message_list, default_user2, openai_agent) + + # Act + query = urllib.parse.quote("/general What did I eat for breakfast?") + response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true") + response_message = response.content.decode("utf-8") + + # Assert that agent only responds with the summary of spending + expected_responses = ["13.00", "13", "13.0", "thirteen"] + assert response.status_code == 200 + assert any([expected_response in response_message.lower() for expected_response in expected_responses]), ( + "Expected green in response, but got: " + response_message + ) + + # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db(transaction=True) @pytest.mark.chatquality From fdf78525b4bb6631e74fc0f8408f51e32049140f Mon Sep 17 00:00:00 2001 From: sabaimran <65192171+sabaimran@users.noreply.github.com> Date: Tue, 26 Mar 2024 05:43:24 -0700 Subject: [PATCH 2/6] Part 2: Add web UI updates for basic agent interactions (#675) * Initial pass at backend changes to support agents - Add a db model for Agents, attaching them to conversations - When an agent is added to a conversation, override the system prompt to tweak the instructions - Agents can be configured with prompt modification, model specification, a profile picture, and other things - Admin-configured models will not be editable by individual users - Add unit tests to verify agent behavior. 
Unit tests demonstrate imperfect adherence to prompt specifications * Customize default behaviors for conversations without agents or with default agents * Add a new web client route for viewing all agents * Use agent_id for getting correct agent * Add web UI views for agents - Add a page to view all agents - Add slugs to manage agents - Add a view to view single agent - Display active agent when in chat window - Fix post-login redirect issue * Fix agent view * Spruce up the 404 page and improve the overall layout for agents pages * Create chat actor for directly reading webpages based on user message - Add prompt for the read webpages chat actor to extract, infer webpage links - Make chat actor infer or extract webpage to read directly from user message - Rename previous read_webpage function to more narrow read_webpage_at_url function * Rename agents_page -> agent_page * Fix unit test for adding the filename to the compiled markdown entry * Fix layout of agent, agents pages * Merge migrations * Let the name, slug of the default agent be Khoj, khoj * Fix chat-related unit tests * Add webpage chat command to read web pages requested by user Update auto chat command inference prompt to show example of when to use webpage chat command (i.e. when url is directly provided in link) * Support webpage command in chat API - Fallback to use webpage when SERPER not set up and online command was attempted - Do not stop responding if can't retrieve online results. 
Try to respond without the online context * Test select webpage as data source and extract web urls chat actors * Tweak prompts to extract information from webpages, online results - Show more of the truncated messages for debugging context - Update Khoj personality prompt to encourage it to remember its capabilities * Rename extract_content online results field to webpages * Parallelize simple webpage read and extractor Similar to what is being done with search_online with olostep * Pass multiple webpages with their urls in online results context Previously even if MAX_WEBPAGES_TO_READ was > 1, only 1 extracted content would ever be passed. URL of the extracted webpage content wasn't passed to clients in online results context. This limited them from being rendered * Render webpage read in chat response references on Web, Desktop apps * Time chat actor responses & chat api request start for perf analysis * Increase the keep alive timeout in the main application for testing * Do not pipe access/error logs to separate files. 
Flow to stdout/stderr * [Temp] Reduce to 1 gunicorn worker * Change prod docker image to use jammy, rather than nvidia base image * Use Khoj icon when Khoj web is installed on iOS as a PWA * Make slug required for agents * Simplify calling logic and prevent agent access for unauthenticated users * Standardize to use personality over tuning in agent nomenclature * Make filtering logic more stringent for accessible agents and remove unused method: * Format chat message query --------- Co-authored-by: Debanjum Singh Solanky --- gunicorn-config.py | 6 +- prod.Dockerfile | 5 +- src/interface/desktop/chat.html | 11 +- src/khoj/configure.py | 2 + src/khoj/database/adapters/__init__.py | 50 ++- .../0031_agent_conversation_agent.py | 3 +- src/khoj/database/models/__init__.py | 20 +- src/khoj/interface/web/404.html | 42 ++- src/khoj/interface/web/agent.html | 286 ++++++++++++++ src/khoj/interface/web/agents.html | 201 ++++++++++ src/khoj/interface/web/assets/khoj.css | 2 +- src/khoj/interface/web/base_config.html | 2 +- src/khoj/interface/web/chat.html | 351 ++++++++++++++++-- src/khoj/interface/web/search.html | 1 + src/khoj/interface/web/utils.html | 36 +- src/khoj/main.py | 4 +- .../content/markdown/markdown_to_entries.py | 2 +- .../conversation/offline/chat_model.py | 4 +- src/khoj/processor/conversation/openai/gpt.py | 8 +- src/khoj/processor/conversation/prompts.py | 61 ++- src/khoj/processor/tools/online_search.py | 56 ++- src/khoj/routers/api_agents.py | 43 +++ src/khoj/routers/api_chat.py | 62 +++- src/khoj/routers/auth.py | 6 +- src/khoj/routers/helpers.py | 56 ++- src/khoj/routers/web_client.py | 83 ++++- src/khoj/utils/helpers.py | 15 +- tests/test_gpt4all_chat_director.py | 4 +- tests/test_helpers.py | 7 +- tests/test_markdown_to_entries.py | 4 +- tests/test_openai_chat_actors.py | 29 ++ 31 files changed, 1332 insertions(+), 130 deletions(-) create mode 100644 src/khoj/interface/web/agent.html create mode 100644 src/khoj/interface/web/agents.html create mode 
100644 src/khoj/routers/api_agents.py diff --git a/gunicorn-config.py b/gunicorn-config.py index bfed49e7..ea382346 100644 --- a/gunicorn-config.py +++ b/gunicorn-config.py @@ -1,10 +1,10 @@ import multiprocessing bind = "0.0.0.0:42110" -workers = 8 +workers = 1 worker_class = "uvicorn.workers.UvicornWorker" timeout = 120 keep_alive = 60 -accesslog = "access.log" -errorlog = "error.log" +accesslog = "-" +errorlog = "-" loglevel = "debug" diff --git a/prod.Dockerfile b/prod.Dockerfile index 413835d0..0da5363a 100644 --- a/prod.Dockerfile +++ b/prod.Dockerfile @@ -1,12 +1,9 @@ -# Use Nvidia's latest Ubuntu 22.04 image as the base image -FROM nvidia/cuda:12.2.0-devel-ubuntu22.04 +FROM ubuntu:jammy LABEL org.opencontainers.image.source https://github.com/khoj-ai/khoj # Install System Dependencies RUN apt update -y && apt -y install python3-pip libsqlite3-0 ffmpeg libsm6 libxext6 -# Install Optional Dependencies -RUN apt install vim -y WORKDIR /app diff --git a/src/interface/desktop/chat.html b/src/interface/desktop/chat.html index 94cde782..f37ae562 100644 --- a/src/interface/desktop/chat.html +++ b/src/interface/desktop/chat.html @@ -87,7 +87,7 @@ function generateOnlineReference(reference, index) { // Generate HTML for Chat Reference - let title = reference.title; + let title = reference.title || reference.link; let link = reference.link; let snippet = reference.snippet; let question = reference.question; @@ -191,6 +191,15 @@ referenceSection.appendChild(polishedReference); } } + + if (onlineReference.webpages && onlineReference.webpages.length > 0) { + numOnlineReferences += onlineReference.webpages.length; + for (let index in onlineReference.webpages) { + let reference = onlineReference.webpages[index]; + let polishedReference = generateOnlineReference(reference, index); + referenceSection.appendChild(polishedReference); + } + } } return numOnlineReferences; diff --git a/src/khoj/configure.py b/src/khoj/configure.py index fb3a93ce..0adbe889 100644 --- 
a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -268,6 +268,7 @@ def initialize_content(regenerate: bool, search_type: Optional[SearchType] = Non def configure_routes(app): # Import APIs here to setup search types before while configuring server from khoj.routers.api import api + from khoj.routers.api_agents import api_agents from khoj.routers.api_chat import api_chat from khoj.routers.api_config import api_config from khoj.routers.indexer import indexer @@ -275,6 +276,7 @@ def configure_routes(app): app.include_router(api, prefix="/api") app.include_router(api_chat, prefix="/api/chat") + app.include_router(api_agents, prefix="/api/agents") app.include_router(api_config, prefix="/api/config") app.include_router(indexer, prefix="/api/v1/index") app.include_router(web_client) diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index cb317275..25f781fe 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -394,20 +394,32 @@ class ClientApplicationAdapters: class AgentAdapters: - DEFAULT_AGENT_NAME = "khoj" + DEFAULT_AGENT_NAME = "Khoj" DEFAULT_AGENT_AVATAR = "https://khoj-web-bucket.s3.amazonaws.com/lamp-128.png" + DEFAULT_AGENT_SLUG = "khoj" @staticmethod - async def aget_agent_by_id(agent_id: int, user: KhojUser): - agent = await Agent.objects.filter(id=agent_id).afirst() - # Check if it's accessible to the user - if agent and (agent.public or agent.creator == user): - return agent - return None + async def aget_agent_by_slug(agent_slug: str, user: KhojUser): + return await Agent.objects.filter( + (Q(slug__iexact=agent_slug.lower())) & (Q(public=True) | Q(creator=user)) + ).afirst() + + @staticmethod + def get_agent_by_slug(slug: str, user: KhojUser = None): + if user: + return Agent.objects.filter((Q(slug__iexact=slug.lower())) & (Q(public=True) | Q(creator=user))).first() + return Agent.objects.filter(slug__iexact=slug.lower(), public=True).first() @staticmethod def 
get_all_accessible_agents(user: KhojUser = None): - return Agent.objects.filter(Q(public=True) | Q(creator=user)).distinct() + if user: + return Agent.objects.filter(Q(public=True) | Q(creator=user)).distinct().order_by("created_at") + return Agent.objects.filter(public=True).order_by("created_at") + + @staticmethod + async def aget_all_accessible_agents(user: KhojUser = None) -> List[Agent]: + agents = await sync_to_async(AgentAdapters.get_all_accessible_agents)(user) + return await sync_to_async(list)(agents) @staticmethod def get_conversation_agent_by_id(agent_id: int): @@ -423,12 +435,19 @@ class AgentAdapters: @staticmethod def create_default_agent(): - # First delete the existing default - Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).delete() - default_conversation_config = ConversationAdapters.get_default_conversation_config() default_personality = prompts.personality.format(current_date="placeholder") + agent = Agent.objects.filter(name=AgentAdapters.DEFAULT_AGENT_NAME).first() + + if agent: + agent.personality = default_personality + agent.chat_model = default_conversation_config + agent.slug = AgentAdapters.DEFAULT_AGENT_SLUG + agent.name = AgentAdapters.DEFAULT_AGENT_NAME + agent.save() + return agent + # The default agent is public and managed by the admin. It's handled a little differently than other agents. 
return Agent.objects.create( name=AgentAdapters.DEFAULT_AGENT_NAME, @@ -438,6 +457,7 @@ class AgentAdapters: personality=default_personality, tools=["*"], avatar=AgentAdapters.DEFAULT_AGENT_AVATAR, + slug=AgentAdapters.DEFAULT_AGENT_SLUG, ) @staticmethod @@ -486,10 +506,12 @@ class ConversationAdapters: @staticmethod async def acreate_conversation_session( - user: KhojUser, client_application: ClientApplication = None, agent_id: int = None + user: KhojUser, client_application: ClientApplication = None, agent_slug: str = None ): - if agent_id: - agent = await AgentAdapters.aget_agent_by_id(agent_id, user) + if agent_slug: + agent = await AgentAdapters.aget_agent_by_slug(agent_slug, user) + if agent is None: + raise HTTPException(status_code=400, detail="No such agent currently exists.") return await Conversation.objects.acreate(user=user, client=client_application, agent=agent) return await Conversation.objects.acreate(user=user, client=client_application) diff --git a/src/khoj/database/migrations/0031_agent_conversation_agent.py b/src/khoj/database/migrations/0031_agent_conversation_agent.py index 16586499..1d08a118 100644 --- a/src/khoj/database/migrations/0031_agent_conversation_agent.py +++ b/src/khoj/database/migrations/0031_agent_conversation_agent.py @@ -1,4 +1,4 @@ -# Generated by Django 4.2.10 on 2024-03-11 05:12 +# Generated by Django 4.2.10 on 2024-03-13 07:38 import django.db.models.deletion from django.conf import settings @@ -23,6 +23,7 @@ class Migration(migrations.Migration): ("tools", models.JSONField(default=list)), ("public", models.BooleanField(default=False)), ("managed_by_admin", models.BooleanField(default=False)), + ("slug", models.CharField(max_length=200)), ( "chat_model", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to="database.chatmodeloptions"), diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py index b8eeb8b1..364a6d1a 100644 --- a/src/khoj/database/models/__init__.py +++ 
b/src/khoj/database/models/__init__.py @@ -1,4 +1,5 @@ import uuid +from random import choice from django.contrib.auth.models import AbstractUser from django.core.exceptions import ValidationError @@ -94,13 +95,28 @@ class Agent(BaseModel): public = models.BooleanField(default=False) managed_by_admin = models.BooleanField(default=False) chat_model = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE) + slug = models.CharField(max_length=200) @receiver(pre_save, sender=Agent) -def check_public_name(sender, instance, **kwargs): - if instance.public: +def verify_agent(sender, instance, **kwargs): + # check if this is a new instance + if instance._state.adding: if Agent.objects.filter(name=instance.name, public=True).exists(): raise ValidationError(f"A public Agent with the name {instance.name} already exists.") + if Agent.objects.filter(name=instance.name, creator=instance.creator).exists(): + raise ValidationError(f"A private Agent with the name {instance.name} already exists.") + + slug = instance.name.lower().replace(" ", "-") + observed_random_numbers = set() + while Agent.objects.filter(slug=slug).exists(): + try: + random_number = choice([i for i in range(0, 1000) if i not in observed_random_numbers]) + except IndexError: + raise ValidationError("Unable to generate a unique slug for the Agent. Please try again later.") + observed_random_numbers.add(random_number) + slug = f"{slug}-{random_number}" + instance.slug = slug class NotionConfig(BaseModel): diff --git a/src/khoj/interface/web/404.html b/src/khoj/interface/web/404.html index 7041ff80..0762bde8 100644 --- a/src/khoj/interface/web/404.html +++ b/src/khoj/interface/web/404.html @@ -2,14 +2,19 @@ Khoj: An AI Personal Assistant for your digital brain - - + + + + + {% import 'utils.html' as utils %} + {{ utils.heading_pane(user_photo, username, is_active, has_documents) }} +
-

Oops, this is awkward. That page couldn't be found.

+

Oops, this is awkward. Looks like there's nothing here.

- Go Home + Go Home @@ -18,5 +23,34 @@ body.not-found { padding: 0 10% } + + body { + background-color: var(--background-color); + color: var(--main-text-color); + text-align: center; + font-family: var(--font-family); + font-size: medium; + font-weight: 300; + line-height: 1.5em; + height: 100vh; + margin: 0; + } + + body a.redirect-link { + font-size: 18px; + font-weight: bold; + background-color: var(--primary); + text-decoration: none; + border: 1px solid var(--main-text-color); + color: var(--main-text-color); + border-radius: 8px; + padding: 4px; + } + + body a.redirect-link:hover { + background-color: var(--main-text-color); + color: var(--primary); + } + diff --git a/src/khoj/interface/web/agent.html b/src/khoj/interface/web/agent.html new file mode 100644 index 00000000..6e6619cf --- /dev/null +++ b/src/khoj/interface/web/agent.html @@ -0,0 +1,286 @@ + + + + + Khoj - Agents + + + + + + + + + {% import 'utils.html' as utils %} + {{ utils.heading_pane(user_photo, username, is_active, has_documents) }} +
+
+
+
Agent Settings
+
+
+
+
+ Agent Avatar + +
+
Instructions
+
+

{{ agent.personality }}

+
+
+
+

Public

+ +
+ + + +
+
+
+ + + + + diff --git a/src/khoj/interface/web/agents.html b/src/khoj/interface/web/agents.html new file mode 100644 index 00000000..dc26606c --- /dev/null +++ b/src/khoj/interface/web/agents.html @@ -0,0 +1,201 @@ + + + + + Khoj - Agents + + + + + + + + + {% import 'utils.html' as utils %} + {{ utils.heading_pane(user_photo, username, is_active, has_documents) }} + + +
+
+
+

Agents

+ +
+ {% for agent in agents %} +
+ +
+ {{ agent.name }} +
+
+
+ +

{{ agent.name }}

+
+

{{ agent.personality }}

+
+
+ +
+
+ {% endfor %} +
+
+ + + + + diff --git a/src/khoj/interface/web/assets/khoj.css b/src/khoj/interface/web/assets/khoj.css index 7ba93c6a..3d7e7d4a 100644 --- a/src/khoj/interface/web/assets/khoj.css +++ b/src/khoj/interface/web/assets/khoj.css @@ -130,7 +130,7 @@ img.khoj-logo { background-color: var(--background-color); min-width: 160px; box-shadow: 0px 8px 16px 0px rgba(0,0,0,0.2); - right: 15vw; + right: 5vw; top: 64px; z-index: 1; opacity: 0; diff --git a/src/khoj/interface/web/base_config.html b/src/khoj/interface/web/base_config.html index 5c26e060..870f4eb9 100644 --- a/src/khoj/interface/web/base_config.html +++ b/src/khoj/interface/web/base_config.html @@ -162,7 +162,7 @@ height: 40px; } .card-title { - font-size: 20px; + font-size: medium; font-weight: normal; margin: 0; padding: 0; diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index 35047c31..52750c12 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -5,6 +5,7 @@ Khoj - Chat + @@ -12,15 +13,16 @@ @@ -1205,13 +1328,27 @@ To get started, just start typing below. You can also type / to see a list of co +
-
+ + +