mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Only notify when scheduled task results satisfy user's requirements
There's a difference between running a scheduled task and notifying the user about the results of running the scheduled task. Decide to notify the user only when the results of running the scheduled task satisfy the user's requirements. Use sync version of send_message_to_model_wrapper for scheduled tasks
This commit is contained in:
parent
7e084ef1e0
commit
7f5981594c
4 changed files with 183 additions and 13 deletions
|
@ -543,6 +543,47 @@ Khoj:
|
|||
""".strip()
|
||||
)
|
||||
|
||||
to_notify_or_not = PromptTemplate.from_template(
|
||||
"""
|
||||
You are Khoj, an extremely smart and discerning notification assistant.
|
||||
- Decide whether the user should be notified of the AI's response using the Original User Query, Executed User Query and AI Response triplet.
|
||||
- Notify the user only if the AI's response satisfies the user specified requirements.
|
||||
- You should only respond with a "Yes" or "No". Do not say anything else.
|
||||
|
||||
# Examples:
|
||||
Original User Query: Hahah, nice! Show a new one every morning at 9am. My Current Location: Shanghai, China
|
||||
Executed User Query: Could you share a funny Calvin and Hobbes quote from my notes?
|
||||
AI Reponse: Here is one I found: "It's not denial. I'm just selective about the reality I accept."
|
||||
Khoj: Yes
|
||||
|
||||
Original User Query: Every evening check if it's going to rain tomorrow. Notify me only if I'll need an umbrella. My Current Location: Nairobi, Kenya
|
||||
Executed User Query: Is it going to rain tomorrow in Nairobi, Kenya
|
||||
AI Response: Tomorrow's forecast is sunny with a high of 28°C and a low of 18°C
|
||||
Khoj: No
|
||||
|
||||
Original User Query: Tell me when version 2.0.0 is released. My Current Location: Mexico City, Mexico
|
||||
Executed User Query: Check if version 2.0.0 of the Khoj python package is released
|
||||
AI Response: The latest released Khoj python package version is 1.5.0.
|
||||
Khoj: No
|
||||
|
||||
Original User Query: Paint me a sunset every evening. My Current Location: Shanghai, China
|
||||
Executed User Query: Paint me a sunset in Shanghai, China
|
||||
AI Response: https://khoj-generated-images.khoj.dev/user110/image78124.webp
|
||||
Khoj: Yes
|
||||
|
||||
Original User Query: Share a summary of the tasks I've completed at the end of the day. My Current Location: Oslo, Norway
|
||||
Executed User Query: Share a summary of the tasks I've completed today.
|
||||
AI Response: I'm sorry, I couldn't find any relevant notes to respond to your message.
|
||||
Khoj: No
|
||||
|
||||
Original User Query: {original_query}
|
||||
Executed User Query: {executed_query}
|
||||
AI Response: {response}
|
||||
Khoj:
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
||||
# System messages to user
|
||||
# --
|
||||
help_message = PromptTemplate.from_template(
|
||||
|
|
|
@ -8,7 +8,7 @@ from typing import Dict, Optional
|
|||
from urllib.parse import unquote
|
||||
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from asgiref.sync import async_to_sync, sync_to_async
|
||||
from asgiref.sync import sync_to_async
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, WebSocket
|
||||
from fastapi.requests import Request
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
|
@ -404,7 +404,7 @@ async def websocket_endpoint(
|
|||
# Generate the job id from the hash of inferred_query and crontime
|
||||
job_id = hashlib.md5(f"{inferred_query}_{crontime}".encode("utf-8")).hexdigest()
|
||||
partial_scheduled_chat = functools.partial(
|
||||
scheduled_chat, inferred_query, websocket.user.object, websocket.url
|
||||
scheduled_chat, inferred_query, q, websocket.user.object, websocket.url
|
||||
)
|
||||
try:
|
||||
job = state.scheduler.add_job(
|
||||
|
@ -668,7 +668,7 @@ async def chat(
|
|||
|
||||
# Generate the job id from the hash of inferred_query and crontime
|
||||
job_id = hashlib.md5(f"{inferred_query}_{crontime}".encode("utf-8")).hexdigest()
|
||||
partial_scheduled_chat = functools.partial(scheduled_chat, inferred_query, request.user.object, request.url)
|
||||
partial_scheduled_chat = functools.partial(scheduled_chat, inferred_query, q, request.user.object, request.url)
|
||||
try:
|
||||
job = state.scheduler.add_job(
|
||||
run_with_process_lock,
|
||||
|
|
|
@ -475,6 +475,51 @@ async def send_message_to_model_wrapper(
|
|||
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
||||
|
||||
|
||||
def send_message_to_model_wrapper_sync(
|
||||
message: str,
|
||||
system_message: str = "",
|
||||
response_type: str = "text",
|
||||
):
|
||||
conversation_config: ChatModelOptions = ConversationAdapters.get_default_conversation_config()
|
||||
|
||||
if conversation_config is None:
|
||||
raise HTTPException(status_code=500, detail="Contact the server administrator to set a default chat model.")
|
||||
|
||||
chat_model = conversation_config.chat_model
|
||||
max_tokens = conversation_config.max_prompt_size
|
||||
|
||||
if conversation_config.model_type == "offline":
|
||||
if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
|
||||
state.offline_chat_processor_config = OfflineChatProcessorModel(chat_model, max_tokens)
|
||||
|
||||
loaded_model = state.offline_chat_processor_config.loaded_model
|
||||
truncated_messages = generate_chatml_messages_with_context(
|
||||
user_message=message, system_message=system_message, model_name=chat_model, loaded_model=loaded_model
|
||||
)
|
||||
|
||||
return send_message_to_model_offline(
|
||||
messages=truncated_messages,
|
||||
loaded_model=loaded_model,
|
||||
model=chat_model,
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
elif conversation_config.model_type == "openai":
|
||||
openai_chat_config = ConversationAdapters.get_openai_conversation_config()
|
||||
api_key = openai_chat_config.api_key
|
||||
truncated_messages = generate_chatml_messages_with_context(
|
||||
user_message=message, system_message=system_message, model_name=chat_model
|
||||
)
|
||||
|
||||
openai_response = send_message_to_model(
|
||||
messages=truncated_messages, api_key=api_key, model=chat_model, response_type=response_type
|
||||
)
|
||||
|
||||
return openai_response
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Invalid conversation config")
|
||||
|
||||
|
||||
def generate_chat_response(
|
||||
q: str,
|
||||
meta_log: dict,
|
||||
|
@ -790,16 +835,41 @@ class CommonQueryParamsClass:
|
|||
CommonQueryParams = Annotated[CommonQueryParamsClass, Depends()]
|
||||
|
||||
|
||||
def scheduled_chat(query, user: KhojUser, calling_url: URL):
|
||||
# Construct the URL, header for the chat API
|
||||
scheme = "http" if calling_url.scheme == "http" or calling_url.scheme == "ws" else "https"
|
||||
# Replace the original scheduling query with the scheduled query
|
||||
query_dict = parse_qs(calling_url.query)
|
||||
query_dict["q"] = [query]
|
||||
# Convert the dictionary back into a query string
|
||||
scheduled_query = urlencode(query_dict, doseq=True)
|
||||
url = f"{scheme}://{calling_url.netloc}/api/chat?{scheduled_query}"
|
||||
def should_notify(original_query: str, executed_query: str, ai_response: str) -> bool:
|
||||
"""
|
||||
Decide whether to notify the user of the AI response.
|
||||
Default to notifying the user for now.
|
||||
"""
|
||||
if any(is_none_or_empty(message) for message in [original_query, executed_query, ai_response]):
|
||||
return False
|
||||
|
||||
to_notify_or_not = prompts.to_notify_or_not.format(
|
||||
original_query=original_query,
|
||||
executed_query=executed_query,
|
||||
response=ai_response,
|
||||
)
|
||||
|
||||
with timer("Chat actor: Decide to notify user of AI response", logger):
|
||||
try:
|
||||
response = send_message_to_model_wrapper_sync(to_notify_or_not)
|
||||
return "no" not in response.lower()
|
||||
except:
|
||||
return True
|
||||
|
||||
|
||||
def scheduled_chat(executing_query: str, scheduling_query: str, user: KhojUser, calling_url: URL):
|
||||
# Extract relevant params from the original URL
|
||||
scheme = "http" if not calling_url.is_secure else "https"
|
||||
query_dict = parse_qs(calling_url.query)
|
||||
|
||||
# Replace the original scheduling query with the scheduled query
|
||||
query_dict["q"] = [executing_query]
|
||||
|
||||
# Construct the URL to call the chat API with the scheduled query string
|
||||
encoded_query = urlencode(query_dict, doseq=True)
|
||||
url = f"{scheme}://{calling_url.netloc}/api/chat?{encoded_query}"
|
||||
|
||||
# Construct the Headers for the chat API
|
||||
headers = {"User-Agent": "Khoj"}
|
||||
if not state.anonymous_mode:
|
||||
# Add authorization request header in non-anonymous mode
|
||||
|
@ -811,4 +881,20 @@ def scheduled_chat(query, user: KhojUser, calling_url: URL):
|
|||
headers["Authorization"] = f"Bearer {token}"
|
||||
|
||||
# Call the chat API endpoint with authenticated user token and query
|
||||
return requests.get(url, headers=headers)
|
||||
raw_response = requests.get(url, headers=headers)
|
||||
|
||||
# Stop if the chat API call was not successful
|
||||
if raw_response.status_code != 200:
|
||||
logger.error(f"Failed to run schedule chat: {raw_response.text}")
|
||||
return None
|
||||
|
||||
# Extract the AI response from the chat API response
|
||||
if raw_response.headers.get("Content-Type") == "application/json":
|
||||
response_map = raw_response.json()
|
||||
ai_response = response_map.get("response") or response_map.get("image")
|
||||
else:
|
||||
ai_response = raw_response.text
|
||||
|
||||
# Notify user if the AI response is satisfactory
|
||||
if should_notify(original_query=scheduling_query, executed_query=executing_query, ai_response=ai_response):
|
||||
return raw_response
|
||||
|
|
|
@ -13,6 +13,7 @@ from khoj.routers.helpers import (
|
|||
generate_online_subqueries,
|
||||
infer_webpage_urls,
|
||||
schedule_query,
|
||||
should_notify,
|
||||
)
|
||||
from khoj.utils.helpers import ConversationCommand
|
||||
from khoj.utils.rawconfig import LocationData
|
||||
|
@ -571,6 +572,48 @@ async def test_infer_task_scheduling_request(chat_client, user_query, location,
|
|||
assert query in inferred_query.lower()
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
@pytest.mark.parametrize(
|
||||
"scheduling_query, executing_query, generated_response, expected_should_notify",
|
||||
[
|
||||
(
|
||||
"Notify me if it is going to rain tomorrow?",
|
||||
"What's the weather forecast for tomorrow?",
|
||||
"It is sunny and warm tomorrow.",
|
||||
False,
|
||||
),
|
||||
(
|
||||
"Summarize the latest news every morning",
|
||||
"Summarize today's news",
|
||||
"Today in the news: AI is taking over the world",
|
||||
True,
|
||||
),
|
||||
(
|
||||
"Create a weather wallpaper every morning using the current weather",
|
||||
"Paint a weather wallpaper using the current weather",
|
||||
"https://khoj-generated-wallpaper.khoj.dev/user110/weathervane.webp",
|
||||
True,
|
||||
),
|
||||
(
|
||||
"Let me know the election results once they are offically declared",
|
||||
"What are the results of the elections? Has the winner been declared?",
|
||||
"The election results has not been declared yet.",
|
||||
False,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_decision_on_when_to_notify_scheduled_task_results(
|
||||
chat_client, scheduling_query, executing_query, generated_response, expected_should_notify
|
||||
):
|
||||
# Act
|
||||
generated_should_notify = should_notify(scheduling_query, executing_query, generated_response)
|
||||
|
||||
# Assert
|
||||
assert generated_should_notify == expected_should_notify
|
||||
|
||||
|
||||
# Helpers
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
def populate_chat_history(message_list):
|
||||
|
|
Loading…
Reference in a new issue