Add default settings to let new users be subscribed on trial

- Add the default user to a subscription trial
- Update associated unit tests
This commit is contained in:
sabaimran 2023-11-10 22:38:28 -08:00
parent 0a950d9382
commit e2e96f9aa4
9 changed files with 86 additions and 11 deletions

View file

@ -101,6 +101,8 @@ async def create_google_user(token: dict) -> KhojUser:
user=user, user=user,
) )
await Subscription.objects.acreate(user=user, type="trial")
return user return user

View file

@ -8,6 +8,7 @@ from database.models import (
ChatModelOptions, ChatModelOptions,
OpenAIProcessorConversationConfig, OpenAIProcessorConversationConfig,
OfflineChatProcessorConversationConfig, OfflineChatProcessorConversationConfig,
Subscription,
) )
admin.site.register(KhojUser, UserAdmin) admin.site.register(KhojUser, UserAdmin)
@ -15,3 +16,4 @@ admin.site.register(KhojUser, UserAdmin)
admin.site.register(ChatModelOptions) admin.site.register(ChatModelOptions)
admin.site.register(OpenAIProcessorConversationConfig) admin.site.register(OpenAIProcessorConversationConfig)
admin.site.register(OfflineChatProcessorConversationConfig) admin.site.register(OfflineChatProcessorConversationConfig)
admin.site.register(Subscription)

View file

@ -0,0 +1,21 @@
# Generated by Django 4.2.5 on 2023-11-11 05:39
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
("database", "0014_alter_googleuser_picture"),
]
operations = [
migrations.AlterField(
model_name="subscription",
name="user",
field=models.OneToOneField(
on_delete=django.db.models.deletion.CASCADE, related_name="subscription", to=settings.AUTH_USER_MODEL
),
),
]

View file

@ -0,0 +1,17 @@
# Generated by Django 4.2.5 on 2023-11-11 06:15
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0015_alter_subscription_user"),
]
operations = [
migrations.AlterField(
model_name="subscription",
name="renewal_date",
field=models.DateTimeField(blank=True, default=None, null=True),
),
]

View file

@ -51,10 +51,10 @@ class Subscription(BaseModel):
TRIAL = "trial" TRIAL = "trial"
STANDARD = "standard" STANDARD = "standard"
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE) user = models.OneToOneField(KhojUser, on_delete=models.CASCADE, related_name="subscription")
type = models.CharField(max_length=20, choices=Type.choices, default=Type.TRIAL) type = models.CharField(max_length=20, choices=Type.choices, default=Type.TRIAL)
is_recurring = models.BooleanField(default=False) is_recurring = models.BooleanField(default=False)
renewal_date = models.DateTimeField(null=True, default=None) renewal_date = models.DateTimeField(null=True, default=None, blank=True)
class NotionConfig(BaseModel): class NotionConfig(BaseModel):

View file

@ -28,7 +28,7 @@ from khoj.utils.config import (
from khoj.utils.fs_syncer import collect_files from khoj.utils.fs_syncer import collect_files
from khoj.utils.rawconfig import FullConfig from khoj.utils.rawconfig import FullConfig
from khoj.routers.indexer import configure_content, load_content, configure_search from khoj.routers.indexer import configure_content, load_content, configure_search
from database.models import KhojUser from database.models import KhojUser, Subscription
from database.adapters import get_all_users from database.adapters import get_all_users
@ -54,27 +54,40 @@ class UserAuthenticationBackend(AuthenticationBackend):
def _initialize_default_user(self): def _initialize_default_user(self):
if not self.khojuser_manager.filter(username="default").exists(): if not self.khojuser_manager.filter(username="default").exists():
self.khojuser_manager.create_user( default_user = self.khojuser_manager.create_user(
username="default", username="default",
email="default@example.com", email="default@example.com",
password="default", password="default",
) )
Subscription.objects.create(
user=default_user,
type="trial",
)
async def authenticate(self, request: HTTPConnection): async def authenticate(self, request: HTTPConnection):
current_user = request.session.get("user") current_user = request.session.get("user")
if current_user and current_user.get("email"): if current_user and current_user.get("email"):
user = await self.khojuser_manager.filter(email=current_user.get("email")).afirst() user = (
await self.khojuser_manager.filter(email=current_user.get("email"))
.prefetch_related("subscription")
.afirst()
)
if user: if user:
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user) return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
if len(request.headers.get("Authorization", "").split("Bearer ")) == 2: if len(request.headers.get("Authorization", "").split("Bearer ")) == 2:
# Get bearer token from header # Get bearer token from header
bearer_token = request.headers["Authorization"].split("Bearer ")[1] bearer_token = request.headers["Authorization"].split("Bearer ")[1]
# Get user owning token # Get user owning token
user_with_token = await self.khojapiuser_manager.filter(token=bearer_token).select_related("user").afirst() user_with_token = (
await self.khojapiuser_manager.filter(token=bearer_token)
.select_related("user")
.prefetch_related("user__subscription")
.afirst()
)
if user_with_token: if user_with_token:
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user) return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
if state.anonymous_mode: if state.anonymous_mode:
user = await self.khojuser_manager.filter(username="default").afirst() user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst()
if user: if user:
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user) return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)

