mirror of
https://github.com/khoj-ai/khoj.git
synced 2025-02-17 08:04:21 +00:00
Move Subscription data into separate table in DB. Merge migrations
This commit is contained in:
parent
3bb10128ef
commit
8178004e6d
7 changed files with 95 additions and 75 deletions
|
@ -1,4 +1,4 @@
|
|||
from typing import Type, TypeVar, List
|
||||
from typing import Optional, Type, TypeVar, List
|
||||
from datetime import date, datetime, timedelta
|
||||
import secrets
|
||||
from typing import Type, TypeVar, List
|
||||
|
@ -30,6 +30,7 @@ from database.models import (
|
|||
GithubRepoConfig,
|
||||
Conversation,
|
||||
ChatModelOptions,
|
||||
Subscription,
|
||||
UserConversationConfig,
|
||||
OpenAIProcessorConversationConfig,
|
||||
OfflineChatProcessorConversationConfig,
|
||||
|
@ -103,35 +104,51 @@ async def create_google_user(token: dict) -> KhojUser:
|
|||
return user
|
||||
|
||||
|
||||
async def set_user_subscription(email: str, is_subscribed=None, renewal_date=None, type="standard") -> KhojUser:
|
||||
user = await KhojUser.objects.filter(email=email).afirst()
|
||||
if user:
|
||||
user.subscription_type = type
|
||||
if is_subscribed is not None:
|
||||
user.is_subscribed = is_subscribed
|
||||
def get_user_subscription(email: str) -> Optional[Subscription]:
|
||||
return Subscription.objects.filter(user__email=email).first()
|
||||
|
||||
|
||||
async def set_user_subscription(
|
||||
email: str, is_recurring=None, renewal_date=None, type="standard"
|
||||
) -> Optional[Subscription]:
|
||||
user_subscription = await Subscription.objects.filter(user__email=email).afirst()
|
||||
if not user_subscription:
|
||||
user = await get_user_by_email(email)
|
||||
if not user:
|
||||
return None
|
||||
user_subscription = await Subscription.objects.acreate(
|
||||
user=user, type=type, is_recurring=is_recurring, renewal_date=renewal_date
|
||||
)
|
||||
return user_subscription
|
||||
elif user_subscription:
|
||||
user_subscription.type = type
|
||||
if is_recurring is not None:
|
||||
user_subscription.is_recurring = is_recurring
|
||||
if renewal_date is False:
|
||||
user.subscription_renewal_date = None
|
||||
user_subscription.renewal_date = None
|
||||
elif renewal_date is not None:
|
||||
user.subscription_renewal_date = renewal_date
|
||||
await user.asave()
|
||||
return user
|
||||
user_subscription.renewal_date = renewal_date
|
||||
await user_subscription.asave()
|
||||
return user_subscription
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def get_user_subscription_state(email: str) -> str:
|
||||
def get_user_subscription_state(user_subscription: Subscription) -> str:
|
||||
"""Get subscription state of user
|
||||
Valid state transitions: trial -> subscribed <-> unsubscribed OR expired
|
||||
"""
|
||||
user = KhojUser.objects.filter(email=email).first()
|
||||
if user.subscription_type == KhojUser.SubscriptionType.TRIAL:
|
||||
if not user_subscription:
|
||||
return "trial"
|
||||
elif user.is_subscribed and user.subscription_renewal_date >= datetime.now(tz=timezone.utc):
|
||||
elif user_subscription.type == Subscription.Type.TRIAL:
|
||||
return "trial"
|
||||
elif user_subscription.is_recurring and user_subscription.renewal_date >= datetime.now(tz=timezone.utc):
|
||||
return "subscribed"
|
||||
elif not user.is_subscribed and user.subscription_renewal_date >= datetime.now(tz=timezone.utc):
|
||||
elif not user_subscription.is_recurring and user_subscription.renewal_date >= datetime.now(tz=timezone.utc):
|
||||
return "unsubscribed"
|
||||
elif not user.is_subscribed and user.subscription_renewal_date < datetime.now(tz=timezone.utc):
|
||||
elif not user_subscription.is_recurring and user_subscription.renewal_date < datetime.now(tz=timezone.utc):
|
||||
return "expired"
|
||||
return "invalid"
|
||||
|
||||
|
||||
async def get_user_by_email(email: str) -> KhojUser:
|
||||
|
|
|
@ -1,24 +0,0 @@
|
|||
# Generated by Django 4.2.5 on 2023-11-07 18:19
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0012_entry_file_source"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="khojuser",
|
||||
name="subscription_renewal_date",
|
||||
field=models.DateTimeField(default=None, null=True),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="khojuser",
|
||||
name="subscription_type",
|
||||
field=models.CharField(
|
||||
choices=[("trial", "Trial"), ("standard", "Standard")], default="trial", max_length=20
|
||||
),
|
||||
),
|
||||
]
|
37
src/database/migrations/0013_subscription.py
Normal file
37
src/database/migrations/0013_subscription.py
Normal file
|
@ -0,0 +1,37 @@
|
|||
# Generated by Django 4.2.5 on 2023-11-09 01:27
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
import django.db.models.deletion
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0012_entry_file_source"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name="Subscription",
|
||||
fields=[
|
||||
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||
("updated_at", models.DateTimeField(auto_now=True)),
|
||||
(
|
||||
"type",
|
||||
models.CharField(
|
||||
choices=[("trial", "Trial"), ("standard", "Standard")], default="trial", max_length=20
|
||||
),
|
||||
),
|
||||
("is_recurring", models.BooleanField(default=False)),
|
||||
("renewal_date", models.DateTimeField(default=None, null=True)),
|
||||
(
|
||||
"user",
|
||||
models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"abstract": False,
|
||||
},
|
||||
),
|
||||
]
|
|
@ -1,17 +0,0 @@
|
|||
# Generated by Django 4.2.5 on 2023-11-08 19:40
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0013_khojuser_subscription_renewal_date_and_more"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="khojuser",
|
||||
name="is_subscribed",
|
||||
field=models.BooleanField(default=False),
|
||||
),
|
||||
]
|
|
@ -14,16 +14,7 @@ class BaseModel(models.Model):
|
|||
|
||||
|
||||
class KhojUser(AbstractUser):
|
||||
class SubscriptionType(models.TextChoices):
|
||||
TRIAL = "trial"
|
||||
STANDARD = "standard"
|
||||
|
||||
uuid = models.UUIDField(models.UUIDField(default=uuid.uuid4, editable=False))
|
||||
subscription_type = models.CharField(
|
||||
max_length=20, choices=SubscriptionType.choices, default=SubscriptionType.TRIAL
|
||||
)
|
||||
is_subscribed = models.BooleanField(default=False)
|
||||
subscription_renewal_date = models.DateTimeField(null=True, default=None)
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
if not self.uuid:
|
||||
|
@ -55,6 +46,17 @@ class KhojApiUser(models.Model):
|
|||
accessed_at = models.DateTimeField(null=True, default=None)
|
||||
|
||||
|
||||
class Subscription(BaseModel):
|
||||
class Type(models.TextChoices):
|
||||
TRIAL = "trial"
|
||||
STANDARD = "standard"
|
||||
|
||||
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
|
||||
type = models.CharField(max_length=20, choices=Type.choices, default=Type.TRIAL)
|
||||
is_recurring = models.BooleanField(default=False)
|
||||
renewal_date = models.DateTimeField(null=True, default=None)
|
||||
|
||||
|
||||
class NotionConfig(BaseModel):
|
||||
token = models.CharField(max_length=200)
|
||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||
|
|
|
@ -4,6 +4,7 @@ import logging
|
|||
import os
|
||||
|
||||
# External Packages
|
||||
from asgiref.sync import sync_to_async
|
||||
from fastapi import APIRouter, Request
|
||||
from starlette.authentication import requires
|
||||
import stripe
|
||||
|
@ -58,20 +59,20 @@ async def subscribe(request: Request):
|
|||
# Mark the user as subscribed and update the next renewal date on payment
|
||||
subscription = stripe.Subscription.list(customer=customer_id).data[0]
|
||||
renewal_date = datetime.fromtimestamp(subscription["current_period_end"], tz=timezone.utc)
|
||||
user = await adapters.set_user_subscription(customer_email, is_subscribed=True, renewal_date=renewal_date)
|
||||
user = await adapters.set_user_subscription(customer_email, is_recurring=True, renewal_date=renewal_date)
|
||||
success = user is not None
|
||||
elif event_type in {"customer.subscription.updated"}:
|
||||
user = await adapters.get_user_by_email(customer_email)
|
||||
user_subscription = await sync_to_async(adapters.get_user_subscription)(customer_email)
|
||||
# Allow updating subscription status if paid user
|
||||
if user.subscription_renewal_date:
|
||||
if user_subscription.renewal_date:
|
||||
# Mark user as unsubscribed or resubscribed
|
||||
is_subscribed = not subscription["cancel_at_period_end"]
|
||||
updated_user = await adapters.set_user_subscription(customer_email, is_subscribed=is_subscribed)
|
||||
is_recurring = not subscription["cancel_at_period_end"]
|
||||
updated_user = await adapters.set_user_subscription(customer_email, is_recurring=is_recurring)
|
||||
success = updated_user is not None
|
||||
elif event_type in {"customer.subscription.deleted"}:
|
||||
# Reset the user to trial state
|
||||
user = await adapters.set_user_subscription(
|
||||
customer_email, is_subscribed=False, renewal_date=False, type="trial"
|
||||
customer_email, is_recurring=False, renewal_date=False, type="trial"
|
||||
)
|
||||
success = user is not None
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ from fastapi import Request
|
|||
from fastapi.responses import HTMLResponse, FileResponse, RedirectResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from starlette.authentication import requires
|
||||
from database import adapters
|
||||
from database.models import KhojUser
|
||||
from khoj.utils.rawconfig import (
|
||||
GithubContentConfig,
|
||||
|
@ -117,9 +118,12 @@ def login_page(request: Request):
|
|||
def config_page(request: Request):
|
||||
user: KhojUser = request.user.object
|
||||
user_picture = request.session.get("user", {}).get("picture")
|
||||
user_subscription_state = get_user_subscription_state(user.email)
|
||||
user_subscription = adapters.get_user_subscription(user.email)
|
||||
user_subscription_state = get_user_subscription_state(user_subscription)
|
||||
subscription_renewal_date = (
|
||||
user.subscription_renewal_date.strftime("%d %b %Y") if user.subscription_renewal_date else None
|
||||
user_subscription.renewal_date.strftime("%d %b %Y")
|
||||
if user_subscription and user_subscription.renewal_date
|
||||
else None
|
||||
)
|
||||
enabled_content_source = set(EntryAdapters.get_unique_file_source(user).all())
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue