pricing page
@@ -909,7 +957,7 @@ export default function SettingsView() {
)) ||
(userConfig.subscription_state === "expired" && (
<>
- Free Plan
+ Humanist
{(userConfig.subscription_renewal_date && (
Subscription expired on{" "}
@@ -923,7 +971,7 @@ export default function SettingsView() {
Check{" "}
pricing page
@@ -960,7 +1008,8 @@ export default function SettingsView() {
/>
Resubscribe
- )) || (
+ )) ||
+ (userConfig.subscription_enabled_trial_at && (
Subscribe
+ )) || (
+
)}
diff --git a/src/khoj/configure.py b/src/khoj/configure.py
index 3fb540bd..1454b164 100644
--- a/src/khoj/configure.py
+++ b/src/khoj/configure.py
@@ -108,7 +108,7 @@ class UserAuthenticationBackend(AuthenticationBackend):
password="default",
)
renewal_date = make_aware(datetime.strptime("2100-04-01", "%Y-%m-%d"))
- Subscription.objects.create(user=default_user, type="standard", renewal_date=renewal_date)
+ Subscription.objects.create(user=default_user, type=Subscription.Type.STANDARD, renewal_date=renewal_date)
async def authenticate(self, request: HTTPConnection):
current_user = request.session.get("user")
@@ -312,7 +312,7 @@ def configure_routes(app):
logger.info("🔑 Enabled Authentication")
if state.billing_enabled:
- from khoj.routers.subscription import subscription_router
+ from khoj.routers.api_subscription import subscription_router
app.include_router(subscription_router, prefix="/api/subscription")
logger.info("💳 Enabled Billing")
diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py
index eb914ed8..d5312b29 100644
--- a/src/khoj/database/adapters/__init__.py
+++ b/src/khoj/database/adapters/__init__.py
@@ -70,6 +70,9 @@ from khoj.utils.helpers import (
logger = logging.getLogger(__name__)
+LENGTH_OF_FREE_TRIAL = 7 #
+
+
class SubscriptionState(Enum):
TRIAL = "trial"
SUBSCRIBED = "subscribed"
@@ -168,7 +171,7 @@ async def acreate_user_by_phone_number(phone_number: str) -> KhojUser:
)
await user.asave()
- await Subscription.objects.acreate(user=user, type="trial")
+ await Subscription.objects.acreate(user=user, type=Subscription.Type.STANDARD)
return user
@@ -185,11 +188,29 @@ async def aget_or_create_user_by_email(email: str) -> tuple[KhojUser, bool]:
user_subscription = await Subscription.objects.filter(user=user).afirst()
if not user_subscription:
- await Subscription.objects.acreate(user=user, type="trial")
+ await Subscription.objects.acreate(user=user, type=Subscription.Type.STANDARD)
return user, is_new
+async def astart_trial_subscription(user: KhojUser) -> Subscription:
+ subscription = await Subscription.objects.filter(user=user).afirst()
+ if not subscription:
+ raise HTTPException(status_code=400, detail="User does not have a subscription")
+
+ if subscription.type == Subscription.Type.TRIAL:
+ raise HTTPException(status_code=400, detail="User already has a trial subscription")
+
+ if subscription.enabled_trial_at:
+ raise HTTPException(status_code=400, detail="User already has a trial subscription")
+
+ subscription.type = Subscription.Type.TRIAL
+ subscription.enabled_trial_at = datetime.now(tz=timezone.utc)
+ subscription.renewal_date = datetime.now(tz=timezone.utc) + timedelta(days=LENGTH_OF_FREE_TRIAL)
+ await subscription.asave()
+ return subscription
+
+
async def aget_user_validated_by_email_verification_code(code: str) -> KhojUser:
user = await KhojUser.objects.filter(email_verification_code=code).afirst()
if not user:
@@ -221,7 +242,7 @@ async def create_user_by_google_token(token: dict) -> KhojUser:
user=user,
)
- await Subscription.objects.acreate(user=user, type="trial")
+ await Subscription.objects.acreate(user=user, type=Subscription.Type.STANDARD)
return user
@@ -279,16 +300,15 @@ def subscription_to_state(subscription: Subscription) -> str:
if not subscription:
return SubscriptionState.INVALID.value
elif subscription.type == Subscription.Type.TRIAL:
- # Trial subscription is valid for 7 days
- if datetime.now(tz=timezone.utc) - subscription.created_at > timedelta(days=14):
+ # Check if the trial has expired
+ if datetime.now(tz=timezone.utc) > subscription.renewal_date:
return SubscriptionState.EXPIRED.value
-
return SubscriptionState.TRIAL.value
- elif subscription.is_recurring and subscription.renewal_date >= datetime.now(tz=timezone.utc):
+ elif subscription.is_recurring and subscription.renewal_date > datetime.now(tz=timezone.utc):
return SubscriptionState.SUBSCRIBED.value
elif not subscription.is_recurring and subscription.renewal_date is None:
return SubscriptionState.EXPIRED.value
- elif not subscription.is_recurring and subscription.renewal_date >= datetime.now(tz=timezone.utc):
+ elif not subscription.is_recurring and subscription.renewal_date > datetime.now(tz=timezone.utc):
return SubscriptionState.UNSUBSCRIBED.value
elif not subscription.is_recurring and subscription.renewal_date < datetime.now(tz=timezone.utc):
return SubscriptionState.EXPIRED.value
diff --git a/src/khoj/database/migrations/0071_subscription_enabled_trial_at_and_more.py b/src/khoj/database/migrations/0071_subscription_enabled_trial_at_and_more.py
new file mode 100644
index 00000000..68b97399
--- /dev/null
+++ b/src/khoj/database/migrations/0071_subscription_enabled_trial_at_and_more.py
@@ -0,0 +1,32 @@
+# Generated by Django 5.0.8 on 2024-10-20 19:24
+
+from django.db import migrations, models
+
+
+def set_enabled_trial_at(apps, schema_editor):
+ Subscription = apps.get_model("database", "Subscription")
+ for subscription in Subscription.objects.all():
+ subscription.enabled_trial_at = subscription.created_at
+ subscription.save()
+
+
+class Migration(migrations.Migration):
+ dependencies = [
+ ("database", "0070_alter_agent_input_tools_alter_agent_output_modes"),
+ ]
+
+ operations = [
+ migrations.AddField(
+ model_name="subscription",
+ name="enabled_trial_at",
+ field=models.DateTimeField(blank=True, default=None, null=True),
+ ),
+ migrations.AlterField(
+ model_name="subscription",
+ name="type",
+ field=models.CharField(
+ choices=[("trial", "Trial"), ("standard", "Standard")], default="standard", max_length=20
+ ),
+ ),
+ migrations.RunPython(set_enabled_trial_at),
+ ]
diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py
index 6b122dac..c89c409a 100644
--- a/src/khoj/database/models/__init__.py
+++ b/src/khoj/database/models/__init__.py
@@ -73,9 +73,10 @@ class Subscription(BaseModel):
STANDARD = "standard"
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE, related_name="subscription")
- type = models.CharField(max_length=20, choices=Type.choices, default=Type.TRIAL)
+ type = models.CharField(max_length=20, choices=Type.choices, default=Type.STANDARD)
is_recurring = models.BooleanField(default=False)
renewal_date = models.DateTimeField(null=True, default=None, blank=True)
+ enabled_trial_at = models.DateTimeField(null=True, default=None, blank=True)
class OpenAIProcessorConversationConfig(BaseModel):
diff --git a/src/khoj/routers/subscription.py b/src/khoj/routers/api_subscription.py
similarity index 87%
rename from src/khoj/routers/subscription.py
rename to src/khoj/routers/api_subscription.py
index 14a75da8..f47775aa 100644
--- a/src/khoj/routers/subscription.py
+++ b/src/khoj/routers/api_subscription.py
@@ -1,12 +1,14 @@
+import json
import logging
import os
from datetime import datetime, timezone
from asgiref.sync import sync_to_async
-from fastapi import APIRouter, Request
+from fastapi import APIRouter, Request, Response
from starlette.authentication import requires
from khoj.database import adapters
+from khoj.database.models import KhojUser, Subscription
from khoj.routers.helpers import update_telemetry_state
from khoj.utils import state
@@ -73,7 +75,7 @@ async def subscribe(request: Request):
elif event_type in {"customer.subscription.deleted"}:
# Reset the user to trial state
user, is_new = await adapters.set_user_subscription(
- customer_email, is_recurring=False, renewal_date=False, type="trial"
+ customer_email, is_recurring=False, renewal_date=False, type=Subscription.Type.TRIAL
)
success = user is not None
@@ -116,3 +118,19 @@ async def update_subscription(request: Request, email: str, operation: str):
return {"success": False, "message": "No subscription found that is set to cancel"}
return {"success": False, "message": "Invalid operation"}
+
+
+@subscription_router.post("/trial", response_class=Response)
+@requires(["authenticated"])
+async def start_trial(request: Request) -> Response:
+ user: KhojUser = request.user.object
+
+ # Start a trial for the user
+ updated_subscription = await adapters.astart_trial_subscription(user)
+
+ # Return trial status as a JSON response
+ return Response(
+ content=json.dumps({"trial_enabled": updated_subscription is not None}),
+ media_type="application/json",
+ status_code=200,
+ )
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index c587c4bd..e9c752fb 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -38,6 +38,7 @@ from starlette.requests import URL
from khoj.database import adapters
from khoj.database.adapters import (
+ LENGTH_OF_FREE_TRIAL,
AgentAdapters,
AutomationAdapters,
ConversationAdapters,
@@ -1673,10 +1674,16 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
user_subscription_state = get_user_subscription_state(user.email)
user_subscription = adapters.get_user_subscription(user.email)
+
subscription_renewal_date = (
user_subscription.renewal_date.strftime("%d %b %Y")
if user_subscription and user_subscription.renewal_date
- else (user_subscription.created_at + timedelta(days=7)).strftime("%d %b %Y")
+ else None
+ )
+ subscription_enabled_trial_at = (
+ user_subscription.enabled_trial_at.strftime("%d %b %Y")
+ if user_subscription and user_subscription.enabled_trial_at
+ else None
)
given_name = get_user_name(user)
@@ -1749,6 +1756,7 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
# user billing info
"subscription_state": user_subscription_state,
"subscription_renewal_date": subscription_renewal_date,
+ "subscription_enabled_trial_at": subscription_enabled_trial_at,
# server settings
"khoj_cloud_subscription_url": os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL"),
"billing_enabled": state.billing_enabled,
@@ -1757,6 +1765,7 @@ def get_user_config(user: KhojUser, request: Request, is_detailed: bool = False)
"khoj_version": state.khoj_version,
"anonymous_mode": state.anonymous_mode,
"notion_oauth_url": notion_oauth_url,
+ "length_of_free_trial": LENGTH_OF_FREE_TRIAL,
}
diff --git a/tests/helpers.py b/tests/helpers.py
index 2e8e5671..ae5c7779 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -86,7 +86,7 @@ class SubscriptionFactory(factory.django.DjangoModelFactory):
model = Subscription
user = factory.SubFactory(UserFactory)
- type = "standard"
+ type = Subscription.Type.STANDARD
is_recurring = False
renewal_date = make_aware(datetime.strptime("2100-04-01", "%Y-%m-%d"))