Move Subscription data into separate table in DB. Merge migrations

This commit is contained in:
Debanjum Singh Solanky 2023-11-08 17:45:25 -08:00
parent 3bb10128ef
commit 8178004e6d
7 changed files with 95 additions and 75 deletions

View file

@ -1,4 +1,4 @@
from typing import Type, TypeVar, List
from typing import Optional, Type, TypeVar, List
from datetime import date, datetime, timedelta
import secrets
from typing import Type, TypeVar, List
@ -30,6 +30,7 @@ from database.models import (
GithubRepoConfig,
Conversation,
ChatModelOptions,
Subscription,
UserConversationConfig,
OpenAIProcessorConversationConfig,
OfflineChatProcessorConversationConfig,
@ -103,35 +104,51 @@ async def create_google_user(token: dict) -> KhojUser:
return user
async def set_user_subscription(email: str, is_subscribed=None, renewal_date=None, type="standard") -> KhojUser:
user = await KhojUser.objects.filter(email=email).afirst()
if user:
user.subscription_type = type
if is_subscribed is not None:
user.is_subscribed = is_subscribed
def get_user_subscription(email: str) -> Optional[Subscription]:
return Subscription.objects.filter(user__email=email).first()
async def set_user_subscription(
email: str, is_recurring=None, renewal_date=None, type="standard"
) -> Optional[Subscription]:
user_subscription = await Subscription.objects.filter(user__email=email).afirst()
if not user_subscription:
user = await get_user_by_email(email)
if not user:
return None
user_subscription = await Subscription.objects.acreate(
user=user, type=type, is_recurring=is_recurring, renewal_date=renewal_date
)
return user_subscription
elif user_subscription:
user_subscription.type = type
if is_recurring is not None:
user_subscription.is_recurring = is_recurring
if renewal_date is False:
user.subscription_renewal_date = None
user_subscription.renewal_date = None
elif renewal_date is not None:
user.subscription_renewal_date = renewal_date
await user.asave()
return user
user_subscription.renewal_date = renewal_date
await user_subscription.asave()
return user_subscription
else:
return None
def get_user_subscription_state(email: str) -> str:
def get_user_subscription_state(user_subscription: Subscription) -> str:
"""Get subscription state of user
Valid state transitions: trial -> subscribed <-> unsubscribed OR expired
"""
user = KhojUser.objects.filter(email=email).first()
if user.subscription_type == KhojUser.SubscriptionType.TRIAL:
if not user_subscription:
return "trial"
elif user.is_subscribed and user.subscription_renewal_date >= datetime.now(tz=timezone.utc):
elif user_subscription.type == Subscription.Type.TRIAL:
return "trial"
elif user_subscription.is_recurring and user_subscription.renewal_date >= datetime.now(tz=timezone.utc):
return "subscribed"
elif not user.is_subscribed and user.subscription_renewal_date >= datetime.now(tz=timezone.utc):
elif not user_subscription.is_recurring and user_subscription.renewal_date >= datetime.now(tz=timezone.utc):
return "unsubscribed"
elif not user.is_subscribed and user.subscription_renewal_date < datetime.now(tz=timezone.utc):
elif not user_subscription.is_recurring and user_subscription.renewal_date < datetime.now(tz=timezone.utc):
return "expired"
return "invalid"
async def get_user_by_email(email: str) -> KhojUser:

View file

@ -1,24 +0,0 @@
# Generated by Django 4.2.5 on 2023-11-07 18:19
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0012_entry_file_source"),
]
operations = [
migrations.AddField(
model_name="khojuser",
name="subscription_renewal_date",
field=models.DateTimeField(default=None, null=True),
),
migrations.AddField(
model_name="khojuser",
name="subscription_type",
field=models.CharField(
choices=[("trial", "Trial"), ("standard", "Standard")], default="trial", max_length=20
),
),
]

View file

@ -0,0 +1,37 @@
# Generated by Django 4.2.5 on 2023-11-09 01:27
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
("database", "0012_entry_file_source"),
]
operations = [
migrations.CreateModel(
name="Subscription",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
(
"type",
models.CharField(
choices=[("trial", "Trial"), ("standard", "Standard")], default="trial", max_length=20
),
),
("is_recurring", models.BooleanField(default=False)),
("renewal_date", models.DateTimeField(default=None, null=True)),
(
"user",
models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL),
),
],
options={
"abstract": False,
},
),
]

