mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-30 19:03:01 +01:00
Use scopes to represent whether the use has a valid subscription in the middleware
This commit is contained in:
parent
c13953311a
commit
69c8f45830
5 changed files with 59 additions and 32 deletions
|
@ -21,7 +21,12 @@ from starlette.authentication import (
|
|||
|
||||
# Internal Packages
|
||||
from khoj.database.models import KhojUser, Subscription
|
||||
from khoj.database.adapters import get_all_users, get_or_create_search_model
|
||||
from khoj.database.adapters import (
|
||||
get_all_users,
|
||||
get_or_create_search_model,
|
||||
aget_user_subscription_state,
|
||||
SubscriptionState,
|
||||
)
|
||||
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
|
||||
from khoj.routers.indexer import configure_content, load_content, configure_search
|
||||
from khoj.utils import constants, state
|
||||
|
@ -70,7 +75,11 @@ class UserAuthenticationBackend(AuthenticationBackend):
|
|||
.afirst()
|
||||
)
|
||||
if user:
|
||||
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
|
||||
subscription_state = await aget_user_subscription_state(user)
|
||||
subscribed = subscription_state == SubscriptionState.SUBSCRIBED.value
|
||||
if subscribed:
|
||||
return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user)
|
||||
return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user)
|
||||
if len(request.headers.get("Authorization", "").split("Bearer ")) == 2:
|
||||
# Get bearer token from header
|
||||
bearer_token = request.headers["Authorization"].split("Bearer ")[1]
|
||||
|
@ -82,11 +91,15 @@ class UserAuthenticationBackend(AuthenticationBackend):
|
|||
.afirst()
|
||||
)
|
||||
if user_with_token:
|
||||
subscription_state = await aget_user_subscription_state(user_with_token.user)
|
||||
subscribed = subscription_state == SubscriptionState.SUBSCRIBED.value
|
||||
if subscribed:
|
||||
return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user_with_token.user)
|
||||
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
|
||||
if state.anonymous_mode:
|
||||
user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst()
|
||||
if user:
|
||||
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
|
||||
return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user)
|
||||
|
||||
return AuthCredentials(), UnauthenticatedUser()
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ import random
|
|||
import secrets
|
||||
from datetime import date, datetime, timezone
|
||||
from typing import List, Optional, Type
|
||||
from enum import Enum
|
||||
|
||||
from asgiref.sync import sync_to_async
|
||||
from django.contrib.sessions.backends.db import SessionStore
|
||||
|
@ -40,6 +41,14 @@ from khoj.utils.config import GPT4AllProcessorModel
|
|||
from khoj.utils.helpers import generate_random_name
|
||||
|
||||
|
||||
class SubscriptionState(Enum):
|
||||
TRIAL = "trial"
|
||||
SUBSCRIBED = "subscribed"
|
||||
UNSUBSCRIBED = "unsubscribed"
|
||||
EXPIRED = "expired"
|
||||
INVALID = "invalid"
|
||||
|
||||
|
||||
async def set_notion_config(token: str, user: KhojUser):
|
||||
notion_config = await NotionConfig.objects.filter(user=user).afirst()
|
||||
if not notion_config:
|
||||
|
@ -127,22 +136,34 @@ async def set_user_subscription(
|
|||
return None
|
||||
|
||||
|
||||
def subscription_to_state(subscription: Subscription) -> str:
|
||||
if not subscription:
|
||||
return SubscriptionState.INVALID.value
|
||||
elif subscription.type == Subscription.Type.TRIAL:
|
||||
return SubscriptionState.TRIAL.value
|
||||
elif subscription.is_recurring and subscription.renewal_date >= datetime.now(tz=timezone.utc):
|
||||
return SubscriptionState.SUBSCRIBED.value
|
||||
elif not subscription.is_recurring and subscription.renewal_date >= datetime.now(tz=timezone.utc):
|
||||
return SubscriptionState.UNSUBSCRIBED.value
|
||||
elif not subscription.is_recurring and subscription.renewal_date < datetime.now(tz=timezone.utc):
|
||||
return SubscriptionState.EXPIRED.value
|
||||
return SubscriptionState.INVALID.value
|
||||
|
||||
|
||||
def get_user_subscription_state(email: str) -> str:
|
||||
"""Get subscription state of user
|
||||
Valid state transitions: trial -> subscribed <-> unsubscribed OR expired
|
||||
"""
|
||||
user_subscription = Subscription.objects.filter(user__email=email).first()
|
||||
if not user_subscription:
|
||||
return "trial"
|
||||
elif user_subscription.type == Subscription.Type.TRIAL:
|
||||
return "trial"
|
||||
elif user_subscription.is_recurring and user_subscription.renewal_date >= datetime.now(tz=timezone.utc):
|
||||
return "subscribed"
|
||||
elif not user_subscription.is_recurring and user_subscription.renewal_date >= datetime.now(tz=timezone.utc):
|
||||
return "unsubscribed"
|
||||
elif not user_subscription.is_recurring and user_subscription.renewal_date < datetime.now(tz=timezone.utc):
|
||||
return "expired"
|
||||
return "invalid"
|
||||
return subscription_to_state(user_subscription)
|
||||
|
||||
|
||||
async def aget_user_subscription_state(email: str) -> str:
|
||||
"""Get subscription state of user
|
||||
Valid state transitions: trial -> subscribed <-> unsubscribed OR expired
|
||||
"""
|
||||
user_subscription = await Subscription.objects.filter(user__email=email).afirst()
|
||||
return subscription_to_state(user_subscription)
|
||||
|
||||
|
||||
async def get_user_by_email(email: str) -> KhojUser:
|
||||
|
|
|
@ -171,7 +171,7 @@
|
|||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{% if billing_enabled %}
|
||||
{% if not billing_enabled %}
|
||||
<div id="billing" class="section">
|
||||
<h2 class="section-title">Billing</h2>
|
||||
<div class="section-cards">
|
||||
|
|
|
@ -12,7 +12,7 @@ from asgiref.sync import sync_to_async
|
|||
from fastapi import APIRouter, Depends, Header, HTTPException, Request
|
||||
from fastapi.requests import Request
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
from starlette.authentication import requires
|
||||
from starlette.authentication import requires, has_required_scope
|
||||
|
||||
# Internal Packages
|
||||
from khoj.configure import configure_server
|
||||
|
|
|
@ -7,7 +7,7 @@ from fastapi import APIRouter
|
|||
from fastapi import Request
|
||||
from fastapi.responses import HTMLResponse, FileResponse, RedirectResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from starlette.authentication import requires
|
||||
from starlette.authentication import requires, has_required_scope
|
||||
from khoj.database import adapters
|
||||
from khoj.database.models import KhojUser
|
||||
from khoj.utils.rawconfig import (
|
||||
|
@ -37,7 +37,6 @@ templates = Jinja2Templates(directory=constants.web_directory)
|
|||
def index(request: Request):
|
||||
user = request.user.object
|
||||
user_picture = request.session.get("user", {}).get("picture")
|
||||
user_subscription_state = get_user_subscription_state(user.email)
|
||||
has_documents = EntryAdapters.user_has_entries(user=user)
|
||||
|
||||
return templates.TemplateResponse(
|
||||
|
@ -46,7 +45,7 @@ def index(request: Request):
|
|||
"request": request,
|
||||
"username": user.username,
|
||||
"user_photo": user_picture,
|
||||
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed",
|
||||
"is_active": has_required_scope(request, ["subscribed"]),
|
||||
"has_documents": has_documents,
|
||||
},
|
||||
)
|
||||
|
@ -57,7 +56,6 @@ def index(request: Request):
|
|||
def index_post(request: Request):
|
||||
user = request.user.object
|
||||
user_picture = request.session.get("user", {}).get("picture")
|
||||
user_subscription_state = get_user_subscription_state(user.email)
|
||||
has_documents = EntryAdapters.user_has_entries(user=user)
|
||||
|
||||
return templates.TemplateResponse(
|
||||
|
@ -66,7 +64,7 @@ def index_post(request: Request):
|
|||
"request": request,
|
||||
"username": user.username,
|
||||
"user_photo": user_picture,
|
||||
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed",
|
||||
"is_active": has_required_scope(request, ["subscribed"]),
|
||||
"has_documents": has_documents,
|
||||
},
|
||||
)
|
||||
|
@ -77,7 +75,6 @@ def index_post(request: Request):
|
|||
def search_page(request: Request):
|
||||
user = request.user.object
|
||||
user_picture = request.session.get("user", {}).get("picture")
|
||||
user_subscription_state = get_user_subscription_state(user.email)
|
||||
has_documents = EntryAdapters.user_has_entries(user=user)
|
||||
|
||||
return templates.TemplateResponse(
|
||||
|
@ -86,7 +83,7 @@ def search_page(request: Request):
|
|||
"request": request,
|
||||
"username": user.username,
|
||||
"user_photo": user_picture,
|
||||
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed",
|
||||
"is_active": has_required_scope(request, ["subscribed"]),
|
||||
"has_documents": has_documents,
|
||||
},
|
||||
)
|
||||
|
@ -97,7 +94,6 @@ def search_page(request: Request):
|
|||
def chat_page(request: Request):
|
||||
user = request.user.object
|
||||
user_picture = request.session.get("user", {}).get("picture")
|
||||
user_subscription_state = get_user_subscription_state(user.email)
|
||||
has_documents = EntryAdapters.user_has_entries(user=user)
|
||||
|
||||
return templates.TemplateResponse(
|
||||
|
@ -106,7 +102,7 @@ def chat_page(request: Request):
|
|||
"request": request,
|
||||
"username": user.username,
|
||||
"user_photo": user_picture,
|
||||
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed",
|
||||
"is_active": has_required_scope(request, ["subscribed"]),
|
||||
"has_documents": has_documents,
|
||||
},
|
||||
)
|
||||
|
@ -171,7 +167,7 @@ def config_page(request: Request):
|
|||
"subscription_state": user_subscription_state,
|
||||
"subscription_renewal_date": subscription_renewal_date,
|
||||
"khoj_cloud_subscription_url": os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL"),
|
||||
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed",
|
||||
"is_active": has_required_scope(request, ["subscribed"]),
|
||||
"has_documents": has_documents,
|
||||
},
|
||||
)
|
||||
|
@ -182,7 +178,6 @@ def config_page(request: Request):
|
|||
def github_config_page(request: Request):
|
||||
user = request.user.object
|
||||
user_picture = request.session.get("user", {}).get("picture")
|
||||
user_subscription_state = get_user_subscription_state(user.email)
|
||||
has_documents = EntryAdapters.user_has_entries(user=user)
|
||||
current_github_config = get_user_github_config(user)
|
||||
|
||||
|
@ -212,7 +207,7 @@ def github_config_page(request: Request):
|
|||
"current_config": current_config,
|
||||
"username": user.username,
|
||||
"user_photo": user_picture,
|
||||
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed",
|
||||
"is_active": has_required_scope(request, ["subscribed"]),
|
||||
"has_documents": has_documents,
|
||||
},
|
||||
)
|
||||
|
@ -223,7 +218,6 @@ def github_config_page(request: Request):
|
|||
def notion_config_page(request: Request):
|
||||
user = request.user.object
|
||||
user_picture = request.session.get("user", {}).get("picture")
|
||||
user_subscription_state = adapters.get_user_subscription(user.email)
|
||||
has_documents = EntryAdapters.user_has_entries(user=user)
|
||||
current_notion_config = get_user_notion_config(user)
|
||||
|
||||
|
@ -240,7 +234,7 @@ def notion_config_page(request: Request):
|
|||
"current_config": current_config,
|
||||
"username": user.username,
|
||||
"user_photo": user_picture,
|
||||
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed",
|
||||
"is_active": has_required_scope(request, ["subscribed"]),
|
||||
"has_documents": has_documents,
|
||||
},
|
||||
)
|
||||
|
@ -251,7 +245,6 @@ def notion_config_page(request: Request):
|
|||
def computer_config_page(request: Request):
|
||||
user = request.user.object
|
||||
user_picture = request.session.get("user", {}).get("picture")
|
||||
user_subscription_state = get_user_subscription_state(user.email)
|
||||
has_documents = EntryAdapters.user_has_entries(user=user)
|
||||
|
||||
return templates.TemplateResponse(
|
||||
|
@ -260,7 +253,7 @@ def computer_config_page(request: Request):
|
|||
"request": request,
|
||||
"username": user.username,
|
||||
"user_photo": user_picture,
|
||||
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed",
|
||||
"is_active": has_required_scope(request, ["subscribed"]),
|
||||
"has_documents": has_documents,
|
||||
},
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue