mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Add default settings to let new users be subscribed on trial
- Add the default user to a subscription trial - Update associated unit tests
This commit is contained in:
parent
0a950d9382
commit
e2e96f9aa4
9 changed files with 86 additions and 11 deletions
|
@ -101,6 +101,8 @@ async def create_google_user(token: dict) -> KhojUser:
|
|||
user=user,
|
||||
)
|
||||
|
||||
await Subscription.objects.acreate(user=user, type="trial")
|
||||
|
||||
return user
|
||||
|
||||
|
||||
|
|
|
@ -8,6 +8,7 @@ from database.models import (
|
|||
ChatModelOptions,
|
||||
OpenAIProcessorConversationConfig,
|
||||
OfflineChatProcessorConversationConfig,
|
||||
Subscription,
|
||||
)
|
||||
|
||||
admin.site.register(KhojUser, UserAdmin)
|
||||
|
@ -15,3 +16,4 @@ admin.site.register(KhojUser, UserAdmin)
|
|||
admin.site.register(ChatModelOptions)
|
||||
admin.site.register(OpenAIProcessorConversationConfig)
|
||||
admin.site.register(OfflineChatProcessorConversationConfig)
|
||||
admin.site.register(Subscription)
|
||||
|
|
21
src/database/migrations/0015_alter_subscription_user.py
Normal file
21
src/database/migrations/0015_alter_subscription_user.py
Normal file
|
@ -0,0 +1,21 @@
|
|||
# Generated by Django 4.2.5 on 2023-11-11 05:39
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
import django.db.models.deletion
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0014_alter_googleuser_picture"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name="subscription",
|
||||
name="user",
|
||||
field=models.OneToOneField(
|
||||
on_delete=django.db.models.deletion.CASCADE, related_name="subscription", to=settings.AUTH_USER_MODEL
|
||||
),
|
||||
),
|
||||
]
|
|
@ -0,0 +1,17 @@
|
|||
# Generated by Django 4.2.5 on 2023-11-11 06:15
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0015_alter_subscription_user"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name="subscription",
|
||||
name="renewal_date",
|
||||
field=models.DateTimeField(blank=True, default=None, null=True),
|
||||
),
|
||||
]
|
|
@ -51,10 +51,10 @@ class Subscription(BaseModel):
|
|||
TRIAL = "trial"
|
||||
STANDARD = "standard"
|
||||
|
||||
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE)
|
||||
user = models.OneToOneField(KhojUser, on_delete=models.CASCADE, related_name="subscription")
|
||||
type = models.CharField(max_length=20, choices=Type.choices, default=Type.TRIAL)
|
||||
is_recurring = models.BooleanField(default=False)
|
||||
renewal_date = models.DateTimeField(null=True, default=None)
|
||||
renewal_date = models.DateTimeField(null=True, default=None, blank=True)
|
||||
|
||||
|
||||
class NotionConfig(BaseModel):
|
||||
|
|
|
@ -28,7 +28,7 @@ from khoj.utils.config import (
|
|||
from khoj.utils.fs_syncer import collect_files
|
||||
from khoj.utils.rawconfig import FullConfig
|
||||
from khoj.routers.indexer import configure_content, load_content, configure_search
|
||||
from database.models import KhojUser
|
||||
from database.models import KhojUser, Subscription
|
||||
from database.adapters import get_all_users
|
||||
|
||||
|
||||
|
@ -54,27 +54,40 @@ class UserAuthenticationBackend(AuthenticationBackend):
|
|||
|
||||
def _initialize_default_user(self):
|
||||
if not self.khojuser_manager.filter(username="default").exists():
|
||||
self.khojuser_manager.create_user(
|
||||
default_user = self.khojuser_manager.create_user(
|
||||
username="default",
|
||||
email="default@example.com",
|
||||
password="default",
|
||||
)
|
||||
Subscription.objects.create(
|
||||
user=default_user,
|
||||
type="trial",
|
||||
)
|
||||
|
||||
async def authenticate(self, request: HTTPConnection):
|
||||
current_user = request.session.get("user")
|
||||
if current_user and current_user.get("email"):
|
||||
user = await self.khojuser_manager.filter(email=current_user.get("email")).afirst()
|
||||
user = (
|
||||
await self.khojuser_manager.filter(email=current_user.get("email"))
|
||||
.prefetch_related("subscription")
|
||||
.afirst()
|
||||
)
|
||||
if user:
|
||||
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
|
||||
if len(request.headers.get("Authorization", "").split("Bearer ")) == 2:
|
||||
# Get bearer token from header
|
||||
bearer_token = request.headers["Authorization"].split("Bearer ")[1]
|
||||
# Get user owning token
|
||||
user_with_token = await self.khojapiuser_manager.filter(token=bearer_token).select_related("user").afirst()
|
||||
user_with_token = (
|
||||
await self.khojapiuser_manager.filter(token=bearer_token)
|
||||
.select_related("user")
|
||||
.prefetch_related("user__subscription")
|
||||
.afirst()
|
||||
)
|
||||
if user_with_token:
|
||||
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
|
||||
if state.anonymous_mode:
|
||||
user = await self.khojuser_manager.filter(username="default").afirst()
|
||||
user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst()
|
||||
if user:
|
||||
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ from khoj.utils.helpers import ConversationCommand, log_telemetry
|
|||
from khoj.processor.conversation.openai.gpt import converse
|
||||
from khoj.processor.conversation.gpt4all.chat_model import converse_offline
|
||||
from khoj.processor.conversation.utils import message_to_log, ThreadedGenerator
|
||||
from database.models import KhojUser
|
||||
from database.models import KhojUser, Subscription
|
||||
from database.adapters import ConversationAdapters
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -61,12 +61,15 @@ def update_telemetry_state(
|
|||
metadata: Optional[dict] = None,
|
||||
):
|
||||
user: KhojUser = request.user.object if request.user.is_authenticated else None
|
||||
subscription: Subscription = user.subscription if user and user.subscription else None
|
||||
user_state = {
|
||||
"client_host": request.client.host if request.client else None,
|
||||
"user_agent": user_agent or "unknown",
|
||||
"referer": referer or "unknown",
|
||||
"host": host or "unknown",
|
||||
"server_id": str(user.uuid) if user else None,
|
||||
"subscription_type": subscription.type if subscription else None,
|
||||
"is_recurring": subscription.is_recurring if subscription else None,
|
||||
}
|
||||
|
||||
if metadata:
|
||||
|
|
|
@ -43,6 +43,7 @@ from tests.helpers import (
|
|||
OpenAIProcessorConversationConfigFactory,
|
||||
OfflineChatProcessorConversationConfigFactory,
|
||||
UserConversationProcessorConfigFactory,
|
||||
SubscriptionFactory,
|
||||
)
|
||||
|
||||
|
||||
|
@ -69,7 +70,9 @@ def search_config() -> SearchConfig:
|
|||
@pytest.mark.django_db
|
||||
@pytest.fixture
|
||||
def default_user():
|
||||
return UserFactory()
|
||||
user = UserFactory()
|
||||
SubscriptionFactory(user=user)
|
||||
return user
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
|
@ -78,11 +81,13 @@ def default_user2():
|
|||
if KhojUser.objects.filter(username="default").exists():
|
||||
return KhojUser.objects.get(username="default")
|
||||
|
||||
return KhojUser.objects.create(
|
||||
user = KhojUser.objects.create(
|
||||
username="default",
|
||||
email="default@example.com",
|
||||
password="default",
|
||||
)
|
||||
SubscriptionFactory(user=user)
|
||||
return user
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
|
@ -94,11 +99,13 @@ def default_user3():
|
|||
if KhojUser.objects.filter(username="default3").exists():
|
||||
return KhojUser.objects.get(username="default3")
|
||||
|
||||
return KhojUser.objects.create(
|
||||
user = KhojUser.objects.create(
|
||||
username="default3",
|
||||
email="default3@example.com",
|
||||
password="default3",
|
||||
)
|
||||
SubscriptionFactory(user=user)
|
||||
return user
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
|
|
|
@ -9,6 +9,7 @@ from database.models import (
|
|||
OpenAIProcessorConversationConfig,
|
||||
UserConversationConfig,
|
||||
Conversation,
|
||||
Subscription,
|
||||
)
|
||||
|
||||
|
||||
|
@ -68,3 +69,12 @@ class ConversationFactory(factory.django.DjangoModelFactory):
|
|||
model = Conversation
|
||||
|
||||
user = factory.SubFactory(UserFactory)
|
||||
|
||||
|
||||
class SubscriptionFactory(factory.django.DjangoModelFactory):
|
||||
class Meta:
|
||||
model = Subscription
|
||||
|
||||
user = factory.SubFactory(UserFactory)
|
||||
type = "trial"
|
||||
is_recurring = False
|
||||
|
|
Loading…
Reference in a new issue