View file

@ -1,17 +0,0 @@
# Generated by Django 4.2.5 on 2023-11-08 19:40
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0013_khojuser_subscription_renewal_date_and_more"),
]
operations = [
migrations.AddField(
model_name="khojuser",
name="is_subscribed",
field=models.BooleanField(default=False),
),
]

View file

@ -14,16 +14,7 @@ class BaseModel(models.Model):
class KhojUser(AbstractUser):
class SubscriptionType(models.TextChoices):
TRIAL = "trial"
STANDARD = "standard"
uuid = models.UUIDField(models.UUIDField(default=uuid.uuid4, editable=False))
subscription_type = models.CharField(
max_length=20, choices=SubscriptionType.choices, default=SubscriptionType.TRIAL
)
is_subscribed = models.BooleanField(default=False)
subscription_renewal_date = models.DateTimeField(null=True, default=None)
def save(self, *args, **kwargs):
if not self.uuid:
@ -55,6 +46,17 @@ class KhojApiUser(models.Model):
accessed_at = models.DateTimeField(null=True, default=None)
class Subscription(BaseModel):
class Type(models.TextChoices):
TRIAL = "trial"
STANDARD = "standard"
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
type = models.CharField(max_length=20, choices=Type.choices, default=Type.TRIAL)
is_recurring = models.BooleanField(default=False)
renewal_date = models.DateTimeField(null=True, default=None)
class NotionConfig(BaseModel):
token = models.CharField(max_length=200)
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)

View file

@ -4,6 +4,7 @@ import logging
import os
# External Packages
from asgiref.sync import sync_to_async
from fastapi import APIRouter, Request
from starlette.authentication import requires
import stripe
@ -58,20 +59,20 @@ async def subscribe(request: Request):
# Mark the user as subscribed and update the next renewal date on payment
subscription = stripe.Subscription.list(customer=customer_id).data[0]
renewal_date = datetime.fromtimestamp(subscription["current_period_end"], tz=timezone.utc)
user = await adapters.set_user_subscription(customer_email, is_subscribed=True, renewal_date=renewal_date)
user = await adapters.set_user_subscription(customer_email, is_recurring=True, renewal_date=renewal_date)
success = user is not None
elif event_type in {"customer.subscription.updated"}:
user = await adapters.get_user_by_email(customer_email)
user_subscription = await sync_to_async(adapters.get_user_subscription)(customer_email)
# Allow updating subscription status if paid user
if user.subscription_renewal_date:
if user_subscription.renewal_date:
# Mark user as unsubscribed or resubscribed
is_subscribed = not subscription["cancel_at_period_end"]
updated_user = await adapters.set_user_subscription(customer_email, is_subscribed=is_subscribed)
is_recurring = not subscription["cancel_at_period_end"]
updated_user = await adapters.set_user_subscription(customer_email, is_recurring=is_recurring)
success = updated_user is not None
elif event_type in {"customer.subscription.deleted"}:
# Reset the user to trial state
user = await adapters.set_user_subscription(
customer_email, is_subscribed=False, renewal_date=False, type="trial"
customer_email, is_recurring=False, renewal_date=False, type="trial"
)
success = user is not None

View file

@ -8,6 +8,7 @@ from fastapi import Request
from fastapi.responses import HTMLResponse, FileResponse, RedirectResponse
from fastapi.templating import Jinja2Templates
from starlette.authentication import requires
from database import adapters
from database.models import KhojUser
from khoj.utils.rawconfig import (
GithubContentConfig,
@ -117,9 +118,12 @@ def login_page(request: Request):
def config_page(request: Request):
user: KhojUser = request.user.object
user_picture = request.session.get("user", {}).get("picture")
user_subscription_state = get_user_subscription_state(user.email)
user_subscription = adapters.get_user_subscription(user.email)
user_subscription_state = get_user_subscription_state(user_subscription)
subscription_renewal_date = (
user.subscription_renewal_date.strftime("%d %b %Y") if user.subscription_renewal_date else None
user_subscription.renewal_date.strftime("%d %b %Y")
if user_subscription and user_subscription.renewal_date
else None
)
enabled_content_source = set(EntryAdapters.get_unique_file_source(user).all())