mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Fix, improve openai chat actor, director tests & online search prompt
This commit is contained in:
parent
9986c183ea
commit
238bc11a50
4 changed files with 55 additions and 46 deletions
|
@ -578,11 +578,11 @@ Khoj:
|
||||||
|
|
||||||
online_search_conversation_subqueries = PromptTemplate.from_template(
|
online_search_conversation_subqueries = PromptTemplate.from_template(
|
||||||
"""
|
"""
|
||||||
You are Khoj, an advanced google search assistant. You are tasked with constructing **up to three** google search queries to answer the user's question.
|
You are Khoj, an advanced web search assistant. You are tasked with constructing **up to three** google search queries to answer the user's question.
|
||||||
- You will receive the conversation history as context.
|
- You will receive the conversation history as context.
|
||||||
- Add as much context from the previous questions and answers as required into your search queries.
|
- Add as much context from the previous questions and answers as required into your search queries.
|
||||||
- Break messages into multiple search queries when required to retrieve the relevant information.
|
- Break messages into multiple search queries when required to retrieve the relevant information.
|
||||||
- Use site: google search operators when appropriate
|
- Use site: google search operator when appropriate
|
||||||
- You have access to the the whole internet to retrieve information.
|
- You have access to the the whole internet to retrieve information.
|
||||||
- Official, up-to-date information about you, Khoj, is available at site:khoj.dev, github or pypi.
|
- Official, up-to-date information about you, Khoj, is available at site:khoj.dev, github or pypi.
|
||||||
|
|
||||||
|
|
|
@ -196,7 +196,7 @@ def openai_agent():
|
||||||
return Agent.objects.create(
|
return Agent.objects.create(
|
||||||
name="Accountant",
|
name="Accountant",
|
||||||
chat_model=chat_model,
|
chat_model=chat_model,
|
||||||
personality="You are a certified CPA. You are able to tell me how much I've spent based on my notes. Regardless of what I ask, you should always respond with the total amount I've spent.",
|
personality="You are a certified CPA. You are able to tell me how much I've spent based on my notes. Regardless of what I ask, you should always respond with the total amount I've spent. ALWAYS RESPOND WITH A SUMMARY TOTAL OF HOW MUCH MONEY I HAVE SPENT.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -405,10 +405,10 @@ def test_answer_general_question_not_in_chat_history_or_retrieved_content():
|
||||||
response = "".join([response_chunk for response_chunk in response_gen])
|
response = "".join([response_chunk for response_chunk in response_gen])
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
expected_responses = ["test", "Test"]
|
expected_responses = ["test", "bug", "code"]
|
||||||
assert len(response.splitlines()) == 3 # haikus are 3 lines long
|
assert len(response.splitlines()) == 3 # haikus are 3 lines long
|
||||||
assert any([expected_response in response for expected_response in expected_responses]), (
|
assert any([expected_response in response.lower() for expected_response in expected_responses]), (
|
||||||
"Expected [T|t]est in response, but got: " + response
|
"Expected haiku about unit test, but got: " + response
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -441,7 +441,13 @@ My sister, Aiyla is married to Tolga. They have 3 kids, Yildiz, Ali and Ahmet.""
|
||||||
response = "".join([response_chunk for response_chunk in response_gen])
|
response = "".join([response_chunk for response_chunk in response_gen])
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
expected_responses = ["which sister", "Which sister", "which of your sister", "Which of your sister"]
|
expected_responses = [
|
||||||
|
"which sister",
|
||||||
|
"Which sister",
|
||||||
|
"which of your sister",
|
||||||
|
"Which of your sister",
|
||||||
|
"Could you provide",
|
||||||
|
]
|
||||||
assert any([expected_response in response for expected_response in expected_responses]), (
|
assert any([expected_response in response for expected_response in expected_responses]), (
|
||||||
"Expected chat actor to ask for clarification in response, but got: " + response
|
"Expected chat actor to ask for clarification in response, but got: " + response
|
||||||
)
|
)
|
||||||
|
@ -555,7 +561,7 @@ async def test_select_data_sources_actor_chooses_to_search_notes(
|
||||||
chat_client, user_query, expected_conversation_commands
|
chat_client, user_query, expected_conversation_commands
|
||||||
):
|
):
|
||||||
# Act
|
# Act
|
||||||
conversation_commands = await aget_relevant_information_sources(user_query, {}, False)
|
conversation_commands = await aget_relevant_information_sources(user_query, {}, False, False)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert set(expected_conversation_commands) == set(conversation_commands)
|
assert set(expected_conversation_commands) == set(conversation_commands)
|
||||||
|
|
|
@ -72,9 +72,8 @@ def test_chat_with_online_content(chat_client):
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
expected_responses = [
|
expected_responses = [
|
||||||
"https://paulgraham.com/greatwork.html",
|
"paulgraham.com/greatwork.html",
|
||||||
"https://www.paulgraham.com/greatwork.html",
|
"paulgraham.com/hwh.html",
|
||||||
"http://www.paulgraham.com/greatwork.html",
|
|
||||||
]
|
]
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert any(
|
assert any(
|
||||||
|
@ -112,7 +111,7 @@ def test_answer_from_chat_history(chat_client, default_user2: KhojUser):
|
||||||
create_conversation(message_list, default_user2)
|
create_conversation(message_list, default_user2)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response = chat_client.get(f'/api/chat?q="What is my name?"&stream=true')
|
response = chat_client.get(f'/api/chat?q="What is my name?"')
|
||||||
response_message = response.content.decode("utf-8")
|
response_message = response.content.decode("utf-8")
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
|
@ -259,8 +258,8 @@ def test_answer_from_retrieved_content_using_notes_command(chat_client, default_
|
||||||
create_conversation(message_list, default_user2)
|
create_conversation(message_list, default_user2)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response = chat_client.get(f"/api/chat?q={query}&stream=true")
|
response = chat_client.get(f"/api/chat?q={query}")
|
||||||
response_message = response.content.decode("utf-8")
|
response_message = response.json()["response"]
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
@ -310,8 +309,8 @@ def test_summarize_one_file(chat_client, default_user2: KhojUser):
|
||||||
json={"filename": summarization_file, "conversation_id": str(conversation.id)},
|
json={"filename": summarization_file, "conversation_id": str(conversation.id)},
|
||||||
)
|
)
|
||||||
query = urllib.parse.quote("/summarize")
|
query = urllib.parse.quote("/summarize")
|
||||||
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
|
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}")
|
||||||
response_message = response.content.decode("utf-8")
|
response_message = response.json()["response"]
|
||||||
# Assert
|
# Assert
|
||||||
assert response_message != ""
|
assert response_message != ""
|
||||||
assert response_message != "No files selected for summarization. Please add files using the section on the left."
|
assert response_message != "No files selected for summarization. Please add files using the section on the left."
|
||||||
|
@ -342,8 +341,8 @@ def test_summarize_extra_text(chat_client, default_user2: KhojUser):
|
||||||
json={"filename": summarization_file, "conversation_id": str(conversation.id)},
|
json={"filename": summarization_file, "conversation_id": str(conversation.id)},
|
||||||
)
|
)
|
||||||
query = urllib.parse.quote("/summarize tell me about Xiu")
|
query = urllib.parse.quote("/summarize tell me about Xiu")
|
||||||
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
|
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}")
|
||||||
response_message = response.content.decode("utf-8")
|
response_message = response.json()["response"]
|
||||||
# Assert
|
# Assert
|
||||||
assert response_message != ""
|
assert response_message != ""
|
||||||
assert response_message != "No files selected for summarization. Please add files using the section on the left."
|
assert response_message != "No files selected for summarization. Please add files using the section on the left."
|
||||||
|
@ -370,8 +369,8 @@ def test_summarize_multiple_files(chat_client, default_user2: KhojUser):
|
||||||
)
|
)
|
||||||
|
|
||||||
query = urllib.parse.quote("/summarize")
|
query = urllib.parse.quote("/summarize")
|
||||||
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
|
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}")
|
||||||
response_message = response.content.decode("utf-8")
|
response_message = response.json()["response"]
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert response_message == "Only one file can be selected for summarization."
|
assert response_message == "Only one file can be selected for summarization."
|
||||||
|
@ -383,8 +382,8 @@ def test_summarize_no_files(chat_client, default_user2: KhojUser):
|
||||||
message_list = []
|
message_list = []
|
||||||
conversation = create_conversation(message_list, default_user2)
|
conversation = create_conversation(message_list, default_user2)
|
||||||
query = urllib.parse.quote("/summarize")
|
query = urllib.parse.quote("/summarize")
|
||||||
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
|
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}")
|
||||||
response_message = response.content.decode("utf-8")
|
response_message = response.json()["response"]
|
||||||
# Assert
|
# Assert
|
||||||
assert response_message == "No files selected for summarization. Please add files using the section on the left."
|
assert response_message == "No files selected for summarization. Please add files using the section on the left."
|
||||||
|
|
||||||
|
@ -415,15 +414,15 @@ def test_summarize_different_conversation(chat_client, default_user2: KhojUser):
|
||||||
)
|
)
|
||||||
|
|
||||||
query = urllib.parse.quote("/summarize")
|
query = urllib.parse.quote("/summarize")
|
||||||
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation2.id}&stream=true")
|
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation2.id}")
|
||||||
response_message = response.content.decode("utf-8")
|
response_message = response.json()["response"]
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert response_message == "No files selected for summarization. Please add files using the section on the left."
|
assert response_message == "No files selected for summarization. Please add files using the section on the left."
|
||||||
|
|
||||||
# now make sure that the file filter is still in conversation 1
|
# now make sure that the file filter is still in conversation 1
|
||||||
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation1.id}&stream=true")
|
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation1.id}")
|
||||||
response_message = response.content.decode("utf-8")
|
response_message = response.json()["response"]
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert response_message != ""
|
assert response_message != ""
|
||||||
|
@ -442,8 +441,8 @@ def test_summarize_nonexistant_file(chat_client, default_user2: KhojUser):
|
||||||
json={"filename": "imaginary.markdown", "conversation_id": str(conversation.id)},
|
json={"filename": "imaginary.markdown", "conversation_id": str(conversation.id)},
|
||||||
)
|
)
|
||||||
query = urllib.parse.quote("/summarize")
|
query = urllib.parse.quote("/summarize")
|
||||||
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
|
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}")
|
||||||
response_message = response.content.decode("utf-8")
|
response_message = response.json()["response"]
|
||||||
# Assert
|
# Assert
|
||||||
assert response_message == "No files selected for summarization. Please add files using the section on the left."
|
assert response_message == "No files selected for summarization. Please add files using the section on the left."
|
||||||
|
|
||||||
|
@ -471,8 +470,8 @@ def test_summarize_diff_user_file(chat_client, default_user: KhojUser, pdf_confi
|
||||||
json={"filename": summarization_file, "conversation_id": str(conversation.id)},
|
json={"filename": summarization_file, "conversation_id": str(conversation.id)},
|
||||||
)
|
)
|
||||||
query = urllib.parse.quote("/summarize")
|
query = urllib.parse.quote("/summarize")
|
||||||
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
|
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}")
|
||||||
response_message = response.content.decode("utf-8")
|
response_message = response.json()["response"]
|
||||||
# Assert
|
# Assert
|
||||||
assert response_message == "No files selected for summarization. Please add files using the section on the left."
|
assert response_message == "No files selected for summarization. Please add files using the section on the left."
|
||||||
|
|
||||||
|
@ -503,8 +502,8 @@ def test_answer_requires_current_date_awareness(chat_client):
|
||||||
def test_answer_requires_date_aware_aggregation_across_provided_notes(chat_client):
|
def test_answer_requires_date_aware_aggregation_across_provided_notes(chat_client):
|
||||||
"Chat director should be able to answer questions that require date aware aggregation across multiple notes"
|
"Chat director should be able to answer questions that require date aware aggregation across multiple notes"
|
||||||
# Act
|
# Act
|
||||||
response = chat_client.get(f'/api/chat?q="How much did I spend on dining this year?"&stream=true')
|
response = chat_client.get(f'/api/chat?q="How much did I spend on dining this year?"')
|
||||||
response_message = response.content.decode("utf-8")
|
response_message = response.json()["response"]
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
@ -575,8 +574,8 @@ def test_answer_in_chat_history_beyond_lookback_window(chat_client, default_user
|
||||||
create_conversation(message_list, default_user2)
|
create_conversation(message_list, default_user2)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
response = chat_client.get(f'/api/chat?q="What is my name?"&stream=true')
|
response = chat_client.get(f'/api/chat?q="What is my name?"')
|
||||||
response_message = response.content.decode("utf-8")
|
response_message = response.json()["response"]
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
expected_responses = ["Testatron", "testatron"]
|
expected_responses = ["Testatron", "testatron"]
|
||||||
|
@ -631,20 +630,24 @@ def test_answer_in_chat_history_by_conversation_id_with_agent(
|
||||||
("When was I born?", "You were born on 1st April 1984.", []),
|
("When was I born?", "You were born on 1st April 1984.", []),
|
||||||
("What's my favorite color", "Your favorite color is green.", []),
|
("What's my favorite color", "Your favorite color is green.", []),
|
||||||
("Where was I born?", "You were born Testville.", []),
|
("Where was I born?", "You were born Testville.", []),
|
||||||
("What did I buy?", "You bought an apple for 2.00, an orange for 3.00, and a potato for 8.00", []),
|
(
|
||||||
|
"What did I buy?",
|
||||||
|
"You bought an apple for 2.00, an orange for 3.00, and a potato for 8.00 for breakfast",
|
||||||
|
[],
|
||||||
|
),
|
||||||
]
|
]
|
||||||
conversation = create_conversation(message_list, default_user2, openai_agent)
|
conversation = create_conversation(message_list, default_user2, openai_agent)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
query = urllib.parse.quote("/general What did I eat for breakfast?")
|
query = urllib.parse.quote("/general What did I buy for breakfast?")
|
||||||
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}&stream=true")
|
response = chat_client.get(f"/api/chat?q={query}&conversation_id={conversation.id}")
|
||||||
response_message = response.content.decode("utf-8")
|
response_message = response.json()["response"]
|
||||||
|
|
||||||
# Assert that agent only responds with the summary of spending
|
# Assert that agent only responds with the summary of spending
|
||||||
expected_responses = ["13.00", "13", "13.0", "thirteen"]
|
expected_responses = ["13.00", "13", "13.0", "thirteen"]
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
|
assert any([expected_response in response_message.lower() for expected_response in expected_responses]), (
|
||||||
"Expected green in response, but got: " + response_message
|
"Expected amount in response, but got: " + response_message
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -654,7 +657,7 @@ def test_answer_in_chat_history_by_conversation_id_with_agent(
|
||||||
def test_answer_requires_multiple_independent_searches(chat_client):
|
def test_answer_requires_multiple_independent_searches(chat_client):
|
||||||
"Chat director should be able to answer by doing multiple independent searches for required information"
|
"Chat director should be able to answer by doing multiple independent searches for required information"
|
||||||
# Act
|
# Act
|
||||||
response = chat_client.get(f'/api/chat?q="Is Xi older than Namita? Just the older persons full name"')
|
response = chat_client.get(f'/api/chat?q="Is Xi Li older than Namita? Just say the older persons full name"')
|
||||||
response_message = response.json()["response"].lower()
|
response_message = response.json()["response"].lower()
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
|
@ -676,7 +679,7 @@ def test_answer_using_file_filter(chat_client):
|
||||||
"Chat should be able to use search filters in the query"
|
"Chat should be able to use search filters in the query"
|
||||||
# Act
|
# Act
|
||||||
query = urllib.parse.quote(
|
query = urllib.parse.quote(
|
||||||
'Is Xi older than Namita? Just say the older persons full name. file:"Namita.markdown" file:"Xi Li.markdown"'
|
'Is Xi Li older than Namita? Just say the older persons full name. file:"Namita.markdown" file:"Xi Li.markdown"'
|
||||||
)
|
)
|
||||||
|
|
||||||
response = chat_client.get(f"/api/chat?q={query}")
|
response = chat_client.get(f"/api/chat?q={query}")
|
||||||
|
@ -703,7 +706,7 @@ async def test_get_correct_tools_online(chat_client):
|
||||||
user_query = "What's the weather in Patagonia this week?"
|
user_query = "What's the weather in Patagonia this week?"
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
tools = await aget_relevant_information_sources(user_query, {}, False)
|
tools = await aget_relevant_information_sources(user_query, {}, False, False)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
tools = [tool.value for tool in tools]
|
tools = [tool.value for tool in tools]
|
||||||
|
@ -718,7 +721,7 @@ async def test_get_correct_tools_notes(chat_client):
|
||||||
user_query = "Where did I go for my first battleship training?"
|
user_query = "Where did I go for my first battleship training?"
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
tools = await aget_relevant_information_sources(user_query, {}, False)
|
tools = await aget_relevant_information_sources(user_query, {}, False, False)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
tools = [tool.value for tool in tools]
|
tools = [tool.value for tool in tools]
|
||||||
|
@ -733,7 +736,7 @@ async def test_get_correct_tools_online_or_general_and_notes(chat_client):
|
||||||
user_query = "What's the highest point in Patagonia and have I been there?"
|
user_query = "What's the highest point in Patagonia and have I been there?"
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
tools = await aget_relevant_information_sources(user_query, {}, False)
|
tools = await aget_relevant_information_sources(user_query, {}, False, False)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
tools = [tool.value for tool in tools]
|
tools = [tool.value for tool in tools]
|
||||||
|
@ -750,7 +753,7 @@ async def test_get_correct_tools_general(chat_client):
|
||||||
user_query = "How many noble gases are there?"
|
user_query = "How many noble gases are there?"
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
tools = await aget_relevant_information_sources(user_query, {}, False)
|
tools = await aget_relevant_information_sources(user_query, {}, False, False)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
tools = [tool.value for tool in tools]
|
tools = [tool.value for tool in tools]
|
||||||
|
@ -774,7 +777,7 @@ async def test_get_correct_tools_with_chat_history(chat_client):
|
||||||
chat_history = generate_history(chat_log)
|
chat_history = generate_history(chat_log)
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
tools = await aget_relevant_information_sources(user_query, chat_history, False)
|
tools = await aget_relevant_information_sources(user_query, chat_history, False, False)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
tools = [tool.value for tool in tools]
|
tools = [tool.value for tool in tools]
|
||||||
|
|
Loading…
Reference in a new issue