View file

@ -13,7 +13,7 @@ from khoj.utils.helpers import ConversationCommand, log_telemetry
from khoj.processor.conversation.openai.gpt import converse from khoj.processor.conversation.openai.gpt import converse
from khoj.processor.conversation.gpt4all.chat_model import converse_offline from khoj.processor.conversation.gpt4all.chat_model import converse_offline
from khoj.processor.conversation.utils import message_to_log, ThreadedGenerator from khoj.processor.conversation.utils import message_to_log, ThreadedGenerator
from database.models import KhojUser from database.models import KhojUser, Subscription
from database.adapters import ConversationAdapters from database.adapters import ConversationAdapters
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -61,12 +61,15 @@ def update_telemetry_state(
metadata: Optional[dict] = None, metadata: Optional[dict] = None,
): ):
user: KhojUser = request.user.object if request.user.is_authenticated else None user: KhojUser = request.user.object if request.user.is_authenticated else None
subscription: Subscription = user.subscription if user and user.subscription else None
user_state = { user_state = {
"client_host": request.client.host if request.client else None, "client_host": request.client.host if request.client else None,
"user_agent": user_agent or "unknown", "user_agent": user_agent or "unknown",
"referer": referer or "unknown", "referer": referer or "unknown",
"host": host or "unknown", "host": host or "unknown",
"server_id": str(user.uuid) if user else None, "server_id": str(user.uuid) if user else None,
"subscription_type": subscription.type if subscription else None,
"is_recurring": subscription.is_recurring if subscription else None,
} }
if metadata: if metadata:

View file

@ -43,6 +43,7 @@ from tests.helpers import (
OpenAIProcessorConversationConfigFactory, OpenAIProcessorConversationConfigFactory,
OfflineChatProcessorConversationConfigFactory, OfflineChatProcessorConversationConfigFactory,
UserConversationProcessorConfigFactory, UserConversationProcessorConfigFactory,
SubscriptionFactory,
) )
@ -69,7 +70,9 @@ def search_config() -> SearchConfig:
@pytest.mark.django_db @pytest.mark.django_db
@pytest.fixture @pytest.fixture
def default_user(): def default_user():
return UserFactory() user = UserFactory()
SubscriptionFactory(user=user)
return user
@pytest.mark.django_db @pytest.mark.django_db
@ -78,11 +81,13 @@ def default_user2():
if KhojUser.objects.filter(username="default").exists(): if KhojUser.objects.filter(username="default").exists():
return KhojUser.objects.get(username="default") return KhojUser.objects.get(username="default")
return KhojUser.objects.create( user = KhojUser.objects.create(
username="default", username="default",
email="default@example.com", email="default@example.com",
password="default", password="default",
) )
SubscriptionFactory(user=user)
return user
@pytest.mark.django_db @pytest.mark.django_db
@ -94,11 +99,13 @@ def default_user3():
if KhojUser.objects.filter(username="default3").exists(): if KhojUser.objects.filter(username="default3").exists():
return KhojUser.objects.get(username="default3") return KhojUser.objects.get(username="default3")
return KhojUser.objects.create( user = KhojUser.objects.create(
username="default3", username="default3",
email="default3@example.com", email="default3@example.com",
password="default3", password="default3",
) )
SubscriptionFactory(user=user)
return user
@pytest.mark.django_db @pytest.mark.django_db

View file

@ -9,6 +9,7 @@ from database.models import (
OpenAIProcessorConversationConfig, OpenAIProcessorConversationConfig,
UserConversationConfig, UserConversationConfig,
Conversation, Conversation,
Subscription,
) )
@ -68,3 +69,12 @@ class ConversationFactory(factory.django.DjangoModelFactory):
model = Conversation model = Conversation
user = factory.SubFactory(UserFactory) user = factory.SubFactory(UserFactory)
class SubscriptionFactory(factory.django.DjangoModelFactory):
class Meta:
model = Subscription
user = factory.SubFactory(UserFactory)
type = "trial"
is_recurring = False