mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-27 09:25:06 +01:00
Fix openai chat actor, director tests
- Update test ChatModelOptions setup since update to its schema - Fix stale function calls using their updated signatures
This commit is contained in:
parent
f91cdf8e18
commit
f440ddbe1d
4 changed files with 62 additions and 46 deletions
|
@ -301,7 +301,7 @@ def chat_client_builder(search_config, user, index_content=True, require_auth=Fa
|
|||
# Initialize Processor from Config
|
||||
if os.getenv("OPENAI_API_KEY"):
|
||||
chat_model = ChatModelOptionsFactory(chat_model="gpt-3.5-turbo", model_type="openai")
|
||||
OpenAIProcessorConversationConfigFactory()
|
||||
chat_model.openai_config = OpenAIProcessorConversationConfigFactory()
|
||||
UserConversationProcessorConfigFactory(user=user, setting=chat_model)
|
||||
|
||||
state.anonymous_mode = not require_auth
|
||||
|
|
|
@ -36,6 +36,13 @@ class ApiUserFactory(factory.django.DjangoModelFactory):
|
|||
token = factory.Faker("password")
|
||||
|
||||
|
||||
class OpenAIProcessorConversationConfigFactory(factory.django.DjangoModelFactory):
|
||||
class Meta:
|
||||
model = OpenAIProcessorConversationConfig
|
||||
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
|
||||
|
||||
class ChatModelOptionsFactory(factory.django.DjangoModelFactory):
|
||||
class Meta:
|
||||
model = ChatModelOptions
|
||||
|
@ -44,6 +51,7 @@ class ChatModelOptionsFactory(factory.django.DjangoModelFactory):
|
|||
tokenizer = None
|
||||
chat_model = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF"
|
||||
model_type = "offline"
|
||||
openai_config = factory.SubFactory(OpenAIProcessorConversationConfigFactory)
|
||||
|
||||
|
||||
class UserConversationProcessorConfigFactory(factory.django.DjangoModelFactory):
|
||||
|
@ -54,13 +62,6 @@ class UserConversationProcessorConfigFactory(factory.django.DjangoModelFactory):
|
|||
setting = factory.SubFactory(ChatModelOptionsFactory)
|
||||
|
||||
|
||||
class OpenAIProcessorConversationConfigFactory(factory.django.DjangoModelFactory):
|
||||
class Meta:
|
||||
model = OpenAIProcessorConversationConfig
|
||||
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
|
||||
|
||||
class ConversationFactory(factory.django.DjangoModelFactory):
|
||||
class Meta:
|
||||
model = Conversation
|
||||
|
|
|
@ -250,7 +250,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content():
|
|||
# Act
|
||||
response_gen = converse(
|
||||
references=[
|
||||
"Testatron was born on 1st April 1984 in Testville."
|
||||
{"compiled": "Testatron was born on 1st April 1984 in Testville.", "file": "background.md"}
|
||||
], # Assume context retrieved from notes for the user_query
|
||||
user_query="Where was I born?",
|
||||
conversation_log=populate_chat_history(message_list),
|
||||
|
@ -304,14 +304,26 @@ def test_answer_requires_current_date_awareness():
|
|||
"Chat actor should be able to answer questions relative to current date using provided notes"
|
||||
# Arrange
|
||||
context = [
|
||||
f"""{datetime.now().strftime("%Y-%m-%d")} "Naco Taco" "Tacos for Dinner"
|
||||
{
|
||||
"compiled": f"""{datetime.now().strftime("%Y-%m-%d")} "Naco Taco" "Tacos for Dinner"
|
||||
Expenses:Food:Dining 10.00 USD""",
|
||||
f"""{datetime.now().strftime("%Y-%m-%d")} "Sagar Ratna" "Dosa for Lunch"
|
||||
"file": "Ledger.org",
|
||||
},
|
||||
{
|
||||
"compiled": f"""{datetime.now().strftime("%Y-%m-%d")} "Sagar Ratna" "Dosa for Lunch"
|
||||
Expenses:Food:Dining 10.00 USD""",
|
||||
f"""2020-04-01 "SuperMercado" "Bananas"
|
||||
"file": "Ledger.org",
|
||||
},
|
||||
{
|
||||
"compiled": f"""2020-04-01 "SuperMercado" "Bananas"
|
||||
Expenses:Food:Groceries 10.00 USD""",
|
||||
f"""2020-01-01 "Naco Taco" "Burittos for Dinner"
|
||||
"file": "Ledger.org",
|
||||
},
|
||||
{
|
||||
"compiled": f"""2020-01-01 "Naco Taco" "Burittos for Dinner"
|
||||
Expenses:Food:Dining 10.00 USD""",
|
||||
"file": "Ledger.org",
|
||||
},
|
||||
]
|
||||
|
||||
# Act
|
||||
|
@ -336,14 +348,26 @@ def test_answer_requires_date_aware_aggregation_across_provided_notes():
|
|||
"Chat actor should be able to answer questions that require date aware aggregation across multiple notes"
|
||||
# Arrange
|
||||
context = [
|
||||
f"""# {datetime.now().strftime("%Y-%m-%d")} "Naco Taco" "Tacos for Dinner"
|
||||
{
|
||||
"compiled": f"""# {datetime.now().strftime("%Y-%m-%d")} "Naco Taco" "Tacos for Dinner"
|
||||
Expenses:Food:Dining 10.00 USD""",
|
||||
f"""{datetime.now().strftime("%Y-%m-%d")} "Sagar Ratna" "Dosa for Lunch"
|
||||
"file": "Ledger.md",
|
||||
},
|
||||
{
|
||||
"compiled": f"""{datetime.now().strftime("%Y-%m-%d")} "Sagar Ratna" "Dosa for Lunch"
|
||||
Expenses:Food:Dining 10.00 USD""",
|
||||
f"""2020-04-01 "SuperMercado" "Bananas"
|
||||
"file": "Ledger.md",
|
||||
},
|
||||
{
|
||||
"compiled": f"""2020-04-01 "SuperMercado" "Bananas"
|
||||
Expenses:Food:Groceries 10.00 USD""",
|
||||
f"""2020-01-01 "Naco Taco" "Burittos for Dinner"
|
||||
"file": "Ledger.md",
|
||||
},
|
||||
{
|
||||
"compiled": f"""2020-01-01 "Naco Taco" "Burittos for Dinner"
|
||||
Expenses:Food:Dining 10.00 USD""",
|
||||
"file": "Ledger.md",
|
||||
},
|
||||
]
|
||||
|
||||
# Act
|
||||
|
@ -423,9 +447,9 @@ def test_agent_prompt_should_be_used(openai_agent):
|
|||
"Chat actor should ask be tuned to think like an accountant based on the agent definition"
|
||||
# Arrange
|
||||
context = [
|
||||
f"""I went to the store and bought some bananas for 2.20""",
|
||||
f"""I went to the store and bought some apples for 1.30""",
|
||||
f"""I went to the store and bought some oranges for 6.00""",
|
||||
{"compiled": f"""I went to the store and bought some bananas for 2.20""", "file": "Ledger.md"},
|
||||
{"compiled": f"""I went to the store and bought some apples for 1.30""", "file": "Ledger.md"},
|
||||
{"compiled": f"""I went to the store and bought some oranges for 6.00""", "file": "Ledger.md"},
|
||||
]
|
||||
expected_responses = ["9.50", "9.5"]
|
||||
|
||||
|
@ -496,10 +520,10 @@ async def test_websearch_khoj_website_for_info_about_khoj(chat_client):
|
|||
@pytest.mark.parametrize(
|
||||
"user_query, expected_mode",
|
||||
[
|
||||
("What's the latest in the Israel/Palestine conflict?", "default"),
|
||||
("Summarize the latest tech news every Monday evening", "reminder"),
|
||||
("What's the latest in the Israel/Palestine conflict?", "text"),
|
||||
("Summarize the latest tech news every Monday evening", "automation"),
|
||||
("Paint a scenery in Timbuktu in the winter", "image"),
|
||||
("Remind me, when did I last visit the Serengeti?", "default"),
|
||||
("Remind me, when did I last visit the Serengeti?", "text"),
|
||||
],
|
||||
)
|
||||
async def test_use_default_response_mode(chat_client, user_query, expected_mode):
|
||||
|
@ -525,10 +549,10 @@ async def test_select_data_sources_actor_chooses_to_search_notes(
|
|||
chat_client, user_query, expected_conversation_commands
|
||||
):
|
||||
# Act
|
||||
conversation_commands = await aget_relevant_information_sources(user_query, {})
|
||||
conversation_commands = await aget_relevant_information_sources(user_query, {}, False)
|
||||
|
||||
# Assert
|
||||
assert expected_conversation_commands in conversation_commands
|
||||
assert set(expected_conversation_commands) == set(conversation_commands)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
|
@ -549,46 +573,37 @@ async def test_infer_webpage_urls_actor_extracts_correct_links(chat_client):
|
|||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.parametrize(
|
||||
"user_query, location, expected_crontime, expected_qs, unexpected_qs",
|
||||
"user_query, expected_crontime, expected_qs, unexpected_qs",
|
||||
[
|
||||
(
|
||||
"Share the weather forecast for the next day daily at 7:30pm",
|
||||
("Ubud", "Bali", "Indonesia"),
|
||||
"30 11 * * *", # ensure correctly converts to utc
|
||||
["weather forecast", "ubud"],
|
||||
"30 19 * * *",
|
||||
["weather forecast"],
|
||||
["7:30"],
|
||||
),
|
||||
(
|
||||
"Notify me when the new President of Brazil is announced",
|
||||
("Sao Paulo", "Sao Paulo", "Brazil"),
|
||||
"* *", # crontime is variable
|
||||
["brazil", "president"],
|
||||
["notify"], # ensure reminder isn't re-triggered on scheduled query run
|
||||
),
|
||||
(
|
||||
"Let me know whenever Elon leaves Twitter. Check this every afternoon at 12",
|
||||
("Karachi", "Sindh", "Pakistan"),
|
||||
"0 7 * * *", # ensure correctly converts to utc
|
||||
"0 12 * * *", # ensure correctly converts to utc
|
||||
["elon", "twitter"],
|
||||
["12"],
|
||||
),
|
||||
(
|
||||
"Draw a wallpaper every morning using the current weather",
|
||||
("Bogota", "Cundinamarca", "Colombia"),
|
||||
"* * *", # daily crontime
|
||||
["weather", "wallpaper", "bogota"],
|
||||
["weather", "wallpaper"],
|
||||
["every"],
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_infer_task_scheduling_request(
|
||||
chat_client, user_query, location, expected_crontime, expected_qs, unexpected_qs
|
||||
):
|
||||
# Arrange
|
||||
location_data = LocationData(city=location[0], region=location[1], country=location[2])
|
||||
|
||||
async def test_infer_task_scheduling_request(chat_client, user_query, expected_crontime, expected_qs, unexpected_qs):
|
||||
# Act
|
||||
crontime, inferred_query = await schedule_query(user_query, location_data, {})
|
||||
crontime, inferred_query, _ = await schedule_query(user_query, {})
|
||||
inferred_query = inferred_query.lower()
|
||||
|
||||
# Assert
|
||||
|
|
|
@ -516,7 +516,7 @@ async def test_get_correct_tools_online(chat_client):
|
|||
user_query = "What's the weather in Patagonia this week?"
|
||||
|
||||
# Act
|
||||
tools = await aget_relevant_information_sources(user_query, {})
|
||||
tools = await aget_relevant_information_sources(user_query, {}, False)
|
||||
|
||||
# Assert
|
||||
tools = [tool.value for tool in tools]
|
||||
|
@ -531,7 +531,7 @@ async def test_get_correct_tools_notes(chat_client):
|
|||
user_query = "Where did I go for my first battleship training?"
|
||||
|
||||
# Act
|
||||
tools = await aget_relevant_information_sources(user_query, {})
|
||||
tools = await aget_relevant_information_sources(user_query, {}, False)
|
||||
|
||||
# Assert
|
||||
tools = [tool.value for tool in tools]
|
||||
|
@ -546,7 +546,7 @@ async def test_get_correct_tools_online_or_general_and_notes(chat_client):
|
|||
user_query = "What's the highest point in Patagonia and have I been there?"
|
||||
|
||||
# Act
|
||||
tools = await aget_relevant_information_sources(user_query, {})
|
||||
tools = await aget_relevant_information_sources(user_query, {}, False)
|
||||
|
||||
# Assert
|
||||
tools = [tool.value for tool in tools]
|
||||
|
@ -563,7 +563,7 @@ async def test_get_correct_tools_general(chat_client):
|
|||
user_query = "How many noble gases are there?"
|
||||
|
||||
# Act
|
||||
tools = await aget_relevant_information_sources(user_query, {})
|
||||
tools = await aget_relevant_information_sources(user_query, {}, False)
|
||||
|
||||
# Assert
|
||||
tools = [tool.value for tool in tools]
|
||||
|
@ -587,7 +587,7 @@ async def test_get_correct_tools_with_chat_history(chat_client):
|
|||
chat_history = generate_history(chat_log)
|
||||
|
||||
# Act
|
||||
tools = await aget_relevant_information_sources(user_query, chat_history)
|
||||
tools = await aget_relevant_information_sources(user_query, chat_history, False)
|
||||
|
||||
# Assert
|
||||
tools = [tool.value for tool in tools]
|
||||
|
|
Loading…
Reference in a new issue