Fix, improve openai chat actor, director tests & online search prompt

This commit is contained in:
Debanjum Singh Solanky 2024-08-22 19:07:53 -07:00
parent 9986c183ea
commit 238bc11a50
4 changed files with 55 additions and 46 deletions

View file

@ -578,11 +578,11 @@ Khoj:
online_search_conversation_subqueries = PromptTemplate.from_template(
"""
You are Khoj, an advanced google search assistant. You are tasked with constructing **up to three** google search queries to answer the user's question.
You are Khoj, an advanced web search assistant. You are tasked with constructing **up to three** google search queries to answer the user's question.
- You will receive the conversation history as context.
- Add as much context from the previous questions and answers as required into your search queries.
- Break messages into multiple search queries when required to retrieve the relevant information.
- Use site: google search operators when appropriate
- Use site: google search operator when appropriate
- You have access to the the whole internet to retrieve information.
- Official, up-to-date information about you, Khoj, is available at site:khoj.dev, github or pypi.

View file

@ -196,7 +196,7 @@ def openai_agent():
return Agent.objects.create(
name="Accountant",
chat_model=chat_model,
personality="You are a certified CPA. You are able to tell me how much I've spent based on my notes. Regardless of what I ask, you should always respond with the total amount I've spent.",
personality="You are a certified CPA. You are able to tell me how much I've spent based on my notes. Regardless of what I ask, you should always respond with the total amount I've spent. ALWAYS RESPOND WITH A SUMMARY TOTAL OF HOW MUCH MONEY I HAVE SPENT.",
)

View file

@ -405,10 +405,10 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content():
response = "".join([response_chunk for response_chunk in response_gen])
# Assert
expected_responses = ["test", "Test"]
expected_responses = ["test", "bug", "code"]
assert len(response.splitlines()) == 3 # haikus are 3 lines long
assert any([expected_response in response for expected_response in expected_responses]), (
"Expected [T|t]est in response, but got: " + response
assert any([expected_response in response.lower() for expected_response in expected_responses]), (
"Expected haiku about unit test, but got: " + response
)
@ -441,7 +441,13 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""
response = "".join([response_chunk for response_chunk in response_gen])
# Assert
expected_responses = ["which sister", "Which sister", "which of your sister", "Which of your sister"]
expected_responses = [
"which sister",
"Which sister",
"which of your sister",
"Which of your sister",
"Could you provide",
]
assert any([expected_response in response for expected_response in expected_responses]), (
"Expected chat actor to ask for clarification in response, but got: " + response
)
@ -555,7 +561,7 @@ async def test_select_data_sources_actor_chooses_to_search_notes(
chat_client, user_query, expected_conversation_commands
):
# Act
conversation_commands = await aget_relevant_information_sources(user_query, {}, False)
conversation_commands = await aget_relevant_information_sources(user_query, {}, False, False)
# Assert
assert set(expected_conversation_commands) == set(conversation_commands)

View file

@ -72,9 +72,8 @@ def test_chat_with_online_content(chat_client):
# Assert
expected_responses = [
"https://paulgraham.com/greatwork.html",
"https://www.paulgraham.com/greatwork.html",
"http://www.paulgraham.com/greatwork.html",
"paulgraham.com/greatwork.html",
"paulgraham.com/hwh.html",
]
assert response.status_code == 200
assert any(
@ -112,7 +111,7 @@ def test_answer_from_chat_history(chat_client, default_user2: KhojUser):
create_conversation(message_list, default_user2)
# Act
response = chat_client.get(f'/api/chat?q="What is my name?"&stream=true')
response = chat_client.get(f'/api/chat?q="What is my name?"')
response_message = response.content.decode("utf-8")
# Assert
@ -259,8 +258,8 @@ def test_answer_from_retrieved_content_using_notes_command(chat_client, default_
create_conversation(message_list, default_user2)
# Act
response = chat_client.get(f"/api/chat?q={query}&stream=true")
response_message = response.content.decode("utf-8")
response = chat_client.get(f"/api/chat?q={query}")
response_message = response.json()["response"]
# Assert
assert response.status_code == 200
@ -310,8 +309,8 @@ def test_summarize_one_file(chat_client, default_user2: KhojUser):
json={"filename": summarization_file, "conversation_id": str(conversation.id)},
)
query = urllib.parse.quote("/summarize")
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
response_message = response.content.decode("utf-8")
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}")
response_message = response.json()["response"]
# Assert
assert response_message != ""
assert response_message != "No files selected for summarization. Please add files using the section on the left."
@ -342,8 +341,8 @@ def test_summarize_extra_text(chat_client, default_user2: KhojUser):
json={"filename": summarization_file, "conversation_id": str(conversation.id)},
)
query = urllib.parse.quote("/summarize tell me about Xiu")
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
response_message = response.content.decode("utf-8")
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}")
response_message = response.json()["response"]
# Assert
assert response_message != ""
assert response_message != "No files selected for summarization. Please add files using the section on the left."
@ -370,8 +369,8 @@ def test_summarize_multiple_files(chat_client, default_user2: KhojUser):
)
query = urllib.parse.quote("/summarize")
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
response_message = response.content.decode("utf-8")
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}")
response_message = response.json()["response"]
# Assert
assert response_message == "Only one file can be selected for summarization."
@ -383,8 +382,8 @@ def test_summarize_no_files(chat_client, default_user2: KhojUser):
message_list = []
conversation = create_conversation(message_list, default_user2)
query = urllib.parse.quote("/summarize")
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
response_message = response.content.decode("utf-8")
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}")
response_message = response.json()["response"]
# Assert
assert response_message == "No files selected for summarization. Please add files using the section on the left."
@ -415,15 +414,15 @@ def test_summarize_different_conversation(chat_client, default_user2: KhojUser):
)
query = urllib.parse.quote("/summarize")
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation2.id}&stream=true")
response_message = response.content.decode("utf-8")
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation2.id}")
response_message = response.json()["response"]
# Assert
assert response_message == "No files selected for summarization. Please add files using the section on the left."
# now make sure that the file filter is still in conversation 1
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation1.id}&stream=true")
response_message = response.content.decode("utf-8")
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation1.id}")
response_message = response.json()["response"]
# Assert
assert response_message != ""
@ -442,8 +441,8 @@ def test_summarize_nonexistant_file(chat_client, default_user2: KhojUser):
json={"filename": "imaginary.markdown", "conversation_id": str(conversation.id)},
)
query = urllib.parse.quote("/summarize")
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
response_message = response.content.decode("utf-8")
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}")
response_message = response.json()["response"]
# Assert
assert response_message == "No files selected for summarization. Please add files using the section on the left."
@ -471,8 +470,8 @@ def test_summarize_diff_user_file(chat_client, default_user: KhojUser, pdf_confi
json={"filename": summarization_file, "conversation_id": str(conversation.id)},
)
query = urllib.parse.quote("/summarize")
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
response_message = response.content.decode("utf-8")
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}")
response_message = response.json()["response"]
# Assert
assert response_message == "No files selected for summarization. Please add files using the section on the left."
@ -503,8 +502,8 @@ def test_answer_requires_current_date_awareness(chat_client):
def test_answer_requires_date_aware_aggregation_across_provided_notes(chat_client):
"Chat director should be able to answer questions that require date aware aggregation across multiple notes"
# Act
response = chat_client.get(f'/api/chat?q="How much did I spend on dining this year?"&stream=true')
response_message = response.content.decode("utf-8")
response = chat_client.get(f'/api/chat?q="How much did I spend on dining this year?"')
response_message = response.json()["response"]
# Assert
assert response.status_code == 200
@ -575,8 +574,8 @@ def test_answer_in_chat_history_beyond_lookback_window(chat_client, default_user
create_conversation(message_list, default_user2)
# Act
response = chat_client.get(f'/api/chat?q="What is my name?"&stream=true')
response_message = response.content.decode("utf-8")
response = chat_client.get(f'/api/chat?q="What is my name?"')
response_message = response.json()["response"]
# Assert
expected_responses = ["Testatron", "testatron"]
@ -631,20 +630,24 @@ def test_answer_in_chat_history_by_conversation_id_with_agent(
("When was I born?", "You were born on 1st April 1984.", []),
("What's my favorite color", "Your favorite color is green.", []),
("Where was I born?", "You were born Testville.", []),
("What did I buy?", "You bought an apple for 2.00, an orange for 3.00, and a potato for 8.00", []),
(
"What did I buy?",
"You bought an apple for 2.00, an orange for 3.00, and a potato for 8.00 for breakfast",
[],
),
]
conversation = create_conversation(message_list, default_user2, openai_agent)
# Act
query = urllib.parse.quote("/general What did I eat for breakfast?")
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
response_message = response.content.decode("utf-8")
query = urllib.parse.quote("/general What did I buy for breakfast?")
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}")
response_message = response.json()["response"]
# Assert that agent only responds with the summary of spending
expected_responses = ["13.00", "13", "13.0", "thirteen"]
assert response.status_code == 200
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
"Expected green in response, but got: " + response_message
"Expected amount in response, but got: " + response_message
)
@ -654,7 +657,7 @@ def test_answer_in_chat_history_by_conversation_id_with_agent(
def test_answer_requires_multiple_independent_searches(chat_client):
"Chat director should be able to answer by doing multiple independent searches for required information"
# Act
response = chat_client.get(f'/api/chat?q="Is Xi older than Namita? Just the older persons full name"')
response = chat_client.get(f'/api/chat?q="Is Xi Li older than Namita? Just say the older persons full name"')
response_message = response.json()["response"].lower()
# Assert
@ -676,7 +679,7 @@ def test_answer_using_file_filter(chat_client):
"Chat should be able to use search filters in the query"
# Act
query = urllib.parse.quote(
'Is Xi older than Namita? Just say the older persons full name. file:"Namita.markdown" file:"Xi Li.markdown"'
'Is Xi Li older than Namita? Just say the older persons full name. file:"Namita.markdown" file:"Xi Li.markdown"'
)
response = chat_client.get(f"/api/chat?q={query}")
@ -703,7 +706,7 @@ async def test_get_correct_tools_online(chat_client):
user_query = "What's the weather in Patagonia this week?"
# Act
tools = await aget_relevant_information_sources(user_query, {}, False)
tools = await aget_relevant_information_sources(user_query, {}, False, False)
# Assert
tools = [tool.value for tool in tools]
@ -718,7 +721,7 @@ async def test_get_correct_tools_notes(chat_client):
user_query = "Where did I go for my first battleship training?"
# Act
tools = await aget_relevant_information_sources(user_query, {}, False)
tools = await aget_relevant_information_sources(user_query, {}, False, False)
# Assert
tools = [tool.value for tool in tools]
@ -733,7 +736,7 @@ async def test_get_correct_tools_online_or_general_and_notes(chat_client):
user_query = "What's the highest point in Patagonia and have I been there?"
# Act
tools = await aget_relevant_information_sources(user_query, {}, False)
tools = await aget_relevant_information_sources(user_query, {}, False, False)
# Assert
tools = [tool.value for tool in tools]
@ -750,7 +753,7 @@ async def test_get_correct_tools_general(chat_client):
user_query = "How many noble gases are there?"
# Act
tools = await aget_relevant_information_sources(user_query, {}, False)
tools = await aget_relevant_information_sources(user_query, {}, False, False)
# Assert
tools = [tool.value for tool in tools]
@ -774,7 +777,7 @@ async def test_get_correct_tools_with_chat_history(chat_client):
chat_history = generate_history(chat_log)
# Act
tools = await aget_relevant_information_sources(user_query, chat_history, False)
tools = await aget_relevant_information_sources(user_query, chat_history, False, False)
# Assert
tools = [tool.value for tool in tools]