Merge pull request #569 from khoj-ai/features/enforce-subscription-status

Enforce subscription state on the chat API access
This commit is contained in:
sabaimran 2023-11-27 16:12:26 -08:00 committed by GitHub
commit 24b5aaef0a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 368 additions and 55 deletions

View file

@ -361,12 +361,25 @@
if (newResponseText.getElementsByClassName("spinner").length > 0) {
newResponseText.removeChild(loadingSpinner);
}
// Try to parse the chunk as a JSON object. It will be a JSON object if there is an error.
if (chunk.startsWith("{") && chunk.endsWith("}")) {
try {
const responseAsJson = JSON.parse(chunk);
if (responseAsJson.detail) {
newResponseText.innerHTML += responseAsJson.detail;
}
} catch (error) {
// If the chunk is not a JSON object, just display it as is
newResponseText.innerHTML += chunk;
}
} else {
// If the chunk is not a JSON object, just display it as is
rawResponse += chunk;
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
rawResponse += chunk;
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
readStream();
readStream();
}
}
// Scroll to bottom of chat window as chat response is streamed

View file

@ -101,6 +101,9 @@
<div class="card-description-row">
<div id="sync-status"></div>
</div>
<div id="needs-subscription" style="display: none;">
Looks like you're out of space to sync your files. <a href="https://app.khoj.dev/config">Upgrade your plan</a> to unlock more space.
</div>
</div>
</body>

View file

