Use DB adapter to unify logic to get, delete automation by auth user

To use place with logic to get, view, delete (and edit soon) automations
by (authenticated) user, instead of scattered across code
This commit is contained in:
Debanjum Singh Solanky 2024-04-30 04:00:48 +05:30
parent 1238cadd31
commit 6936875a82
2 changed files with 73 additions and 46 deletions

View file

@ -1,12 +1,16 @@
import json
import logging import logging
import math import math
import random import random
import re
import secrets import secrets
import sys import sys
from datetime import date, datetime, timedelta, timezone from datetime import date, datetime, timedelta, timezone
from enum import Enum from enum import Enum
from typing import Callable, List, Optional, Type from typing import Callable, Iterable, List, Optional, Type
import cron_descriptor
from apscheduler.job import Job
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
from django.contrib.sessions.backends.db import SessionStore from django.contrib.sessions.backends.db import SessionStore
from django.db import models from django.db import models
@ -908,3 +912,57 @@ class EntryAdapters:
@staticmethod @staticmethod
def get_unique_file_sources(user: KhojUser): def get_unique_file_sources(user: KhojUser):
return Entry.objects.filter(user=user).values_list("file_source", flat=True).distinct().all() return Entry.objects.filter(user=user).values_list("file_source", flat=True).distinct().all()
class AutomationAdapters:
@staticmethod
def get_automations(user: KhojUser) -> Iterable[Job]:
all_automations: Iterable[Job] = state.scheduler.get_jobs()
for automation in all_automations:
if automation.id.startswith(f"automation_{user.uuid}_"):
yield automation
@staticmethod
def get_automations_metadata(user: KhojUser):
for automation in AutomationAdapters.get_automations(user):
automation_metadata = json.loads(automation.name)
crontime = automation_metadata["crontime"]
timezone = automation.next_run_time.strftime("%Z")
schedule = f"{cron_descriptor.get_description(crontime)} {timezone}"
yield {
"id": automation.id,
"subject": automation_metadata["subject"],
"query_to_run": re.sub(r"^/automated_task\s*", "", automation_metadata["query_to_run"]),
"scheduling_request": automation_metadata["scheduling_request"],
"schedule": schedule,
"next": automation.next_run_time.strftime("%Y-%m-%d %I:%M %p %Z"),
}
@staticmethod
def get_automation(user: KhojUser, automation_id: str) -> Job:
# Perform validation checks
# Check if user is allowed to delete this automation id
if not automation_id.startswith(f"automation_{user.uuid}_"):
raise ValueError("Invalid automation id")
# Check if automation with this id exist
automation: Job = state.scheduler.get_job(job_id=automation_id)
if not automation:
raise ValueError("Invalid automation id")
return automation
@staticmethod
def delete_automation(user: KhojUser, automation_id: str):
# Get valid, user-owned automation
automation: Job = AutomationAdapters.get_automation(user, automation_id)
# Collate info about user automation to be deleted
automation_metadata = json.loads(automation.name)
automation_info = {
"id": automation.id,
"name": automation_metadata["query_to_run"],
"next": automation.next_run_time.strftime("%Y-%m-%d %I:%M %p %Z"),
}
automation.remove()
return automation_info

View file

