From 69c8f45830f02c4d7b19865a57779e17b85b8f74 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Fri, 24 Nov 2023 20:29:36 -0800 Subject: [PATCH 01/12] Use scopes to represent whether the use has a valid subscription in the middleware --- src/khoj/configure.py | 19 ++++++++++-- src/khoj/database/adapters/__init__.py | 43 +++++++++++++++++++------- src/khoj/interface/web/config.html | 2 +- src/khoj/routers/api.py | 2 +- src/khoj/routers/web_client.py | 25 ++++++--------- 5 files changed, 59 insertions(+), 32 deletions(-) diff --git a/src/khoj/configure.py b/src/khoj/configure.py index d34205d9..4e4e1008 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -21,7 +21,12 @@ from starlette.authentication import ( # Internal Packages from khoj.database.models import KhojUser, Subscription -from khoj.database.adapters import get_all_users, get_or_create_search_model +from khoj.database.adapters import ( + get_all_users, + get_or_create_search_model, + aget_user_subscription_state, + SubscriptionState, +) from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel from khoj.routers.indexer import configure_content, load_content, configure_search from khoj.utils import constants, state @@ -70,7 +75,11 @@ class UserAuthenticationBackend(AuthenticationBackend): .afirst() ) if user: - return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user) + subscription_state = await aget_user_subscription_state(user) + subscribed = subscription_state == SubscriptionState.SUBSCRIBED.value + if subscribed: + return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user) + return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user) if len(request.headers.get("Authorization", "").split("Bearer ")) == 2: # Get bearer token from header bearer_token = request.headers["Authorization"].split("Bearer ")[1] @@ -82,11 +91,15 @@ class UserAuthenticationBackend(AuthenticationBackend): .afirst() ) if user_with_token: + subscription_state = await aget_user_subscription_state(user_with_token.user) + subscribed = subscription_state == SubscriptionState.SUBSCRIBED.value + if subscribed: + return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user_with_token.user) return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user) if state.anonymous_mode: user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst() if user: - return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user) + return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user) return AuthCredentials(), UnauthenticatedUser() diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 7fd04006..146de11c 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -3,6 +3,7 @@ import random import secrets from datetime import date, datetime, timezone from typing import List, Optional, Type +from enum import Enum from asgiref.sync import sync_to_async from django.contrib.sessions.backends.db import SessionStore @@ -40,6 +41,14 @@ from khoj.utils.config import GPT4AllProcessorModel from khoj.utils.helpers import generate_random_name +class SubscriptionState(Enum): + TRIAL = "trial" + SUBSCRIBED = "subscribed" + UNSUBSCRIBED = "unsubscribed" + EXPIRED = "expired" + INVALID = "invalid" + + async def set_notion_config(token: str, user: KhojUser): notion_config = await NotionConfig.objects.filter(user=user).afirst() if not notion_config: @@ -127,22 +136,34 @@ async def set_user_subscription( return None +def subscription_to_state(subscription: Subscription) -> str: + if not subscription: + return SubscriptionState.INVALID.value + elif subscription.type == Subscription.Type.TRIAL: + return SubscriptionState.TRIAL.value + elif subscription.is_recurring and subscription.renewal_date >= datetime.now(tz=timezone.utc): + return SubscriptionState.SUBSCRIBED.value + elif not subscription.is_recurring and subscription.renewal_date >= datetime.now(tz=timezone.utc): + return SubscriptionState.UNSUBSCRIBED.value + elif not subscription.is_recurring and subscription.renewal_date < datetime.now(tz=timezone.utc): + return SubscriptionState.EXPIRED.value + return SubscriptionState.INVALID.value + + def get_user_subscription_state(email: str) -> str: """Get subscription state of user Valid state transitions: trial -> subscribed <-> unsubscribed OR expired """ user_subscription = Subscription.objects.filter(user__email=email).first() - if not user_subscription: - return "trial" - elif user_subscription.type == Subscription.Type.TRIAL: - return "trial" - elif user_subscription.is_recurring and user_subscription.renewal_date >= datetime.now(tz=timezone.utc): - return "subscribed" - elif not user_subscription.is_recurring and user_subscription.renewal_date >= datetime.now(tz=timezone.utc): - return "unsubscribed" - elif not user_subscription.is_recurring and user_subscription.renewal_date < datetime.now(tz=timezone.utc): - return "expired" - return "invalid" + return subscription_to_state(user_subscription) + + +async def aget_user_subscription_state(email: str) -> str: + """Get subscription state of user + Valid state transitions: trial -> subscribed <-> unsubscribed OR expired + """ + user_subscription = await Subscription.objects.filter(user__email=email).afirst() + return subscription_to_state(user_subscription) async def get_user_by_email(email: str) -> KhojUser: diff --git a/src/khoj/interface/web/config.html b/src/khoj/interface/web/config.html index 01a3786f..318759df 100644 --- a/src/khoj/interface/web/config.html +++ b/src/khoj/interface/web/config.html @@ -171,7 +171,7 @@ - {% if billing_enabled %} + {% if not billing_enabled %}

Billing

diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index 83955088..d5b6ce0e 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -12,7 +12,7 @@ from asgiref.sync import sync_to_async from fastapi import APIRouter, Depends, Header, HTTPException, Request from fastapi.requests import Request from fastapi.responses import Response, StreamingResponse -from starlette.authentication import requires +from starlette.authentication import requires, has_required_scope # Internal Packages from khoj.configure import configure_server diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py index dab16fa8..bf3ff957 100644 --- a/src/khoj/routers/web_client.py +++ b/src/khoj/routers/web_client.py @@ -7,7 +7,7 @@ from fastapi import APIRouter from fastapi import Request from fastapi.responses import HTMLResponse, FileResponse, RedirectResponse from fastapi.templating import Jinja2Templates -from starlette.authentication import requires +from starlette.authentication import requires, has_required_scope from khoj.database import adapters from khoj.database.models import KhojUser from khoj.utils.rawconfig import ( @@ -37,7 +37,6 @@ templates = Jinja2Templates(directory=constants.web_directory) def index(request: Request): user = request.user.object user_picture = request.session.get("user", {}).get("picture") - user_subscription_state = get_user_subscription_state(user.email) has_documents = EntryAdapters.user_has_entries(user=user) return templates.TemplateResponse( @@ -46,7 +45,7 @@ def index(request: Request): "request": request, "username": user.username, "user_photo": user_picture, - "is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed", + "is_active": has_required_scope(request, ["subscribed"]), "has_documents": has_documents, }, ) @@ -57,7 +56,6 @@ def index(request: Request): def index_post(request: Request): user = request.user.object user_picture = request.session.get("user", {}).get("picture") - user_subscription_state = get_user_subscription_state(user.email) has_documents = EntryAdapters.user_has_entries(user=user) return templates.TemplateResponse( @@ -66,7 +64,7 @@ def index_post(request: Request): "request": request, "username": user.username, "user_photo": user_picture, - "is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed", + "is_active": has_required_scope(request, ["subscribed"]), "has_documents": has_documents, }, ) @@ -77,7 +75,6 @@ def index_post(request: Request): def search_page(request: Request): user = request.user.object user_picture = request.session.get("user", {}).get("picture") - user_subscription_state = get_user_subscription_state(user.email) has_documents = EntryAdapters.user_has_entries(user=user) return templates.TemplateResponse( @@ -86,7 +83,7 @@ def search_page(request: Request): "request": request, "username": user.username, "user_photo": user_picture, - "is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed", + "is_active": has_required_scope(request, ["subscribed"]), "has_documents": has_documents, }, ) @@ -97,7 +94,6 @@ def search_page(request: Request): def chat_page(request: Request): user = request.user.object user_picture = request.session.get("user", {}).get("picture") - user_subscription_state = get_user_subscription_state(user.email) has_documents = EntryAdapters.user_has_entries(user=user) return templates.TemplateResponse( @@ -106,7 +102,7 @@ def chat_page(request: Request): "request": request, "username": user.username, "user_photo": user_picture, - "is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed", + "is_active": has_required_scope(request, ["subscribed"]), "has_documents": has_documents, }, ) @@ -171,7 +167,7 @@ def config_page(request: Request): "subscription_state": user_subscription_state, "subscription_renewal_date": subscription_renewal_date, "khoj_cloud_subscription_url": os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL"), - "is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed", + "is_active": has_required_scope(request, ["subscribed"]), "has_documents": has_documents, }, ) @@ -182,7 +178,6 @@ def config_page(request: Request): def github_config_page(request: Request): user = request.user.object user_picture = request.session.get("user", {}).get("picture") - user_subscription_state = get_user_subscription_state(user.email) has_documents = EntryAdapters.user_has_entries(user=user) current_github_config = get_user_github_config(user) @@ -212,7 +207,7 @@ def github_config_page(request: Request): "current_config": current_config, "username": user.username, "user_photo": user_picture, - "is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed", + "is_active": has_required_scope(request, ["subscribed"]), "has_documents": has_documents, }, ) @@ -223,7 +218,6 @@ def github_config_page(request: Request): def notion_config_page(request: Request): user = request.user.object user_picture = request.session.get("user", {}).get("picture") - user_subscription_state = adapters.get_user_subscription(user.email) has_documents = EntryAdapters.user_has_entries(user=user) current_notion_config = get_user_notion_config(user) @@ -240,7 +234,7 @@ def notion_config_page(request: Request): "current_config": current_config, "username": user.username, "user_photo": user_picture, - "is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed", + "is_active": has_required_scope(request, ["subscribed"]), "has_documents": has_documents, }, ) @@ -251,7 +245,6 @@ def notion_config_page(request: Request): def computer_config_page(request: Request): user = request.user.object user_picture = request.session.get("user", {}).get("picture") - user_subscription_state = get_user_subscription_state(user.email) has_documents = EntryAdapters.user_has_entries(user=user) return templates.TemplateResponse( @@ -260,7 +253,7 @@ def computer_config_page(request: Request): "request": request, "username": user.username, "user_photo": user_picture, - "is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed", + "is_active": has_required_scope(request, ["subscribed"]), "has_documents": has_documents, }, ) From 9c868ee10ba26a90e7e2c30df06a257e32b4e29d Mon Sep 17 00:00:00 2001 From: sabaimran Date: Fri, 24 Nov 2023 20:41:19 -0800 Subject: [PATCH 02/12] Use the state.billing_enabled field to determine whether to use the subscribed scope --- src/khoj/configure.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 4e4e1008..2d21bd01 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -75,11 +75,14 @@ class UserAuthenticationBackend(AuthenticationBackend): .afirst() ) if user: - subscription_state = await aget_user_subscription_state(user) - subscribed = subscription_state == SubscriptionState.SUBSCRIBED.value - if subscribed: - return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user) - return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user) + if state.billing_enabled: + subscription_state = await aget_user_subscription_state(user) + subscribed = subscription_state == SubscriptionState.SUBSCRIBED.value + if subscribed: + return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser( + user_with_token.user + ) + return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user) if len(request.headers.get("Authorization", "").split("Bearer ")) == 2: # Get bearer token from header bearer_token = request.headers["Authorization"].split("Bearer ")[1] @@ -91,10 +94,13 @@ class UserAuthenticationBackend(AuthenticationBackend): .afirst() ) if user_with_token: - subscription_state = await aget_user_subscription_state(user_with_token.user) - subscribed = subscription_state == SubscriptionState.SUBSCRIBED.value - if subscribed: - return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user_with_token.user) + if state.billing_enabled: + subscription_state = await aget_user_subscription_state(user_with_token.user) + subscribed = subscription_state == SubscriptionState.SUBSCRIBED.value + if subscribed: + return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser( + user_with_token.user + ) return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user) if state.anonymous_mode: user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst() From e5b1350523102f89efe1ef579546234e24950f36 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Fri, 24 Nov 2023 21:55:16 -0800 Subject: [PATCH 03/12] Enforce API use limits depending on whether the server has billing enabled and whether the given user is subscribed --- src/interface/desktop/chat.html | 19 ++++++++++++++++--- src/khoj/configure.py | 6 ++++-- src/khoj/interface/web/chat.html | 18 ++++++++++++++++-- src/khoj/routers/api.py | 4 ++-- src/khoj/routers/helpers.py | 9 +++++++-- 5 files changed, 45 insertions(+), 11 deletions(-) diff --git a/src/interface/desktop/chat.html b/src/interface/desktop/chat.html index ecd8ebf9..ac8d9589 100644 --- a/src/interface/desktop/chat.html +++ b/src/interface/desktop/chat.html @@ -361,9 +361,22 @@ if (newResponseText.getElementsByClassName("spinner").length > 0) { newResponseText.removeChild(loadingSpinner); } - - newResponseText.innerHTML += chunk; - readStream(); + // Try to parse the chunk as a JSON object. It will be a JSON object if there is an error. + if (chunk.startsWith("{") && chunk.endsWith("}")) { + try { + const responseAsJson = JSON.parse(chunk); + if (responseAsJson.detail) { + newResponseText.innerHTML += responseAsJson.detail; + } + } catch (error) { + // If the chunk is not a JSON object, just display it as is + newResponseText.innerHTML += chunk; + } + } else { + // If the chunk is not a JSON object, just display it as is + newResponseText.innerHTML += chunk; + readStream(); + } } // Scroll to bottom of chat window as chat response is streamed diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 2d21bd01..19c1fd81 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -82,7 +82,8 @@ class UserAuthenticationBackend(AuthenticationBackend): return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser( user_with_token.user ) - return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user) + return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user) + return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user) if len(request.headers.get("Authorization", "").split("Bearer ")) == 2: # Get bearer token from header bearer_token = request.headers["Authorization"].split("Bearer ")[1] @@ -101,7 +102,8 @@ class UserAuthenticationBackend(AuthenticationBackend): return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser( user_with_token.user ) - return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user) + return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user) + return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user) if state.anonymous_mode: user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst() if user: diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html index abab83ab..f67ab857 100644 --- a/src/khoj/interface/web/chat.html +++ b/src/khoj/interface/web/chat.html @@ -403,8 +403,22 @@ To get started, just start typing below. You can also type / to see a list of co newResponseText.removeChild(loadingSpinner); } - newResponseText.innerHTML += chunk; - readStream(); + // Try to parse the chunk as a JSON object. It will be a JSON object if there is an error. + if (chunk.startsWith("{") && chunk.endsWith("}")) { + try { + const responseAsJson = JSON.parse(chunk); + if (responseAsJson.detail) { + newResponseText.innerHTML += responseAsJson.detail; + } + } catch (error) { + // If the chunk is not a JSON object, just display it as is + newResponseText.innerHTML += chunk; + } + } else { + // If the chunk is not a JSON object, just display it as is + newResponseText.innerHTML += chunk; + readStream(); + } } // Scroll to bottom of chat window as chat response is streamed diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index d5b6ce0e..f1967e67 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -573,8 +573,8 @@ async def chat( n: Optional[int] = 5, d: Optional[float] = 0.18, stream: Optional[bool] = False, - rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=30, window=60)), - rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=500, window=60 * 60 * 24)), + rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=60, window=60)), + rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)), ) -> Response: user = request.user.object diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index c6fcb436..dbcef1bc 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -11,6 +11,7 @@ from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union # External Packages from fastapi import Depends, Header, HTTPException, Request +from starlette.authentication import has_required_scope from khoj.database.adapters import ConversationAdapters from khoj.database.models import KhojUser, Subscription @@ -270,13 +271,15 @@ def generate_chat_response( class ApiUserRateLimiter: - def __init__(self, requests: int, window: int): + def __init__(self, requests: int, subscribed_requests: int, window: int): self.requests = requests + self.subscribed_requests = subscribed_requests self.window = window self.cache: dict[str, list[float]] = defaultdict(list) def __call__(self, request: Request): user: KhojUser = request.user.object + subscribed = has_required_scope(request, ["subscribed"]) user_requests = self.cache[user.uuid] # Remove requests outside of the time window @@ -285,8 +288,10 @@ class ApiUserRateLimiter: user_requests.pop(0) # Check if the user has exceeded the rate limit - if len(user_requests) >= self.requests: + if subscribed and len(user_requests) >= self.subscribed_requests: raise HTTPException(status_code=429, detail="Too Many Requests") + if not subscribed and len(user_requests) >= self.requests: + raise HTTPException(status_code=429, detail="Too Many Requests. Subscribe to increase your rate limit.") # Add the current request to the cache user_requests.append(time()) From 771f9bcfa1844827bb23b99be9f9bf6bf3e98d35 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Fri, 24 Nov 2023 22:08:32 -0800 Subject: [PATCH 04/12] If the user subscription was created over 7 days ago, then their trial is expired --- src/khoj/configure.py | 10 ++++++++-- src/khoj/database/adapters/__init__.py | 6 +++++- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 19c1fd81..fc64af0f 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -77,7 +77,10 @@ class UserAuthenticationBackend(AuthenticationBackend): if user: if state.billing_enabled: subscription_state = await aget_user_subscription_state(user) - subscribed = subscription_state == SubscriptionState.SUBSCRIBED.value + subscribed = ( + subscription_state == SubscriptionState.SUBSCRIBED.value + or subscription_state == SubscriptionState.TRIAL.value + ) if subscribed: return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser( user_with_token.user @@ -97,7 +100,10 @@ class UserAuthenticationBackend(AuthenticationBackend): if user_with_token: if state.billing_enabled: subscription_state = await aget_user_subscription_state(user_with_token.user) - subscribed = subscription_state == SubscriptionState.SUBSCRIBED.value + subscribed = ( + subscription_state == SubscriptionState.SUBSCRIBED.value + or subscription_state == SubscriptionState.TRIAL.value + ) if subscribed: return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser( user_with_token.user diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 146de11c..7f76b796 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -1,7 +1,7 @@ import math import random import secrets -from datetime import date, datetime, timezone +from datetime import date, datetime, timezone, timedelta from typing import List, Optional, Type from enum import Enum @@ -140,6 +140,10 @@ def subscription_to_state(subscription: Subscription) -> str: if not subscription: return SubscriptionState.INVALID.value elif subscription.type == Subscription.Type.TRIAL: + # Trial subscription is valid for 7 days + if datetime.now(tz=timezone.utc) - subscription.created_at > timedelta(days=7): + return SubscriptionState.EXPIRED.value + return SubscriptionState.TRIAL.value elif subscription.is_recurring and subscription.renewal_date >= datetime.now(tz=timezone.utc): return SubscriptionState.SUBSCRIBED.value From 48b91161953ac62d102c6c3fa46caf3d7fd93419 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Fri, 24 Nov 2023 22:18:00 -0800 Subject: [PATCH 05/12] Fix to use user rather than user_with_token in authenticated credentials --- src/khoj/configure.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/khoj/configure.py b/src/khoj/configure.py index fc64af0f..022f1395 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -82,9 +82,7 @@ class UserAuthenticationBackend(AuthenticationBackend): or subscription_state == SubscriptionState.TRIAL.value ) if subscribed: - return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser( - user_with_token.user - ) + return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user) return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user) return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user) if len(request.headers.get("Authorization", "").split("Bearer ")) == 2: From dd1badae81f336cd79d7faea1056ea2505d87b99 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Fri, 24 Nov 2023 22:18:45 -0800 Subject: [PATCH 06/12] Use userwithtoken.user when authenticating with an API key --- src/khoj/configure.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 022f1395..6aa747e8 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -106,8 +106,8 @@ class UserAuthenticationBackend(AuthenticationBackend): return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser( user_with_token.user ) - return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user) - return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user) + return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user) + return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user_with_token.user) if state.anonymous_mode: user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst() if user: From b2afbaa31537a344ff96919b9ff71640343f8e34 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Sat, 25 Nov 2023 20:28:04 -0800 Subject: [PATCH 07/12] Add support for rate limiting the amount of data indexed - Add a dependency on the indexer API endpoint that rounds up the amount of data indexed and uses that to determine whether the next set of data should be processed - Delete any files that are being removed for adminstering the calculation - Show current amount of data indexed in the config page --- src/khoj/configure.py | 2 + src/khoj/database/adapters/__init__.py | 7 +++ src/khoj/interface/web/config.html | 5 ++- src/khoj/routers/api.py | 4 +- src/khoj/routers/helpers.py | 61 ++++++++++++++++++++++++-- src/khoj/routers/indexer.py | 15 +++++-- src/khoj/routers/web_client.py | 6 ++- tests/test_client.py | 38 ++++++++++++++++ 8 files changed, 127 insertions(+), 11 deletions(-) diff --git a/src/khoj/configure.py b/src/khoj/configure.py index 6aa747e8..bff7e3ca 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -80,6 +80,7 @@ class UserAuthenticationBackend(AuthenticationBackend): subscribed = ( subscription_state == SubscriptionState.SUBSCRIBED.value or subscription_state == SubscriptionState.TRIAL.value + or subscription_state == SubscriptionState.UNSUBSCRIBED.value ) if subscribed: return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user) @@ -101,6 +102,7 @@ class UserAuthenticationBackend(AuthenticationBackend): subscribed = ( subscription_state == SubscriptionState.SUBSCRIBED.value or subscription_state == SubscriptionState.TRIAL.value + or subscription_state == SubscriptionState.UNSUBSCRIBED.value ) if subscribed: return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser( diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py index 7f76b796..bcf4856c 100644 --- a/src/khoj/database/adapters/__init__.py +++ b/src/khoj/database/adapters/__init__.py @@ -1,6 +1,7 @@ import math import random import secrets +import sys from datetime import date, datetime, timezone, timedelta from typing import List, Optional, Type from enum import Enum @@ -474,6 +475,12 @@ class EntryAdapters: async def adelete_all_entries(user: KhojUser): return await Entry.objects.filter(user=user).adelete() + @staticmethod + def get_size_of_indexed_data_in_mb(user: KhojUser): + entries = Entry.objects.filter(user=user).iterator() + total_size = sum(sys.getsizeof(entry.compiled) for entry in entries) + return total_size / 1024 / 1024 + @staticmethod def apply_filters(user: KhojUser, query: str, file_type_filter: str = None): q_filter_terms = Q() diff --git a/src/khoj/interface/web/config.html b/src/khoj/interface/web/config.html index 318759df..88fbc70d 100644 --- a/src/khoj/interface/web/config.html +++ b/src/khoj/interface/web/config.html @@ -4,6 +4,7 @@

Content

+

{{indexed_data_size_in_mb}} MB used

@@ -171,7 +172,7 @@
- {% if not billing_enabled %} + {% if billing_enabled %}

Billing

@@ -191,7 +192,7 @@

- Subscribe to Khoj Cloud + Subscribe to Khoj Cloud. See pricing for details.

= self.subscribed_num_entries_size: + raise HTTPException(status_code=429, detail="Too much data indexed.") + if not subscribed and incoming_data_size_mb >= self.num_entries_size: + raise HTTPException( + status_code=429, detail="Too much data indexed. Subscribe to increase your data index limit." + ) + + user_size_data = EntryAdapters.get_size_of_indexed_data_in_mb(user) + if subscribed and user_size_data + incoming_data_size_mb >= self.subscribed_total_entries_size: + raise HTTPException(status_code=429, detail="Too much data indexed.") + if not subscribed and user_size_data + incoming_data_size_mb >= self.total_entries_size_limit: + raise HTTPException( + status_code=429, detail="Too much data indexed. Subscribe to increase your data index limit." + ) + + class CommonQueryParamsClass: def __init__( self, diff --git a/src/khoj/routers/indexer.py b/src/khoj/routers/indexer.py index 0432eed0..0c906707 100644 --- a/src/khoj/routers/indexer.py +++ b/src/khoj/routers/indexer.py @@ -2,7 +2,7 @@ import asyncio import logging from typing import Dict, Optional, Union -from fastapi import APIRouter, Header, Request, Response, UploadFile +from fastapi import APIRouter, Header, Request, Response, UploadFile, Depends from pydantic import BaseModel from starlette.authentication import requires @@ -18,6 +18,7 @@ from khoj.search_type import image_search, text_search from khoj.utils import constants, state from khoj.utils.config import ContentIndex, SearchModels from khoj.utils.helpers import LRU, get_file_type +from khoj.routers.helpers import ApiIndexedDataLimiter from khoj.utils.rawconfig import ContentConfig, FullConfig, SearchConfig from khoj.utils.yaml import save_config_to_file_updated_state @@ -53,6 +54,14 @@ async def update( user_agent: Optional[str] = Header(None), referer: Optional[str] = Header(None), host: Optional[str] = Header(None), + indexed_data_limiter: ApiIndexedDataLimiter = Depends( + ApiIndexedDataLimiter( + incoming_entries_size_limit=10, + subscribed_incoming_entries_size_limit=25, + total_entries_size_limit=10, + subscribed_total_entries_size_limit=100, + ) + ), ): user = request.user.object try: @@ -92,7 +101,7 @@ async def update( logger.info("📬 Initializing content index on first run.") default_full_config = FullConfig( content_type=None, - search_type=SearchConfig.parse_obj(constants.default_config["search-type"]), + search_type=SearchConfig.model_validate(constants.default_config["search-type"]), processor=None, ) state.config = default_full_config @@ -116,7 +125,7 @@ async def update( configure_content, state.content_index, state.config.content_type, - indexer_input.dict(), + indexer_input.model_dump(), state.search_models, force, t, diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py index bf3ff957..c17704bd 100644 --- a/src/khoj/routers/web_client.py +++ b/src/khoj/routers/web_client.py @@ -1,6 +1,8 @@ # System Packages import json import os +import math +from datetime import timedelta # External Packages from fastapi import APIRouter @@ -137,8 +139,9 @@ def config_page(request: Request): subscription_renewal_date = ( user_subscription.renewal_date.strftime("%d %b %Y") if user_subscription and user_subscription.renewal_date - else None + else (user_subscription.created_at + timedelta(days=7)).strftime("%d %b %Y") ) + indexed_data_size_in_mb = math.ceil(EntryAdapters.get_size_of_indexed_data_in_mb(user)) enabled_content_source = set(EntryAdapters.get_unique_file_sources(user)) successfully_configured = { @@ -169,6 +172,7 @@ def config_page(request: Request): "khoj_cloud_subscription_url": os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL"), "is_active": has_required_scope(request, ["subscribed"]), "has_documents": has_documents, + "indexed_data_size_in_mb": indexed_data_size_in_mb, }, ) diff --git a/tests/test_client.py b/tests/test_client.py index 19aba03b..98affe27 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -125,6 +125,34 @@ def test_regenerate_with_invalid_content_type(client): assert response.status_code == 422 +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db(transaction=True) +def test_index_update_big_files(client): + state.billing_enabled = True + # Arrange + files = get_big_size_sample_files_data() + headers = {"Authorization": "Bearer kk-secret"} + + # Act + response = client.post("/api/v1/index/update", files=files, headers=headers) + + # Assert + assert response.status_code == 429 + + +@pytest.mark.django_db(transaction=True) +def test_index_update_big_files_no_billing(client): + # Arrange + files = get_big_size_sample_files_data() + headers = {"Authorization": "Bearer kk-secret"} + + # Act + response = client.post("/api/v1/index/update", files=files, headers=headers) + + # Assert + assert response.status_code == 200 + + # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db(transaction=True) def test_index_update(client): @@ -421,3 +449,13 @@ def get_sample_files_data(): ), ("files", ("path/to/filename2.md", "**Understanding science through the lens of art**", "text/markdown")), ] + + +def get_big_size_sample_files_data(): + big_text = "a" * (25 * 1024 * 1024) # a string of approximately 25 MB + return [ + ( + "files", + ("path/to/filename.org", big_text, "text/org"), + ) + ] From 73e38fccf35ebea42ac6b557436a84f3a3f76d9e Mon Sep 17 00:00:00 2001 From: sabaimran Date: Sat, 25 Nov 2023 20:48:32 -0800 Subject: [PATCH 08/12] Explicitly set billing to off in the test for being able to index a large set of data --- tests/test_client.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_client.py b/tests/test_client.py index 98affe27..9ba04416 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -128,8 +128,8 @@ def test_regenerate_with_invalid_content_type(client): # ---------------------------------------------------------------------------------------------------- @pytest.mark.django_db(transaction=True) def test_index_update_big_files(client): - state.billing_enabled = True # Arrange + state.billing_enabled = True files = get_big_size_sample_files_data() headers = {"Authorization": "Bearer kk-secret"} @@ -143,6 +143,7 @@ def test_index_update_big_files(client): @pytest.mark.django_db(transaction=True) def test_index_update_big_files_no_billing(client): # Arrange + state.billing_enabled = False files = get_big_size_sample_files_data() headers = {"Authorization": "Bearer kk-secret"} From 52b88de7f465961bdbc8db43e81d4fba493ce129 Mon Sep 17 00:00:00 2001 From: sabaimran Date: Sat, 25 Nov 2023 22:31:23 -0800 Subject: [PATCH 09/12] Indicate in the desktop if the user gets rate limited for indexing --- src/interface/desktop/config.html | 3 +++ src/interface/desktop/main.js | 10 ++++++++++ src/interface/desktop/preload.js | 4 ++++ src/interface/desktop/renderer.js | 11 ++++++++++- 4 files changed, 27 insertions(+), 1 deletion(-) diff --git a/src/interface/desktop/config.html b/src/interface/desktop/config.html index fb39fbb8..f8ecb06f 100644 --- a/src/interface/desktop/config.html +++ b/src/interface/desktop/config.html @@ -101,6 +101,9 @@

+
diff --git a/src/interface/desktop/main.js b/src/interface/desktop/main.js index eb355a5f..95927be1 100644 --- a/src/interface/desktop/main.js +++ b/src/interface/desktop/main.js @@ -198,6 +198,11 @@ function pushDataToKhoj (regenerate = false) { }) .catch(error => { console.error(error); + if (error.response.status == 429) { + const win = BrowserWindow.getAllWindows()[0]; + if (win) win.webContents.send('needsSubscription', true); + if (win) win.webContents.send('update-state', state); + } state['completed'] = false }) .finally(() => { @@ -396,6 +401,11 @@ app.whenReady().then(() => { event.reply('update-state', arg); }); + ipcMain.on('needsSubscription', (event, arg) => { + console.log(arg); + event.reply('needsSubscription', arg); + }); + ipcMain.on('navigate', (event, page) => { win.loadFile(page); }); diff --git a/src/interface/desktop/preload.js b/src/interface/desktop/preload.js index 1d4c6ec0..8d5152b7 100644 --- a/src/interface/desktop/preload.js +++ b/src/interface/desktop/preload.js @@ -31,6 +31,10 @@ contextBridge.exposeInMainWorld('updateStateAPI', { onUpdateState: (callback) => ipcRenderer.on('update-state', callback) }) +contextBridge.exposeInMainWorld('needsSubscriptionAPI', { + onNeedsSubscription: (callback) => ipcRenderer.on('needsSubscription', callback) +}) + contextBridge.exposeInMainWorld('removeFileAPI', { removeFile: (filePath) => ipcRenderer.invoke('removeFile', filePath) }) diff --git a/src/interface/desktop/renderer.js b/src/interface/desktop/renderer.js index 7d0d906e..16df6d2f 100644 --- a/src/interface/desktop/renderer.js +++ b/src/interface/desktop/renderer.js @@ -1,7 +1,7 @@ const setFolderButton = document.getElementById('update-folder'); const setFileButton = document.getElementById('update-file'); -const showKey = document.getElementById('show-key'); const loadingBar = document.getElementById('loading-bar'); +const needsSubscriptionElement = document.getElementById('needs-subscription'); async function removeFile(filePath) { const updatedFiles = await window.removeFileAPI.removeFile(filePath); @@ -165,6 +165,15 @@ window.updateStateAPI.onUpdateState((event, state) => { syncStatusElement.innerHTML = `⏱️ Synced at ${currentTime.toLocaleTimeString(undefined, options)}. Next sync at ${nextSyncTime.toLocaleTimeString(undefined, options)}.`; }); +window.needsSubscriptionAPI.onNeedsSubscription((event, needsSubscription) => { + console.log("needs subscription", needsSubscription); + if (needsSubscription) { + needsSubscriptionElement.style.display = 'block'; + } else { + needsSubscriptionElement.style.display = 'none'; + } +}); + const urlInput = document.getElementById('khoj-host-url'); (async function() { const url = await window.hostURLAPI.getURL(); From e438853b09e31ea84727ac77ed517a3759c696ac Mon Sep 17 00:00:00 2001 From: sabaimran Date: Sun, 26 Nov 2023 13:09:00 -0800 Subject: [PATCH 10/12] Add additional unit tests to verify behavior of unsubscribed/subscribed users --- tests/conftest.py | 31 +++++++++++++++++++++++++++++++ tests/test_client.py | 42 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index 54c664d5..9a500609 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -102,6 +102,24 @@ def default_user3(): return user +@pytest.mark.django_db +@pytest.fixture +def default_user4(): + """ + This user should not have a valid subscription + """ + if KhojUser.objects.filter(username="default4").exists(): + return KhojUser.objects.get(username="default4") + + user = KhojUser.objects.create( + username="default4", + email="default4@example.com", + password="default4", + ) + SubscriptionFactory(user=user, renewal_date=None) + return user + + @pytest.mark.django_db @pytest.fixture def api_user(default_user): @@ -141,6 +159,19 @@ def api_user3(default_user3): ) +@pytest.mark.django_db +@pytest.fixture +def api_user4(default_user4): + if KhojApiUser.objects.filter(user=default_user4).exists(): + return KhojApiUser.objects.get(user=default_user4) + + return KhojApiUser.objects.create( + user=default_user4, + name="api-key", + token="kk-diff-secret-4", + ) + + @pytest.fixture(scope="session") def search_models(search_config: SearchConfig): search_models = SearchModels() diff --git a/tests/test_client.py b/tests/test_client.py index 9ba04416..f23a350e 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -140,6 +140,38 @@ def test_index_update_big_files(client): assert response.status_code == 429 +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db(transaction=True) +def test_index_update_medium_file_unsubscribed(client, api_user4: KhojApiUser): + # Arrange + api_token = api_user4.token + state.billing_enabled = True + files = get_medium_size_sample_files_data() + headers = {"Authorization": f"Bearer {api_token}"} + + # Act + response = client.post("/api/v1/index/update", files=files, headers=headers) + + # Assert + assert response.status_code == 429 + + +# ---------------------------------------------------------------------------------------------------- +@pytest.mark.django_db(transaction=True) +def test_index_update_normal_file_unsubscribed(client, api_user4: KhojApiUser): + # Arrange + api_token = api_user4.token + state.billing_enabled = True + files = get_sample_files_data() + headers = {"Authorization": f"Bearer {api_token}"} + + # Act + response = client.post("/api/v1/index/update", files=files, headers=headers) + + # Assert + assert response.status_code == 200 + + @pytest.mark.django_db(transaction=True) def test_index_update_big_files_no_billing(client): # Arrange @@ -460,3 +492,13 @@ def get_big_size_sample_files_data(): ("path/to/filename.org", big_text, "text/org"), ) ] + + +def get_medium_size_sample_files_data(): + big_text = "a" * (10 * 1024 * 1024) # a string of approximately 10 MB + return [ + ( + "files", + ("path/to/filename.org", big_text, "text/org"), + ) + ] From eb5e3096e09771e6a102222f809245fac808855c Mon Sep 17 00:00:00 2001 From: sabaimran Date: Mon, 27 Nov 2023 11:39:20 -0800 Subject: [PATCH 11/12] Change subscribed scope to premium --- src/khoj/configure.py | 10 +++++----- src/khoj/routers/helpers.py | 4 ++-- src/khoj/routers/web_client.py | 16 ++++++++-------- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/khoj/configure.py b/src/khoj/configure.py index bff7e3ca..19e7d403 100644 --- a/src/khoj/configure.py +++ b/src/khoj/configure.py @@ -83,9 +83,9 @@ class UserAuthenticationBackend(AuthenticationBackend): or subscription_state == SubscriptionState.UNSUBSCRIBED.value ) if subscribed: - return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user) + return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user) return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user) - return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user) + return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user) if len(request.headers.get("Authorization", "").split("Bearer ")) == 2: # Get bearer token from header bearer_token = request.headers["Authorization"].split("Bearer ")[1] @@ -105,15 +105,15 @@ class UserAuthenticationBackend(AuthenticationBackend): or subscription_state == SubscriptionState.UNSUBSCRIBED.value ) if subscribed: - return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser( + return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser( user_with_token.user ) return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user) - return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user_with_token.user) + return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user) if state.anonymous_mode: user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst() if user: - return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user) + return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user) return AuthCredentials(), UnauthenticatedUser() diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index 39448d1a..1dd3f4c7 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -280,7 +280,7 @@ class ApiUserRateLimiter: def __call__(self, request: Request): user: KhojUser = request.user.object - subscribed = has_required_scope(request, ["subscribed"]) + subscribed = has_required_scope(request, ["premium"]) user_requests = self.cache[user.uuid] # Remove requests outside of the time window @@ -314,7 +314,7 @@ class ApiIndexedDataLimiter: def __call__(self, request: Request, files: List[UploadFile]): if state.billing_enabled is False: return - subscribed = has_required_scope(request, ["subscribed"]) + subscribed = has_required_scope(request, ["premium"]) incoming_data_size_mb = 0 deletion_file_names = set() diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py index c17704bd..8ce9dbe3 100644 --- a/src/khoj/routers/web_client.py +++ b/src/khoj/routers/web_client.py @@ -47,7 +47,7 @@ def index(request: Request): "request": request, "username": user.username, "user_photo": user_picture, - "is_active": has_required_scope(request, ["subscribed"]), + "is_active": has_required_scope(request, ["premium"]), "has_documents": has_documents, }, ) @@ -66,7 +66,7 @@ def index_post(request: Request): "request": request, "username": user.username, "user_photo": user_picture, - "is_active": has_required_scope(request, ["subscribed"]), + "is_active": has_required_scope(request, ["premium"]), "has_documents": has_documents, }, ) @@ -85,7 +85,7 @@ def search_page(request: Request): "request": request, "username": user.username, "user_photo": user_picture, - "is_active": has_required_scope(request, ["subscribed"]), + "is_active": has_required_scope(request, ["premium"]), "has_documents": has_documents, }, ) @@ -104,7 +104,7 @@ def chat_page(request: Request): "request": request, "username": user.username, "user_photo": user_picture, - "is_active": has_required_scope(request, ["subscribed"]), + "is_active": has_required_scope(request, ["premium"]), "has_documents": has_documents, }, ) @@ -170,7 +170,7 @@ def config_page(request: Request): "subscription_state": user_subscription_state, "subscription_renewal_date": subscription_renewal_date, "khoj_cloud_subscription_url": os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL"), - "is_active": has_required_scope(request, ["subscribed"]), + "is_active": has_required_scope(request, ["premium"]), "has_documents": has_documents, "indexed_data_size_in_mb": indexed_data_size_in_mb, }, @@ -211,7 +211,7 @@ def github_config_page(request: Request): "current_config": current_config, "username": user.username, "user_photo": user_picture, - "is_active": has_required_scope(request, ["subscribed"]), + "is_active": has_required_scope(request, ["premium"]), "has_documents": has_documents, }, ) @@ -238,7 +238,7 @@ def notion_config_page(request: Request): "current_config": current_config, "username": user.username, "user_photo": user_picture, - "is_active": has_required_scope(request, ["subscribed"]), + "is_active": has_required_scope(request, ["premium"]), "has_documents": has_documents, }, ) @@ -257,7 +257,7 @@ def computer_config_page(request: Request): "request": request, "username": user.username, "user_photo": user_picture, - "is_active": has_required_scope(request, ["subscribed"]), + "is_active": has_required_scope(request, ["premium"]), "has_documents": has_documents, }, ) From 6290b463f518a7d2f22954c05d54a2d499256e7d Mon Sep 17 00:00:00 2001 From: sabaimran Date: Mon, 27 Nov 2023 12:05:00 -0800 Subject: [PATCH 12/12] Compute size of the indexed data only if explicitly requested to avoid heavy load on the DB --- src/khoj/interface/web/config.html | 14 +++++++++++++- src/khoj/routers/api.py | 12 ++++++++++++ src/khoj/routers/web_client.py | 2 -- 3 files changed, 25 insertions(+), 3 deletions(-) diff --git a/src/khoj/interface/web/config.html b/src/khoj/interface/web/config.html index 88fbc70d..96d82131 100644 --- a/src/khoj/interface/web/config.html +++ b/src/khoj/interface/web/config.html @@ -4,7 +4,10 @@

Content

-

{{indexed_data_size_in_mb}} MB used

+ +

@@ -472,6 +475,15 @@ }); } + function getIndexedDataSize() { + document.getElementById("indexed-data-size").innerHTML = "Calculating..."; + fetch('/api/config/index/size') + .then(response => response.json()) + .then(data => { + document.getElementById("indexed-data-size").innerHTML = data.indexed_data_size_in_mb + " MB used"; + }); + } + // List user's API keys on page load listApiKeys(); diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py index cb0606f1..ae125980 100644 --- a/src/khoj/routers/api.py +++ b/src/khoj/routers/api.py @@ -334,6 +334,18 @@ def get_default_config_data(): return constants.empty_config +@api.get("/config/index/size", response_model=Dict[str, int]) +@requires(["authenticated"]) +async def get_indexed_data_size(request: Request, common: CommonQueryParams): + user = request.user.object + indexed_data_size_in_mb = await sync_to_async(EntryAdapters.get_size_of_indexed_data_in_mb)(user) + return Response( + content=json.dumps({"indexed_data_size_in_mb": math.ceil(indexed_data_size_in_mb)}), + media_type="application/json", + status_code=200, + ) + + @api.get("/config/types", response_model=List[str]) @requires(["authenticated"]) def get_config_types( diff --git a/src/khoj/routers/web_client.py b/src/khoj/routers/web_client.py index 8ce9dbe3..7907f99e 100644 --- a/src/khoj/routers/web_client.py +++ b/src/khoj/routers/web_client.py @@ -141,7 +141,6 @@ def config_page(request: Request): if user_subscription and user_subscription.renewal_date else (user_subscription.created_at + timedelta(days=7)).strftime("%d %b %Y") ) - indexed_data_size_in_mb = math.ceil(EntryAdapters.get_size_of_indexed_data_in_mb(user)) enabled_content_source = set(EntryAdapters.get_unique_file_sources(user)) successfully_configured = { @@ -172,7 +171,6 @@ def config_page(request: Request): "khoj_cloud_subscription_url": os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL"), "is_active": has_required_scope(request, ["premium"]), "has_documents": has_documents, - "indexed_data_size_in_mb": indexed_data_size_in_mb, }, )