@ -198,6 +198,11 @@ function pushDataToKhoj (regenerate = false) {
})
.catch(error => {
console.error(error);
if (error.response.status == 429) {
const win = BrowserWindow.getAllWindows()[0];
if (win) win.webContents.send('needsSubscription', true);
if (win) win.webContents.send('update-state', state);
}
state['completed'] = false
})
.finally(() => {
@ -396,6 +401,11 @@ app.whenReady().then(() => {
event.reply('update-state', arg);
});
ipcMain.on('needsSubscription', (event, arg) => {
console.log(arg);
event.reply('needsSubscription', arg);
});
ipcMain.on('navigate', (event, page) => {
win.loadFile(page);
});

View file

@ -31,6 +31,10 @@ contextBridge.exposeInMainWorld('updateStateAPI', {
onUpdateState: (callback) => ipcRenderer.on('update-state', callback)
})
contextBridge.exposeInMainWorld('needsSubscriptionAPI', {
onNeedsSubscription: (callback) => ipcRenderer.on('needsSubscription', callback)
})
contextBridge.exposeInMainWorld('removeFileAPI', {
removeFile: (filePath) => ipcRenderer.invoke('removeFile', filePath)
})

View file

@ -1,7 +1,7 @@
const setFolderButton = document.getElementById('update-folder');
const setFileButton = document.getElementById('update-file');
const showKey = document.getElementById('show-key');
const loadingBar = document.getElementById('loading-bar');
const needsSubscriptionElement = document.getElementById('needs-subscription');
async function removeFile(filePath) {
const updatedFiles = await window.removeFileAPI.removeFile(filePath);
@ -165,6 +165,15 @@ window.updateStateAPI.onUpdateState((event, state) => {
syncStatusElement.innerHTML = `⏱️ Synced at ${currentTime.toLocaleTimeString(undefined, options)}. Next sync at ${nextSyncTime.toLocaleTimeString(undefined, options)}.`;
});
window.needsSubscriptionAPI.onNeedsSubscription((event, needsSubscription) => {
console.log("needs subscription", needsSubscription);
if (needsSubscription) {
needsSubscriptionElement.style.display = 'block';
} else {
needsSubscriptionElement.style.display = 'none';
}
});
const urlInput = document.getElementById('khoj-host-url');
(async function() {
const url = await window.hostURLAPI.getURL();

View file

@ -21,7 +21,12 @@ from starlette.authentication import (
# Internal Packages
from khoj.database.models import KhojUser, Subscription
from khoj.database.adapters import get_all_users, get_or_create_search_model
from khoj.database.adapters import (
get_all_users,
get_or_create_search_model,
aget_user_subscription_state,
SubscriptionState,
)
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
from khoj.routers.indexer import configure_content, load_content, configure_search
from khoj.utils import constants, state
@ -70,7 +75,17 @@ class UserAuthenticationBackend(AuthenticationBackend):
.afirst()
)
if user:
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
if state.billing_enabled:
subscription_state = await aget_user_subscription_state(user)
subscribed = (
subscription_state == SubscriptionState.SUBSCRIBED.value
or subscription_state == SubscriptionState.TRIAL.value
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
)
if subscribed:
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
if len(request.headers.get("Authorization", "").split("Bearer ")) == 2:
# Get bearer token from header
bearer_token = request.headers["Authorization"].split("Bearer ")[1]
@ -82,11 +97,23 @@ class UserAuthenticationBackend(AuthenticationBackend):
.afirst()
)
if user_with_token:
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
if state.billing_enabled:
subscription_state = await aget_user_subscription_state(user_with_token.user)
subscribed = (
subscription_state == SubscriptionState.SUBSCRIBED.value
or subscription_state == SubscriptionState.TRIAL.value
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
)
if subscribed:
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(
user_with_token.user
)
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user)
if state.anonymous_mode:
user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst()
if user:
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
return AuthCredentials(), UnauthenticatedUser()

View file

@ -1,8 +1,10 @@
import math
import random
import secrets
from datetime import date, datetime, timezone
import sys
from datetime import date, datetime, timezone, timedelta
from typing import List, Optional, Type
from enum import Enum
from asgiref.sync import sync_to_async
from django.contrib.sessions.backends.db import SessionStore
@ -41,6 +43,14 @@ from khoj.utils.config import GPT4AllProcessorModel
from khoj.utils.helpers import generate_random_name
class SubscriptionState(Enum):
TRIAL = "trial"
SUBSCRIBED = "subscribed"
UNSUBSCRIBED = "unsubscribed"
EXPIRED = "expired"
INVALID = "invalid"
async def set_notion_config(token: str, user: KhojUser):
notion_config = await NotionConfig.objects.filter(user=user).afirst()
if not notion_config:
@ -128,22 +138,38 @@ async def set_user_subscription(
return None
def subscription_to_state(subscription: Subscription) -> str:
if not subscription:
return SubscriptionState.INVALID.value
elif subscription.type == Subscription.Type.TRIAL:
# Trial subscription is valid for 7 days
if datetime.now(tz=timezone.utc) - subscription.created_at > timedelta(days=7):
return SubscriptionState.EXPIRED.value
return SubscriptionState.TRIAL.value
elif subscription.is_recurring and subscription.renewal_date >= datetime.now(tz=timezone.utc):
return SubscriptionState.SUBSCRIBED.value
elif not subscription.is_recurring and subscription.renewal_date >= datetime.now(tz=timezone.utc):
return SubscriptionState.UNSUBSCRIBED.value
elif not subscription.is_recurring and subscription.renewal_date < datetime.now(tz=timezone.utc):
return SubscriptionState.EXPIRED.value
return SubscriptionState.INVALID.value
def get_user_subscription_state(email: str) -> str:
"""Get subscription state of user
Valid state transitions: trial -> subscribed <-> unsubscribed OR expired
"""
user_subscription = Subscription.objects.filter(user__email=email).first()
if not user_subscription:
return "trial"
elif user_subscription.type == Subscription.Type.TRIAL:
return "trial"
elif user_subscription.is_recurring and user_subscription.renewal_date >= datetime.now(tz=timezone.utc):
return "subscribed"
elif not user_subscription.is_recurring and user_subscription.renewal_date >= datetime.now(tz=timezone.utc):
return "unsubscribed"
elif not user_subscription.is_recurring and user_subscription.renewal_date < datetime.now(tz=timezone.utc):
return "expired"
return "invalid"
return subscription_to_state(user_subscription)
async def aget_user_subscription_state(email: str) -> str:
"""Get subscription state of user
Valid state transitions: trial -> subscribed <-> unsubscribed OR expired
"""
user_subscription = await Subscription.objects.filter(user__email=email).afirst()
return subscription_to_state(user_subscription)
async def get_user_by_email(email: str) -> KhojUser:
@ -458,6 +484,12 @@ class EntryAdapters:
async def adelete_all_entries(user: KhojUser):
return await Entry.objects.filter(user=user).adelete()
@staticmethod
def get_size_of_indexed_data_in_mb(user: KhojUser):
entries = Entry.objects.filter(user=user).iterator()
total_size = sum(sys.getsizeof(entry.compiled) for entry in entries)
return total_size / 1024 / 1024
@staticmethod
def apply_filters(user: KhojUser, query: str, file_type_filter: str = None):
q_filter_terms = Q()

View file

@ -402,10 +402,24 @@ To get started, just start typing below. You can also type / to see a list of co
newResponseText.removeChild(loadingSpinner);
}
rawResponse += chunk;
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
readStream();
// Try to parse the chunk as a JSON object. It will be a JSON object if there is an error.
if (chunk.startsWith("{") && chunk.endsWith("}")) {
try {
const responseAsJson = JSON.parse(chunk);
if (responseAsJson.detail) {
newResponseText.innerHTML += responseAsJson.detail;
}
} catch (error) {
// If the chunk is not a JSON object, just display it as is
newResponseText.innerHTML += chunk;
}
} else {
// If the chunk is not a JSON object, just display it as is
rawResponse += chunk;
newResponseText.innerHTML = "";
newResponseText.appendChild(formatHTMLMessage(rawResponse));
readStream();
}
}
// Scroll to bottom of chat window as chat response is streamed

View file

@ -4,6 +4,10 @@
<div class="page">
<div id="content" class="section">
<h2 class="section-title">Content</h2>
<button id="compute-index-size" class="card-button" onclick="getIndexedDataSize()">
Data Usage
</button>
<p id="indexed-data-size" class="card-description"></p>
<div class="section-cards">
<div class="card">
<div class="card-title-row">
@ -191,7 +195,7 @@
<p id="trial-description"
class="card-description"
style="display: {% if subscription_state != 'trial' %}none{% endif %}">
Subscribe to Khoj Cloud
Subscribe to Khoj Cloud. See <a href="https://khoj.dev/pricing">pricing</a> for details.
</p>
<p id="unsubscribe-description"
class="card-description"
@ -471,6 +475,15 @@
});
}
function getIndexedDataSize() {
document.getElementById("indexed-data-size").innerHTML = "Calculating...";
fetch('/api/config/index/size')
.then(response => response.json())
.then(data => {
document.getElementById("indexed-data-size").innerHTML = data.indexed_data_size_in_mb + " MB used";
});
}
// List user's API keys on page load
listApiKeys();