@ -3,7 +3,6 @@ import json
import logging import logging
import math import math
import os import os
import re
import time import time
import uuid import uuid
from typing import Any, Callable, List, Optional, Union from typing import Any, Callable, List, Optional, Union
@ -18,6 +17,7 @@ from starlette.authentication import has_required_scope, requires
from khoj.configure import initialize_content from khoj.configure import initialize_content
from khoj.database.adapters import ( from khoj.database.adapters import (
AutomationAdapters,
ConversationAdapters, ConversationAdapters,
EntryAdapters, EntryAdapters,
get_user_photo, get_user_photo,
@ -39,7 +39,7 @@ from khoj.search_filter.date_filter import DateFilter
from khoj.search_filter.file_filter import FileFilter from khoj.search_filter.file_filter import FileFilter
from khoj.search_filter.word_filter import WordFilter from khoj.search_filter.word_filter import WordFilter
from khoj.search_type import text_search from khoj.search_type import text_search
from khoj.utils import constants, state from khoj.utils import state
from khoj.utils.config import OfflineChatProcessorModel from khoj.utils.config import OfflineChatProcessorModel
from khoj.utils.helpers import ConversationCommand, timer from khoj.utils.helpers import ConversationCommand, timer
from khoj.utils.rawconfig import LocationData, SearchResponse from khoj.utils.rawconfig import LocationData, SearchResponse
@ -396,26 +396,9 @@ def user_info(request: Request) -> Response:
@requires(["authenticated"]) @requires(["authenticated"])
def get_automations(request: Request) -> Response: def get_automations(request: Request) -> Response:
user: KhojUser = request.user.object user: KhojUser = request.user.object
automations: list[Job] = state.scheduler.get_jobs()
# Collate all automations created by user that are still active # Collate all automations created by user that are still active
automations_info = [] automations_info = [automation_info for automation_info in AutomationAdapters.get_automations_metadata(user)]
for automation in automations:
if automation.id.startswith(f"automation_{user.uuid}_"):
automation_metadata = json.loads(automation.name)
crontime = automation_metadata["crontime"]
timezone = automation.next_run_time.strftime("%Z")
schedule = f"{cron_descriptor.get_description(crontime)} {timezone}"
automations_info.append(
{
"id": automation.id,
"subject": automation_metadata["subject"],
"query_to_run": re.sub(r"^/automated_task\s*", "", automation_metadata["query_to_run"]),
"scheduling_request": automation_metadata["scheduling_request"],
"schedule": schedule,
"next": automation.next_run_time.strftime("%Y-%m-%d %I:%M %p %Z"),
}
)
# Return tasks information as a JSON response # Return tasks information as a JSON response
return Response(content=json.dumps(automations_info), media_type="application/json", status_code=200) return Response(content=json.dumps(automations_info), media_type="application/json", status_code=200)
@ -426,25 +409,10 @@ def get_automations(request: Request) -> Response:
def delete_automation(request: Request, automation_id: str) -> Response: def delete_automation(request: Request, automation_id: str) -> Response:
user: KhojUser = request.user.object user: KhojUser = request.user.object
# Perform validation checks try:
# Check if user is allowed to delete this automation id automation_info = AutomationAdapters.delete_automation(user, automation_id)
if not automation_id.startswith(f"automation_{user.uuid}_"): except ValueError as e:
return Response(content="Unauthorized job deletion request", status_code=403) return Response(content="Could not find automation", status_code=403)
# Check if automation with this id exist
automation: Job = state.scheduler.get_job(job_id=automation_id)
if not automation:
return Response(content="Invalid job", status_code=403)
# Collate info about user task to be deleted
automation_metadata = json.loads(automation.name)
automation_info = {
"id": automation.id,
"name": automation_metadata["query_to_run"],
"next": automation.next_run_time.strftime("%Y-%m-%d %I:%M %p %Z"),
}
# Delete job
automation.remove()
# Return deleted automation information as a JSON response # Return deleted automation information as a JSON response
return Response(content=json.dumps(automation_info), media_type="application/json", status_code=200) return Response(content=json.dumps(automation_info), media_type="application/json", status_code=200)
@ -500,13 +468,14 @@ def edit_job(
# Check at least one of query or crontime is provided # Check at least one of query or crontime is provided
if not query_to_run and not crontime: if not query_to_run and not crontime:
return Response(content="A query or crontime is required", status_code=400) return Response(content="A query or crontime is required", status_code=400)
# Check if user is allowed to edit this automation id
if not automation_id.startswith(f"automation_{user.uuid}_"): # Check, get automation to edit
return Response(content="Unauthorized automation deletion request", status_code=403) try:
# Check if automation with this id exist automation: Job = AutomationAdapters.get_automation(user, automation_id)
automation: Job = state.scheduler.get_job(job_id=automation_id) except ValueError as e:
if not automation:
return Response(content="Invalid automation", status_code=403) return Response(content="Invalid automation", status_code=403)
# Add /automated_task prefix to query if not present
if not query_to_run.startswith("/automated_task"): if not query_to_run.startswith("/automated_task"):
query_to_run = f"/automated_task {query_to_run}" query_to_run = f"/automated_task {query_to_run}"