Add default settings to let new users be subscribed on trial

- Add the default user to a subscription trial
- Update associated unit tests
This commit is contained in:
sabaimran 2023-11-10 22:38:28 -08:00
parent 0a950d9382
commit e2e96f9aa4
9 changed files with 86 additions and 11 deletions

View file

@ -101,6 +101,8 @@ async def create_google_user(token: dict) -> KhojUser:
user=user,
)
await Subscription.objects.acreate(user=user, type="trial")
return user

View file

@ -8,6 +8,7 @@ from database.models import (
ChatModelOptions,
OpenAIProcessorConversationConfig,
OfflineChatProcessorConversationConfig,
Subscription,
)
admin.site.register(KhojUser, UserAdmin)
@ -15,3 +16,4 @@ admin.site.register(KhojUser, UserAdmin)
admin.site.register(ChatModelOptions)
admin.site.register(OpenAIProcessorConversationConfig)
admin.site.register(OfflineChatProcessorConversationConfig)
admin.site.register(Subscription)

View file

@ -0,0 +1,21 @@
# Generated by Django 4.2.5 on 2023-11-11 05:39
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
("database", "0014_alter_googleuser_picture"),
]
operations = [
migrations.AlterField(
model_name="subscription",
name="user",
field=models.OneToOneField(
on_delete=django.db.models.deletion.CASCADE, related_name="subscription", to=settings.AUTH_USER_MODEL
),
),
]

View file

@ -0,0 +1,17 @@
# Generated by Django 4.2.5 on 2023-11-11 06:15
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0015_alter_subscription_user"),
]
operations = [
migrations.AlterField(
model_name="subscription",
name="renewal_date",
field=models.DateTimeField(blank=True, default=None, null=True),
),
]

View file

@ -51,10 +51,10 @@ class Subscription(BaseModel):
TRIAL = "trial"
STANDARD = "standard"
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE, related_name="subscription")
type = models.CharField(max_length=20, choices=Type.choices, default=Type.TRIAL)
is_recurring = models.BooleanField(default=False)
renewal_date = models.DateTimeField(null=True, default=None)
renewal_date = models.DateTimeField(null=True, default=None, blank=True)
class NotionConfig(BaseModel):

View file

@ -28,7 +28,7 @@ from khoj.utils.config import (
from khoj.utils.fs_syncer import collect_files
from khoj.utils.rawconfig import FullConfig
from khoj.routers.indexer import configure_content, load_content, configure_search
from database.models import KhojUser
from database.models import KhojUser, Subscription
from database.adapters import get_all_users
@ -54,27 +54,40 @@ class UserAuthenticationBackend(AuthenticationBackend):
def _initialize_default_user(self):
if not self.khojuser_manager.filter(username="default").exists():
self.khojuser_manager.create_user(
default_user = self.khojuser_manager.create_user(
username="default",
email="default@example.com",
password="default",
)
Subscription.objects.create(
user=default_user,
type="trial",
)
async def authenticate(self, request: HTTPConnection):
current_user = request.session.get("user")
if current_user and current_user.get("email"):
user = await self.khojuser_manager.filter(email=current_user.get("email")).afirst()
user = (
await self.khojuser_manager.filter(email=current_user.get("email"))
.prefetch_related("subscription")
.afirst()
)
if user:
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
if len(request.headers.get("Authorization", "").split("Bearer ")) == 2:
# Get bearer token from header
bearer_token = request.headers["Authorization"].split("Bearer ")[1]
# Get user owning token
user_with_token = await self.khojapiuser_manager.filter(token=bearer_token).select_related("user").afirst()
user_with_token = (
await self.khojapiuser_manager.filter(token=bearer_token)
.select_related("user")
.prefetch_related("user__subscription")
.afirst()
)
if user_with_token:
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
if state.anonymous_mode:
user = await self.khojuser_manager.filter(username="default").afirst()
user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst()
if user:
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)

View file

@ -13,7 +13,7 @@ from khoj.utils.helpers import ConversationCommand, log_telemetry
from khoj.processor.conversation.openai.gpt import converse
from khoj.processor.conversation.gpt4all.chat_model import converse_offline
from khoj.processor.conversation.utils import message_to_log, ThreadedGenerator
from database.models import KhojUser
from database.models import KhojUser, Subscription
from database.adapters import ConversationAdapters
logger = logging.getLogger(__name__)
@ -61,12 +61,15 @@ def update_telemetry_state(
metadata: Optional[dict] = None,
):
user: KhojUser = request.user.object if request.user.is_authenticated else None
subscription: Subscription = user.subscription if user and user.subscription else None
user_state = {
"client_host": request.client.host if request.client else None,
"user_agent": user_agent or "unknown",
"referer": referer or "unknown",
"host": host or "unknown",
"server_id": str(user.uuid) if user else None,
"subscription_type": subscription.type if subscription else None,
"is_recurring": subscription.is_recurring if subscription else None,
}
if metadata:

View file

@ -43,6 +43,7 @@ from tests.helpers import (
OpenAIProcessorConversationConfigFactory,
OfflineChatProcessorConversationConfigFactory,
UserConversationProcessorConfigFactory,
SubscriptionFactory,
)
@ -69,7 +70,9 @@ def search_config() -> SearchConfig:
@pytest.mark.django_db
@pytest.fixture
def default_user():
return UserFactory()
user = UserFactory()
SubscriptionFactory(user=user)
return user
@pytest.mark.django_db
@ -78,11 +81,13 @@ def default_user2():
if KhojUser.objects.filter(username="default").exists():
return KhojUser.objects.get(username="default")
return KhojUser.objects.create(
user = KhojUser.objects.create(
username="default",
email="default@example.com",
password="default",
)
SubscriptionFactory(user=user)
return user
@pytest.mark.django_db
@ -94,11 +99,13 @@ def default_user3():
if KhojUser.objects.filter(username="default3").exists():
return KhojUser.objects.get(username="default3")
return KhojUser.objects.create(
user = KhojUser.objects.create(
username="default3",
email="default3@example.com",
password="default3",
)
SubscriptionFactory(user=user)
return user
@pytest.mark.django_db

View file

@ -9,6 +9,7 @@ from database.models import (
OpenAIProcessorConversationConfig,
UserConversationConfig,
Conversation,
Subscription,
)
@ -68,3 +69,12 @@ class ConversationFactory(factory.django.DjangoModelFactory):
model = Conversation
user = factory.SubFactory(UserFactory)
class SubscriptionFactory(factory.django.DjangoModelFactory):
class Meta:
model = Subscription
user = factory.SubFactory(UserFactory)
type = "trial"
is_recurring = False