Use scopes to represent whether the use has a valid subscription in the middleware

This commit is contained in:
sabaimran 2023-11-24 20:29:36 -08:00
parent c13953311a
commit 69c8f45830
5 changed files with 59 additions and 32 deletions

View file

@ -21,7 +21,12 @@ from starlette.authentication import (
# Internal Packages
from khoj.database.models import KhojUser, Subscription
from khoj.database.adapters import get_all_users, get_or_create_search_model
from khoj.database.adapters import (
get_all_users,
get_or_create_search_model,
aget_user_subscription_state,
SubscriptionState,
)
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from khoj.routers.indexer import configure_content, load_content, configure_search
from khoj.utils import constants, state
@ -70,7 +75,11 @@ class UserAuthenticationBackend(AuthenticationBackend):
.afirst()
)
if user:
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
subscription_state = await aget_user_subscription_state(user)
subscribed = subscription_state == SubscriptionState.SUBSCRIBED.value
if subscribed:
return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user)
return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user)
if len(request.headers.get("Authorization", "").split("Bearer ")) == 2:
# Get bearer token from header
bearer_token = request.headers["Authorization"].split("Bearer ")[1]
@ -82,11 +91,15 @@ class UserAuthenticationBackend(AuthenticationBackend):
.afirst()
)
if user_with_token:
subscription_state = await aget_user_subscription_state(user_with_token.user)
subscribed = subscription_state == SubscriptionState.SUBSCRIBED.value
if subscribed:
return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user_with_token.user)
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
if state.anonymous_mode:
user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst()
if user:
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user)
return AuthCredentials(), UnauthenticatedUser()

View file

@ -3,6 +3,7 @@ import random
import secrets
from datetime import date, datetime, timezone
from typing import List, Optional, Type
from enum import Enum
from asgiref.sync import sync_to_async
from django.contrib.sessions.backends.db import SessionStore
@ -40,6 +41,14 @@ from khoj.utils.config import GPT4AllProcessorModel
from khoj.utils.helpers import generate_random_name
class SubscriptionState(Enum):
TRIAL = "trial"
SUBSCRIBED = "subscribed"
UNSUBSCRIBED = "unsubscribed"
EXPIRED = "expired"
INVALID = "invalid"
async def set_notion_config(token: str, user: KhojUser):
notion_config = await NotionConfig.objects.filter(user=user).afirst()
if not notion_config:
@ -127,22 +136,34 @@ async def set_user_subscription(
return None
def subscription_to_state(subscription: Subscription) -> str:
if not subscription:
return SubscriptionState.INVALID.value
elif subscription.type == Subscription.Type.TRIAL:
return SubscriptionState.TRIAL.value
elif subscription.is_recurring and subscription.renewal_date >= datetime.now(tz=timezone.utc):
return SubscriptionState.SUBSCRIBED.value
elif not subscription.is_recurring and subscription.renewal_date >= datetime.now(tz=timezone.utc):
return SubscriptionState.UNSUBSCRIBED.value
elif not subscription.is_recurring and subscription.renewal_date < datetime.now(tz=timezone.utc):
return SubscriptionState.EXPIRED.value
return SubscriptionState.INVALID.value
def get_user_subscription_state(email: str) -> str:
"""Get subscription state of user
Valid state transitions: trial -> subscribed <-> unsubscribed OR expired
"""
user_subscription = Subscription.objects.filter(user__email=email).first()
if not user_subscription:
return "trial"
elif user_subscription.type == Subscription.Type.TRIAL:
return "trial"
elif user_subscription.is_recurring and user_subscription.renewal_date >= datetime.now(tz=timezone.utc):
return "subscribed"
elif not user_subscription.is_recurring and user_subscription.renewal_date >= datetime.now(tz=timezone.utc):
return "unsubscribed"
elif not user_subscription.is_recurring and user_subscription.renewal_date < datetime.now(tz=timezone.utc):
return "expired"
return "invalid"
return subscription_to_state(user_subscription)
async def aget_user_subscription_state(email: str) -> str:
"""Get subscription state of user
Valid state transitions: trial -> subscribed <-> unsubscribed OR expired
"""
user_subscription = await Subscription.objects.filter(user__email=email).afirst()
return subscription_to_state(user_subscription)
async def get_user_by_email(email: str) -> KhojUser:

View file

@ -171,7 +171,7 @@
</div>
</div>
</div>
{% if billing_enabled %}
{% if not billing_enabled %}
<div id="billing" class="section">
<h2 class="section-title">Billing</h2>
<div class="section-cards">

View file

@ -12,7 +12,7 @@ from asgiref.sync import sync_to_async
from fastapi import APIRouter, Depends, Header, HTTPException, Request
from fastapi.requests import Request
from fastapi.responses import Response, StreamingResponse
from starlette.authentication import requires
from starlette.authentication import requires, has_required_scope
# Internal Packages
from khoj.configure import configure_server

View file

@ -7,7 +7,7 @@ from fastapi import APIRouter
from fastapi import Request
from fastapi.responses import HTMLResponse, FileResponse, RedirectResponse
from fastapi.templating import Jinja2Templates
from starlette.authentication import requires
from starlette.authentication import requires, has_required_scope
from khoj.database import adapters
from khoj.database.models import KhojUser
from khoj.utils.rawconfig import (
@ -37,7 +37,6 @@ templates = Jinja2Templates(directory=constants.web_directory)
def index(request: Request):
user = request.user.object
user_picture = request.session.get("user", {}).get("picture")
user_subscription_state = get_user_subscription_state(user.email)
has_documents = EntryAdapters.user_has_entries(user=user)
return templates.TemplateResponse(
@ -46,7 +45,7 @@ def index(request: Request):
"request": request,
"username": user.username,
"user_photo": user_picture,
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed",
"is_active": has_required_scope(request, ["subscribed"]),
"has_documents": has_documents,
},
)
@ -57,7 +56,6 @@ def index(request: Request):
def index_post(request: Request):
user = request.user.object
user_picture = request.session.get("user", {}).get("picture")
user_subscription_state = get_user_subscription_state(user.email)
has_documents = EntryAdapters.user_has_entries(user=user)
return templates.TemplateResponse(
@ -66,7 +64,7 @@ def index_post(request: Request):
"request": request,
"username": user.username,
"user_photo": user_picture,
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed",
"is_active": has_required_scope(request, ["subscribed"]),
"has_documents": has_documents,
},
)
@ -77,7 +75,6 @@ def index_post(request: Request):
def search_page(request: Request):
user = request.user.object
user_picture = request.session.get("user", {}).get("picture")
user_subscription_state = get_user_subscription_state(user.email)
has_documents = EntryAdapters.user_has_entries(user=user)
return templates.TemplateResponse(
@ -86,7 +83,7 @@ def search_page(request: Request):
"request": request,
"username": user.username,
"user_photo": user_picture,
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed",
"is_active": has_required_scope(request, ["subscribed"]),
"has_documents": has_documents,
},
)
@ -97,7 +94,6 @@ def search_page(request: Request):
def chat_page(request: Request):
user = request.user.object
user_picture = request.session.get("user", {}).get("picture")
user_subscription_state = get_user_subscription_state(user.email)
has_documents = EntryAdapters.user_has_entries(user=user)
return templates.TemplateResponse(
@ -106,7 +102,7 @@ def chat_page(request: Request):
"request": request,
"username": user.username,
"user_photo": user_picture,
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed",
"is_active": has_required_scope(request, ["subscribed"]),
"has_documents": has_documents,
},
)
@ -171,7 +167,7 @@ def config_page(request: Request):
"subscription_state": user_subscription_state,
"subscription_renewal_date": subscription_renewal_date,
"khoj_cloud_subscription_url": os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL"),
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed",
"is_active": has_required_scope(request, ["subscribed"]),
"has_documents": has_documents,
},
)
@ -182,7 +178,6 @@ def config_page(request: Request):
def github_config_page(request: Request):
user = request.user.object
user_picture = request.session.get("user", {}).get("picture")
user_subscription_state = get_user_subscription_state(user.email)
has_documents = EntryAdapters.user_has_entries(user=user)
current_github_config = get_user_github_config(user)
@ -212,7 +207,7 @@ def github_config_page(request: Request):
"current_config": current_config,
"username": user.username,
"user_photo": user_picture,
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed",
"is_active": has_required_scope(request, ["subscribed"]),
"has_documents": has_documents,
},
)
@ -223,7 +218,6 @@ def github_config_page(request: Request):
def notion_config_page(request: Request):
user = request.user.object
user_picture = request.session.get("user", {}).get("picture")
user_subscription_state = adapters.get_user_subscription(user.email)
has_documents = EntryAdapters.user_has_entries(user=user)
current_notion_config = get_user_notion_config(user)
@ -240,7 +234,7 @@ def notion_config_page(request: Request):
"current_config": current_config,
"username": user.username,
"user_photo": user_picture,
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed",
"is_active": has_required_scope(request, ["subscribed"]),
"has_documents": has_documents,
},
)
@ -251,7 +245,6 @@ def notion_config_page(request: Request):
def computer_config_page(request: Request):
user = request.user.object
user_picture = request.session.get("user", {}).get("picture")
user_subscription_state = get_user_subscription_state(user.email)
has_documents = EntryAdapters.user_has_entries(user=user)
return templates.TemplateResponse(
@ -260,7 +253,7 @@ def computer_config_page(request: Request):
"request": request,
"username": user.username,
"user_photo": user_picture,
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed",
"is_active": has_required_scope(request, ["subscribed"]),
"has_documents": has_documents,
},
)