Improve task scheduling by using json mode and agent scratchpad

- The task scheduling actor was having trouble calculating the
  timezone. Giving the actor a scratchpad to improve correctness by
  thinking step by step
- Add more examples to reduce chances of the inferred query looping to
  create another reminder instead of running the query and sharing
  results with user
- Improve task scheduling chat actor test with more tests and
  by ensuring unexpected words not present in response
This commit is contained in:
Debanjum Singh Solanky 2024-04-22 00:50:59 +05:30
parent 7f5981594c
commit 22289a0002
3 changed files with 84 additions and 21 deletions

View file

@ -513,27 +513,61 @@ crontime_prompt = PromptTemplate.from_template(
""" """
You are Khoj, an extremely smart and helpful task scheduling assistant You are Khoj, an extremely smart and helpful task scheduling assistant
- Given a user query, you infer the date, time to run the query at as a cronjob time string (converted to UTC time zone) - Given a user query, you infer the date, time to run the query at as a cronjob time string (converted to UTC time zone)
- Convert the cron job time to run in UTC - Convert the cron job time to run in UTC. Use the scratchpad to calculate the cron job time.
- Infer user's time zone from the current location provided in their message - Infer user's time zone from the current location provided in their message. Think step-by-step.
- Use an approximate time that makes sense, if it not unspecified. - Use an approximate time that makes sense, if it not unspecified.
- Also extract the query to run at the scheduled time. Add any context required from the chat history to improve the query. - Also extract the search query to run at the scheduled time. Add any context required from the chat history to improve the query.
- Return the scratchpad, cronjob time and the search query to run as a JSON object.
# Examples: # Examples:
## Chat History
User: Could you share a funny Calvin and Hobbes quote from my notes? User: Could you share a funny Calvin and Hobbes quote from my notes?
AI: Here is one I found: "It's not denial. I'm just selective about the reality I accept." AI: Here is one I found: "It's not denial. I'm just selective about the reality I accept."
User: Hahah, nice! Show a new one every morning at 9am. My Current Location: Shanghai, China
Khoj: ["0 1 * * *", "Share a funny Calvin and Hobbes or Bill Watterson quote from my notes."]
User: Share the top weekly posts on Hacker News on Monday evenings. Format it as a newsletter. My Current Location: Nairobi, Kenya User: Hahah, nice! Show a new one every morning at 9:40. My Current Location: Shanghai, China
Khoj: ["30 15 * * 1", "Top posts last week on Hacker News"] Khoj: {{
"Scratchpad": "Shanghai is UTC+8. So, 9:40 in Shanghai is 1:40 UTC. I'll also generalize the search query to get better results.",
"Crontime": "40 1 * * *",
"Query": "Share a funny Calvin and Hobbes or Bill Watterson quote from my notes."
}}
## Chat History
User: Every Monday evening share the top posts on Hacker News from last week. Format it as a newsletter. My Current Location: Nairobi, Kenya
Khoj: {{
"Scratchpad": "Nairobi is UTC+3. As evening specified, I'll share at 18:30 your time. Which will be 15:30 UTC.",
"Crontime": "30 15 * * 1",
"Query": "Top posts last week on Hacker News"
}}
## Chat History
User: What is the latest version of the Khoj python package? User: What is the latest version of the Khoj python package?
AI: The latest released Khoj python package version is 1.5.0. AI: The latest released Khoj python package version is 1.5.0.
User: Notify me when version 2.0.0 is released. My Current Location: Mexico City, Mexico User: Notify me when version 2.0.0 is released. My Current Location: Mexico City, Mexico
Khoj: ["0 16 * * *", "Check if the latest released version of the Khoj python package is >= 2.0.0?"] Khoj: {{
"Scratchpad": "Mexico City is UTC-6. No time is specified, so I'll notify at 10:00 your time. Which will be 16:00 in UTC. Also I'll ensure the search query doesn't trigger another reminder.",
"Crontime": "0 16 * * *",
"Query": "Check if the latest released version of the Khoj python package is >= 2.0.0?"
}}
## Chat History
User: Tell me the latest local tech news on the first Sunday of every Month. My Current Location: Dublin, Ireland User: Tell me the latest local tech news on the first Sunday of every Month. My Current Location: Dublin, Ireland
Khoj: ["0 9 1-7 * 0", "Latest tech, AI and engineering news from around Dublin, Ireland"] Khoj: {{
"Scratchpad": "Dublin is UTC+1. So, 10:00 in Dublin is 8:00 UTC. First Sunday of every month is 1-7. Also I'll enhance the search query.",
"Crontime": "0 9 1-7 * 0",
"Query": "Find the latest tech, AI and engineering news from around Dublin, Ireland"
}}
## Chat History
User: Inform me when the national election results are officially declared. Run task at 4pm every thursday. My Current Location: Trichy, India
Khoj: {{
"Scratchpad": "Trichy is UTC+5:30. So, 4pm in Trichy is 10:30 UTC. Also let's add location details to the search query.",
"Crontime": "30 10 * * 4",
"Query": "Check if the Indian national election results are officially declared."
}}
# Chat History: # Chat History:
{chat_history} {chat_history}

View file

@ -336,15 +336,15 @@ async def schedule_query(q: str, location_data: LocationData, conversation_histo
chat_history=chat_history, chat_history=chat_history,
) )
raw_response = await send_message_to_model_wrapper(crontime_prompt) raw_response = await send_message_to_model_wrapper(crontime_prompt, response_type="json_object")
# Validate that the response is a non-empty, JSON-serializable list # Validate that the response is a non-empty, JSON-serializable list
try: try:
raw_response = raw_response.strip() raw_response = raw_response.strip()
response: List[str] = json.loads(raw_response) response: Dict[str, str] = json.loads(raw_response)
if not isinstance(response, list) or not response or len(response) != 2: if not response or not isinstance(response, Dict) or len(response) != 3:
raise AssertionError(f"Invalid response for scheduling query : {response}") raise AssertionError(f"Invalid response for scheduling query : {response}")
return tuple(response) return tuple(response.values())[1:]
except Exception: except Exception:
raise AssertionError(f"Invalid response for scheduling query: {raw_response}") raise AssertionError(f"Invalid response for scheduling query: {raw_response}")

View file

@ -549,27 +549,56 @@ async def test_infer_webpage_urls_actor_extracts_correct_links(chat_client):
@pytest.mark.anyio @pytest.mark.anyio
@pytest.mark.django_db(transaction=True) @pytest.mark.django_db(transaction=True)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"user_query, location, expected_crontime, expected_queries", "user_query, location, expected_crontime, expected_qs, unexpected_qs",
[ [
( (
"Share the weather forecast for the next day at 19:30", "Share the weather forecast for the next day daily at 7:30pm",
("Boston", "MA", "USA"), ("Ubud", "Bali", "Indonesia"),
"30 23 * * *", "30 11 * * *", # ensure correctly converts to utc
["weather forecast", "boston"], ["weather forecast", "ubud"],
["7:30"],
),
(
"Notify me when the new President of Brazil is announced",
("Sao Paulo", "Sao Paulo", "Brazil"),
"* *", # crontime is variable
["brazil", "president"],
["notify"], # ensure reminder isn't re-triggered on scheduled query run
),
(
"Let me know whenever Elon leaves Twitter. Check this every afternoon at 12",
("Karachi", "Sindh", "Pakistan"),
"0 7 * * *", # ensure correctly converts to utc
["elon", "twitter"],
["12"],
),
(
"Draw a wallpaper every morning using the current weather",
("Bogota", "Cundinamarca", "Colombia"),
"* * *", # daily crontime
["weather", "wallpaper", "bogota"],
["every"],
), ),
], ],
) )
async def test_infer_task_scheduling_request(chat_client, user_query, location, expected_crontime, expected_queries): async def test_infer_task_scheduling_request(
chat_client, user_query, location, expected_crontime, expected_qs, unexpected_qs
):
# Arrange # Arrange
location_data = LocationData(city=location[0], region=location[1], country=location[2]) location_data = LocationData(city=location[0], region=location[1], country=location[2])
# Act # Act
crontime, inferred_query = await schedule_query(user_query, location_data, {}) crontime, inferred_query = await schedule_query(user_query, location_data, {})
inferred_query = inferred_query.lower()
# Assert # Assert
assert expected_crontime in crontime assert expected_crontime in crontime
for query in expected_queries: for expected_q in expected_qs:
assert query in inferred_query.lower() assert expected_q in inferred_query, f"Expected fragment {expected_q} in query: {inferred_query}"
for unexpected_q in unexpected_qs:
assert (
unexpected_q not in inferred_query
), f"Did not expect fragment '{unexpected_q}' in query: '{inferred_query}'"
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------