Fix openai chat actor, director tests

- Update test ChatModelOptions setup to match its updated schema
- Fix stale function calls to use their updated signatures
This commit is contained in:
Debanjum Singh Solanky 2024-06-09 07:16:55 +05:30
parent f91cdf8e18
commit f440ddbe1d
4 changed files with 62 additions and 46 deletions

View file

@ -301,7 +301,7 @@ def chat_client_builder(search_config, user, index_content=True, require_auth=Fa
# Initialize Processor from Config
if os.getenv("OPENAI_API_KEY"):
chat_model = ChatModelOptionsFactory(chat_model="gpt-3.5-turbo", model_type="openai")
OpenAIProcessorConversationConfigFactory()
chat_model.openai_config = OpenAIProcessorConversationConfigFactory()
UserConversationProcessorConfigFactory(user=user, setting=chat_model)
state.anonymous_mode = not require_auth

View file

@ -36,6 +36,13 @@ class ApiUserFactory(factory.django.DjangoModelFactory):
token = factory.Faker("password")
class OpenAIProcessorConversationConfigFactory(factory.django.DjangoModelFactory):
class Meta:
model = OpenAIProcessorConversationConfig
api_key = os.getenv("OPENAI_API_KEY")
class ChatModelOptionsFactory(factory.django.DjangoModelFactory):
class Meta:
model = ChatModelOptions
@ -44,6 +51,7 @@ class ChatModelOptionsFactory(factory.django.DjangoModelFactory):
tokenizer = None
chat_model = "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF"
model_type = "offline"
openai_config = factory.SubFactory(OpenAIProcessorConversationConfigFactory)
class UserConversationProcessorConfigFactory(factory.django.DjangoModelFactory):
@ -54,13 +62,6 @@ class UserConversationProcessorConfigFactory(factory.django.DjangoModelFactory):
setting = factory.SubFactory(ChatModelOptionsFactory)
class OpenAIProcessorConversationConfigFactory(factory.django.DjangoModelFactory):
class Meta:
model = OpenAIProcessorConversationConfig
api_key = os.getenv("OPENAI_API_KEY")
class ConversationFactory(factory.django.DjangoModelFactory):
class Meta:
model = Conversation

View file

@ -250,7 +250,7 @@ def test_answer_from_chat_history_and_currently_retrieved_content():
# Act
response_gen = converse(
references=[
"Testatron was born on 1st April 1984 in Testville."
{"compiled": "Testatron was born on 1st April 1984 in Testville.", "file": "background.md"}
], # Assume context retrieved from notes for the user_query
user_query="Where was I born?",
conversation_log=populate_chat_history(message_list),
@ -304,14 +304,26 @@ def test_answer_requires_current_date_awareness():
"Chat actor should be able to answer questions relative to current date using provided notes"
# Arrange
context = [
f"""{datetime.now().strftime("%Y-%m-%d")} "Naco Taco" "Tacos for Dinner"
{
"compiled": f"""{datetime.now().strftime("%Y-%m-%d")} "Naco Taco" "Tacos for Dinner"
Expenses:Food:Dining 10.00 USD""",
f"""{datetime.now().strftime("%Y-%m-%d")} "Sagar Ratna" "Dosa for Lunch"
"file": "Ledger.org",
},
{
"compiled": f"""{datetime.now().strftime("%Y-%m-%d")} "Sagar Ratna" "Dosa for Lunch"
Expenses:Food:Dining 10.00 USD""",
f"""2020-04-01 "SuperMercado" "Bananas"
"file": "Ledger.org",
},
{
"compiled": f"""2020-04-01 "SuperMercado" "Bananas"
Expenses:Food:Groceries 10.00 USD""",
f"""2020-01-01 "Naco Taco" "Burittos for Dinner"
"file": "Ledger.org",
},
{
"compiled": f"""2020-01-01 "Naco Taco" "Burittos for Dinner"
Expenses:Food:Dining 10.00 USD""",
"file": "Ledger.org",
},
]
# Act
@ -336,14 +348,26 @@ def test_answer_requires_date_aware_aggregation_across_provided_notes():
"Chat actor should be able to answer questions that require date aware aggregation across multiple notes"
# Arrange
context = [
f"""# {datetime.now().strftime("%Y-%m-%d")} "Naco Taco" "Tacos for Dinner"
{
"compiled": f"""# {datetime.now().strftime("%Y-%m-%d")} "Naco Taco" "Tacos for Dinner"
Expenses:Food:Dining 10.00 USD""",
f"""{datetime.now().strftime("%Y-%m-%d")} "Sagar Ratna" "Dosa for Lunch"
"file": "Ledger.md",
},
{
"compiled": f"""{datetime.now().strftime("%Y-%m-%d")} "Sagar Ratna" "Dosa for Lunch"
Expenses:Food:Dining 10.00 USD""",
f"""2020-04-01 "SuperMercado" "Bananas"
"file": "Ledger.md",
},
{
"compiled": f"""2020-04-01 "SuperMercado" "Bananas"
Expenses:Food:Groceries 10.00 USD""",
f"""2020-01-01 "Naco Taco" "Burittos for Dinner"
"file": "Ledger.md",
},
{
"compiled": f"""2020-01-01 "Naco Taco" "Burittos for Dinner"
Expenses:Food:Dining 10.00 USD""",
"file": "Ledger.md",
},
]
# Act
@ -423,9 +447,9 @@ def test_agent_prompt_should_be_used(openai_agent):
"Chat actor should ask be tuned to think like an accountant based on the agent definition"
# Arrange
context = [
f"""I went to the store and bought some bananas for 2.20""",
f"""I went to the store and bought some apples for 1.30""",
f"""I went to the store and bought some oranges for 6.00""",
{"compiled": f"""I went to the store and bought some bananas for 2.20""", "file": "Ledger.md"},
{"compiled": f"""I went to the store and bought some apples for 1.30""", "file": "Ledger.md"},
{"compiled": f"""I went to the store and bought some oranges for 6.00""", "file": "Ledger.md"},
]
expected_responses = ["9.50", "9.5"]
@ -496,10 +520,10 @@ async def test_websearch_khoj_website_for_info_about_khoj(chat_client):
@pytest.mark.parametrize(
"user_query, expected_mode",
[
("What's the latest in the Israel/Palestine conflict?", "default"),
("Summarize the latest tech news every Monday evening", "reminder"),
("What's the latest in the Israel/Palestine conflict?", "text"),
("Summarize the latest tech news every Monday evening", "automation"),
("Paint a scenery in Timbuktu in the winter", "image"),
("Remind me, when did I last visit the Serengeti?", "default"),
("Remind me, when did I last visit the Serengeti?", "text"),
],
)
async def test_use_default_response_mode(chat_client, user_query, expected_mode):
@ -525,10 +549,10 @@ async def test_select_data_sources_actor_chooses_to_search_notes(
chat_client, user_query, expected_conversation_commands
):
# Act
conversation_commands = await aget_relevant_information_sources(user_query, {})
conversation_commands = await aget_relevant_information_sources(user_query, {}, False)
# Assert
assert expected_conversation_commands in conversation_commands
assert set(expected_conversation_commands) == set(conversation_commands)
# ----------------------------------------------------------------------------------------------------
@ -549,46 +573,37 @@ async def test_infer_webpage_urls_actor_extracts_correct_links(chat_client):
@pytest.mark.anyio
@pytest.mark.django_db(transaction=True)
@pytest.mark.parametrize(
"user_query, location, expected_crontime, expected_qs, unexpected_qs",
"user_query, expected_crontime, expected_qs, unexpected_qs",
[
(
"Share the weather forecast for the next day daily at 7:30pm",
("Ubud", "Bali", "Indonesia"),
"30 11 * * *", # ensure correctly converts to utc
["weather forecast", "ubud"],
"30 19 * * *",
["weather forecast"],
["7:30"],
),
(
"Notify me when the new President of Brazil is announced",
("Sao Paulo", "Sao Paulo", "Brazil"),
"* *", # crontime is variable
["brazil", "president"],
["notify"], # ensure reminder isn't re-triggered on scheduled query run
),
(
"Let me know whenever Elon leaves Twitter. Check this every afternoon at 12",
("Karachi", "Sindh", "Pakistan"),
"0 7 * * *", # ensure correctly converts to utc
"0 12 * * *", # ensure correctly converts to utc
["elon", "twitter"],
["12"],
),
(
"Draw a wallpaper every morning using the current weather",
("Bogota", "Cundinamarca", "Colombia"),
"* * *", # daily crontime
["weather", "wallpaper", "bogota"],
["weather", "wallpaper"],
["every"],
),
],
)
async def test_infer_task_scheduling_request(
chat_client, user_query, location, expected_crontime, expected_qs, unexpected_qs
):
# Arrange
location_data = LocationData(city=location[0], region=location[1], country=location[2])
async def test_infer_task_scheduling_request(chat_client, user_query, expected_crontime, expected_qs, unexpected_qs):
# Act
crontime, inferred_query = await schedule_query(user_query, location_data, {})
crontime, inferred_query, _ = await schedule_query(user_query, {})
inferred_query = inferred_query.lower()
# Assert

View file

@ -516,7 +516,7 @@ async def test_get_correct_tools_online(chat_client):
user_query = "What's the weather in Patagonia this week?"
# Act
tools = await aget_relevant_information_sources(user_query, {})
tools = await aget_relevant_information_sources(user_query, {}, False)
# Assert
tools = [tool.value for tool in tools]
@ -531,7 +531,7 @@ async def test_get_correct_tools_notes(chat_client):
user_query = "Where did I go for my first battleship training?"
# Act
tools = await aget_relevant_information_sources(user_query, {})
tools = await aget_relevant_information_sources(user_query, {}, False)
# Assert
tools = [tool.value for tool in tools]
@ -546,7 +546,7 @@ async def test_get_correct_tools_online_or_general_and_notes(chat_client):
user_query = "What's the highest point in Patagonia and have I been there?"
# Act
tools = await aget_relevant_information_sources(user_query, {})
tools = await aget_relevant_information_sources(user_query, {}, False)
# Assert
tools = [tool.value for tool in tools]
@ -563,7 +563,7 @@ async def test_get_correct_tools_general(chat_client):
user_query = "How many noble gases are there?"
# Act
tools = await aget_relevant_information_sources(user_query, {})
tools = await aget_relevant_information_sources(user_query, {}, False)
# Assert
tools = [tool.value for tool in tools]
@ -587,7 +587,7 @@ async def test_get_correct_tools_with_chat_history(chat_client):
chat_history = generate_history(chat_log)
# Act
tools = await aget_relevant_information_sources(user_query, chat_history)
tools = await aget_relevant_information_sources(user_query, chat_history, False)
# Assert
tools = [tool.value for tool in tools]