diff --git a/src/interface/desktop/chat.html b/src/interface/desktop/chat.html index 34f78951..120f6647 100644 --- a/src/interface/desktop/chat.html +++ b/src/interface/desktop/chat.html @@ -361,12 +361,25 @@ if (newResponseText.getElementsByClassName("spinner").length > 0) { newResponseText.removeChild(loadingSpinner); } + // Try to parse the chunk as a JSON object. It will be a JSON object if there is an error. + if (chunk.startsWith("{") && chunk.endsWith("}")) { + try { + const responseAsJson = JSON.parse(chunk); + if (responseAsJson.detail) { + newResponseText.innerHTML += responseAsJson.detail; + } + } catch (error) { + // If the chunk is not a JSON object, just display it as is + newResponseText.innerHTML += chunk; + } + } else { + // If the chunk is not a JSON object, just display it as is + rawResponse += chunk; + newResponseText.innerHTML = ""; + newResponseText.appendChild(formatHTMLMessage(rawResponse)); - rawResponse += chunk; - newResponseText.innerHTML = ""; - newResponseText.appendChild(formatHTMLMessage(rawResponse)); - - readStream(); + readStream(); + } } // Scroll to bottom of chat window as chat response is streamed diff --git a/src/interface/desktop/config.html b/src/interface/desktop/config.html index fb39fbb8..f8ecb06f 100644 --- a/src/interface/desktop/config.html +++ b/src/interface/desktop/config.html @@ -101,6 +101,9 @@
+ diff --git a/src/interface/desktop/main.js b/src/interface/desktop/main.js index eb355a5f..95927be1 100644 --- a/src/interface/desktop/main.js +++ b/src/interface/desktop/main.js @@ -198,6 +198,11 @@ function pushDataToKhoj (regenerate = false) { }) .catch(error => { console.error(error); + if (error.response.status == 429) { + const win = BrowserWindow.getAllWindows()[0]; + if (win) win.webContents.send('needsSubscription', true); + if (win) win.webContents.send('update-state', state); + } state['completed'] = false }) .finally(() => { @@ -396,6 +401,11 @@ app.whenReady().then(() => { event.reply('update-state', arg); }); + ipcMain.on('needsSubscription', (event, arg) => { + console.log(arg); + event.reply('needsSubscription', arg); + }); + ipcMain.on('navigate', (event, page) => { win.loadFile(page); }); diff --git a/src/interface/desktop/preload.js b/src/interface/desktop/preload.js index 1d4c6ec0..8d5152b7 100644 --- a/src/interface/desktop/preload.js +++ b/src/interface/desktop/preload.js @@ -31,6 +31,10 @@ contextBridge.exposeInMainWorld('updateStateAPI', { onUpdateState: (callback) => ipcRenderer.on('update-state', callback) }) +contextBridge.exposeInMainWorld('needsSubscriptionAPI', { + onNeedsSubscription: (callback) => ipcRenderer.on('needsSubscription', callback) +}) + contextBridge.exposeInMainWorld('removeFileAPI', { removeFile: (filePath) => ipcRenderer.invoke('removeFile', filePath) }) diff --git a/src/interface/desktop/renderer.js b/src/interface/desktop/renderer.js index 7d0d906e..16df6d2f 100644 --- a/src/interface/desktop/renderer.js +++ b/src/interface/desktop/renderer.js @@ -1,7 +1,7 @@ const setFolderButton = document.getElementById('update-folder'); const setFileButton = document.getElementById('update-file'); -const showKey = document.getElementById('show-key'); const loadingBar = document.getElementById('loading-bar'); +const needsSubscriptionElement = document.getElementById('needs-subscription'); async function removeFile(filePath) { const updatedFiles = await window.removeFileAPI.removeFile(filePath); @@ -165,6 +165,15 @@ window.updateStateAPI.onUpdateState((event, state) => { syncStatusElement.innerHTML = `⏱️ Synced at ${currentTime.toLocaleTimeString(undefined, options)}. Next sync at ${nextSyncTime.toLocaleTimeString(undefined, options)}.`; }); +window.needsSubscriptionAPI.onNeedsSubscription((event, needsSubscription) => { + console.log("needs subscription", needsSubscription); + if (needsSubscription) { + needsSubscriptionElement.style.display = 'block'; + } else { + needsSubscriptionElement.style.display = 'none'; + } +}); + const urlInput = document.getElementById('khoj-host-url'); (async function() { const url = await window.hostURLAPI.getURL(); diff --git a/src/khoj/configure.py b/src/khoj/configure.py index d34205d9..19e7d403 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -21,7 +21,12 @@ from starlette.authentication import ( # Internal Packages from khoj.database.models import KhojUser, Subscription -from khoj.database.adapters import get_all_users, get_or_create_search_model +from khoj.database.adapters import ( + get_all_users, + get_or_create_search_model, + aget_user_subscription_state, + SubscriptionState, +) from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel from khoj.routers.indexer import configure_content, load_content, configure_search from khoj.utils import constants, state @@ -70,7 +75,17 @@ class UserAuthenticationBackend(AuthenticationBackend): .afirst() ) if user: - return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user) + if state.billing_enabled: + subscription_state = await aget_user_subscription_state(user) + subscribed = ( + subscription_state == SubscriptionState.SUBSCRIBED.value + or subscription_state == SubscriptionState.TRIAL.value + or subscription_state == SubscriptionState.UNSUBSCRIBED.value + ) + if subscribed: + return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user) + return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user) + return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user) if len(request.headers.get("Authorization", "").split("Bearer ")) == 2: # Get bearer token from header bearer_token = request.headers["Authorization"].split("Bearer ")[1] @@ -82,11 +97,23 @@ class UserAuthenticationBackend(AuthenticationBackend): .afirst() ) if user_with_token: - return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user) + if state.billing_enabled: + subscription_state = await aget_user_subscription_state(user_with_token.user) + subscribed = ( + subscription_state == SubscriptionState.SUBSCRIBED.value + or subscription_state == SubscriptionState.TRIAL.value + or subscription_state == SubscriptionState.UNSUBSCRIBED.value + ) + if subscribed: + return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser( + user_with_token.user + ) + return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user) + return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user) if state.anonymous_mode: user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst() if user: - return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user) + return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user) return AuthCredentials(), UnauthenticatedUser() diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index eb143ab6..12a127e9 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -1,8 +1,10 @@ import math import random import secrets -from datetime import date, datetime, timezone +import sys +from datetime import date, datetime, timezone, timedelta from typing import List, Optional, Type +from enum import Enum from asgiref.sync import sync_to_async from django.contrib.sessions.backends.db import SessionStore @@ -41,6 +43,14 @@ from khoj.utils.config import GPT4AllProcessorModel from khoj.utils.helpers import generate_random_name +class SubscriptionState(Enum): + TRIAL = "trial" + SUBSCRIBED = "subscribed" + UNSUBSCRIBED = "unsubscribed" + EXPIRED = "expired" + INVALID = "invalid" + + async def set_notion_config(token: str, user: KhojUser): notion_config = await NotionConfig.objects.filter(user=user).afirst() if not notion_config: @@ -128,22 +138,38 @@ async def set_user_subscription( return None +def subscription_to_state(subscription: Subscription) -> str: + if not subscription: + return SubscriptionState.INVALID.value + elif subscription.type == Subscription.Type.TRIAL: + # Trial subscription is valid for 7 days + if datetime.now(tz=timezone.utc) - subscription.created_at > timedelta(days=7): + return SubscriptionState.EXPIRED.value + + return SubscriptionState.TRIAL.value + elif subscription.is_recurring and subscription.renewal_date >= datetime.now(tz=timezone.utc): + return SubscriptionState.SUBSCRIBED.value + elif not subscription.is_recurring and subscription.renewal_date >= datetime.now(tz=timezone.utc): + return SubscriptionState.UNSUBSCRIBED.value + elif not subscription.is_recurring and subscription.renewal_date < datetime.now(tz=timezone.utc): + return SubscriptionState.EXPIRED.value + return SubscriptionState.INVALID.value + + def get_user_subscription_state(email: str) -> str: """Get subscription state of user Valid state transitions: trial -> subscribed <-> unsubscribed OR expired """ user_subscription = Subscription.objects.filter(user__email=email).first() - if not user_subscription: - return "trial" - elif user_subscription.type == Subscription.Type.TRIAL: - return "trial" - elif user_subscription.is_recurring and user_subscription.renewal_date >= datetime.now(tz=timezone.utc): - return "subscribed" - elif not user_subscription.is_recurring and user_subscription.renewal_date >= datetime.now(tz=timezone.utc): - return "unsubscribed" - elif not user_subscription.is_recurring and user_subscription.renewal_date < datetime.now(tz=timezone.utc): - return "expired" - return "invalid" + return subscription_to_state(user_subscription) + + +async def aget_user_subscription_state(email: str) -> str: + """Get subscription state of user + Valid state transitions: trial -> subscribed <-> unsubscribed OR expired + """ + user_subscription = await Subscription.objects.filter(user__email=email).afirst() + return subscription_to_state(user_subscription) async def get_user_by_email(email: str) -> KhojUser: @@ -458,6 +484,12 @@ class EntryAdapters: async def adelete_all_entries(user: KhojUser): return await Entry.objects.filter(user=user).adelete() + @staticmethod + def get_size_of_indexed_data_in_mb(user: KhojUser): + entries = Entry.objects.filter(user=user).iterator() + total_size = sum(sys.getsizeof(entry.compiled) for entry in entries) + return total_size / 1024 / 1024 + @staticmethod def apply_filters(user: KhojUser, query: str, file_type_filter: str = None): q_filter_terms = Q() diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index f62af230..de6b899c 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -402,10 +402,24 @@ To get started, just start typing below. You can also type / to see a list of co newResponseText.removeChild(loadingSpinner); } - rawResponse += chunk; - newResponseText.innerHTML = ""; - newResponseText.appendChild(formatHTMLMessage(rawResponse)); - readStream(); + // Try to parse the chunk as a JSON object. It will be a JSON object if there is an error. + if (chunk.startsWith("{") && chunk.endsWith("}")) { + try { + const responseAsJson = JSON.parse(chunk); + if (responseAsJson.detail) { + newResponseText.innerHTML += responseAsJson.detail; + } + } catch (error) { + // If the chunk is not a JSON object, just display it as is + newResponseText.innerHTML += chunk; + } + } else { + // If the chunk is not a JSON object, just display it as is + rawResponse += chunk; + newResponseText.innerHTML = ""; + newResponseText.appendChild(formatHTMLMessage(rawResponse)); + readStream(); + } } // Scroll to bottom of chat window as chat response is streamed diff --git a/src/khoj/interface/web/config.html b/src/khoj/interface/web/config.html index 01a3786f..96d82131 100644 --- a/src/khoj/interface/web/config.html +++ b/src/khoj/interface/web/config.html @@ -4,6 +4,10 @@

Content

+ +

@@ -191,7 +195,7 @@

- Subscribe to Khoj Cloud + Subscribe to Khoj Cloud. See pricing for details.

response.json()) + .then(data => { + document.getElementById("indexed-data-size").innerHTML = data.indexed_data_size_in_mb + " MB used"; + }); + } + // List user's API keys on page load listApiKeys(); diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 3fd2285d..ae125980 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -9,8 +9,8 @@ from typing import Any, Dict, List, Optional, Union import uuid # External Packages -from asgiref.sync import sync_to_async from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile +from asgiref.sync import sync_to_async from fastapi.requests import Request from fastapi.responses import Response, StreamingResponse from starlette.authentication import requires @@ -334,6 +334,18 @@ def get_default_config_data(): return constants.empty_config +@api.get("/config/index/size", response_model=Dict[str, int]) +@requires(["authenticated"]) +async def get_indexed_data_size(request: Request, common: CommonQueryParams): + user = request.user.object + indexed_data_size_in_mb = await sync_to_async(EntryAdapters.get_size_of_indexed_data_in_mb)(user) + return Response( + content=json.dumps({"indexed_data_size_in_mb": math.ceil(indexed_data_size_in_mb)}), + media_type="application/json", + status_code=200, + ) + + @api.get("/config/types", response_model=List[str]) @requires(["authenticated"]) def get_config_types( @@ -650,8 +662,8 @@ async def chat( n: Optional[int] = 5, d: Optional[float] = 0.18, stream: Optional[bool] = False, - rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=30, window=60)), - rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=500, window=60 * 60 * 24)), + rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=60, window=60)), + rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)), ) -> Response: user = request.user.object diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index e1ab05b5..1dd3f4c7 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -9,10 +9,12 @@ from functools import partial from time import time from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union -# External Packages -from fastapi import Depends, Header, HTTPException, Request +from fastapi import Depends, Header, HTTPException, Request, UploadFile +from starlette.authentication import has_required_scope +from asgiref.sync import sync_to_async -from khoj.database.adapters import ConversationAdapters + +from khoj.database.adapters import ConversationAdapters, EntryAdapters from khoj.database.models import KhojUser, Subscription from khoj.processor.conversation import prompts from khoj.processor.conversation.offline.chat_model import converse_offline, send_message_to_model_offline @@ -270,13 +272,15 @@ def generate_chat_response( class ApiUserRateLimiter: - def __init__(self, requests: int, window: int): + def __init__(self, requests: int, subscribed_requests: int, window: int): self.requests = requests + self.subscribed_requests = subscribed_requests self.window = window self.cache: dict[str, list[float]] = defaultdict(list) def __call__(self, request: Request): user: KhojUser = request.user.object + subscribed = has_required_scope(request, ["premium"]) user_requests = self.cache[user.uuid] # Remove requests outside of the time window @@ -285,13 +289,69 @@ class ApiUserRateLimiter: user_requests.pop(0) # Check if the user has exceeded the rate limit - if len(user_requests) >= self.requests: + if subscribed and len(user_requests) >= self.subscribed_requests: raise HTTPException(status_code=429, detail="Too Many Requests") + if not subscribed and len(user_requests) >= self.requests: + raise HTTPException(status_code=429, detail="Too Many Requests. Subscribe to increase your rate limit.") # Add the current request to the cache user_requests.append(time()) +class ApiIndexedDataLimiter: + def __init__( + self, + incoming_entries_size_limit: float, + subscribed_incoming_entries_size_limit: float, + total_entries_size_limit: float, + subscribed_total_entries_size_limit: float, + ): + self.num_entries_size = incoming_entries_size_limit + self.subscribed_num_entries_size = subscribed_incoming_entries_size_limit + self.total_entries_size_limit = total_entries_size_limit + self.subscribed_total_entries_size = subscribed_total_entries_size_limit + + def __call__(self, request: Request, files: List[UploadFile]): + if state.billing_enabled is False: + return + subscribed = has_required_scope(request, ["premium"]) + incoming_data_size_mb = 0 + deletion_file_names = set() + + if not request.user.is_authenticated: + return + + user: KhojUser = request.user.object + + for file in files: + if file.size == 0: + deletion_file_names.add(file.filename) + + incoming_data_size_mb += file.size / 1024 / 1024 + + num_deleted_entries = 0 + for file_path in deletion_file_names: + deleted_count = EntryAdapters.delete_entry_by_file(user, file_path) + num_deleted_entries += deleted_count + + logger.info(f"Deleted {num_deleted_entries} entries for user: {user}.") + + if subscribed and incoming_data_size_mb >= self.subscribed_num_entries_size: + raise HTTPException(status_code=429, detail="Too much data indexed.") + if not subscribed and incoming_data_size_mb >= self.num_entries_size: + raise HTTPException( + status_code=429, detail="Too much data indexed. Subscribe to increase your data index limit." + ) + + user_size_data = EntryAdapters.get_size_of_indexed_data_in_mb(user) + if subscribed and user_size_data + incoming_data_size_mb >= self.subscribed_total_entries_size: + raise HTTPException(status_code=429, detail="Too much data indexed.") + if not subscribed and user_size_data + incoming_data_size_mb >= self.total_entries_size_limit: + raise HTTPException( + status_code=429, detail="Too much data indexed. Subscribe to increase your data index limit." + ) + + class CommonQueryParamsClass: def __init__( self, diff --git a/src/khoj/routers/indexer.py b/src/khoj/routers/indexer.py index 0432eed0..0c906707 100644 --- a/src/khoj/routers/indexer.py +++ b/src/khoj/routers/indexer.py @@ -2,7 +2,7 @@ import asyncio import logging from typing import Dict, Optional, Union -from fastapi import APIRouter, Header, Request, Response, UploadFile +from fastapi import APIRouter, Header, Request, Response, UploadFile, Depends from pydantic import BaseModel from starlette.authentication import requires @@ -18,6 +18,7 @@ from khoj.search_type import image_search, text_search from khoj.utils import constants, state from khoj.utils.config import ContentIndex, SearchModels from khoj.utils.helpers import LRU, get_file_type +from khoj.routers.helpers import ApiIndexedDataLimiter from khoj.utils.rawconfig import ContentConfig, FullConfig, SearchConfig from khoj.utils.yaml import save_config_to_file_updated_state @@ -53,6 +54,14 @@ async def update( user_agent: Optional[str] = Header(None), referer: Optional[str] = Header(None), host: Optional[str] = Header(None), + indexed_data_limiter: ApiIndexedDataLimiter = Depends( + ApiIndexedDataLimiter( + incoming_entries_size_limit=10, + subscribed_incoming_entries_size_limit=25, + total_entries_size_limit=10, + subscribed_total_entries_size_limit=100, + ) + ), ): user = request.user.object try: @@ -92,7 +101,7 @@ async def update( logger.info("📬 Initializing content index on first run.") default_full_config = FullConfig( content_type=None, - search_type=SearchConfig.parse_obj(constants.default_config["search-type"]), + search_type=SearchConfig.model_validate(constants.default_config["search-type"]), processor=None, ) state.config = default_full_config @@ -116,7 +125,7 @@ async def update( configure_content, state.content_index, state.config.content_type, - indexer_input.dict(), + indexer_input.model_dump(), state.search_models, force, t, diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py index dab16fa8..7907f99e 100644 --- a/src/khoj/routers/web_client.py +++ b/src/khoj/routers/web_client.py @@ -1,13 +1,15 @@ # System Packages import json import os +import math +from datetime import timedelta # External Packages from fastapi import APIRouter from fastapi import Request from fastapi.responses import HTMLResponse, FileResponse, RedirectResponse from fastapi.templating import Jinja2Templates -from starlette.authentication import requires +from starlette.authentication import requires, has_required_scope from khoj.database import adapters from khoj.database.models import KhojUser from khoj.utils.rawconfig import ( @@ -37,7 +39,6 @@ templates = Jinja2Templates(directory=constants.web_directory) def index(request: Request): user = request.user.object user_picture = request.session.get("user", {}).get("picture") - user_subscription_state = get_user_subscription_state(user.email) has_documents = EntryAdapters.user_has_entries(user=user) return templates.TemplateResponse( @@ -46,7 +47,7 @@ def index(request: Request): "request": request, "username": user.username, "user_photo": user_picture, - "is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed", + "is_active": has_required_scope(request, ["premium"]), "has_documents": has_documents, }, ) @@ -57,7 +58,6 @@ def index(request: Request): def index_post(request: Request): user = request.user.object user_picture = request.session.get("user", {}).get("picture") - user_subscription_state = get_user_subscription_state(user.email) has_documents = EntryAdapters.user_has_entries(user=user) return templates.TemplateResponse( @@ -66,7 +66,7 @@ def index_post(request: Request): "request": request, "username": user.username, "user_photo": user_picture, - "is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed", + "is_active": has_required_scope(request, ["premium"]), "has_documents": has_documents, }, ) @@ -77,7 +77,6 @@ def index_post(request: Request): def search_page(request: Request): user = request.user.object user_picture = request.session.get("user", {}).get("picture") - user_subscription_state = get_user_subscription_state(user.email) has_documents = EntryAdapters.user_has_entries(user=user) return templates.TemplateResponse( @@ -86,7 +85,7 @@ def search_page(request: Request): "request": request, "username": user.username, "user_photo": user_picture, - "is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed", + "is_active": has_required_scope(request, ["premium"]), "has_documents": has_documents, }, ) @@ -97,7 +96,6 @@ def search_page(request: Request): def chat_page(request: Request): user = request.user.object user_picture = request.session.get("user", {}).get("picture") - user_subscription_state = get_user_subscription_state(user.email) has_documents = EntryAdapters.user_has_entries(user=user) return templates.TemplateResponse( @@ -106,7 +104,7 @@ def chat_page(request: Request): "request": request, "username": user.username, "user_photo": user_picture, - "is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed", + "is_active": has_required_scope(request, ["premium"]), "has_documents": has_documents, }, ) @@ -141,7 +139,7 @@ def config_page(request: Request): subscription_renewal_date = ( user_subscription.renewal_date.strftime("%d %b %Y") if user_subscription and user_subscription.renewal_date - else None + else (user_subscription.created_at + timedelta(days=7)).strftime("%d %b %Y") ) enabled_content_source = set(EntryAdapters.get_unique_file_sources(user)) @@ -171,7 +169,7 @@ def config_page(request: Request): "subscription_state": user_subscription_state, "subscription_renewal_date": subscription_renewal_date, "khoj_cloud_subscription_url": os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL"), - "is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed", + "is_active": has_required_scope(request, ["premium"]), "has_documents": has_documents, }, ) @@ -182,7 +180,6 @@ def config_page(request: Request): def github_config_page(request: Request): user = request.user.object user_picture = request.session.get("user", {}).get("picture") - user_subscription_state = get_user_subscription_state(user.email) has_documents = EntryAdapters.user_has_entries(user=user) current_github_config = get_user_github_config(user) @@ -212,7 +209,7 @@ def github_config_page(request: Request): "current_config": current_config, "username": user.username, "user_photo": user_picture, - "is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed", + "is_active": has_required_scope(request, ["premium"]), "has_documents": has_documents, }, ) @@ -223,7 +220,6 @@ def github_config_page(request: Request): def notion_config_page(request: Request): user = request.user.object user_picture = request.session.get("user", {}).get("picture") - user_subscription_state = adapters.get_user_subscription(user.email) has_documents = EntryAdapters.user_has_entries(user=user) current_notion_config = get_user_notion_config(user) @@ -240,7 +236,7 @@ def notion_config_page(request: Request): "current_config": current_config, "username": user.username, "user_photo": user_picture, - "is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed", + "is_active": has_required_scope(request, ["premium"]), "has_documents": has_documents, }, ) @@ -251,7 +247,6 @@ def notion_config_page(request: Request): def computer_config_page(request: Request): user = request.user.object user_picture = request.session.get("user", {}).get("picture") - user_subscription_state = get_user_subscription_state(user.email) has_documents = EntryAdapters.user_has_entries(user=user) return templates.TemplateResponse( @@ -260,7 +255,7 @@ def computer_config_page(request: Request): "request": request, "username": user.username, "user_photo": user_picture, - "is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed", + "is_active": has_required_scope(request, ["premium"]), "has_documents": has_documents, }, ) diff --git a/tests/conftest.py b/tests/conftest.py index 54c664d5..9a500609 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -102,6 +102,24 @@ def default_user3(): return user +@pytest.mark.django_db +@pytest.fixture +def default_user4(): + """ + This user should not have a valid subscription + """ + if KhojUser.objects.filter(username="default4").exists(): + return KhojUser.objects.get(username="default4") + + user = KhojUser.objects.create( + username="default4", + email="default4@example.com", + password="default4", + ) + SubscriptionFactory(user=user, renewal_date=None) + return user + + @pytest.mark.django_db @pytest.fixture def api_user(default_user): @@ -141,6 +159,19 @@ def api_user3(default_user3): ) +@pytest.mark.django_db +@pytest.fixture +def api_user4(default_user4): + if KhojApiUser.objects.filter(user=default_user4).exists(): + return KhojApiUser.objects.get(user=default_user4) + + return KhojApiUser.objects.create( + user=default_user4, + name="api-key", + token="kk-diff-secret-4", + ) + + @pytest.fixture(scope="session") def search_models(search_config: SearchConfig): search_models = SearchModels() diff --git a/tests/test_client.py b/tests/test_client.py index 19aba03b..f23a350e 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -125,6 +125,67 @@ def test_regenerate_with_invalid_content_type(client): assert response.status_code == 422 +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db(transaction=True) +def test_index_update_big_files(client): + # Arrange + state.billing_enabled = True + files = get_big_size_sample_files_data() + headers = {"Authorization": "Bearer kk-secret"} + + # Act + response = client.post("/api/v1/index/update", files=files, headers=headers) + + # Assert + assert response.status_code == 429 + + +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db(transaction=True) +def test_index_update_medium_file_unsubscribed(client, api_user4: KhojApiUser): + # Arrange + api_token = api_user4.token + state.billing_enabled = True + files = get_medium_size_sample_files_data() + headers = {"Authorization": f"Bearer {api_token}"} + + # Act + response = client.post("/api/v1/index/update", files=files, headers=headers) + + # Assert + assert response.status_code == 429 + + +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db(transaction=True) +def test_index_update_normal_file_unsubscribed(client, api_user4: KhojApiUser): + # Arrange + api_token = api_user4.token + state.billing_enabled = True + files = get_sample_files_data() + headers = {"Authorization": f"Bearer {api_token}"} + + # Act + response = client.post("/api/v1/index/update", files=files, headers=headers) + + # Assert + assert response.status_code == 200 + + +@pytest.mark.django_db(transaction=True) +def test_index_update_big_files_no_billing(client): + # Arrange + state.billing_enabled = False + files = get_big_size_sample_files_data() + headers = {"Authorization": "Bearer kk-secret"} + + # Act + response = client.post("/api/v1/index/update", files=files, headers=headers) + + # Assert + assert response.status_code == 200 + + # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db(transaction=True) def test_index_update(client): @@ -421,3 +482,23 @@ def get_sample_files_data(): ), ("files", ("path/to/filename2.md", "**Understanding science through the lens of art**", "text/markdown")), ] + + +def get_big_size_sample_files_data(): + big_text = "a" * (25 * 1024 * 1024) # a string of approximately 25 MB + return [ + ( + "files", + ("path/to/filename.org", big_text, "text/org"), + ) + ] + + +def get_medium_size_sample_files_data(): + big_text = "a" * (10 * 1024 * 1024) # a string of approximately 10 MB + return [ + ( + "files", + ("path/to/filename.org", big_text, "text/org"), + ) + ]