From e2e96f9aa45117c6fec1456c9454ded345ed297c Mon Sep 17 00:00:00 2001 From: sabaimran Date: Fri, 10 Nov 2023 22:38:28 -0800 Subject: [PATCH] Add default settings to let new users be subscribed on trial - Add the default user to a subscription trial - Update associated unit tests --- src/database/adapters/__init__.py | 2 ++ src/database/admin.py | 2 ++ .../0015_alter_subscription_user.py | 21 +++++++++++++++++ .../0016_alter_subscription_renewal_date.py | 17 ++++++++++++++ src/database/models/__init__.py | 4 ++-- src/khoj/configure.py | 23 +++++++++++++++---- src/khoj/routers/helpers.py | 5 +++- tests/conftest.py | 13 ++++++++--- tests/helpers.py | 10 ++++++++ 9 files changed, 86 insertions(+), 11 deletions(-) create mode 100644 src/database/migrations/0015_alter_subscription_user.py create mode 100644 src/database/migrations/0016_alter_subscription_renewal_date.py diff --git a/src/database/adapters/__init__.py b/src/database/adapters/__init__.py index 70d94df3..28999369 100644 --- a/src/database/adapters/__init__.py +++ b/src/database/adapters/__init__.py @@ -101,6 +101,8 @@ async def create_google_user(token: dict) -> KhojUser: user=user, ) + await Subscription.objects.acreate(user=user, type="trial") + return user diff --git a/src/database/admin.py b/src/database/admin.py index 5f41f54a..03c2ca42 100644 --- a/src/database/admin.py +++ b/src/database/admin.py @@ -8,6 +8,7 @@ from database.models import ( ChatModelOptions, OpenAIProcessorConversationConfig, OfflineChatProcessorConversationConfig, + Subscription, ) admin.site.register(KhojUser, UserAdmin) @@ -15,3 +16,4 @@ admin.site.register(KhojUser, UserAdmin) admin.site.register(ChatModelOptions) admin.site.register(OpenAIProcessorConversationConfig) admin.site.register(OfflineChatProcessorConversationConfig) +admin.site.register(Subscription) diff --git a/src/database/migrations/0015_alter_subscription_user.py b/src/database/migrations/0015_alter_subscription_user.py new file mode 100644 index 00000000..e4ba6ab0 --- /dev/null +++ b/src/database/migrations/0015_alter_subscription_user.py @@ -0,0 +1,21 @@ +# Generated by Django 4.2.5 on 2023-11-11 05:39 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0014_alter_googleuser_picture"), + ] + + operations = [ + migrations.AlterField( + model_name="subscription", + name="user", + field=models.OneToOneField( + on_delete=django.db.models.deletion.CASCADE, related_name="subscription", to=settings.AUTH_USER_MODEL + ), + ), + ] diff --git a/src/database/migrations/0016_alter_subscription_renewal_date.py b/src/database/migrations/0016_alter_subscription_renewal_date.py new file mode 100644 index 00000000..bc7c5ada --- /dev/null +++ b/src/database/migrations/0016_alter_subscription_renewal_date.py @@ -0,0 +1,17 @@ +# Generated by Django 4.2.5 on 2023-11-11 06:15 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("database", "0015_alter_subscription_user"), + ] + + operations = [ + migrations.AlterField( + model_name="subscription", + name="renewal_date", + field=models.DateTimeField(blank=True, default=None, null=True), + ), + ] diff --git a/src/database/models/__init__.py b/src/database/models/__init__.py index 73f19c36..437d86ed 100644 --- a/src/database/models/__init__.py +++ b/src/database/models/__init__.py @@ -51,10 +51,10 @@ class Subscription(BaseModel): TRIAL = "trial" STANDARD = "standard" - user = models.OneToOneField(KhojUser, on_delete=models.CASCADE) + user = models.OneToOneField(KhojUser, on_delete=models.CASCADE, related_name="subscription") type = models.CharField(max_length=20, choices=Type.choices, default=Type.TRIAL) is_recurring = models.BooleanField(default=False) - renewal_date = models.DateTimeField(null=True, default=None) + renewal_date = models.DateTimeField(null=True, default=None, blank=True) class NotionConfig(BaseModel): diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 6f0589a8..cfb9fac4 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -28,7 +28,7 @@ from khoj.utils.config import ( from khoj.utils.fs_syncer import collect_files from khoj.utils.rawconfig import FullConfig from khoj.routers.indexer import configure_content, load_content, configure_search -from database.models import KhojUser +from database.models import KhojUser, Subscription from database.adapters import get_all_users @@ -54,27 +54,40 @@ class UserAuthenticationBackend(AuthenticationBackend): def _initialize_default_user(self): if not self.khojuser_manager.filter(username="default").exists(): - self.khojuser_manager.create_user( + default_user = self.khojuser_manager.create_user( username="default", email="default@example.com", password="default", ) + Subscription.objects.create( + user=default_user, + type="trial", + ) async def authenticate(self, request: HTTPConnection): current_user = request.session.get("user") if current_user and current_user.get("email"): - user = await self.khojuser_manager.filter(email=current_user.get("email")).afirst() + user = ( + await self.khojuser_manager.filter(email=current_user.get("email")) + .prefetch_related("subscription") + .afirst() + ) if user: return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user) if len(request.headers.get("Authorization", "").split("Bearer ")) == 2: # Get bearer token from header bearer_token = request.headers["Authorization"].split("Bearer ")[1] # Get user owning token - user_with_token = await self.khojapiuser_manager.filter(token=bearer_token).select_related("user").afirst() + user_with_token = ( + await self.khojapiuser_manager.filter(token=bearer_token) + .select_related("user") + .prefetch_related("user__subscription") + .afirst() + ) if user_with_token: return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user) if state.anonymous_mode: - user = await self.khojuser_manager.filter(username="default").afirst() + user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst() if user: return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index a5263d56..46ef0641 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -13,7 +13,7 @@ from khoj.utils.helpers import ConversationCommand, log_telemetry from khoj.processor.conversation.openai.gpt import converse from khoj.processor.conversation.gpt4all.chat_model import converse_offline from khoj.processor.conversation.utils import message_to_log, ThreadedGenerator -from database.models import KhojUser +from database.models import KhojUser, Subscription from database.adapters import ConversationAdapters logger = logging.getLogger(__name__) @@ -61,12 +61,15 @@ def update_telemetry_state( metadata: Optional[dict] = None, ): user: KhojUser = request.user.object if request.user.is_authenticated else None + subscription: Subscription = user.subscription if user and user.subscription else None user_state = { "client_host": request.client.host if request.client else None, "user_agent": user_agent or "unknown", "referer": referer or "unknown", "host": host or "unknown", "server_id": str(user.uuid) if user else None, + "subscription_type": subscription.type if subscription else None, + "is_recurring": subscription.is_recurring if subscription else None, } if metadata: diff --git a/tests/conftest.py b/tests/conftest.py index 8cf0a391..95fa9a99 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -43,6 +43,7 @@ from tests.helpers import ( OpenAIProcessorConversationConfigFactory, OfflineChatProcessorConversationConfigFactory, UserConversationProcessorConfigFactory, + SubscriptionFactory, ) @@ -69,7 +70,9 @@ def search_config() -> SearchConfig: @pytest.mark.django_db @pytest.fixture def default_user(): - return UserFactory() + user = UserFactory() + SubscriptionFactory(user=user) + return user @pytest.mark.django_db @@ -78,11 +81,13 @@ def default_user2(): if KhojUser.objects.filter(username="default").exists(): return KhojUser.objects.get(username="default") - return KhojUser.objects.create( + user = KhojUser.objects.create( username="default", email="default@example.com", password="default", ) + SubscriptionFactory(user=user) + return user @pytest.mark.django_db @@ -94,11 +99,13 @@ def default_user3(): if KhojUser.objects.filter(username="default3").exists(): return KhojUser.objects.get(username="default3") - return KhojUser.objects.create( + user = KhojUser.objects.create( username="default3", email="default3@example.com", password="default3", ) + SubscriptionFactory(user=user) + return user @pytest.mark.django_db diff --git a/tests/helpers.py b/tests/helpers.py index 3aa7c435..0f0f9cf2 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -9,6 +9,7 @@ from database.models import ( OpenAIProcessorConversationConfig, UserConversationConfig, Conversation, + Subscription, ) @@ -68,3 +69,12 @@ class ConversationFactory(factory.django.DjangoModelFactory): model = Conversation user = factory.SubFactory(UserFactory) + + +class SubscriptionFactory(factory.django.DjangoModelFactory): + class Meta: + model = Subscription + + user = factory.SubFactory(UserFactory) + type = "trial" + is_recurring = False