mirror of
https://github.com/khoj-ai/khoj.git
synced 2025-02-17 08:04:21 +00:00
Move subscription API to separate, independent router
This commit is contained in:
parent
ec1395d072
commit
3bb10128ef
4 changed files with 111 additions and 96 deletions
|
@ -330,7 +330,7 @@ class EntryAdapters:
|
|||
return deleted_count
|
||||
|
||||
@staticmethod
|
||||
def delete_all_entries_by_source(user: KhojUser, file_source: str = None):
|
||||
def delete_all_entries(user: KhojUser, file_source: str = None):
|
||||
if file_source is None:
|
||||
deleted_count, _ = Entry.objects.filter(user=user).delete()
|
||||
else:
|
||||
|
|
|
@ -145,10 +145,12 @@ def configure_routes(app):
|
|||
from khoj.routers.web_client import web_client
|
||||
from khoj.routers.indexer import indexer
|
||||
from khoj.routers.auth import auth_router
|
||||
from khoj.routers.subscription import subscription_router
|
||||
|
||||
app.include_router(api, prefix="/api")
|
||||
app.include_router(api_beta, prefix="/api/beta")
|
||||
app.include_router(indexer, prefix="/api/v1/index")
|
||||
app.include_router(subscription_router, prefix="/api/subscription")
|
||||
app.include_router(web_client)
|
||||
app.include_router(auth_router, prefix="/auth")
|
||||
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
# Standard Packages
|
||||
import concurrent.futures
|
||||
from datetime import datetime, timezone
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
import json
|
||||
|
@ -12,7 +10,6 @@ from typing import List, Optional, Union, Any
|
|||
from fastapi import APIRouter, HTTPException, Header, Request
|
||||
from starlette.authentication import requires
|
||||
from asgiref.sync import sync_to_async
|
||||
import stripe
|
||||
|
||||
# Internal Packages
|
||||
from khoj.configure import configure_server
|
||||
|
@ -245,7 +242,7 @@ async def remove_content_source_data(
|
|||
raise ValueError(f"Invalid content source: {content_source}")
|
||||
elif content_object != "Computer":
|
||||
await content_object.objects.filter(user=user).adelete()
|
||||
await sync_to_async(EntryAdapters.delete_all_entries_by_source)(user, content_source)
|
||||
await sync_to_async(EntryAdapters.delete_all_entries)(user, content_source)
|
||||
|
||||
enabled_content = await sync_to_async(EntryAdapters.get_unique_file_types)(user)
|
||||
return {"status": "ok"}
|
||||
|
@ -725,94 +722,3 @@ async def extract_references_and_questions(
|
|||
compiled_references = [item.additional["compiled"] for item in result_list]
|
||||
|
||||
return compiled_references, inferred_queries, defiltered_query
|
||||
|
||||
|
||||
# Stripe integration for Khoj Cloud Subscription
|
||||
stripe.api_key = os.getenv("STRIPE_API_KEY")
|
||||
endpoint_secret = os.getenv("STRIPE_SIGNING_SECRET")
|
||||
|
||||
|
||||
@api.post("/subscription")
|
||||
async def subscribe(request: Request):
|
||||
"""Webhook for Stripe to send subscription events to Khoj Cloud"""
|
||||
event = None
|
||||
try:
|
||||
payload = await request.body()
|
||||
sig_header = request.headers["stripe-signature"]
|
||||
event = stripe.Webhook.construct_event(payload, sig_header, endpoint_secret)
|
||||
except ValueError as e:
|
||||
# Invalid payload
|
||||
raise e
|
||||
except stripe.error.SignatureVerificationError as e:
|
||||
# Invalid signature
|
||||
raise e
|
||||
|
||||
event_type = event["type"]
|
||||
if event_type not in {
|
||||
"invoice.paid",
|
||||
"customer.subscription.updated",
|
||||
"customer.subscription.deleted",
|
||||
"subscription_schedule.canceled",
|
||||
}:
|
||||
logger.warn(f"Unhandled Stripe event type: {event['type']}")
|
||||
return {"success": False}
|
||||
|
||||
# Retrieve the customer's details
|
||||
subscription = event["data"]["object"]
|
||||
customer_id = subscription["customer"]
|
||||
customer = stripe.Customer.retrieve(customer_id)
|
||||
customer_email = customer["email"]
|
||||
|
||||
# Handle valid stripe webhook events
|
||||
success = True
|
||||
if event_type in {"invoice.paid"}:
|
||||
# Mark the user as subscribed and update the next renewal date on payment
|
||||
subscription = stripe.Subscription.list(customer=customer_id).data[0]
|
||||
renewal_date = datetime.fromtimestamp(subscription["current_period_end"], tz=timezone.utc)
|
||||
user = await adapters.set_user_subscription(customer_email, is_subscribed=True, renewal_date=renewal_date)
|
||||
success = user is not None
|
||||
elif event_type in {"customer.subscription.updated"}:
|
||||
user = await adapters.get_user_by_email(customer_email)
|
||||
# Allow updating subscription status if paid user
|
||||
if user.subscription_renewal_date:
|
||||
# Mark user as unsubscribed or resubscribed
|
||||
is_subscribed = not subscription["cancel_at_period_end"]
|
||||
updated_user = await adapters.set_user_subscription(customer_email, is_subscribed=is_subscribed)
|
||||
success = updated_user is not None
|
||||
elif event_type in {"customer.subscription.deleted"}:
|
||||
# Reset the user to trial state
|
||||
user = await adapters.set_user_subscription(
|
||||
customer_email, is_subscribed=False, renewal_date=False, type="trial"
|
||||
)
|
||||
success = user is not None
|
||||
|
||||
logger.info(f'Stripe subscription {event["type"]} for {customer["email"]}')
|
||||
return {"success": success}
|
||||
|
||||
|
||||
@api.patch("/subscription")
|
||||
@requires(["authenticated"])
|
||||
async def unsubscribe(request: Request, email: str, operation: str):
|
||||
# Retrieve the customer's details
|
||||
customers = stripe.Customer.list(email=email).auto_paging_iter()
|
||||
customer = next(customers, None)
|
||||
if customer is None:
|
||||
return {"success": False, "message": "Customer not found"}
|
||||
|
||||
if operation == "cancel":
|
||||
customer_id = customer.id
|
||||
for subscription in stripe.Subscription.list(customer=customer_id):
|
||||
stripe.Subscription.modify(subscription.id, cancel_at_period_end=True)
|
||||
return {"success": True}
|
||||
|
||||
elif operation == "resubscribe":
|
||||
subscriptions = stripe.Subscription.list(customer=customer.id).auto_paging_iter()
|
||||
# Find the subscription that is set to cancel at the end of the period
|
||||
for subscription in subscriptions:
|
||||
if subscription.cancel_at_period_end:
|
||||
# Update the subscription to not cancel at the end of the period
|
||||
stripe.Subscription.modify(subscription.id, cancel_at_period_end=False)
|
||||
return {"success": True}
|
||||
return {"success": False, "message": "No subscription found that is set to cancel"}
|
||||
|
||||
return {"success": False, "message": "Invalid operation"}
|
||||
|
|
107
src/khoj/routers/subscription.py
Normal file
107
src/khoj/routers/subscription.py
Normal file
|
@ -0,0 +1,107 @@
|
|||
# Standard Packages
|
||||
from datetime import datetime, timezone
|
||||
import logging
|
||||
import os
|
||||
|
||||
# External Packages
|
||||
from fastapi import APIRouter, Request
|
||||
from starlette.authentication import requires
|
||||
import stripe
|
||||
|
||||
# Internal Packages
|
||||
from database import adapters
|
||||
|
||||
# Stripe integration for Khoj Cloud Subscription
|
||||
stripe.api_key = os.getenv("STRIPE_API_KEY")
|
||||
endpoint_secret = os.getenv("STRIPE_SIGNING_SECRET")
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
subscription_router = APIRouter()
|
||||
|
||||
|
||||
@subscription_router.post("")
|
||||
async def subscribe(request: Request):
|
||||
"""Webhook for Stripe to send subscription events to Khoj Cloud"""
|
||||
event = None
|
||||
try:
|
||||
payload = await request.body()
|
||||
sig_header = request.headers["stripe-signature"]
|
||||
event = stripe.Webhook.construct_event(payload, sig_header, endpoint_secret)
|
||||
except ValueError as e:
|
||||
# Invalid payload
|
||||
raise e
|
||||
except stripe.error.SignatureVerificationError as e:
|
||||
# Invalid signature
|
||||
raise e
|
||||
|
||||
event_type = event["type"]
|
||||
if event_type not in {
|
||||
"invoice.paid",
|
||||
"customer.subscription.updated",
|
||||
"customer.subscription.deleted",
|
||||
"subscription_schedule.canceled",
|
||||
}:
|
||||
logger.warn(f"Unhandled Stripe event type: {event['type']}")
|
||||
return {"success": False}
|
||||
|
||||
# Retrieve the customer's details
|
||||
subscription = event["data"]["object"]
|
||||
customer_id = subscription["customer"]
|
||||
customer = stripe.Customer.retrieve(customer_id)
|
||||
customer_email = customer["email"]
|
||||
|
||||
# Handle valid stripe webhook events
|
||||
success = True
|
||||
if event_type in {"invoice.paid"}:
|
||||
# Mark the user as subscribed and update the next renewal date on payment
|
||||
subscription = stripe.Subscription.list(customer=customer_id).data[0]
|
||||
renewal_date = datetime.fromtimestamp(subscription["current_period_end"], tz=timezone.utc)
|
||||
user = await adapters.set_user_subscription(customer_email, is_subscribed=True, renewal_date=renewal_date)
|
||||
success = user is not None
|
||||
elif event_type in {"customer.subscription.updated"}:
|
||||
user = await adapters.get_user_by_email(customer_email)
|
||||
# Allow updating subscription status if paid user
|
||||
if user.subscription_renewal_date:
|
||||
# Mark user as unsubscribed or resubscribed
|
||||
is_subscribed = not subscription["cancel_at_period_end"]
|
||||
updated_user = await adapters.set_user_subscription(customer_email, is_subscribed=is_subscribed)
|
||||
success = updated_user is not None
|
||||
elif event_type in {"customer.subscription.deleted"}:
|
||||
# Reset the user to trial state
|
||||
user = await adapters.set_user_subscription(
|
||||
customer_email, is_subscribed=False, renewal_date=False, type="trial"
|
||||
)
|
||||
success = user is not None
|
||||
|
||||
logger.info(f'Stripe subscription {event["type"]} for {customer["email"]}')
|
||||
return {"success": success}
|
||||
|
||||
|
||||
@subscription_router.patch("")
|
||||
@requires(["authenticated"])
|
||||
async def update_subscription(request: Request, email: str, operation: str):
|
||||
# Retrieve the customer's details
|
||||
customers = stripe.Customer.list(email=email).auto_paging_iter()
|
||||
customer = next(customers, None)
|
||||
if customer is None:
|
||||
return {"success": False, "message": "Customer not found"}
|
||||
|
||||
if operation == "cancel":
|
||||
customer_id = customer.id
|
||||
for subscription in stripe.Subscription.list(customer=customer_id):
|
||||
stripe.Subscription.modify(subscription.id, cancel_at_period_end=True)
|
||||
return {"success": True}
|
||||
|
||||
elif operation == "resubscribe":
|
||||
subscriptions = stripe.Subscription.list(customer=customer.id).auto_paging_iter()
|
||||
# Find the subscription that is set to cancel at the end of the period
|
||||
for subscription in subscriptions:
|
||||
if subscription.cancel_at_period_end:
|
||||
# Update the subscription to not cancel at the end of the period
|
||||
stripe.Subscription.modify(subscription.id, cancel_at_period_end=False)
|
||||
return {"success": True}
|
||||
return {"success": False, "message": "No subscription found that is set to cancel"}
|
||||
|
||||
return {"success": False, "message": "Invalid operation"}
|
Loading…
Add table
Reference in a new issue