From 69c8f45830f02c4d7b19865a57779e17b85b8f74 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Fri, 24 Nov 2023 20:29:36 -0800 Subject: [PATCH] Use scopes to represent whether the use has a valid subscription in the middleware --- src/khoj/configure.py | 19 ++++++++++-- src/khoj/database/adapters/__init__.py | 43 +++++++++++++++++++------- src/khoj/interface/web/config.html | 2 +- src/khoj/routers/api.py | 2 +- src/khoj/routers/web_client.py | 25 ++++++--------- 5 files changed, 59 insertions(+), 32 deletions(-) diff --git a/src/khoj/configure.py b/src/khoj/configure.py index d34205d9..4e4e1008 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -21,7 +21,12 @@ from starlette.authentication import ( # Internal Packages from khoj.database.models import KhojUser, Subscription -from khoj.database.adapters import get_all_users, get_or_create_search_model +from khoj.database.adapters import ( + get_all_users, + get_or_create_search_model, + aget_user_subscription_state, + SubscriptionState, +) from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel from khoj.routers.indexer import configure_content, load_content, configure_search from khoj.utils import constants, state @@ -70,7 +75,11 @@ class UserAuthenticationBackend(AuthenticationBackend): .afirst() ) if user: - return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user) + subscription_state = await aget_user_subscription_state(user) + subscribed = subscription_state == SubscriptionState.SUBSCRIBED.value + if subscribed: + return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user) + return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user) if len(request.headers.get("Authorization", "").split("Bearer ")) == 2: # Get bearer token from header bearer_token = request.headers["Authorization"].split("Bearer ")[1] @@ -82,11 +91,15 @@ class UserAuthenticationBackend(AuthenticationBackend): .afirst() ) if user_with_token: + subscription_state = await aget_user_subscription_state(user_with_token.user) + subscribed = subscription_state == SubscriptionState.SUBSCRIBED.value + if subscribed: + return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user_with_token.user) return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user) if state.anonymous_mode: user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst() if user: - return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user) + return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user) return AuthCredentials(), UnauthenticatedUser() diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 7fd04006..146de11c 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -3,6 +3,7 @@ import random import secrets from datetime import date, datetime, timezone from typing import List, Optional, Type +from enum import Enum from asgiref.sync import sync_to_async from django.contrib.sessions.backends.db import SessionStore @@ -40,6 +41,14 @@ from khoj.utils.config import GPT4AllProcessorModel from khoj.utils.helpers import generate_random_name +class SubscriptionState(Enum): + TRIAL = "trial" + SUBSCRIBED = "subscribed" + UNSUBSCRIBED = "unsubscribed" + EXPIRED = "expired" + INVALID = "invalid" + + async def set_notion_config(token: str, user: KhojUser): notion_config = await NotionConfig.objects.filter(user=user).afirst() if not notion_config: @@ -127,22 +136,34 @@ async def set_user_subscription( return None +def subscription_to_state(subscription: Subscription) -> str: + if not subscription: + return SubscriptionState.INVALID.value + elif subscription.type == Subscription.Type.TRIAL: + return SubscriptionState.TRIAL.value + elif subscription.is_recurring and subscription.renewal_date >= datetime.now(tz=timezone.utc): + return SubscriptionState.SUBSCRIBED.value + elif not subscription.is_recurring and subscription.renewal_date >= datetime.now(tz=timezone.utc): + return SubscriptionState.UNSUBSCRIBED.value + elif not subscription.is_recurring and subscription.renewal_date < datetime.now(tz=timezone.utc): + return SubscriptionState.EXPIRED.value + return SubscriptionState.INVALID.value + + def get_user_subscription_state(email: str) -> str: """Get subscription state of user Valid state transitions: trial -> subscribed <-> unsubscribed OR expired """ user_subscription = Subscription.objects.filter(user__email=email).first() - if not user_subscription: - return "trial" - elif user_subscription.type == Subscription.Type.TRIAL: - return "trial" - elif user_subscription.is_recurring and user_subscription.renewal_date >= datetime.now(tz=timezone.utc): - return "subscribed" - elif not user_subscription.is_recurring and user_subscription.renewal_date >= datetime.now(tz=timezone.utc): - return "unsubscribed" - elif not user_subscription.is_recurring and user_subscription.renewal_date < datetime.now(tz=timezone.utc): - return "expired" - return "invalid" + return subscription_to_state(user_subscription) + + +async def aget_user_subscription_state(email: str) -> str: + """Get subscription state of user + Valid state transitions: trial -> subscribed <-> unsubscribed OR expired + """ + user_subscription = await Subscription.objects.filter(user__email=email).afirst() + return subscription_to_state(user_subscription) async def get_user_by_email(email: str) -> KhojUser: diff --git a/src/khoj/interface/web/config.html b/src/khoj/interface/web/config.html index 01a3786f..318759df 100644 --- a/src/khoj/interface/web/config.html +++ b/src/khoj/interface/web/config.html @@ -171,7 +171,7 @@ - {% if billing_enabled %} + {% if not billing_enabled %}

Billing

diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 83955088..d5b6ce0e 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -12,7 +12,7 @@ from asgiref.sync import sync_to_async from fastapi import APIRouter, Depends, Header, HTTPException, Request from fastapi.requests import Request from fastapi.responses import Response, StreamingResponse -from starlette.authentication import requires +from starlette.authentication import requires, has_required_scope # Internal Packages from khoj.configure import configure_server diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py index dab16fa8..bf3ff957 100644 --- a/src/khoj/routers/web_client.py +++ b/src/khoj/routers/web_client.py @@ -7,7 +7,7 @@ from fastapi import APIRouter from fastapi import Request from fastapi.responses import HTMLResponse, FileResponse, RedirectResponse from fastapi.templating import Jinja2Templates -from starlette.authentication import requires +from starlette.authentication import requires, has_required_scope from khoj.database import adapters from khoj.database.models import KhojUser from khoj.utils.rawconfig import ( @@ -37,7 +37,6 @@ templates = Jinja2Templates(directory=constants.web_directory) def index(request: Request): user = request.user.object user_picture = request.session.get("user", {}).get("picture") - user_subscription_state = get_user_subscription_state(user.email) has_documents = EntryAdapters.user_has_entries(user=user) return templates.TemplateResponse( @@ -46,7 +45,7 @@ def index(request: Request): "request": request, "username": user.username, "user_photo": user_picture, - "is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed", + "is_active": has_required_scope(request, ["subscribed"]), "has_documents": has_documents, }, ) @@ -57,7 +56,6 @@ def index(request: Request): def index_post(request: Request): user = request.user.object user_picture = request.session.get("user", {}).get("picture") - user_subscription_state = get_user_subscription_state(user.email) has_documents = EntryAdapters.user_has_entries(user=user) return templates.TemplateResponse( @@ -66,7 +64,7 @@ def index_post(request: Request): "request": request, "username": user.username, "user_photo": user_picture, - "is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed", + "is_active": has_required_scope(request, ["subscribed"]), "has_documents": has_documents, }, ) @@ -77,7 +75,6 @@ def index_post(request: Request): def search_page(request: Request): user = request.user.object user_picture = request.session.get("user", {}).get("picture") - user_subscription_state = get_user_subscription_state(user.email) has_documents = EntryAdapters.user_has_entries(user=user) return templates.TemplateResponse( @@ -86,7 +83,7 @@ def search_page(request: Request): "request": request, "username": user.username, "user_photo": user_picture, - "is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed", + "is_active": has_required_scope(request, ["subscribed"]), "has_documents": has_documents, }, ) @@ -97,7 +94,6 @@ def search_page(request: Request): def chat_page(request: Request): user = request.user.object user_picture = request.session.get("user", {}).get("picture") - user_subscription_state = get_user_subscription_state(user.email) has_documents = EntryAdapters.user_has_entries(user=user) return templates.TemplateResponse( @@ -106,7 +102,7 @@ def chat_page(request: Request): "request": request, "username": user.username, "user_photo": user_picture, - "is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed", + "is_active": has_required_scope(request, ["subscribed"]), "has_documents": has_documents, }, ) @@ -171,7 +167,7 @@ def config_page(request: Request): "subscription_state": user_subscription_state, "subscription_renewal_date": subscription_renewal_date, "khoj_cloud_subscription_url": os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL"), - "is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed", + "is_active": has_required_scope(request, ["subscribed"]), "has_documents": has_documents, }, ) @@ -182,7 +178,6 @@ def config_page(request: Request): def github_config_page(request: Request): user = request.user.object user_picture = request.session.get("user", {}).get("picture") - user_subscription_state = get_user_subscription_state(user.email) has_documents = EntryAdapters.user_has_entries(user=user) current_github_config = get_user_github_config(user) @@ -212,7 +207,7 @@ def github_config_page(request: Request): "current_config": current_config, "username": user.username, "user_photo": user_picture, - "is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed", + "is_active": has_required_scope(request, ["subscribed"]), "has_documents": has_documents, }, ) @@ -223,7 +218,6 @@ def github_config_page(request: Request): def notion_config_page(request: Request): user = request.user.object user_picture = request.session.get("user", {}).get("picture") - user_subscription_state = adapters.get_user_subscription(user.email) has_documents = EntryAdapters.user_has_entries(user=user) current_notion_config = get_user_notion_config(user) @@ -240,7 +234,7 @@ def notion_config_page(request: Request): "current_config": current_config, "username": user.username, "user_photo": user_picture, - "is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed", + "is_active": has_required_scope(request, ["subscribed"]), "has_documents": has_documents, }, ) @@ -251,7 +245,6 @@ def notion_config_page(request: Request): def computer_config_page(request: Request): user = request.user.object user_picture = request.session.get("user", {}).get("picture") - user_subscription_state = get_user_subscription_state(user.email) has_documents = EntryAdapters.user_has_entries(user=user) return templates.TemplateResponse( @@ -260,7 +253,7 @@ def computer_config_page(request: Request): "request": request, "username": user.username, "user_photo": user_picture, - "is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed", + "is_active": has_required_scope(request, ["subscribed"]), "has_documents": has_documents, }, )