Create API webhook, endpoints for subscription payments using Stripe

- Add fields to mark users as subscribed to a specific plan and
  subscription renewal date in DB
- Add ability to unsubscribe a user using their email address
- Expose webhook for stripe to callback confirming payment
This commit is contained in:
Debanjum Singh Solanky 2023-11-07 10:15:21 -08:00
parent 156421d30a
commit 9aaf475c8a
5 changed files with 116 additions and 3 deletions

View file

@ -73,7 +73,8 @@ dependencies = [
"gunicorn == 21.2.0",
"lxml == 4.9.3",
"tzdata == 2023.3",
"rapidocr-onnxruntime == 1.3.8"
"rapidocr-onnxruntime == 1.3.8",
"stripe == 7.3.0",
]
dynamic = ["version"]

View file

@ -1,5 +1,5 @@
from typing import Type, TypeVar, List
from datetime import date
from datetime import date, datetime, timedelta
import secrets
from typing import Type, TypeVar, List
from datetime import date
@ -103,6 +103,27 @@ async def create_google_user(token: dict) -> KhojUser:
return user
async def set_user_subscribed(email: str, type="standard") -> KhojUser:
user = await KhojUser.objects.filter(email=email).afirst()
if user:
user.subscription_type = type
start_date = user.subscription_renewal_date or datetime.now()
user.subscription_renewal_date = start_date + timedelta(days=30)
await user.asave()
return user
else:
return None
def is_user_subscribed(email: str, type="standard") -> bool:
user = KhojUser.objects.filter(email=email, subscription_type=type).first()
if user and user.subscription_renewal_date:
is_subscribed = user.subscription_renewal_date > date.today()
return is_subscribed
else:
return False
async def get_user_by_token(token: dict) -> KhojUser:
google_user = await GoogleUser.objects.filter(sub=token.get("sub")).select_related("user").afirst()
if not google_user:

View file

@ -0,0 +1,24 @@
# Generated by Django 4.2.5 on 2023-11-07 18:19
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0012_entry_file_source"),
]
operations = [
migrations.AddField(
model_name="khojuser",
name="subscription_renewal_date",
field=models.DateTimeField(default=None, null=True),
),
migrations.AddField(
model_name="khojuser",
name="subscription_type",
field=models.CharField(
choices=[("trial", "Trial"), ("standard", "Standard")], default="trial", max_length=20
),
),
]

View file

@ -14,7 +14,15 @@ class BaseModel(models.Model):
class KhojUser(AbstractUser):
class SubscriptionType(models.TextChoices):
TRIAL = "trial"
STANDARD = "standard"
uuid = models.UUIDField(models.UUIDField(default=uuid.uuid4, editable=False))
subscription_type = models.CharField(
max_length=20, choices=SubscriptionType.choices, default=SubscriptionType.TRIAL
)
subscription_renewal_date = models.DateTimeField(null=True, default=None)
def save(self, *args, **kwargs):
if not self.uuid:

View file

@ -1,6 +1,7 @@
# Standard Packages
import concurrent.futures
import math
import os
import time
import logging
import json
@ -10,6 +11,7 @@ from typing import List, Optional, Union, Any
from fastapi import APIRouter, HTTPException, Header, Request
from starlette.authentication import requires
from asgiref.sync import sync_to_async
import stripe
# Internal Packages
from khoj.configure import configure_server
@ -23,7 +25,6 @@ from khoj.utils.rawconfig import (
FullConfig,
SearchConfig,
SearchResponse,
TextContentConfig,
GithubContentConfig,
NotionContentConfig,
)
@ -723,3 +724,61 @@ async def extract_references_and_questions(
compiled_references = [item.additional["compiled"] for item in result_list]
return compiled_references, inferred_queries, defiltered_query
# Stripe integration for Khoj Cloud Subscription
stripe.api_key = os.getenv("STRIPE_API_KEY")
endpoint_secret = os.getenv("STRIPE_SIGINING_SECRET")
@api.post("/subscription")
async def subscribe(request: Request):
"""Webhook for Stripe to send subscription events to Khoj Cloud"""
event = None
try:
payload = await request.body()
sig_header = request.headers["stripe-signature"]
event = stripe.Webhook.construct_event(payload, sig_header, endpoint_secret)
except ValueError as e:
# Invalid payload
raise e
except stripe.error.SignatureVerificationError as e:
# Invalid signature
raise e
# Handle the event
success = True
if (
event["type"] == "payment_intent.succeeded"
or event["type"] == "invoice.payment_succeeded"
or event["type"] == "customer.subscription.created"
):
# Retrieve the customer's details
customer_id = event["data"]["object"]["customer"]
customer = stripe.Customer.retrieve(customer_id)
customer_email = customer["email"]
# Mark the customer as subscribed
user = await adapters.set_user_subscribed(customer_email)
if not user:
success = False
elif event["type"] == "customer.subscription.updated" or event["type"] == "customer.subscription.deleted":
# Retrieve the customer's details
customer_id = event["data"]["object"]["customer"]
customer = stripe.Customer.retrieve(customer_id)
logger.info(f'Stripe subscription {event["type"]} for {customer["email"]}')
return {"success": success}
@api.delete("/subscription")
@requires(["authenticated"])
async def unsubscribe(request: Request, user_email: str):
customer = stripe.Customer.list(email=user_email).data
if not is_none_or_empty(customer):
stripe.Subscription.modify(customer[0].id, cancel_at_period_end=True)
success = True
else:
success = False
return {"success": success}