Create explicit flow to enable the free trial (#944)

* Create explicit flow to enable the free trial

The current design is confusing. It obfuscates the fact that the user is on a free trial. This design will make the opt-in explicit and more intuitive.

* Use the Subscription Type enum instead of hardcoded strings everywhere

* Use length of free trial in the frontend code as well
This commit is contained in:
sabaimran 2024-10-23 15:29:23 -07:00 committed by GitHub
parent c5e91c346a
commit f3ce47b445
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 167 additions and 24 deletions

View file

@ -12,7 +12,7 @@ Without any desktop clients, you can start chatting with Khoj on WhatsApp. Bear
In order to use Khoj on WhatsApp with your own data, you need to setup a Khoj Cloud account and connect your WhatsApp account to it. This is a one time setup and you can do it from the [Khoj Cloud config page](https://app.khoj.dev/settings). In order to use Khoj on WhatsApp with your own data, you need to setup a Khoj Cloud account and connect your WhatsApp account to it. This is a one time setup and you can do it from the [Khoj Cloud config page](https://app.khoj.dev/settings).
If you hit usage limits for the WhatsApp bot, upgrade to [a paid plan](https://khoj.dev/pricing) on Khoj Cloud. If you hit usage limits for the WhatsApp bot, upgrade to [a paid plan](https://khoj.dev/#pricing) on Khoj Cloud.
<img src="https://khoj-web-bucket.s3.amazonaws.com/khojwhatsapp.png" alt="WhatsApp QR Code" width="300" height="300" /> <img src="https://khoj-web-bucket.s3.amazonaws.com/khojwhatsapp.png" alt="WhatsApp QR Code" width="300" height="300" />

View file

@ -68,7 +68,8 @@ export interface UserConfig {
selected_voice_model_config: number; selected_voice_model_config: number;
// user billing info // user billing info
subscription_state: SubscriptionStates; subscription_state: SubscriptionStates;
subscription_renewal_date: string; subscription_renewal_date: string | undefined;
subscription_enabled_trial_at: string | undefined;
// server settings // server settings
khoj_cloud_subscription_url: string | undefined; khoj_cloud_subscription_url: string | undefined;
billing_enabled: boolean; billing_enabled: boolean;
@ -78,6 +79,7 @@ export interface UserConfig {
anonymous_mode: boolean; anonymous_mode: boolean;
notion_oauth_url: string; notion_oauth_url: string;
detail: string; detail: string;
length_of_free_trial: number;
} }
export function useUserConfig(detailed: boolean = false) { export function useUserConfig(detailed: boolean = false) {

View file

@ -513,7 +513,7 @@ export default function SettingsView() {
const isMobileWidth = useIsMobileWidth(); const isMobileWidth = useIsMobileWidth();
const cardClassName = const cardClassName =
"w-full lg:w-1/3 grid grid-flow-column border border-gray-300 shadow-md rounded-lg bg-gradient-to-b from-background to-gray-50 dark:to-gray-950"; "w-full lg:w-1/3 grid grid-flow-column border border-gray-300 shadow-md rounded-lg bg-gradient-to-b from-background to-gray-50 dark:to-gray-950 border border-opacity-50";
useEffect(() => { useEffect(() => {
setUserConfig(initialUserConfig); setUserConfig(initialUserConfig);
@ -640,6 +640,51 @@ export default function SettingsView() {
} }
}; };
const enableFreeTrial = async () => {
const formatDate = (dateString: Date) => {
const date = new Date(dateString);
return new Intl.DateTimeFormat("en-US", {
day: "2-digit",
month: "short",
year: "numeric",
}).format(date);
};
try {
const response = await fetch(`/api/subscription/trial`, {
method: "POST",
});
if (!response.ok) throw new Error("Failed to enable free trial");
const responseBody = await response.json();
// Set updated user settings
if (responseBody.trial_enabled && userConfig) {
let newUserConfig = userConfig;
newUserConfig.subscription_state = SubscriptionStates.TRIAL;
const renewalDate = new Date(
Date.now() + userConfig.length_of_free_trial * 24 * 60 * 60 * 1000,
);
newUserConfig.subscription_renewal_date = formatDate(renewalDate);
newUserConfig.subscription_enabled_trial_at = new Date().toISOString();
setUserConfig(newUserConfig);
// Notify user of free trial
toast({
title: "🎉 Trial Enabled",
description: `Your free trial will end on ${newUserConfig.subscription_renewal_date}`,
});
}
} catch (error) {
console.error("Error enabling free trial:", error);
toast({
title: "⚠️ Failed to Enable Free Trial",
description:
"Failed to enable free trial. Try again or contact us at team@khoj.dev",
});
}
};
const saveName = async () => { const saveName = async () => {
if (!name) return; if (!name) return;
try { try {
@ -866,10 +911,13 @@ export default function SettingsView() {
Futurist (Trial) Futurist (Trial)
</p> </p>
<p className="text-gray-400"> <p className="text-gray-400">
You are on a 14 day trial of the Khoj You are on a{" "}
Futurist plan. Check{" "} {userConfig.length_of_free_trial} day trial
of the Khoj Futurist plan. Your trial ends
on {userConfig.subscription_renewal_date}.
Check{" "}
<a <a
href="https://khoj.dev/pricing" href="https://khoj.dev/#pricing"
target="_blank" target="_blank"
> >
pricing page pricing page
@ -909,7 +957,7 @@ export default function SettingsView() {
)) || )) ||
(userConfig.subscription_state === "expired" && ( (userConfig.subscription_state === "expired" && (
<> <>
<p className="text-xl">Free Plan</p> <p className="text-xl">Humanist</p>
{(userConfig.subscription_renewal_date && ( {(userConfig.subscription_renewal_date && (
<p className="text-gray-400"> <p className="text-gray-400">
Subscription <b>expired</b> on{" "} Subscription <b>expired</b> on{" "}
@ -923,7 +971,7 @@ export default function SettingsView() {
<p className="text-gray-400"> <p className="text-gray-400">
Check{" "} Check{" "}
<a <a
href="https://khoj.dev/pricing" href="https://khoj.dev/#pricing"
target="_blank" target="_blank"
> >
pricing page pricing page
@ -960,7 +1008,8 @@ export default function SettingsView() {
/> />
Resubscribe Resubscribe
</Button> </Button>
)) || ( )) ||
(userConfig.subscription_enabled_trial_at && (
<Button <Button
variant="outline" variant="outline"
className="text-primary/80 hover:text-primary" className="text-primary/80 hover:text-primary"
@ -978,6 +1027,18 @@ export default function SettingsView() {
/> />
Subscribe Subscribe
</Button> </Button>
)) || (
<Button
variant="outline"
className="text-primary/80 hover:text-primary"
onClick={enableFreeTrial}
>
<ArrowCircleUp
weight="bold"
className="h-5 w-5 mr-2"
/>
Enable Trial
</Button>
)} )}
</CardFooter> </CardFooter>
</Card> </Card>

View file

@ -108,7 +108,7 @@ class UserAuthenticationBackend(AuthenticationBackend):
password="default", password="default",
) )
renewal_date = make_aware(datetime.strptime("2100-04-01", "%Y-%m-%d")) renewal_date = make_aware(datetime.strptime("2100-04-01", "%Y-%m-%d"))
Subscription.objects.create(user=default_user, type="standard", renewal_date=renewal_date) Subscription.objects.create(user=default_user, type=Subscription.Type.STANDARD, renewal_date=renewal_date)
async def authenticate(self, request: HTTPConnection): async def authenticate(self, request: HTTPConnection):
current_user = request.session.get("user") current_user = request.session.get("user")
@ -312,7 +312,7 @@ def configure_routes(app):
logger.info("🔑 Enabled Authentication") logger.info("🔑 Enabled Authentication")
if state.billing_enabled: if state.billing_enabled:
from khoj.routers.subscription import subscription_router from khoj.routers.api_subscription import subscription_router
app.include_router(subscription_router, prefix="/api/subscription") app.include_router(subscription_router, prefix="/api/subscription")
logger.info("💳 Enabled Billing") logger.info("💳 Enabled Billing")

View file

@ -70,6 +70,9 @@ from khoj.utils.helpers import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
LENGTH_OF_FREE_TRIAL = 7 #
class SubscriptionState(Enum): class SubscriptionState(Enum):
TRIAL = "trial" TRIAL = "trial"
SUBSCRIBED = "subscribed" SUBSCRIBED = "subscribed"
@ -168,7 +171,7 @@ async def acreate_user_by_phone_number(phone_number: str) -> KhojUser:
) )
await user.asave() await user.asave()
await Subscription.objects.acreate(user=user, type="trial") await Subscription.objects.acreate(user=user, type=Subscription.Type.STANDARD)
return user return user
@ -185,11 +188,29 @@ async def aget_or_create_user_by_email(email: str) -> tuple[KhojUser, bool]:
user_subscription = await Subscription.objects.filter(user=user).afirst() user_subscription = await Subscription.objects.filter(user=user).afirst()
if not user_subscription: if not user_subscription:
await Subscription.objects.acreate(user=user, type="trial") await Subscription.objects.acreate(user=user, type=Subscription.Type.STANDARD)
return user, is_new return user, is_new
async def astart_trial_subscription(user: KhojUser) -> Subscription:
subscription = await Subscription.objects.filter(user=user).afirst()
if not subscription:
raise HTTPException(status_code=400, detail="User does not have a subscription")
if subscription.type == Subscription.Type.TRIAL:
raise HTTPException(status_code=400, detail="User already has a trial subscription")
if subscription.enabled_trial_at:
raise HTTPException(status_code=400, detail="User already has a trial subscription")
subscription.type = Subscription.Type.TRIAL
subscription.enabled_trial_at = datetime.now(tz=timezone.utc)
subscription.renewal_date = datetime.now(tz=timezone.utc) + timedelta(days=LENGTH_OF_FREE_TRIAL)
await subscription.asave()
return subscription
async def aget_user_validated_by_email_verification_code(code: str) -> KhojUser: async def aget_user_validated_by_email_verification_code(code: str) -> KhojUser:
user = await KhojUser.objects.filter(email_verification_code=code).afirst() user = await KhojUser.objects.filter(email_verification_code=code).afirst()
if not user: if not user:
@ -221,7 +242,7 @@ async def create_user_by_google_token(token: dict) -> KhojUser:
user=user, user=user,
) )
await Subscription.objects.acreate(user=user, type="trial") await Subscription.objects.acreate(user=user, type=Subscription.Type.STANDARD)
return user return user
@ -279,16 +300,15 @@ def subscription_to_state(subscription: Subscription) -> str:
if not subscription: if not subscription:
return SubscriptionState.INVALID.value return SubscriptionState.INVALID.value
elif subscription.type == Subscription.Type.TRIAL: elif subscription.type == Subscription.Type.TRIAL:
# Trial subscription is valid for 7 days # Check if the trial has expired
if datetime.now(tz=timezone.utc) - subscription.created_at > timedelta(days=14): if datetime.now(tz=timezone.utc) > subscription.renewal_date:
return SubscriptionState.EXPIRED.value return SubscriptionState.EXPIRED.value
return SubscriptionState.TRIAL.value return SubscriptionState.TRIAL.value
elif subscription.is_recurring and subscription.renewal_date >= datetime.now(tz=timezone.utc): elif subscription.is_recurring and subscription.renewal_date > datetime.now(tz=timezone.utc):
return SubscriptionState.SUBSCRIBED.value return SubscriptionState.SUBSCRIBED.value
elif not subscription.is_recurring and subscription.renewal_date is None: elif not subscription.is_recurring and subscription.renewal_date is None:
return SubscriptionState.EXPIRED.value return SubscriptionState.EXPIRED.value
elif not subscription.is_recurring and subscription.renewal_date >= datetime.now(tz=timezone.utc): elif not subscription.is_recurring and subscription.renewal_date > datetime.now(tz=timezone.utc):
return SubscriptionState.UNSUBSCRIBED.value return SubscriptionState.UNSUBSCRIBED.value
elif not subscription.is_recurring and subscription.renewal_date < datetime.now(tz=timezone.utc): elif not subscription.is_recurring and subscription.renewal_date < datetime.now(tz=timezone.utc):
return SubscriptionState.EXPIRED.value return SubscriptionState.EXPIRED.value

View file

@ -0,0 +1,32 @@
# Generated by Django 5.0.8 on 2024-10-20 19:24
from django.db import migrations, models
def set_enabled_trial_at(apps, schema_editor):
Subscription = apps.get_model("database", "Subscription")
for subscription in Subscription.objects.all():
subscription.enabled_trial_at = subscription.created_at
subscription.save()
class Migration(migrations.Migration):
dependencies = [
("database", "0070_alter_agent_input_tools_alter_agent_output_modes"),
]
operations = [
migrations.AddField(
model_name="subscription",
name="enabled_trial_at",
field=models.DateTimeField(blank=True, default=None, null=True),
),
migrations.AlterField(
model_name="subscription",
name="type",
field=models.CharField(
choices=[("trial", "Trial"), ("standard", "Standard")], default="standard", max_length=20
),
),
migrations.RunPython(set_enabled_trial_at),
]

View file

@ -73,9 +73,10 @@ class Subscription(BaseModel):
STANDARD = "standard" STANDARD = "standard"
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE, related_name="subscription") user = models.OneToOneField(KhojUser, on_delete=models.CASCADE, related_name="subscription")
type = models.CharField(max_length=20, choices=Type.choices, default=Type.TRIAL) type = models.CharField(max_length=20, choices=Type.choices, default=Type.STANDARD)
is_recurring = models.BooleanField(default=False) is_recurring = models.BooleanField(default=False)
renewal_date = models.DateTimeField(null=True, default=None, blank=True) renewal_date = models.DateTimeField(null=True, default=None, blank=True)
enabled_trial_at = models.DateTimeField(null=True, default=None, blank=True)
class OpenAIProcessorConversationConfig(BaseModel): class OpenAIProcessorConversationConfig(BaseModel):

View file

@ -1,12 +1,14 @@
import json
import logging import logging
import os import os
from datetime import datetime, timezone from datetime import datetime, timezone
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
from fastapi import APIRouter, Request from fastapi import APIRouter, Request, Response
from starlette.authentication import requires from starlette.authentication import requires
from khoj.database import adapters from khoj.database import adapters
from khoj.database.models import KhojUser, Subscription
from khoj.routers.helpers import update_telemetry_state from khoj.routers.helpers import update_telemetry_state
from khoj.utils import state from khoj.utils import state
@ -73,7 +75,7 @@ async def subscribe(request: Request):
elif event_type in {"customer.subscription.deleted"}: elif event_type in {"customer.subscription.deleted"}:
# Reset the user to trial state # Reset the user to trial state
user, is_new = await adapters.set_user_subscription( user, is_new = await adapters.set_user_subscription(
customer_email, is_recurring=False, renewal_date=False, type="trial" customer_email, is_recurring=False, renewal_date=False, type=Subscription.Type.TRIAL
) )
success = user is not None success = user is not None
@ -116,3 +118,19 @@ async def update_subscription(request: Request, email: str, operation: str):
return {"success": False, "message": "No subscription found that is set to cancel"} return {"success": False, "message": "No subscription found that is set to cancel"}
return {"success": False, "message": "Invalid operation"} return {"success": False, "message": "Invalid operation"}
@subscription_router.post("/trial", response_class=Response)
@requires(["authenticated"])
async def start_trial(request: Request) -> Response:
user: KhojUser = request.user.object
# Start a trial for the user
updated_subscription = await adapters.astart_trial_subscription(user)
# Return trial status as a JSON response
return Response(
content=json.dumps({"trial_enabled": updated_subscription is not None}),
media_type="application/json",
status_code=200,
)

View file

@ -38,6 +38,7 @@ from starlette.requests import URL
from khoj.database import adapters from khoj.database import adapters
from khoj.database.adapters import ( from khoj.database.adapters import (
LENGTH_OF_FREE_TRIAL,
AgentAdapters, AgentAdapters,
AutomationAdapters, AutomationAdapters,
ConversationAdapters, ConversationAdapters,
@ -1673,10 +1674,16 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
user_subscription_state = get_user_subscription_state(user.email) user_subscription_state = get_user_subscription_state(user.email)
user_subscription = adapters.get_user_subscription(user.email) user_subscription = adapters.get_user_subscription(user.email)
subscription_renewal_date = ( subscription_renewal_date = (
user_subscription.renewal_date.strftime("%d %b %Y") user_subscription.renewal_date.strftime("%d %b %Y")
if user_subscription and user_subscription.renewal_date if user_subscription and user_subscription.renewal_date
else (user_subscription.created_at + timedelta(days=7)).strftime("%d %b %Y") else None
)
subscription_enabled_trial_at = (
user_subscription.enabled_trial_at.strftime("%d %b %Y")
if user_subscription and user_subscription.enabled_trial_at
else None
) )
given_name = get_user_name(user) given_name = get_user_name(user)
@ -1749,6 +1756,7 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
# user billing info # user billing info
"subscription_state": user_subscription_state, "subscription_state": user_subscription_state,
"subscription_renewal_date": subscription_renewal_date, "subscription_renewal_date": subscription_renewal_date,
"subscription_enabled_trial_at": subscription_enabled_trial_at,
# server settings # server settings
"khoj_cloud_subscription_url": os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL"), "khoj_cloud_subscription_url": os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL"),
"billing_enabled": state.billing_enabled, "billing_enabled": state.billing_enabled,
@ -1757,6 +1765,7 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
"khoj_version": state.khoj_version, "khoj_version": state.khoj_version,
"anonymous_mode": state.anonymous_mode, "anonymous_mode": state.anonymous_mode,
"notion_oauth_url": notion_oauth_url, "notion_oauth_url": notion_oauth_url,
"length_of_free_trial": LENGTH_OF_FREE_TRIAL,
} }

View file

@ -86,7 +86,7 @@ class SubscriptionFactory(factory.django.DjangoModelFactory):
model = Subscription model = Subscription
user = factory.SubFactory(UserFactory) user = factory.SubFactory(UserFactory)
type = "standard" type = Subscription.Type.STANDARD
is_recurring = False is_recurring = False
renewal_date = make_aware(datetime.strptime("2100-04-01", "%Y-%m-%d")) renewal_date = make_aware(datetime.strptime("2100-04-01", "%Y-%m-%d"))