Mirror of https://github.com/khoj-ai/khoj.git
Add chat actor to schedule run query for user at specified times
- Detect when user intends to schedule a task, aka reminder
- Add new output mode: reminder. Add example of selecting the reminder output mode
- Extract schedule time (as cron timestring) and inferred query to run from user message
- Use APScheduler to call chat with inferred query at scheduled time
- Handle reminder scheduling from both websocket and http chat requests
- Support constructing scheduled task using chat history as context
- Pass chat history to scheduled query generator for improved context for scheduled task generation
Parent: 9e068fad4f
Commit: c11742f443
7 changed files with 175 additions and 6 deletions
@@ -79,6 +79,7 @@ dependencies = [
     "websockets == 12.0",
     "psutil >= 5.8.0",
     "huggingface-hub >= 0.22.2",
+    "apscheduler ~= 3.10.0",
 ]
 dynamic = ["version"]

@@ -23,6 +23,7 @@ warnings.filterwarnings("ignore", message=r"legacy way to download files from th

 import uvicorn
 import django
+from apscheduler.schedulers.background import BackgroundScheduler
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.staticfiles import StaticFiles

@@ -126,6 +127,10 @@ def run(should_start_server=True):
     # Setup task scheduler
     poll_task_scheduler()

+    # Setup Background Scheduler
+    state.scheduler = BackgroundScheduler()
+    state.scheduler.start()
+
     # Start Server
     configure_routes(app)

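The two hunks above wire an in-process APScheduler `BackgroundScheduler` into server startup and stash it on the shared `state` module so request handlers can register jobs later. Below is a minimal, self-contained sketch of that pattern; only the APScheduler API is real, the surrounding names are illustrative rather than the commit's exact code.

```python
from apscheduler.schedulers.background import BackgroundScheduler

scheduler = BackgroundScheduler()  # runs jobs on a daemon thread pool inside the server process
scheduler.start()                  # begins polling the (in-memory) job store for due jobs

# Request handlers can later register work against this shared instance, e.g.:
# scheduler.add_job(some_callable, trigger="interval", minutes=30)
# Note: the default job store is in-memory, so scheduled reminders would not
# survive a server restart unless a persistent job store were configured
# (persistence is not part of this commit).
```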
@@ -10,8 +10,7 @@ You were created by Khoj Inc. with the following capabilities:

 - You *CAN REMEMBER ALL NOTES and PERSONAL INFORMATION FOREVER* that the user ever shares with you.
 - Users can share files and other information with you using the Khoj Desktop, Obsidian or Emacs app. They can also drag and drop their files into the chat window.
-- You *CAN* generate images, look-up real-time information from the internet, and answer questions based on the user's notes.
-- You cannot set reminders.
+- You *CAN* generate images, look-up real-time information from the internet, set reminders and answer questions based on the user's notes.
 - Say "I don't know" or "I don't understand" if you don't know what to say or if you don't know the answer to a question.
 - Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided notes or past conversations.
 - Sometimes the user will share personal information that needs to be remembered, like an account ID or a residential address. These can be acknowledged with a simple "Got it" or "Okay".

@@ -301,6 +300,22 @@ AI: I can help with that. I see online that there is a new model of the Dell XPS
 Q: What are the specs of the new Dell XPS 15?
 Khoj: default

+Example:
+Chat History:
+User: Where did I go on my last vacation?
+AI: You went to Jordan and visited Petra, the Dead Sea, and Wadi Rum.
+
+Q: Remind me who did I go with on that trip?
+Khoj: default
+
+Example:
+Chat History:
+User: How's the weather outside? Current Location: Bali, Indonesia
+AI: It's currently 28°C and partly cloudy in Bali.
+
+Q: Share a painting using the weather for Bali every morning.
+Khoj: reminder
+
 Now it's your turn to pick the mode you would like to use to answer the user's question. Provide your response as a string.

 Chat History:

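The new few-shot examples teach the output-mode chat actor to answer with the single word `reminder` when a query asks for something on a schedule, and with `default` when it does not. A hedged sketch of how such a one-word reply could be normalized server-side follows; the helper is hypothetical, since the commit's actual `aget_relevant_output_modes` implementation is not part of this diff.

```python
# Hypothetical normalizer for the mode-selection actor's one-word reply.
def parse_output_mode(model_answer: str) -> str:
    """Map the model's reply (e.g. 'reminder', '"default"') onto a known mode, defaulting otherwise."""
    mode = model_answer.strip().strip('"').strip("'").lower()
    return mode if mode in {"default", "image", "reminder"} else "default"

print(parse_output_mode(' "reminder" '))          # -> "reminder"
print(parse_output_mode("something unexpected"))  # -> "default"
```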
@@ -492,6 +507,42 @@ Khoj:
 """.strip()
 )

+
+# Schedule task
+# --
+crontime_prompt = PromptTemplate.from_template(
+    """
+You are Khoj, an extremely smart and helpful task scheduling assistant
+- Given a user query, you infer the date, time to run the query at as a cronjob time string (converted to UTC time zone)
+- Convert the cron job time to run in UTC
+- Infer user's time zone from the current location provided in their message
+- Use an approximate time that makes sense, if it not unspecified.
+- Also extract the query to run at the scheduled time. Add any context required from the chat history to improve the query.
+
+# Examples:
+User: Could you share a funny Calvin and Hobbes quote from my notes?
+AI: Here is one I found: "It's not denial. I'm just selective about the reality I accept."
+User: Hahah, nice! Show a new one every morning at 9am. My Current Location: Shanghai, China
+Khoj: ["0 1 * * *", "Share a funny Calvin and Hobbes or Bill Watterson quote from my notes."]
+
+User: Share the top weekly posts on Hacker News on Monday evenings. Format it as a newsletter. My Current Location: Nairobi, Kenya
+Khoj: ["30 15 * * 1", "Top posts last week on Hacker News"]
+
+User: What is the latest version of the Khoj python package?
+AI: The latest released Khoj python package version is 1.5.0.
+User: Notify me when version 2.0.0 is released. My Current Location: Mexico City, Mexico
+Khoj: ["0 16 * * *", "Check if the latest released version of the Khoj python package is >= 2.0.0?"]
+
+User: Tell me the latest local tech news on the first Sunday of every Month. My Current Location: Dublin, Ireland
+Khoj: ["0 9 1-7 * 0", "Latest tech, AI and engineering news from around Dublin, Ireland"]
+
+# Chat History:
+{chat_history}
+
+User: {query}. My Current Location: {user_location}
+Khoj:
+""".strip()
+)
+
 # System messages to user
 # --
 help_message = PromptTemplate.from_template(

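The prompt asks the model to emit the crontab string already converted to UTC: in the first example, "every morning at 9am" in Shanghai (UTC+8) becomes `0 1 * * *`. A small sketch of that conversion for a fixed-offset zone, handy for sanity-checking the model's output; the helper name is made up, and `zoneinfo` is from the Python standard library.

```python
from datetime import datetime
from zoneinfo import ZoneInfo

def daily_local_time_to_utc_cron(hour: int, minute: int, tz_name: str) -> str:
    """Convert a daily local wall-clock time into a UTC crontab string (ignores DST shifts)."""
    local_dt = datetime(2024, 1, 1, hour, minute, tzinfo=ZoneInfo(tz_name))
    utc_dt = local_dt.astimezone(ZoneInfo("UTC"))
    return f"{utc_dt.minute} {utc_dt.hour} * * *"

print(daily_local_time_to_utc_cron(9, 0, "Asia/Shanghai"))  # "0 1 * * *", matching the first example
```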
@@ -4,7 +4,8 @@ import math
 from typing import Dict, Optional
 from urllib.parse import unquote

-from asgiref.sync import sync_to_async
+from apscheduler.triggers.cron import CronTrigger
+from asgiref.sync import async_to_sync, sync_to_async
 from fastapi import APIRouter, Depends, HTTPException, Request, WebSocket
 from fastapi.requests import Request
 from fastapi.responses import Response, StreamingResponse

@@ -29,12 +30,14 @@ from khoj.routers.api import extract_references_and_questions
 from khoj.routers.helpers import (
     ApiUserRateLimiter,
     CommonQueryParams,
+    CommonQueryParamsClass,
     ConversationCommandRateLimiter,
     agenerate_chat_response,
     aget_relevant_information_sources,
     aget_relevant_output_modes,
     get_conversation_command,
     is_ready_to_chat,
+    schedule_query,
     text_to_image,
     update_telemetry_state,
     validate_conversation_config,

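Alongside `schedule_query`, the handler now imports `CommonQueryParamsClass`. Its definition is not part of this diff, but the websocket hunk below constructs it directly with keyword arguments, which suggests a plain-class counterpart to the `CommonQueryParams` FastAPI dependency, roughly like the hedged sketch here (field names inferred from the call site; treat the shape as an assumption).

```python
# Hedged approximation of a plain-class query-params holder; the real
# CommonQueryParamsClass is not shown in this diff.
from typing import Optional

class CommonQueryParamsClass:
    def __init__(
        self,
        client: Optional[str] = None,
        user_agent: Optional[str] = None,
        host: Optional[str] = None,
    ):
        self.client = client
        self.user_agent = user_agent
        self.host = host

# HTTP routes can still receive such an object via dependency injection, while the
# websocket handler below instantiates it directly from websocket headers.
```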
@@ -381,6 +384,55 @@ async def websocket_endpoint(
             await conversation_command_rate_limiter.update_and_check_if_valid(websocket, cmd)
             q = q.replace(f"/{cmd.value}", "").strip()

+        if ConversationCommand.Reminder in conversation_commands:
+            crontime, inferred_query = await schedule_query(q, location, meta_log)
+            trigger = CronTrigger.from_crontab(crontime)
+            common = CommonQueryParamsClass(
+                client=websocket.user.client_app,
+                user_agent=websocket.headers.get("user-agent"),
+                host=websocket.headers.get("host"),
+            )
+            scope = websocket.scope.copy()
+            scope["path"] = "/api/chat"
+            scope["type"] = "http"
+            request = Request(scope)
+
+            state.scheduler.add_job(
+                async_to_sync(chat),
+                trigger=trigger,
+                args=(request, common, inferred_query),
+                kwargs={
+                    "stream": False,
+                    "conversation_id": conversation_id,
+                    "city": city,
+                    "region": region,
+                    "country": country,
+                },
+                id=f"job_{user.uuid}_{inferred_query}",
+                replace_existing=True,
+            )
+
+            llm_response = (
+                f'🕒 Scheduled running Query: "{inferred_query}" on Schedule: `{crontime}` (in server timezone).'
+            )
+            await sync_to_async(save_to_conversation_log)(
+                q,
+                llm_response,
+                user,
+                meta_log,
+                intent_type="reminder",
+                client_application=websocket.user.client_app,
+                conversation_id=conversation_id,
+            )
+            update_telemetry_state(
+                request=websocket,
+                telemetry_type="api",
+                api="chat",
+                **common.__dict__,
+            )
+            await send_complete_llm_response(llm_response)
+            continue
+
         compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
             websocket, meta_log, q, 7, 0.18, conversation_commands, location, send_status_update
         )

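Two details of the websocket block above are worth noting. First, APScheduler 3.x executes jobs on worker threads and cannot await a coroutine directly, so the async `chat` handler is wrapped with asgiref's `async_to_sync` before registration. Second, `replace_existing=True` together with a job id keyed on the user and inferred query means re-scheduling the same request overwrites the earlier job instead of duplicating it. A self-contained sketch of that registration pattern follows; the coroutine and job id are placeholders, not the commit's code.

```python
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
from asgiref.sync import async_to_sync

async def run_scheduled_chat(query: str) -> None:
    # Placeholder for the async chat handler the scheduler should invoke.
    print(f"Running scheduled query: {query}")

scheduler = BackgroundScheduler()
scheduler.start()
scheduler.add_job(
    async_to_sync(run_scheduled_chat),             # sync wrapper the worker thread can call
    trigger=CronTrigger.from_crontab("0 1 * * *"),
    args=("Share a funny Calvin and Hobbes quote from my notes.",),
    id="job_<user-uuid>_<inferred-query>",         # placeholder id mirroring the commit's scheme
    replace_existing=True,                         # same id re-scheduled -> old job is replaced
)
```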
@@ -576,6 +628,33 @@ async def chat(

     user_name = await aget_user_name(user)

+    if ConversationCommand.Reminder in conversation_commands:
+        crontime, inferred_query = await schedule_query(q, location, meta_log)
+        trigger = CronTrigger.from_crontab(crontime)
+        state.scheduler.add_job(
+            async_to_sync(chat),
+            trigger=trigger,
+            args=(request, common, inferred_query, n, d, False, title, conversation_id, city, region, country),
+            id=f"job_{user.uuid}_{inferred_query}",
+            replace_existing=True,
+        )
+
+        llm_response = f'🕒 Scheduled running Query: "{inferred_query}" on Schedule: `{crontime}` (in server timezone).'
+        await sync_to_async(save_to_conversation_log)(
+            q,
+            llm_response,
+            user,
+            meta_log,
+            intent_type="reminder",
+            client_application=request.user.client_app,
+            conversation_id=conversation_id,
+        )
+
+        if stream:
+            return StreamingResponse(llm_response, media_type="text/event-stream", status_code=200)
+        else:
+            return Response(content=llm_response, media_type="text/plain", status_code=200)
+
     compiled_references, inferred_queries, defiltered_query = await extract_references_and_questions(
         request, meta_log, q, (n or 5), (d or math.inf), conversation_commands, location
     )

@@ -134,7 +134,7 @@ def update_telemetry_state(
 def construct_chat_history(conversation_history: dict, n: int = 4, agent_name="AI") -> str:
     chat_history = ""
     for chat in conversation_history.get("chat", [])[-n:]:
-        if chat["by"] == "khoj" and chat["intent"].get("type") == "remember":
+        if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder"]:
             chat_history += f"User: {chat['intent']['query']}\n"
             chat_history += f"{agent_name}: {chat['message']}\n"
         elif chat["by"] == "khoj" and ("text-to-image" in chat["intent"].get("type")):

@@ -312,6 +312,34 @@ async def generate_online_subqueries(q: str, conversation_history: dict, locatio
     return [q]


+async def schedule_query(q: str, location_data: LocationData, conversation_history: dict) -> Tuple[str, ...]:
+    """
+    Schedule the date, time to run the query. Assume the server timezone is UTC.
+    """
+    user_location = (
+        f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Greenwich"
+    )
+    chat_history = construct_chat_history(conversation_history)
+
+    crontime_prompt = prompts.crontime_prompt.format(
+        query=q,
+        user_location=user_location,
+        chat_history=chat_history,
+    )
+
+    raw_response = await send_message_to_model_wrapper(crontime_prompt)
+
+    # Validate that the response is a non-empty, JSON-serializable list
+    try:
+        raw_response = raw_response.strip()
+        response: List[str] = json.loads(raw_response)
+        if not isinstance(response, list) or not response or len(response) != 2:
+            raise AssertionError(f"Invalid response for scheduling query : {response}")
+        return tuple(response)
+    except Exception:
+        raise AssertionError(f"Invalid response for scheduling query: {raw_response}")
+
+
 async def extract_relevant_info(q: str, corpus: str) -> Union[str, None]:
     """
     Extract relevant information for a given query from the target corpus

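`schedule_query` expects the model to reply with a JSON array of exactly two strings, `[crontab_expression, inferred_query]`, and raises if the reply does not match that shape. A tiny sketch of that contract in isolation; the helper below is hypothetical and only mirrors the validation logic added above.

```python
import json
from typing import Tuple

def parse_crontime_response(raw_response: str) -> Tuple[str, str]:
    """Parse the scheduling actor's reply: a JSON list of [crontab_expression, query_to_run]."""
    response = json.loads(raw_response.strip())
    if not isinstance(response, list) or len(response) != 2:
        raise AssertionError(f"Invalid response for scheduling query: {raw_response}")
    crontime, inferred_query = response
    return crontime, inferred_query

print(parse_crontime_response('["0 1 * * *", "Share a funny Calvin and Hobbes quote from my notes."]'))
# ('0 1 * * *', 'Share a funny Calvin and Hobbes quote from my notes.')
```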
@@ -547,7 +575,7 @@ async def text_to_image(
     text2image_model = text_to_image_config.model_name
     chat_history = ""
     for chat in conversation_log.get("chat", [])[-4:]:
-        if chat["by"] == "khoj" and chat["intent"].get("type") == "remember":
+        if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder"]:
             chat_history += f"Q: {chat['intent']['query']}\n"
             chat_history += f"A: {chat['message']}\n"
         elif chat["by"] == "khoj" and "text-to-image" in chat["intent"].get("type"):

@@ -304,6 +304,7 @@ class ConversationCommand(str, Enum):
     Online = "online"
     Webpage = "webpage"
     Image = "image"
+    Reminder = "reminder"


 command_descriptions = {

@@ -313,6 +314,7 @@ command_descriptions = {
     ConversationCommand.Online: "Search for information on the internet.",
     ConversationCommand.Webpage: "Get information from webpage links provided by you.",
     ConversationCommand.Image: "Generate images by describing your imagination in words.",
+    ConversationCommand.Reminder: "Schedule your query to run at a specified time or interval.",
     ConversationCommand.Help: "Display a help message with all available commands and other metadata.",
 }

@@ -325,7 +327,8 @@ tool_descriptions_for_llm = {
 }

 mode_descriptions_for_llm = {
-    ConversationCommand.Image: "Use this if you think the user is requesting an image or visual response to their query.",
+    ConversationCommand.Image: "Use this if the user is requesting an image or visual response to their query.",
+    ConversationCommand.Reminder: "Use this if the user is requesting a response at a scheduled date or time.",
     ConversationCommand.Default: "Use this if the other response modes don't seem to fit the query.",
 }

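Because `ConversationCommand` is a `str`-valued enum, the new `Reminder` member lets a `/reminder` slash command resolve to the same value the mode-selection actor emits, and its description can be surfaced in help text. A compact sketch of that lookup, with the enum trimmed to a few members for illustration.

```python
from enum import Enum

class ConversationCommand(str, Enum):
    Default = "default"
    Image = "image"
    Reminder = "reminder"

command_descriptions = {
    ConversationCommand.Reminder: "Schedule your query to run at a specified time or interval.",
}

cmd = ConversationCommand("/reminder".lstrip("/"))  # str-valued enum: lookup by value
print(cmd.value, "-", command_descriptions[cmd])    # reminder - Schedule your query to run ...
```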
@@ -4,6 +4,7 @@ from collections import defaultdict
 from pathlib import Path
 from typing import Any, Dict, List

+from apscheduler.schedulers.background import BackgroundScheduler
 from openai import OpenAI
 from whisper import Whisper

@@ -29,6 +30,7 @@ cli_args: List[str] = None
 query_cache: Dict[str, LRU] = defaultdict(LRU)
 chat_lock = threading.Lock()
 SearchType = utils_config.SearchType
+scheduler: BackgroundScheduler = None
 telemetry: List[Dict[str, str]] = []
 khoj_version: str = None
 device = get_device()