View file

@ -9,8 +9,8 @@ from typing import Any, Dict, List, Optional, Union
import uuid
# External Packages
from asgiref.sync import sync_to_async
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
from asgiref.sync import sync_to_async
from fastapi.requests import Request
from fastapi.responses import Response, StreamingResponse
from starlette.authentication import requires
@ -334,6 +334,18 @@ def get_default_config_data():
return constants.empty_config
@api.get("/config/index/size", response_model=Dict[str, int])
@requires(["authenticated"])
async def get_indexed_data_size(request: Request, common: CommonQueryParams):
user = request.user.object
indexed_data_size_in_mb = await sync_to_async(EntryAdapters.get_size_of_indexed_data_in_mb)(user)
return Response(
content=json.dumps({"indexed_data_size_in_mb": math.ceil(indexed_data_size_in_mb)}),
media_type="application/json",
status_code=200,
)
@api.get("/config/types", response_model=List[str])
@requires(["authenticated"])
def get_config_types(
@ -650,8 +662,8 @@ async def chat(
n: Optional[int] = 5,
d: Optional[float] = 0.18,
stream: Optional[bool] = False,
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=30, window=60)),
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=500, window=60 * 60 * 24)),
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=60, window=60)),
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)),
) -> Response:
user = request.user.object

View file

