mirror of
https://github.com/khoj-ai/khoj.git
synced 2025-02-17 08:04:21 +00:00
Create API webhook, endpoints for subscription payments using Stripe
- Add fields to mark users as subscribed to a specific plan and subscription renewal date in DB - Add ability to unsubscribe a user using their email address - Expose webhook for stripe to callback confirming payment
This commit is contained in:
parent
156421d30a
commit
9aaf475c8a
5 changed files with 116 additions and 3 deletions
|
@ -73,7 +73,8 @@ dependencies = [
|
|||
"gunicorn == 21.2.0",
|
||||
"lxml == 4.9.3",
|
||||
"tzdata == 2023.3",
|
||||
"rapidocr-onnxruntime == 1.3.8"
|
||||
"rapidocr-onnxruntime == 1.3.8",
|
||||
"stripe == 7.3.0",
|
||||
]
|
||||
dynamic = ["version"]
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from typing import Type, TypeVar, List
|
||||
from datetime import date
|
||||
from datetime import date, datetime, timedelta
|
||||
import secrets
|
||||
from typing import Type, TypeVar, List
|
||||
from datetime import date
|
||||
|
@ -103,6 +103,27 @@ async def create_google_user(token: dict) -> KhojUser:
|
|||
return user
|
||||
|
||||
|
||||
async def set_user_subscribed(email: str, type="standard") -> KhojUser:
|
||||
user = await KhojUser.objects.filter(email=email).afirst()
|
||||
if user:
|
||||
user.subscription_type = type
|
||||
start_date = user.subscription_renewal_date or datetime.now()
|
||||
user.subscription_renewal_date = start_date + timedelta(days=30)
|
||||
await user.asave()
|
||||
return user
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def is_user_subscribed(email: str, type="standard") -> bool:
|
||||
user = KhojUser.objects.filter(email=email, subscription_type=type).first()
|
||||
if user and user.subscription_renewal_date:
|
||||
is_subscribed = user.subscription_renewal_date > date.today()
|
||||
return is_subscribed
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
async def get_user_by_token(token: dict) -> KhojUser:
|
||||
google_user = await GoogleUser.objects.filter(sub=token.get("sub")).select_related("user").afirst()
|
||||
if not google_user:
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
# Generated by Django 4.2.5 on 2023-11-07 18:19
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0012_entry_file_source"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="khojuser",
|
||||
name="subscription_renewal_date",
|
||||
field=models.DateTimeField(default=None, null=True),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="khojuser",
|
||||
name="subscription_type",
|
||||
field=models.CharField(
|
||||
choices=[("trial", "Trial"), ("standard", "Standard")], default="trial", max_length=20
|
||||
),
|
||||
),
|
||||
]
|
|
@ -14,7 +14,15 @@ class BaseModel(models.Model):
|
|||
|
||||
|
||||
class KhojUser(AbstractUser):
|
||||
class SubscriptionType(models.TextChoices):
|
||||
TRIAL = "trial"
|
||||
STANDARD = "standard"
|
||||
|
||||
uuid = models.UUIDField(models.UUIDField(default=uuid.uuid4, editable=False))
|
||||
subscription_type = models.CharField(
|
||||
max_length=20, choices=SubscriptionType.choices, default=SubscriptionType.TRIAL
|
||||
)
|
||||
subscription_renewal_date = models.DateTimeField(null=True, default=None)
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
if not self.uuid:
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# Standard Packages
|
||||
import concurrent.futures
|
||||
import math
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
import json
|
||||
|
@ -10,6 +11,7 @@ from typing import List, Optional, Union, Any
|
|||
from fastapi import APIRouter, HTTPException, Header, Request
|
||||
from starlette.authentication import requires
|
||||
from asgiref.sync import sync_to_async
|
||||
import stripe
|
||||
|
||||
# Internal Packages
|
||||
from khoj.configure import configure_server
|
||||
|
@ -23,7 +25,6 @@ from khoj.utils.rawconfig import (
|
|||
FullConfig,
|
||||
SearchConfig,
|
||||
SearchResponse,
|
||||
TextContentConfig,
|
||||
GithubContentConfig,
|
||||
NotionContentConfig,
|
||||
)
|
||||
|
@ -723,3 +724,61 @@ async def extract_references_and_questions(
|
|||
compiled_references = [item.additional["compiled"] for item in result_list]
|
||||
|
||||
return compiled_references, inferred_queries, defiltered_query
|
||||
|
||||
|
||||
# Stripe integration for Khoj Cloud Subscription
|
||||
stripe.api_key = os.getenv("STRIPE_API_KEY")
|
||||
endpoint_secret = os.getenv("STRIPE_SIGINING_SECRET")
|
||||
|
||||
|
||||
@api.post("/subscription")
|
||||
async def subscribe(request: Request):
|
||||
"""Webhook for Stripe to send subscription events to Khoj Cloud"""
|
||||
event = None
|
||||
try:
|
||||
payload = await request.body()
|
||||
sig_header = request.headers["stripe-signature"]
|
||||
event = stripe.Webhook.construct_event(payload, sig_header, endpoint_secret)
|
||||
except ValueError as e:
|
||||
# Invalid payload
|
||||
raise e
|
||||
except stripe.error.SignatureVerificationError as e:
|
||||
# Invalid signature
|
||||
raise e
|
||||
|
||||
# Handle the event
|
||||
success = True
|
||||
if (
|
||||
event["type"] == "payment_intent.succeeded"
|
||||
or event["type"] == "invoice.payment_succeeded"
|
||||
or event["type"] == "customer.subscription.created"
|
||||
):
|
||||
# Retrieve the customer's details
|
||||
customer_id = event["data"]["object"]["customer"]
|
||||
customer = stripe.Customer.retrieve(customer_id)
|
||||
customer_email = customer["email"]
|
||||
# Mark the customer as subscribed
|
||||
user = await adapters.set_user_subscribed(customer_email)
|
||||
if not user:
|
||||
success = False
|
||||
elif event["type"] == "customer.subscription.updated" or event["type"] == "customer.subscription.deleted":
|
||||
# Retrieve the customer's details
|
||||
customer_id = event["data"]["object"]["customer"]
|
||||
customer = stripe.Customer.retrieve(customer_id)
|
||||
|
||||
logger.info(f'Stripe subscription {event["type"]} for {customer["email"]}')
|
||||
|
||||
return {"success": success}
|
||||
|
||||
|
||||
@api.delete("/subscription")
|
||||
@requires(["authenticated"])
|
||||
async def unsubscribe(request: Request, user_email: str):
|
||||
customer = stripe.Customer.list(email=user_email).data
|
||||
if not is_none_or_empty(customer):
|
||||
stripe.Subscription.modify(customer[0].id, cancel_at_period_end=True)
|
||||
success = True
|
||||
else:
|
||||
success = False
|
||||
|
||||
return {"success": success}
|
||||
|
|
Loading…
Add table
Reference in a new issue