Fix, improve openai chat actor, director tests & online search prompt

This commit is contained in:
Debanjum Singh Solanky 2024-08-22 19:07:53 -07:00
parent 9986c183ea
commit 238bc11a50
4 changed files with 55 additions and 46 deletions

View file

@ -578,11 +578,11 @@ Khoj:
online_search_conversation_subqueries = PromptTemplate.from_template( online_search_conversation_subqueries = PromptTemplate.from_template(
""" """
You are Khoj, an advanced google search assistant. You are tasked with constructing **up to three** google search queries to answer the user's question. You are Khoj, an advanced web search assistant. You are tasked with constructing **up to three** google search queries to answer the user's question.
- You will receive the conversation history as context. - You will receive the conversation history as context.
- Add as much context from the previous questions and answers as required into your search queries. - Add as much context from the previous questions and answers as required into your search queries.
- Break messages into multiple search queries when required to retrieve the relevant information. - Break messages into multiple search queries when required to retrieve the relevant information.
- Use site: google search operators when appropriate - Use site: google search operator when appropriate
- You have access to the the whole internet to retrieve information. - You have access to the the whole internet to retrieve information.
- Official, up-to-date information about you, Khoj, is available at site:khoj.dev, github or pypi. - Official, up-to-date information about you, Khoj, is available at site:khoj.dev, github or pypi.

View file

@ -196,7 +196,7 @@ def openai_agent():
return Agent.objects.create( return Agent.objects.create(
name="Accountant", name="Accountant",
chat_model=chat_model, chat_model=chat_model,
personality="You are a certified CPA. You are able to tell me how much I've spent based on my notes. Regardless of what I ask, you should always respond with the total amount I've spent.", personality="You are a certified CPA. You are able to tell me how much I've spent based on my notes. Regardless of what I ask, you should always respond with the total amount I've spent. ALWAYS RESPOND WITH A SUMMARY TOTAL OF HOW MUCH MONEY I HAVE SPENT.",
) )

View file

@ -405,10 +405,10 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content():
response = "".join([response_chunk for response_chunk in response_gen]) response = "".join([response_chunk for response_chunk in response_gen])
# Assert # Assert
expected_responses = ["test", "Test"] expected_responses = ["test", "bug", "code"]
assert len(response.splitlines()) == 3 # haikus are 3 lines long assert len(response.splitlines()) == 3 # haikus are 3 lines long
assert any([expected_response in response for expected_response in expected_responses]), ( assert any([expected_response in response.lower() for expected_response in expected_responses]), (
"Expected [T|t]est in response, but got: " + response "Expected haiku about unit test, but got: " + response
) )
@ -441,7 +441,13 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""
response = "".join([response_chunk for response_chunk in response_gen]) response = "".join([response_chunk for response_chunk in response_gen])
# Assert # Assert
expected_responses = ["which sister", "Which sister", "which of your sister", "Which of your sister"] expected_responses = [
"which sister",
"Which sister",
"which of your sister",
"Which of your sister",
"Could you provide",
]
assert any([expected_response in response for expected_response in expected_responses]), ( assert any([expected_response in response for expected_response in expected_responses]), (
"Expected chat actor to ask for clarification in response, but got: " + response "Expected chat actor to ask for clarification in response, but got: " + response
) )
@ -555,7 +561,7 @@ async def test_select_data_sources_actor_chooses_to_search_notes(
chat_client, user_query, expected_conversation_commands chat_client, user_query, expected_conversation_commands
): ):
# Act # Act
conversation_commands = await aget_relevant_information_sources(user_query, {}, False) conversation_commands = await aget_relevant_information_sources(user_query, {}, False, False)
# Assert # Assert
assert set(expected_conversation_commands) == set(conversation_commands) assert set(expected_conversation_commands) == set(conversation_commands)

View file

@ -72,9 +72,8 @@ def test_chat_with_online_content(chat_client):
# Assert # Assert
expected_responses = [ expected_responses = [
"https://paulgraham.com/greatwork.html", "paulgraham.com/greatwork.html",
"https://www.paulgraham.com/greatwork.html", "paulgraham.com/hwh.html",
"http://www.paulgraham.com/greatwork.html",
] ]
assert response.status_code == 200 assert response.status_code == 200
assert any( assert any(
@ -112,7 +111,7 @@ def test_answer_from_chat_history(chat_client, default_user2: KhojUser):
create_conversation(message_list, default_user2) create_conversation(message_list, default_user2)
# Act # Act
response = chat_client.get(f'/api/chat?q="What is my name?"&stream=true') response = chat_client.get(f'/api/chat?q="What is my name?"')
response_message = response.content.decode("utf-8") response_message = response.content.decode("utf-8")
# Assert # Assert
@ -259,8 +258,8 @@ def test_answer_from_retrieved_content_using_notes_command(chat_client, default_
create_conversation(message_list, default_user2) create_conversation(message_list, default_user2)
# Act # Act
response = chat_client.get(f"/api/chat?q={query}&stream=true") response = chat_client.get(f"/api/chat?q={query}")
response_message = response.content.decode("utf-8") response_message = response.json()["response"]
# Assert # Assert
assert response.status_code == 200 assert response.status_code == 200
@ -310,8 +309,8 @@ def test_summarize_one_file(chat_client, default_user2: KhojUser):
json={"filename": summarization_file, "conversation_id": str(conversation.id)}, json={"filename": summarization_file, "conversation_id": str(conversation.id)},
) )
query = urllib.parse.quote("/summarize") query = urllib.parse.quote("/summarize")
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true") response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}")
response_message = response.content.decode("utf-8") response_message = response.json()["response"]
# Assert # Assert
assert response_message != "" assert response_message != ""
assert response_message != "No files selected for summarization. Please add files using the section on the left." assert response_message != "No files selected for summarization. Please add files using the section on the left."
@ -342,8 +341,8 @@ def test_summarize_extra_text(chat_client, default_user2: KhojUser):
json={"filename": summarization_file, "conversation_id": str(conversation.id)}, json={"filename": summarization_file, "conversation_id": str(conversation.id)},
) )
query = urllib.parse.quote("/summarize tell me about Xiu") query = urllib.parse.quote("/summarize tell me about Xiu")
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true") response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}")
response_message = response.content.decode("utf-8") response_message = response.json()["response"]
# Assert # Assert
assert response_message != "" assert response_message != ""
assert response_message != "No files selected for summarization. Please add files using the section on the left." assert response_message != "No files selected for summarization. Please add files using the section on the left."
@ -370,8 +369,8 @@ def test_summarize_multiple_files(chat_client, default_user2: KhojUser):
) )
query = urllib.parse.quote("/summarize") query = urllib.parse.quote("/summarize")
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true") response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}")
response_message = response.content.decode("utf-8") response_message = response.json()["response"]
# Assert # Assert
assert response_message == "Only one file can be selected for summarization." assert response_message == "Only one file can be selected for summarization."
@ -383,8 +382,8 @@ def test_summarize_no_files(chat_client, default_user2: KhojUser):
message_list = [] message_list = []
conversation = create_conversation(message_list, default_user2) conversation = create_conversation(message_list, default_user2)
query = urllib.parse.quote("/summarize") query = urllib.parse.quote("/summarize")
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true") response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}")
response_message = response.content.decode("utf-8") response_message = response.json()["response"]
# Assert # Assert
assert response_message == "No files selected for summarization. Please add files using the section on the left." assert response_message == "No files selected for summarization. Please add files using the section on the left."
@ -415,15 +414,15 @@ def test_summarize_different_conversation(chat_client, default_user2: KhojUser):
) )
query = urllib.parse.quote("/summarize") query = urllib.parse.quote("/summarize")
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation2.id}&stream=true") response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation2.id}")
response_message = response.content.decode("utf-8") response_message = response.json()["response"]
# Assert # Assert
assert response_message == "No files selected for summarization. Please add files using the section on the left." assert response_message == "No files selected for summarization. Please add files using the section on the left."
# now make sure that the file filter is still in conversation 1 # now make sure that the file filter is still in conversation 1
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation1.id}&stream=true") response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation1.id}")
response_message = response.content.decode("utf-8") response_message = response.json()["response"]
# Assert # Assert
assert response_message != "" assert response_message != ""
@ -442,8 +441,8 @@ def test_summarize_nonexistant_file(chat_client, default_user2: KhojUser):
json={"filename": "imaginary.markdown", "conversation_id": str(conversation.id)}, json={"filename": "imaginary.markdown", "conversation_id": str(conversation.id)},
) )
query = urllib.parse.quote("/summarize") query = urllib.parse.quote("/summarize")
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true") response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}")
response_message = response.content.decode("utf-8") response_message = response.json()["response"]
# Assert # Assert
assert response_message == "No files selected for summarization. Please add files using the section on the left." assert response_message == "No files selected for summarization. Please add files using the section on the left."
@ -471,8 +470,8 @@ def test_summarize_diff_user_file(chat_client, default_user: KhojUser, pdf_confi
json={"filename": summarization_file, "conversation_id": str(conversation.id)}, json={"filename": summarization_file, "conversation_id": str(conversation.id)},
) )
query = urllib.parse.quote("/summarize") query = urllib.parse.quote("/summarize")
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true") response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}")
response_message = response.content.decode("utf-8") response_message = response.json()["response"]
# Assert # Assert
assert response_message == "No files selected for summarization. Please add files using the section on the left." assert response_message == "No files selected for summarization. Please add files using the section on the left."
@ -503,8 +502,8 @@ def test_answer_requires_current_date_awareness(chat_client):
def test_answer_requires_date_aware_aggregation_across_provided_notes(chat_client): def test_answer_requires_date_aware_aggregation_across_provided_notes(chat_client):
"Chat director should be able to answer questions that require date aware aggregation across multiple notes" "Chat director should be able to answer questions that require date aware aggregation across multiple notes"
# Act # Act
response = chat_client.get(f'/api/chat?q="How much did I spend on dining this year?"&stream=true') response = chat_client.get(f'/api/chat?q="How much did I spend on dining this year?"')
response_message = response.content.decode("utf-8") response_message = response.json()["response"]
# Assert # Assert
assert response.status_code == 200 assert response.status_code == 200
@ -575,8 +574,8 @@ def test_answer_in_chat_history_beyond_lookback_window(chat_client, default_user
create_conversation(message_list, default_user2) create_conversation(message_list, default_user2)
# Act # Act
response = chat_client.get(f'/api/chat?q="What is my name?"&stream=true') response = chat_client.get(f'/api/chat?q="What is my name?"')
response_message = response.content.decode("utf-8") response_message = response.json()["response"]
# Assert # Assert
expected_responses = ["Testatron", "testatron"] expected_responses = ["Testatron", "testatron"]
@ -631,20 +630,24 @@ def test_answer_in_chat_history_by_conversation_id_with_agent(
("When was I born?", "You were born on 1st April 1984.", []), ("When was I born?", "You were born on 1st April 1984.", []),
("What's my favorite color", "Your favorite color is green.", []), ("What's my favorite color", "Your favorite color is green.", []),
("Where was I born?", "You were born Testville.", []), ("Where was I born?", "You were born Testville.", []),
("What did I buy?", "You bought an apple for 2.00, an orange for 3.00, and a potato for 8.00", []), (
"What did I buy?",
"You bought an apple for 2.00, an orange for 3.00, and a potato for 8.00 for breakfast",
[],
),
] ]
conversation = create_conversation(message_list, default_user2, openai_agent) conversation = create_conversation(message_list, default_user2, openai_agent)
# Act # Act
query = urllib.parse.quote("/general What did I eat for breakfast?") query = urllib.parse.quote("/general What did I buy for breakfast?")
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true") response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}")
response_message = response.content.decode("utf-8") response_message = response.json()["response"]
# Assert that agent only responds with the summary of spending # Assert that agent only responds with the summary of spending
expected_responses = ["13.00", "13", "13.0", "thirteen"] expected_responses = ["13.00", "13", "13.0", "thirteen"]
assert response.status_code == 200 assert response.status_code == 200
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), ( assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
"Expected green in response, but got: " + response_message "Expected amount in response, but got: " + response_message
) )
@ -654,7 +657,7 @@ def test_answer_in_chat_history_by_conversation_id_with_agent(
def test_answer_requires_multiple_independent_searches(chat_client): def test_answer_requires_multiple_independent_searches(chat_client):
"Chat director should be able to answer by doing multiple independent searches for required information" "Chat director should be able to answer by doing multiple independent searches for required information"
# Act # Act
response = chat_client.get(f'/api/chat?q="Is Xi older than Namita? Just the older persons full name"') response = chat_client.get(f'/api/chat?q="Is Xi Li older than Namita? Just say the older persons full name"')
response_message = response.json()["response"].lower() response_message = response.json()["response"].lower()
# Assert # Assert
@ -676,7 +679,7 @@ def test_answer_using_file_filter(chat_client):
"Chat should be able to use search filters in the query" "Chat should be able to use search filters in the query"
# Act # Act
query = urllib.parse.quote( query = urllib.parse.quote(
'Is Xi older than Namita? Just say the older persons full name. file:"Namita.markdown" file:"Xi Li.markdown"' 'Is Xi Li older than Namita? Just say the older persons full name. file:"Namita.markdown" file:"Xi Li.markdown"'
) )
response = chat_client.get(f"/api/chat?q={query}") response = chat_client.get(f"/api/chat?q={query}")
@ -703,7 +706,7 @@ async def test_get_correct_tools_online(chat_client):
user_query = "What's the weather in Patagonia this week?" user_query = "What's the weather in Patagonia this week?"
# Act # Act
tools = await aget_relevant_information_sources(user_query, {}, False) tools = await aget_relevant_information_sources(user_query, {}, False, False)
# Assert # Assert
tools = [tool.value for tool in tools] tools = [tool.value for tool in tools]
@ -718,7 +721,7 @@ async def test_get_correct_tools_notes(chat_client):
user_query = "Where did I go for my first battleship training?" user_query = "Where did I go for my first battleship training?"
# Act # Act
tools = await aget_relevant_information_sources(user_query, {}, False) tools = await aget_relevant_information_sources(user_query, {}, False, False)
# Assert # Assert
tools = [tool.value for tool in tools] tools = [tool.value for tool in tools]
@ -733,7 +736,7 @@ async def test_get_correct_tools_online_or_general_and_notes(chat_client):
user_query = "What's the highest point in Patagonia and have I been there?" user_query = "What's the highest point in Patagonia and have I been there?"
# Act # Act
tools = await aget_relevant_information_sources(user_query, {}, False) tools = await aget_relevant_information_sources(user_query, {}, False, False)
# Assert # Assert
tools = [tool.value for tool in tools] tools = [tool.value for tool in tools]
@ -750,7 +753,7 @@ async def test_get_correct_tools_general(chat_client):
user_query = "How many noble gases are there?" user_query = "How many noble gases are there?"
# Act # Act
tools = await aget_relevant_information_sources(user_query, {}, False) tools = await aget_relevant_information_sources(user_query, {}, False, False)
# Assert # Assert
tools = [tool.value for tool in tools] tools = [tool.value for tool in tools]
@ -774,7 +777,7 @@ async def test_get_correct_tools_with_chat_history(chat_client):
chat_history = generate_history(chat_log) chat_history = generate_history(chat_log)
# Act # Act
tools = await aget_relevant_information_sources(user_query, chat_history, False) tools = await aget_relevant_information_sources(user_query, chat_history, False, False)
# Assert # Assert
tools = [tool.value for tool in tools] tools = [tool.value for tool in tools]