@ -9,10 +9,12 @@ from functools import partial
from time import time
from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union
# External Packages
from fastapi import Depends, Header, HTTPException, Request
from fastapi import Depends, Header, HTTPException, Request, UploadFile
from starlette.authentication import has_required_scope
from asgiref.sync import sync_to_async
from khoj.database.adapters import ConversationAdapters
from khoj.database.adapters import ConversationAdapters, EntryAdapters
from khoj.database.models import KhojUser, Subscription
from khoj.processor.conversation import prompts
from khoj.processor.conversation.offline.chat_model import converse_offline, send_message_to_model_offline
@ -270,13 +272,15 @@ def generate_chat_response(
class ApiUserRateLimiter:
def __init__(self, requests: int, window: int):
def __init__(self, requests: int, subscribed_requests: int, window: int):
self.requests = requests
self.subscribed_requests = subscribed_requests
self.window = window
self.cache: dict[str, list[float]] = defaultdict(list)
def __call__(self, request: Request):
user: KhojUser = request.user.object
subscribed = has_required_scope(request, ["premium"])
user_requests = self.cache[user.uuid]
# Remove requests outside of the time window
@ -285,13 +289,69 @@ class ApiUserRateLimiter:
user_requests.pop(0)
# Check if the user has exceeded the rate limit
if len(user_requests) >= self.requests:
if subscribed and len(user_requests) >= self.subscribed_requests:
raise HTTPException(status_code=429, detail="Too Many Requests")
if not subscribed and len(user_requests) >= self.requests:
raise HTTPException(status_code=429, detail="Too Many Requests. Subscribe to increase your rate limit.")
# Add the current request to the cache
user_requests.append(time())
class ApiIndexedDataLimiter:
def __init__(
self,
incoming_entries_size_limit: float,
subscribed_incoming_entries_size_limit: float,
total_entries_size_limit: float,
subscribed_total_entries_size_limit: float,
):
self.num_entries_size = incoming_entries_size_limit
self.subscribed_num_entries_size = subscribed_incoming_entries_size_limit
self.total_entries_size_limit = total_entries_size_limit
self.subscribed_total_entries_size = subscribed_total_entries_size_limit
def __call__(self, request: Request, files: List[UploadFile]):
if state.billing_enabled is False:
return
subscribed = has_required_scope(request, ["premium"])
incoming_data_size_mb = 0
deletion_file_names = set()
if not request.user.is_authenticated:
return
user: KhojUser = request.user.object
for file in files:
if file.size == 0:
deletion_file_names.add(file.filename)
incoming_data_size_mb += file.size / 1024 / 1024
num_deleted_entries = 0
for file_path in deletion_file_names:
deleted_count = EntryAdapters.delete_entry_by_file(user, file_path)
num_deleted_entries += deleted_count
logger.info(f"Deleted {num_deleted_entries} entries for user: {user}.")
if subscribed and incoming_data_size_mb >= self.subscribed_num_entries_size:
raise HTTPException(status_code=429, detail="Too much data indexed.")
if not subscribed and incoming_data_size_mb >= self.num_entries_size:
raise HTTPException(
status_code=429, detail="Too much data indexed. Subscribe to increase your data index limit."
)
user_size_data = EntryAdapters.get_size_of_indexed_data_in_mb(user)
if subscribed and user_size_data + incoming_data_size_mb >= self.subscribed_total_entries_size:
raise HTTPException(status_code=429, detail="Too much data indexed.")
if not subscribed and user_size_data + incoming_data_size_mb >= self.total_entries_size_limit:
raise HTTPException(
status_code=429, detail="Too much data indexed. Subscribe to increase your data index limit."
)
class CommonQueryParamsClass:
def __init__(
self,

View file

@ -2,7 +2,7 @@ import asyncio
import logging
from typing import Dict, Optional, Union
from fastapi import APIRouter, Header, Request, Response, UploadFile
from fastapi import APIRouter, Header, Request, Response, UploadFile, Depends
from pydantic import BaseModel
from starlette.authentication import requires
@ -18,6 +18,7 @@ from khoj.search_type import image_search, text_search
from khoj.utils import constants, state
from khoj.utils.config import ContentIndex, SearchModels
from khoj.utils.helpers import LRU, get_file_type
from khoj.routers.helpers import ApiIndexedDataLimiter
from khoj.utils.rawconfig import ContentConfig, FullConfig, SearchConfig
from khoj.utils.yaml import save_config_to_file_updated_state
@ -53,6 +54,14 @@ async def update(
user_agent: Optional[str] = Header(None),
referer: Optional[str] = Header(None),
host: Optional[str] = Header(None),
indexed_data_limiter: ApiIndexedDataLimiter = Depends(
ApiIndexedDataLimiter(
incoming_entries_size_limit=10,
subscribed_incoming_entries_size_limit=25,
total_entries_size_limit=10,
subscribed_total_entries_size_limit=100,
)
),
):
user = request.user.object
try:
@ -92,7 +101,7 @@ async def update(
logger.info("📬 Initializing content index on first run.")
default_full_config = FullConfig(
content_type=None,
search_type=SearchConfig.parse_obj(constants.default_config["search-type"]),
search_type=SearchConfig.model_validate(constants.default_config["search-type"]),
processor=None,
)
state.config = default_full_config
@ -116,7 +125,7 @@ async def update(
configure_content,
state.content_index,
state.config.content_type,
indexer_input.dict(),
indexer_input.model_dump(),
state.search_models,
force,
t,

View file

@ -1,13 +1,15 @@
# System Packages
import json
import os
import math
from datetime import timedelta
# External Packages
from fastapi import APIRouter
from fastapi import Request
from fastapi.responses import HTMLResponse, FileResponse, RedirectResponse
from fastapi.templating import Jinja2Templates
from starlette.authentication import requires
from starlette.authentication import requires, has_required_scope
from khoj.database import adapters
from khoj.database.models import KhojUser
from khoj.utils.rawconfig import (
@ -37,7 +39,6 @@ templates = Jinja2Templates(directory=constants.web_directory)
def index(request: Request):
user = request.user.object
user_picture = request.session.get("user", {}).get("picture")
user_subscription_state = get_user_subscription_state(user.email)
has_documents = EntryAdapters.user_has_entries(user=user)
return templates.TemplateResponse(
@ -46,7 +47,7 @@ def index(request: Request):
"request": request,
"username": user.username,
"user_photo": user_picture,
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed",
"is_active": has_required_scope(request, ["premium"]),
"has_documents": has_documents,
},
)
@ -57,7 +58,6 @@ def index(request: Request):
def index_post(request: Request):
user = request.user.object
user_picture = request.session.get("user", {}).get("picture")
user_subscription_state = get_user_subscription_state(user.email)
has_documents = EntryAdapters.user_has_entries(user=user)
return templates.TemplateResponse(
@ -66,7 +66,7 @@ def index_post(request: Request):
"request": request,
"username": user.username,
"user_photo": user_picture,
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed",
"is_active": has_required_scope(request, ["premium"]),
"has_documents": has_documents,
},
)
@ -77,7 +77,6 @@ def index_post(request: Request):
def search_page(request: Request):
user = request.user.object
user_picture = request.session.get("user", {}).get("picture")
user_subscription_state = get_user_subscription_state(user.email)
has_documents = EntryAdapters.user_has_entries(user=user)
return templates.TemplateResponse(
@ -86,7 +85,7 @@ def search_page(request: Request):
"request": request,
"username": user.username,
"user_photo": user_picture,
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed",
"is_active": has_required_scope(request, ["premium"]),
"has_documents": has_documents,
},
)
@ -97,7 +96,6 @@ def search_page(request: Request):
def chat_page(request: Request):
user = request.user.object
user_picture = request.session.get("user", {}).get("picture")
user_subscription_state = get_user_subscription_state(user.email)
has_documents = EntryAdapters.user_has_entries(user=user)
return templates.TemplateResponse(
@ -106,7 +104,7 @@ def chat_page(request: Request):
"request": request,
"username": user.username,
"user_photo": user_picture,
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed",
"is_active": has_required_scope(request, ["premium"]),
"has_documents": has_documents,
},
)
@ -141,7 +139,7 @@ def config_page(request: Request):
subscription_renewal_date = (
user_subscription.renewal_date.strftime("%d %b %Y")
if user_subscription and user_subscription.renewal_date
else None
else (user_subscription.created_at + timedelta(days=7)).strftime("%d %b %Y")
)
enabled_content_source = set(EntryAdapters.get_unique_file_sources(user))
@ -171,7 +169,7 @@ def config_page(request: Request):
"subscription_state": user_subscription_state,
"subscription_renewal_date": subscription_renewal_date,
"khoj_cloud_subscription_url": os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL"),
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed",
"is_active": has_required_scope(request, ["premium"]),
"has_documents": has_documents,
},
)
@ -182,7 +180,6 @@ def config_page(request: Request):
def github_config_page(request: Request):
user = request.user.object
user_picture = request.session.get("user", {}).get("picture")
user_subscription_state = get_user_subscription_state(user.email)
has_documents = EntryAdapters.user_has_entries(user=user)
current_github_config = get_user_github_config(user)
@ -212,7 +209,7 @@ def github_config_page(request: Request):
"current_config": current_config,
"username": user.username,
"user_photo": user_picture,
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed",
"is_active": has_required_scope(request, ["premium"]),
"has_documents": has_documents,
},
)
@ -223,7 +220,6 @@ def github_config_page(request: Request):
def notion_config_page(request: Request):
user = request.user.object
user_picture = request.session.get("user", {}).get("picture")
user_subscription_state = adapters.get_user_subscription(user.email)
has_documents = EntryAdapters.user_has_entries(user=user)
current_notion_config = get_user_notion_config(user)
@ -240,7 +236,7 @@ def notion_config_page(request: Request):
"current_config": current_config,
"username": user.username,
"user_photo": user_picture,
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed",
"is_active": has_required_scope(request, ["premium"]),
"has_documents": has_documents,
},
)
@ -251,7 +247,6 @@ def notion_config_page(request: Request):
def computer_config_page(request: Request):
user = request.user.object
user_picture = request.session.get("user", {}).get("picture")
user_subscription_state = get_user_subscription_state(user.email)
has_documents = EntryAdapters.user_has_entries(user=user)
return templates.TemplateResponse(
@ -260,7 +255,7 @@ def computer_config_page(request: Request):
"request": request,
"username": user.username,
"user_photo": user_picture,
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed",
"is_active": has_required_scope(request, ["premium"]),
"has_documents": has_documents,
},
)

