mirror of
https://github.com/khoj-ai/khoj.git
synced 2025-02-17 08:04:21 +00:00
Initial pass at backend changes to support agents
- Add a db model for Agents, attaching them to conversations - When an agent is added to a conversation, override the system prompt to tweak the instructions - Agents can be configured with prompt modification, model specification, a profile picture, and other things - Admin-configured models will not be editable by individual users - Add unit tests to verify agent behavior. Unit tests demonstrate imperfect adherence to prompt specifications
This commit is contained in:
parent
1da453306e
commit
9b88976f36
14 changed files with 428 additions and 58 deletions
|
@ -16,6 +16,7 @@ from pgvector.django import CosineDistance
|
|||
from torch import Tensor
|
||||
|
||||
from khoj.database.models import (
|
||||
Agent,
|
||||
ChatModelOptions,
|
||||
ClientApplication,
|
||||
Conversation,
|
||||
|
@ -391,6 +392,12 @@ class ClientApplicationAdapters:
|
|||
return await ClientApplication.objects.filter(client_id=client_id, client_secret=client_secret).afirst()
|
||||
|
||||
|
||||
class AgentAdapters:
|
||||
@staticmethod
|
||||
async def aget_agent_by_id(agent_id: int):
|
||||
return await Agent.objects.filter(id=agent_id).afirst()
|
||||
|
||||
|
||||
class ConversationAdapters:
|
||||
@staticmethod
|
||||
def get_conversation_by_user(
|
||||
|
@ -431,7 +438,12 @@ class ConversationAdapters:
|
|||
return Conversation.objects.filter(id=conversation_id).first()
|
||||
|
||||
@staticmethod
|
||||
async def acreate_conversation_session(user: KhojUser, client_application: ClientApplication = None):
|
||||
async def acreate_conversation_session(
|
||||
user: KhojUser, client_application: ClientApplication = None, agent_id: int = None
|
||||
):
|
||||
if agent_id:
|
||||
agent = await AgentAdapters.aget_agent_by_id(id)
|
||||
return await Conversation.objects.acreate(user=user, client=client_application, agent=agent)
|
||||
return await Conversation.objects.acreate(user=user, client=client_application)
|
||||
|
||||
@staticmethod
|
||||
|
@ -446,7 +458,7 @@ class ConversationAdapters:
|
|||
conversation = Conversation.objects.filter(user=user, client=client_application).order_by("-updated_at")
|
||||
|
||||
if await conversation.aexists():
|
||||
return await conversation.afirst()
|
||||
return await conversation.prefetch_related("agent").afirst()
|
||||
|
||||
return await Conversation.objects.acreate(user=user, client=client_application, slug=slug)
|
||||
|
||||
|
@ -606,9 +618,14 @@ class ConversationAdapters:
|
|||
return random.sample(all_questions, max_results)
|
||||
|
||||
@staticmethod
|
||||
def get_valid_conversation_config(user: KhojUser):
|
||||
def get_valid_conversation_config(user: KhojUser, conversation: Conversation):
|
||||
offline_chat_config = ConversationAdapters.get_offline_chat_conversation_config()
|
||||
conversation_config = ConversationAdapters.get_conversation_config(user)
|
||||
|
||||
if conversation.agent and conversation.agent.chat_model:
|
||||
conversation_config = conversation.agent.chat_model
|
||||
else:
|
||||
conversation_config = ConversationAdapters.get_conversation_config(user)
|
||||
|
||||
if conversation_config is None:
|
||||
conversation_config = ConversationAdapters.get_default_conversation_config()
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@ from django.contrib.auth.admin import UserAdmin
|
|||
from django.http import HttpResponse
|
||||
|
||||
from khoj.database.models import (
|
||||
Agent,
|
||||
ChatModelOptions,
|
||||
ClientApplication,
|
||||
Conversation,
|
||||
|
@ -50,6 +51,7 @@ admin.site.register(ReflectiveQuestion)
|
|||
admin.site.register(UserSearchModelConfig)
|
||||
admin.site.register(TextToImageModelConfig)
|
||||
admin.site.register(ClientApplication)
|
||||
admin.site.register(Agent)
|
||||
|
||||
|
||||
@admin.register(Entry)
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
# Generated by Django 4.2.10 on 2024-03-11 05:12
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0030_conversation_slug_and_title"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name="Agent",
|
||||
fields=[
|
||||
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||
("updated_at", models.DateTimeField(auto_now=True)),
|
||||
("name", models.CharField(max_length=200)),
|
||||
("tuning", models.TextField()),
|
||||
("avatar", models.URLField(blank=True, default=None, max_length=400, null=True)),
|
||||
("tools", models.JSONField(default=list)),
|
||||
("public", models.BooleanField(default=False)),
|
||||
("managed_by_admin", models.BooleanField(default=False)),
|
||||
(
|
||||
"chat_model",
|
||||
models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to="database.chatmodeloptions"),
|
||||
),
|
||||
(
|
||||
"creator",
|
||||
models.ForeignKey(
|
||||
blank=True,
|
||||
default=None,
|
||||
null=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
to=settings.AUTH_USER_MODEL,
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"abstract": False,
|
||||
},
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="conversation",
|
||||
name="agent",
|
||||
field=models.ForeignKey(
|
||||
blank=True, default=None, null=True, on_delete=django.db.models.deletion.SET_NULL, to="database.agent"
|
||||
),
|
||||
),
|
||||
]
|
|
@ -69,6 +69,28 @@ class Subscription(BaseModel):
|
|||
renewal_date = models.DateTimeField(null=True, default=None, blank=True)
|
||||
|
||||
|
||||
class ChatModelOptions(BaseModel):
|
||||
class ModelType(models.TextChoices):
|
||||
OPENAI = "openai"
|
||||
OFFLINE = "offline"
|
||||
|
||||
max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
|
||||
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||
chat_model = models.CharField(max_length=200, default="mistral-7b-instruct-v0.1.Q4_0.gguf")
|
||||
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
|
||||
|
||||
|
||||
class Agent(BaseModel):
|
||||
creator = models.ForeignKey(KhojUser, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||
name = models.CharField(max_length=200)
|
||||
tuning = models.TextField()
|
||||
avatar = models.URLField(max_length=400, default=None, null=True, blank=True)
|
||||
tools = models.JSONField(default=list) # List of tools the agent has access to, like online search or notes search
|
||||
public = models.BooleanField(default=False)
|
||||
managed_by_admin = models.BooleanField(default=False)
|
||||
chat_model = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE)
|
||||
|
||||
|
||||
class NotionConfig(BaseModel):
|
||||
token = models.CharField(max_length=200)
|
||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||
|
@ -153,17 +175,6 @@ class SpeechToTextModelOptions(BaseModel):
|
|||
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
|
||||
|
||||
|
||||
class ChatModelOptions(BaseModel):
|
||||
class ModelType(models.TextChoices):
|
||||
OPENAI = "openai"
|
||||
OFFLINE = "offline"
|
||||
|
||||
max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
|
||||
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||
chat_model = models.CharField(max_length=200, default="mistral-7b-instruct-v0.1.Q4_0.gguf")
|
||||
model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OFFLINE)
|
||||
|
||||
|
||||
class UserConversationConfig(BaseModel):
|
||||
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
|
||||
setting = models.ForeignKey(ChatModelOptions, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||
|
@ -180,6 +191,7 @@ class Conversation(BaseModel):
|
|||
client = models.ForeignKey(ClientApplication, on_delete=models.CASCADE, default=None, null=True, blank=True)
|
||||
slug = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||
title = models.CharField(max_length=200, default=None, null=True, blank=True)
|
||||
agent = models.ForeignKey(Agent, on_delete=models.SET_NULL, default=None, null=True, blank=True)
|
||||
|
||||
|
||||
class ReflectiveQuestion(BaseModel):
|
||||
|
|
|
@ -6,6 +6,7 @@ from typing import Any, Iterator, List, Union
|
|||
|
||||
from langchain.schema import ChatMessage
|
||||
|
||||
from khoj.database.models import Agent
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.utils import (
|
||||
ThreadedGenerator,
|
||||
|
@ -141,6 +142,7 @@ def converse_offline(
|
|||
tokenizer_name=None,
|
||||
location_data: LocationData = None,
|
||||
user_name: str = None,
|
||||
agent: Agent = None,
|
||||
) -> Union[ThreadedGenerator, Iterator[str]]:
|
||||
"""
|
||||
Converse with user using Llama
|
||||
|
@ -156,6 +158,16 @@ def converse_offline(
|
|||
# Initialize Variables
|
||||
compiled_references_message = "\n\n".join({f"{item}" for item in references})
|
||||
|
||||
system_prompt = ""
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
|
||||
if agent:
|
||||
system_prompt = prompts.custom_system_prompt_message_gpt4all.format(
|
||||
name=agent.name, bio=agent.tuning, current_date=current_date
|
||||
)
|
||||
else:
|
||||
system_prompt = prompts.system_prompt_message_gpt4all.format(current_date=current_date)
|
||||
|
||||
conversation_primer = prompts.query_prompt.format(query=user_query)
|
||||
|
||||
if location_data:
|
||||
|
@ -185,10 +197,9 @@ def converse_offline(
|
|||
conversation_primer = f"{prompts.notes_conversation_gpt4all.format(references=compiled_references_message)}\n{conversation_primer}"
|
||||
|
||||
# Setup Prompt with Primer or Conversation History
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
messages = generate_chatml_messages_with_context(
|
||||
conversation_primer,
|
||||
prompts.system_prompt_message_gpt4all.format(current_date=current_date),
|
||||
system_prompt,
|
||||
conversation_log,
|
||||
model_name=model,
|
||||
max_prompt_size=max_prompt_size,
|
||||
|
|
|
@ -5,6 +5,7 @@ from typing import Optional
|
|||
|
||||
from langchain.schema import ChatMessage
|
||||
|
||||
from khoj.database.models import Agent
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.openai.utils import (
|
||||
chat_completion_with_backoff,
|
||||
|
@ -118,6 +119,7 @@ def converse(
|
|||
tokenizer_name=None,
|
||||
location_data: LocationData = None,
|
||||
user_name: str = None,
|
||||
agent: Agent = None,
|
||||
):
|
||||
"""
|
||||
Converse with user using OpenAI's ChatGPT
|
||||
|
@ -128,6 +130,13 @@ def converse(
|
|||
|
||||
conversation_primer = prompts.query_prompt.format(query=user_query)
|
||||
|
||||
system_prompt = ""
|
||||
|
||||
if agent:
|
||||
system_prompt = prompts.custom_personality.format(name=agent.name, bio=agent.tuning, current_date=current_date)
|
||||
else:
|
||||
system_prompt = prompts.personality.format(current_date=current_date)
|
||||
|
||||
if location_data:
|
||||
location = f"{location_data.city}, {location_data.region}, {location_data.country}"
|
||||
location_prompt = prompts.user_location.format(location=location)
|
||||
|
@ -158,7 +167,7 @@ def converse(
|
|||
# Setup Prompt with Primer or Conversation History
|
||||
messages = generate_chatml_messages_with_context(
|
||||
conversation_primer,
|
||||
prompts.personality.format(current_date=current_date),
|
||||
system_prompt,
|
||||
conversation_log,
|
||||
model,
|
||||
max_prompt_size,
|
||||
|
|
|
@ -21,6 +21,24 @@ Today is {current_date} in UTC.
|
|||
""".strip()
|
||||
)
|
||||
|
||||
custom_personality = PromptTemplate.from_template(
|
||||
"""
|
||||
Your are {name}, a personal agent on Khoj.
|
||||
Use your general knowledge and past conversation with the user as context to inform your responses.
|
||||
You were created by Khoj Inc. with the following capabilities:
|
||||
|
||||
- You *CAN REMEMBER ALL NOTES and PERSONAL INFORMATION FOREVER* that the user ever shares with you.
|
||||
- Users can share files and other information with you using the Khoj Desktop, Obsidian or Emacs app. They can also drag and drop their files into the chat window.
|
||||
- Say "I don't know" or "I don't understand" if you don't know what to say or if you don't know the answer to a question.
|
||||
- Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided notes or past conversations.
|
||||
- Sometimes the user will share personal information that needs to be remembered, like an account ID or a residential address. These can be acknowledged with a simple "Got it" or "Okay".
|
||||
|
||||
Today is {current_date} in UTC.
|
||||
|
||||
Here's a bio about you: {bio}
|
||||
""".strip()
|
||||
)
|
||||
|
||||
## General Conversation
|
||||
## --
|
||||
general_conversation = PromptTemplate.from_template(
|
||||
|
@ -61,6 +79,20 @@ Today is {current_date} in UTC.
|
|||
""".strip()
|
||||
)
|
||||
|
||||
custom_system_prompt_message_gpt4all = PromptTemplate.from_template(
|
||||
"""
|
||||
You are {name}, a personal agent on Khoj.
|
||||
- Use your general knowledge and past conversation with the user as context to inform your responses.
|
||||
- If you do not know the answer, say 'I don't know.'
|
||||
- Think step-by-step and ask questions to get the necessary information to answer the user's question.
|
||||
- Do not print verbatim Notes unless necessary.
|
||||
|
||||
Today is {current_date} in UTC.
|
||||
|
||||
Here is your instruction set:\n{bio}
|
||||
""".strip()
|
||||
)
|
||||
|
||||
system_prompt_message_extract_questions_gpt4all = f"""You are Khoj, a kind and intelligent personal assistant. When the user asks you a question, you ask follow-up questions to clarify the necessary information you need in order to answer from the user's perspective.
|
||||
- Write the question as if you can search for the answer on the user's personal notes.
|
||||
- Try to be as specific as possible. Instead of saying "they" or "it" or "he", use the name of the person or thing you are referring to. For example, instead of saying "Which store did they go to?", say "Which store did Alice and Bob go to?".
|
||||
|
|
|
@ -148,6 +148,7 @@ def chat_sessions(
|
|||
async def create_chat_session(
|
||||
request: Request,
|
||||
common: CommonQueryParams,
|
||||
agent_id: Optional[int] = None,
|
||||
):
|
||||
user = request.user.object
|
||||
|
||||
|
@ -250,9 +251,11 @@ async def chat(
|
|||
formatted_help = help_message.format(model=model_type, version=state.khoj_version, device=get_device())
|
||||
return StreamingResponse(iter([formatted_help]), media_type="text/event-stream", status_code=200)
|
||||
|
||||
meta_log = (
|
||||
await ConversationAdapters.aget_conversation_by_user(user, request.user.client_app, conversation_id, slug)
|
||||
).conversation_log
|
||||
conversation = await ConversationAdapters.aget_conversation_by_user(
|
||||
user, request.user.client_app, conversation_id, slug
|
||||
)
|
||||
|
||||
meta_log = conversation.conversation_log
|
||||
|
||||
if conversation_commands == [ConversationCommand.Default]:
|
||||
conversation_commands = await aget_relevant_information_sources(q, meta_log)
|
||||
|
@ -335,6 +338,7 @@ async def chat(
|
|||
llm_response, chat_metadata = await agenerate_chat_response(
|
||||
defiltered_query,
|
||||
meta_log,
|
||||
conversation,
|
||||
compiled_references,
|
||||
online_results,
|
||||
inferred_queries,
|
||||
|
|
|
@ -12,8 +12,10 @@ from starlette.authentication import has_required_scope
|
|||
|
||||
from khoj.database.adapters import ConversationAdapters, EntryAdapters
|
||||
from khoj.database.models import (
|
||||
Agent,
|
||||
ChatModelOptions,
|
||||
ClientApplication,
|
||||
Conversation,
|
||||
KhojUser,
|
||||
Subscription,
|
||||
TextToImageModelConfig,
|
||||
|
@ -359,6 +361,7 @@ async def send_message_to_model_wrapper(
|
|||
def generate_chat_response(
|
||||
q: str,
|
||||
meta_log: dict,
|
||||
conversation: Conversation,
|
||||
compiled_references: List[str] = [],
|
||||
online_results: Dict[str, Any] = {},
|
||||
inferred_queries: List[str] = [],
|
||||
|
@ -374,6 +377,7 @@ def generate_chat_response(
|
|||
logger.debug(f"Conversation Types: {conversation_commands}")
|
||||
|
||||
metadata = {}
|
||||
agent = conversation.agent
|
||||
|
||||
try:
|
||||
partial_completion = partial(
|
||||
|
@ -388,7 +392,7 @@ def generate_chat_response(
|
|||
conversation_id=conversation_id,
|
||||
)
|
||||
|
||||
conversation_config = ConversationAdapters.get_valid_conversation_config(user)
|
||||
conversation_config = ConversationAdapters.get_valid_conversation_config(user, conversation)
|
||||
if conversation_config.model_type == "offline":
|
||||
if state.gpt4all_processor_config is None or state.gpt4all_processor_config.loaded_model is None:
|
||||
state.gpt4all_processor_config = GPT4AllProcessorModel(conversation_config.chat_model)
|
||||
|
@ -407,6 +411,7 @@ def generate_chat_response(
|
|||
tokenizer_name=conversation_config.tokenizer,
|
||||
location_data=location_data,
|
||||
user_name=user_name,
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
elif conversation_config.model_type == "openai":
|
||||
|
@ -426,6 +431,7 @@ def generate_chat_response(
|
|||
tokenizer_name=conversation_config.tokenizer,
|
||||
location_data=location_data,
|
||||
user_name=user_name,
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
metadata.update({"chat_model": conversation_config.chat_model})
|
||||
|
|
|
@ -12,6 +12,7 @@ from khoj.configure import (
|
|||
configure_search_types,
|
||||
)
|
||||
from khoj.database.models import (
|
||||
Agent,
|
||||
GithubConfig,
|
||||
GithubRepoConfig,
|
||||
KhojApiUser,
|
||||
|
@ -181,6 +182,28 @@ def api_user4(default_user4):
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.fixture
|
||||
def offline_agent():
|
||||
chat_model = ChatModelOptionsFactory()
|
||||
return Agent.objects.create(
|
||||
name="Accountant",
|
||||
chat_model=chat_model,
|
||||
tuning="You are a certified CPA. You are able to tell me how much I've spent based on my notes. Regardless of what I ask, you should always respond with the total amount I've spent. ALWAYS RESPOND WITH A SUMMARY TOTAL OF HOW MUCH MONEY I HAVE SPENT.",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.fixture
|
||||
def openai_agent():
|
||||
chat_model = ChatModelOptionsFactory(chat_model="gpt-3.5-turbo", model_type="openai")
|
||||
return Agent.objects.create(
|
||||
name="Accountant",
|
||||
chat_model=chat_model,
|
||||
tuning="You are a certified CPA. You are able to tell me how much I've spent based on my notes. Regardless of what I ask, you should always respond with the total amount I've spent.",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def search_models(search_config: SearchConfig):
|
||||
search_models = SearchModels()
|
||||
|
|
|
@ -465,6 +465,47 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""
|
|||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
def test_agent_prompt_should_be_used(loaded_model, offline_agent):
|
||||
"Chat actor should ask be tuned to think like an accountant based on the agent definition"
|
||||
# Arrange
|
||||
context = [
|
||||
f"""I went to the store and bought some bananas for 2.20""",
|
||||
f"""I went to the store and bought some apples for 1.30""",
|
||||
f"""I went to the store and bought some oranges for 6.00""",
|
||||
]
|
||||
|
||||
# Act
|
||||
response_gen = converse_offline(
|
||||
references=context, # Assume context retrieved from notes for the user_query
|
||||
user_query="What did I buy?",
|
||||
loaded_model=loaded_model,
|
||||
)
|
||||
response = "".join([response_chunk for response_chunk in response_gen])
|
||||
|
||||
# Assert that the model without the agent prompt does not include the summary of purchases
|
||||
expected_responses = ["9.50", "9.5"]
|
||||
assert all([expected_response not in response for expected_response in expected_responses]), (
|
||||
"Expected chat actor to summarize values of purchases" + response
|
||||
)
|
||||
|
||||
# Act
|
||||
response_gen = converse_offline(
|
||||
references=context, # Assume context retrieved from notes for the user_query
|
||||
user_query="What did I buy?",
|
||||
loaded_model=loaded_model,
|
||||
agent=offline_agent,
|
||||
)
|
||||
response = "".join([response_chunk for response_chunk in response_gen])
|
||||
|
||||
# Assert that the model with the agent prompt does include the summary of purchases
|
||||
expected_responses = ["9.50", "9.5"]
|
||||
assert any([expected_response in response for expected_response in expected_responses]), (
|
||||
"Expected chat actor to summarize values of purchases" + response
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def test_chat_does_not_exceed_prompt_size(loaded_model):
|
||||
"Ensure chat context and response together do not exceed max prompt size for the model"
|
||||
|
|
|
@ -6,6 +6,7 @@ import pytest
|
|||
from faker import Faker
|
||||
from freezegun import freeze_time
|
||||
|
||||
from khoj.database.models import Agent, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.utils import message_to_log
|
||||
from khoj.routers.helpers import aget_relevant_information_sources
|
||||
|
@ -26,20 +27,20 @@ def generate_history(message_list):
|
|||
# Generate conversation logs
|
||||
conversation_log = {"chat": []}
|
||||
for user_message, gpt_message, context in message_list:
|
||||
conversation_log["chat"] += message_to_log(
|
||||
message_to_log(
|
||||
user_message,
|
||||
gpt_message,
|
||||
{"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}},
|
||||
conversation_log=conversation_log.get("chat", []),
|
||||
)
|
||||
return conversation_log
|
||||
|
||||
|
||||
def populate_chat_history(message_list, user):
|
||||
def create_conversation(message_list, user, agent=None):
|
||||
# Generate conversation logs
|
||||
conversation_log = generate_history(message_list)
|
||||
|
||||
# Update Conversation Metadata Logs in Database
|
||||
ConversationFactory(user=user, conversation_log=conversation_log)
|
||||
return ConversationFactory(user=user, conversation_log=conversation_log, agent=agent)
|
||||
|
||||
|
||||
# Tests
|
||||
|
@ -114,7 +115,7 @@ def test_answer_from_chat_history(client_offline_chat, default_user2):
|
|||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
]
|
||||
populate_chat_history(message_list, default_user2)
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(f'/api/chat?q="What is my name?"&stream=true')
|
||||
|
@ -141,7 +142,7 @@ def test_answer_from_currently_retrieved_content(client_offline_chat, default_us
|
|||
["Testatron was born on 1st April 1984 in Testville."],
|
||||
),
|
||||
]
|
||||
populate_chat_history(message_list, default_user2)
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(f'/api/chat?q="Where was Xi Li born?"')
|
||||
|
@ -165,7 +166,7 @@ def test_answer_from_chat_history_and_previously_retrieved_content(client_offlin
|
|||
["Testatron was born on 1st April 1984 in Testville."],
|
||||
),
|
||||
]
|
||||
populate_chat_history(message_list, default_user2)
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(f'/api/chat?q="Where was I born?"')
|
||||
|
@ -191,7 +192,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content(client_offline
|
|||
("Hello, my name is Xi Li. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
]
|
||||
populate_chat_history(message_list, default_user2)
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(f'/api/chat?q="Where was I born?"')
|
||||
|
@ -217,7 +218,7 @@ def test_no_answer_in_chat_history_or_retrieved_content(client_offline_chat, def
|
|||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
]
|
||||
populate_chat_history(message_list, default_user2)
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(f'/api/chat?q="Where was I born?"&stream=true')
|
||||
|
@ -238,7 +239,7 @@ def test_answer_using_general_command(client_offline_chat, default_user2):
|
|||
# Arrange
|
||||
query = urllib.parse.quote("/general Where was Xi Li born?")
|
||||
message_list = []
|
||||
populate_chat_history(message_list, default_user2)
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(f"/api/chat?q={query}&stream=true")
|
||||
|
@ -256,7 +257,7 @@ def test_answer_from_retrieved_content_using_notes_command(client_offline_chat,
|
|||
# Arrange
|
||||
query = urllib.parse.quote("/notes Where was Xi Li born?")
|
||||
message_list = []
|
||||
populate_chat_history(message_list, default_user2)
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(f"/api/chat?q={query}&stream=true")
|
||||
|
@ -275,7 +276,7 @@ def test_answer_using_file_filter(client_offline_chat, default_user2):
|
|||
no_answer_query = urllib.parse.quote('Where was Xi Li born? file:"Namita.markdown"')
|
||||
answer_query = urllib.parse.quote('Where was Xi Li born? file:"Xi Li.markdown"')
|
||||
message_list = []
|
||||
populate_chat_history(message_list, default_user2)
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
no_answer_response = client_offline_chat.get(f"/api/chat?q={no_answer_query}&stream=true").content.decode("utf-8")
|
||||
|
@ -293,7 +294,7 @@ def test_answer_not_known_using_notes_command(client_offline_chat, default_user2
|
|||
# Arrange
|
||||
query = urllib.parse.quote("/notes Where was Testatron born?")
|
||||
message_list = []
|
||||
populate_chat_history(message_list, default_user2)
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(f"/api/chat?q={query}&stream=true")
|
||||
|
@ -351,7 +352,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(client
|
|||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
("Where was I born?", "You were born Testville.", []),
|
||||
]
|
||||
populate_chat_history(message_list, default_user2)
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(
|
||||
|
@ -394,14 +395,14 @@ def test_ask_for_clarification_if_not_enough_context_in_question(client_offline_
|
|||
@pytest.mark.xfail(reason="Chat director not capable of answering this question yet")
|
||||
@pytest.mark.chatquality
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_in_chat_history_beyond_lookback_window(client_offline_chat, default_user2):
|
||||
def test_answer_in_chat_history_beyond_lookback_window(client_offline_chat, default_user2: KhojUser):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
("Where was I born?", "You were born Testville.", []),
|
||||
]
|
||||
populate_chat_history(message_list, default_user2)
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(f'/api/chat?q="What is my name?"&stream=true')
|
||||
|
@ -415,13 +416,77 @@ def test_answer_in_chat_history_beyond_lookback_window(client_offline_chat, defa
|
|||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_in_chat_history_by_conversation_id(client_offline_chat, default_user2: KhojUser):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
("What's my favorite color", "Your favorite color is green.", []),
|
||||
("Where was I born?", "You were born Testville.", []),
|
||||
]
|
||||
message_list2 = [
|
||||
("Hello, my name is Julia. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 14th August 1947.", []),
|
||||
("What's my favorite color", "Your favorite color is maroon.", []),
|
||||
("Where was I born?", "You were born in a potato farm.", []),
|
||||
]
|
||||
conversation = create_conversation(message_list, default_user2)
|
||||
create_conversation(message_list2, default_user2)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(
|
||||
f'/api/chat?q="What is my favorite color?"&conversation_id={conversation.id}&stream=true'
|
||||
)
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
expected_responses = ["green"]
|
||||
assert response.status_code == 200
|
||||
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
|
||||
"Expected green in response, but got: " + response_message
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.xfail(reason="Chat director not great at adhering to agent instructions yet")
|
||||
@pytest.mark.chatquality
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_in_chat_history_by_conversation_id_with_agent(
|
||||
client_offline_chat, default_user2: KhojUser, offline_agent: Agent
|
||||
):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
("What's my favorite color", "Your favorite color is green.", []),
|
||||
("Where was I born?", "You were born Testville.", []),
|
||||
("What did I buy?", "You bought an apple for 2.00, an orange for 3.00, and a potato for 8.00", []),
|
||||
]
|
||||
conversation = create_conversation(message_list, default_user2, offline_agent)
|
||||
|
||||
# Act
|
||||
query = urllib.parse.quote("/general What did I eat for breakfast?")
|
||||
response = client_offline_chat.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert that agent only responds with the summary of spending
|
||||
expected_responses = ["13.00", "13", "13.0", "thirteen"]
|
||||
assert response.status_code == 200
|
||||
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
|
||||
"Expected green in response, but got: " + response_message
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.chatquality
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_chat_history_very_long(client_offline_chat, default_user2):
|
||||
# Arrange
|
||||
message_list = [(" ".join([fake.paragraph() for _ in range(50)]), fake.sentence(), []) for _ in range(10)]
|
||||
|
||||
populate_chat_history(message_list, default_user2)
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = client_offline_chat.get(f'/api/chat?q="What is my name?"&stream=true')
|
||||
|
@ -525,7 +590,7 @@ async def test_get_correct_tools_with_chat_history(client_offline_chat):
|
|||
),
|
||||
("What's up in New York City?", "A Pride parade has recently been held in New York City, on July 31st.", []),
|
||||
]
|
||||
chat_history = populate_chat_history(chat_log)
|
||||
chat_history = create_conversation(chat_log)
|
||||
|
||||
# Act
|
||||
tools = await aget_relevant_information_sources(user_query, chat_history)
|
||||
|
|
|
@ -435,6 +435,42 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""
|
|||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
def test_agent_prompt_should_be_used(openai_agent):
|
||||
"Chat actor should ask be tuned to think like an accountant based on the agent definition"
|
||||
# Arrange
|
||||
context = [
|
||||
f"""I went to the store and bought some bananas for 2.20""",
|
||||
f"""I went to the store and bought some apples for 1.30""",
|
||||
f"""I went to the store and bought some oranges for 6.00""",
|
||||
]
|
||||
expected_responses = ["9.50", "9.5"]
|
||||
|
||||
# Act
|
||||
response_gen = converse(
|
||||
references=context, # Assume context retrieved from notes for the user_query
|
||||
user_query="What did I buy?",
|
||||
api_key=api_key,
|
||||
)
|
||||
no_agent_response = "".join([response_chunk for response_chunk in response_gen])
|
||||
response_gen = converse(
|
||||
references=context, # Assume context retrieved from notes for the user_query
|
||||
user_query="What did I buy?",
|
||||
api_key=api_key,
|
||||
agent=openai_agent,
|
||||
)
|
||||
agent_response = "".join([response_chunk for response_chunk in response_gen])
|
||||
|
||||
# Assert that the model without the agent prompt does not include the summary of purchases
|
||||
assert all([expected_response not in no_agent_response for expected_response in expected_responses]), (
|
||||
"Expected chat actor to summarize values of purchases" + no_agent_response
|
||||
)
|
||||
assert any([expected_response in agent_response for expected_response in expected_responses]), (
|
||||
"Expected chat actor to summarize values of purchases" + agent_response
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
|
|
|
@ -5,13 +5,10 @@ from urllib.parse import quote
|
|||
import pytest
|
||||
from freezegun import freeze_time
|
||||
|
||||
from khoj.database.models import KhojUser
|
||||
from khoj.database.models import Agent, KhojUser
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.utils import message_to_log
|
||||
from khoj.routers.helpers import (
|
||||
aget_relevant_information_sources,
|
||||
aget_relevant_output_modes,
|
||||
)
|
||||
from khoj.routers.helpers import aget_relevant_information_sources
|
||||
from tests.helpers import ConversationFactory
|
||||
|
||||
# Initialize variables for tests
|
||||
|
@ -29,20 +26,21 @@ def generate_history(message_list):
|
|||
# Generate conversation logs
|
||||
conversation_log = {"chat": []}
|
||||
for user_message, gpt_message, context in message_list:
|
||||
conversation_log["chat"] += message_to_log(
|
||||
message_to_log(
|
||||
user_message,
|
||||
gpt_message,
|
||||
{"context": context, "intent": {"query": user_message, "inferred-queries": f'["{user_message}"]'}},
|
||||
conversation_log=conversation_log.get("chat", []),
|
||||
)
|
||||
return conversation_log
|
||||
|
||||
|
||||
def populate_chat_history(message_list, user):
|
||||
def create_conversation(message_list, user, agent=None):
|
||||
# Generate conversation logs
|
||||
conversation_log = generate_history(message_list)
|
||||
|
||||
# Update Conversation Metadata Logs in Database
|
||||
ConversationFactory(user=user, conversation_log=conversation_log)
|
||||
return ConversationFactory(user=user, conversation_log=conversation_log, agent=agent)
|
||||
|
||||
|
||||
# Tests
|
||||
|
@ -116,7 +114,7 @@ def test_answer_from_chat_history(chat_client, default_user2: KhojUser):
|
|||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
]
|
||||
populate_chat_history(message_list, default_user2)
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q="What is my name?"&stream=true')
|
||||
|
@ -143,7 +141,7 @@ def test_answer_from_currently_retrieved_content(chat_client, default_user2: Kho
|
|||
["Testatron was born on 1st April 1984 in Testville."],
|
||||
),
|
||||
]
|
||||
populate_chat_history(message_list, default_user2)
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q="Where was Xi Li born?"')
|
||||
|
@ -167,7 +165,7 @@ def test_answer_from_chat_history_and_previously_retrieved_content(chat_client_n
|
|||
["Testatron was born on 1st April 1984 in Testville."],
|
||||
),
|
||||
]
|
||||
populate_chat_history(message_list, default_user2)
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = chat_client_no_background.get(f'/api/chat?q="Where was I born?"')
|
||||
|
@ -190,7 +188,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content(chat_client, d
|
|||
("Hello, my name is Xi Li. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
]
|
||||
populate_chat_history(message_list, default_user2)
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q="Where was I born?"')
|
||||
|
@ -215,7 +213,7 @@ def test_no_answer_in_chat_history_or_retrieved_content(chat_client, default_use
|
|||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
]
|
||||
populate_chat_history(message_list, default_user2)
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q="Where was I born?"&stream=true')
|
||||
|
@ -236,7 +234,7 @@ def test_answer_using_general_command(chat_client, default_user2: KhojUser):
|
|||
# Arrange
|
||||
query = urllib.parse.quote("/general Where was Xi Li born?")
|
||||
message_list = []
|
||||
populate_chat_history(message_list, default_user2)
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = chat_client.get(f"/api/chat?q={query}&stream=true")
|
||||
|
@ -254,7 +252,7 @@ def test_answer_from_retrieved_content_using_notes_command(chat_client, default_
|
|||
# Arrange
|
||||
query = urllib.parse.quote("/notes Where was Xi Li born?")
|
||||
message_list = []
|
||||
populate_chat_history(message_list, default_user2)
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = chat_client.get(f"/api/chat?q={query}&stream=true")
|
||||
|
@ -272,7 +270,7 @@ def test_answer_not_known_using_notes_command(chat_client_no_background, default
|
|||
# Arrange
|
||||
query = urllib.parse.quote("/notes Where was Testatron born?")
|
||||
message_list = []
|
||||
populate_chat_history(message_list, default_user2)
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = chat_client_no_background.get(f"/api/chat?q={query}&stream=true")
|
||||
|
@ -327,7 +325,7 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content(chat_c
|
|||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
("Where was I born?", "You were born Testville.", []),
|
||||
]
|
||||
populate_chat_history(message_list, default_user2)
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = chat_client.get(
|
||||
|
@ -379,7 +377,7 @@ def test_answer_in_chat_history_beyond_lookback_window(chat_client, default_user
|
|||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
("Where was I born?", "You were born Testville.", []),
|
||||
]
|
||||
populate_chat_history(message_list, default_user2)
|
||||
create_conversation(message_list, default_user2)
|
||||
|
||||
# Act
|
||||
response = chat_client.get(f'/api/chat?q="What is my name?"&stream=true')
|
||||
|
@ -393,6 +391,68 @@ def test_answer_in_chat_history_beyond_lookback_window(chat_client, default_user
|
|||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_in_chat_history_by_conversation_id(chat_client, default_user2: KhojUser):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
("What's my favorite color", "Your favorite color is green.", []),
|
||||
("Where was I born?", "You were born Testville.", []),
|
||||
]
|
||||
message_list2 = [
|
||||
("Hello, my name is Julia. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 14th August 1947.", []),
|
||||
("What's my favorite color", "Your favorite color is maroon.", []),
|
||||
("Where was I born?", "You were born in a potato farm.", []),
|
||||
]
|
||||
conversation = create_conversation(message_list, default_user2)
|
||||
create_conversation(message_list2, default_user2)
|
||||
|
||||
# Act
|
||||
query = urllib.parse.quote("/general What is my favorite color?")
|
||||
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert
|
||||
expected_responses = ["green"]
|
||||
assert response.status_code == 200
|
||||
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
|
||||
"Expected green in response, but got: " + response_message
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.chatquality
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_answer_in_chat_history_by_conversation_id_with_agent(
|
||||
chat_client, default_user2: KhojUser, openai_agent: Agent
|
||||
):
|
||||
# Arrange
|
||||
message_list = [
|
||||
("Hello, my name is Testatron. Who are you?", "Hi, I am Khoj, a personal assistant. How can I help?", []),
|
||||
("When was I born?", "You were born on 1st April 1984.", []),
|
||||
("What's my favorite color", "Your favorite color is green.", []),
|
||||
("Where was I born?", "You were born Testville.", []),
|
||||
("What did I buy?", "You bought an apple for 2.00, an orange for 3.00, and a potato for 8.00", []),
|
||||
]
|
||||
conversation = create_conversation(message_list, default_user2, openai_agent)
|
||||
|
||||
# Act
|
||||
query = urllib.parse.quote("/general What did I eat for breakfast?")
|
||||
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
|
||||
response_message = response.content.decode("utf-8")
|
||||
|
||||
# Assert that agent only responds with the summary of spending
|
||||
expected_responses = ["13.00", "13", "13.0", "thirteen"]
|
||||
assert response.status_code == 200
|
||||
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
|
||||
"Expected green in response, but got: " + response_message
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.chatquality
|
||||
|
|
Loading…
Add table
Reference in a new issue