View file

@ -102,6 +102,24 @@ def default_user3():
return user
@pytest.mark.django_db
@pytest.fixture
def default_user4():
"""
This user should not have a valid subscription
"""
if KhojUser.objects.filter(username="default4").exists():
return KhojUser.objects.get(username="default4")
user = KhojUser.objects.create(
username="default4",
email="default4@example.com",
password="default4",
)
SubscriptionFactory(user=user, renewal_date=None)
return user
@pytest.mark.django_db
@pytest.fixture
def api_user(default_user):
@ -141,6 +159,19 @@ def api_user3(default_user3):
)
@pytest.mark.django_db
@pytest.fixture
def api_user4(default_user4):
if KhojApiUser.objects.filter(user=default_user4).exists():
return KhojApiUser.objects.get(user=default_user4)
return KhojApiUser.objects.create(
user=default_user4,
name="api-key",
token="kk-diff-secret-4",
)
@pytest.fixture(scope="session")
def search_models(search_config: SearchConfig):
search_models = SearchModels()

View file

@ -125,6 +125,67 @@ def test_regenerate_with_invalid_content_type(client):
assert response.status_code == 422
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_index_update_big_files(client):
# Arrange
state.billing_enabled = True
files = get_big_size_sample_files_data()
headers = {"Authorization": "Bearer kk-secret"}
# Act
response = client.post("/api/v1/index/update", files=files, headers=headers)
# Assert
assert response.status_code == 429
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_index_update_medium_file_unsubscribed(client, api_user4: KhojApiUser):
# Arrange
api_token = api_user4.token
state.billing_enabled = True
files = get_medium_size_sample_files_data()
headers = {"Authorization": f"Bearer {api_token}"}
# Act
response = client.post("/api/v1/index/update", files=files, headers=headers)
# Assert
assert response.status_code == 429
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_index_update_normal_file_unsubscribed(client, api_user4: KhojApiUser):
# Arrange
api_token = api_user4.token
state.billing_enabled = True
files = get_sample_files_data()
headers = {"Authorization": f"Bearer {api_token}"}
# Act
response = client.post("/api/v1/index/update", files=files, headers=headers)
# Assert
assert response.status_code == 200
@pytest.mark.django_db(transaction=True)
def test_index_update_big_files_no_billing(client):
# Arrange
state.billing_enabled = False
files = get_big_size_sample_files_data()
headers = {"Authorization": "Bearer kk-secret"}
# Act
response = client.post("/api/v1/index/update", files=files, headers=headers)
# Assert
assert response.status_code == 200
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_index_update(client):
@ -421,3 +482,23 @@ def get_sample_files_data():
),
("files", ("path/to/filename2.md", "**Understanding science through the lens of art**", "text/markdown")),
]
def get_big_size_sample_files_data():
big_text = "a" * (25 * 1024 * 1024) # a string of approximately 25 MB
return [
(
"files",
("path/to/filename.org", big_text, "text/org"),
)
]
def get_medium_size_sample_files_data():
big_text = "a" * (10 * 1024 * 1024) # a string of approximately 10 MB
return [
(
"files",
("path/to/filename.org", big_text, "text/org"),
)
]