mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Merge pull request #569 from khoj-ai/features/enforce-subscription-status
Enforce subscription state on the chat API access
This commit is contained in:
commit
24b5aaef0a
15 changed files with 368 additions and 55 deletions
|
@ -361,12 +361,25 @@
|
|||
if (newResponseText.getElementsByClassName("spinner").length > 0) {
|
||||
newResponseText.removeChild(loadingSpinner);
|
||||
}
|
||||
// Try to parse the chunk as a JSON object. It will be a JSON object if there is an error.
|
||||
if (chunk.startsWith("{") && chunk.endsWith("}")) {
|
||||
try {
|
||||
const responseAsJson = JSON.parse(chunk);
|
||||
if (responseAsJson.detail) {
|
||||
newResponseText.innerHTML += responseAsJson.detail;
|
||||
}
|
||||
} catch (error) {
|
||||
// If the chunk is not a JSON object, just display it as is
|
||||
newResponseText.innerHTML += chunk;
|
||||
}
|
||||
} else {
|
||||
// If the chunk is not a JSON object, just display it as is
|
||||
rawResponse += chunk;
|
||||
newResponseText.innerHTML = "";
|
||||
newResponseText.appendChild(formatHTMLMessage(rawResponse));
|
||||
|
||||
rawResponse += chunk;
|
||||
newResponseText.innerHTML = "";
|
||||
newResponseText.appendChild(formatHTMLMessage(rawResponse));
|
||||
|
||||
readStream();
|
||||
readStream();
|
||||
}
|
||||
}
|
||||
|
||||
// Scroll to bottom of chat window as chat response is streamed
|
||||
|
|
|
@ -101,6 +101,9 @@
|
|||
<div class="card-description-row">
|
||||
<div id="sync-status"></div>
|
||||
</div>
|
||||
<div id="needs-subscription" style="display: none;">
|
||||
Looks like you're out of space to sync your files. <a href="https://app.khoj.dev/config">Upgrade your plan</a> to unlock more space.
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
|
||||
|
|
|
@ -198,6 +198,11 @@ function pushDataToKhoj (regenerate = false) {
|
|||
})
|
||||
.catch(error => {
|
||||
console.error(error);
|
||||
if (error.response.status == 429) {
|
||||
const win = BrowserWindow.getAllWindows()[0];
|
||||
if (win) win.webContents.send('needsSubscription', true);
|
||||
if (win) win.webContents.send('update-state', state);
|
||||
}
|
||||
state['completed'] = false
|
||||
})
|
||||
.finally(() => {
|
||||
|
@ -396,6 +401,11 @@ app.whenReady().then(() => {
|
|||
event.reply('update-state', arg);
|
||||
});
|
||||
|
||||
ipcMain.on('needsSubscription', (event, arg) => {
|
||||
console.log(arg);
|
||||
event.reply('needsSubscription', arg);
|
||||
});
|
||||
|
||||
ipcMain.on('navigate', (event, page) => {
|
||||
win.loadFile(page);
|
||||
});
|
||||
|
|
|
@ -31,6 +31,10 @@ contextBridge.exposeInMainWorld('updateStateAPI', {
|
|||
onUpdateState: (callback) => ipcRenderer.on('update-state', callback)
|
||||
})
|
||||
|
||||
contextBridge.exposeInMainWorld('needsSubscriptionAPI', {
|
||||
onNeedsSubscription: (callback) => ipcRenderer.on('needsSubscription', callback)
|
||||
})
|
||||
|
||||
contextBridge.exposeInMainWorld('removeFileAPI', {
|
||||
removeFile: (filePath) => ipcRenderer.invoke('removeFile', filePath)
|
||||
})
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
const setFolderButton = document.getElementById('update-folder');
|
||||
const setFileButton = document.getElementById('update-file');
|
||||
const showKey = document.getElementById('show-key');
|
||||
const loadingBar = document.getElementById('loading-bar');
|
||||
const needsSubscriptionElement = document.getElementById('needs-subscription');
|
||||
|
||||
async function removeFile(filePath) {
|
||||
const updatedFiles = await window.removeFileAPI.removeFile(filePath);
|
||||
|
@ -165,6 +165,15 @@ window.updateStateAPI.onUpdateState((event, state) => {
|
|||
syncStatusElement.innerHTML = `⏱️ Synced at ${currentTime.toLocaleTimeString(undefined, options)}. Next sync at ${nextSyncTime.toLocaleTimeString(undefined, options)}.`;
|
||||
});
|
||||
|
||||
window.needsSubscriptionAPI.onNeedsSubscription((event, needsSubscription) => {
|
||||
console.log("needs subscription", needsSubscription);
|
||||
if (needsSubscription) {
|
||||
needsSubscriptionElement.style.display = 'block';
|
||||
} else {
|
||||
needsSubscriptionElement.style.display = 'none';
|
||||
}
|
||||
});
|
||||
|
||||
const urlInput = document.getElementById('khoj-host-url');
|
||||
(async function() {
|
||||
const url = await window.hostURLAPI.getURL();
|
||||
|
|
|
@ -21,7 +21,12 @@ from starlette.authentication import (
|
|||
|
||||
# Internal Packages
|
||||
from khoj.database.models import KhojUser, Subscription
|
||||
from khoj.database.adapters import get_all_users, get_or_create_search_model
|
||||
from khoj.database.adapters import (
|
||||
get_all_users,
|
||||
get_or_create_search_model,
|
||||
aget_user_subscription_state,
|
||||
SubscriptionState,
|
||||
)
|
||||
from khoj.processor.embeddings import CrossEncoderModel, EmbeddingsModel
|
||||
from khoj.routers.indexer import configure_content, load_content, configure_search
|
||||
from khoj.utils import constants, state
|
||||
|
@ -70,7 +75,17 @@ class UserAuthenticationBackend(AuthenticationBackend):
|
|||
.afirst()
|
||||
)
|
||||
if user:
|
||||
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
|
||||
if state.billing_enabled:
|
||||
subscription_state = await aget_user_subscription_state(user)
|
||||
subscribed = (
|
||||
subscription_state == SubscriptionState.SUBSCRIBED.value
|
||||
or subscription_state == SubscriptionState.TRIAL.value
|
||||
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
|
||||
)
|
||||
if subscribed:
|
||||
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
|
||||
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
|
||||
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
|
||||
if len(request.headers.get("Authorization", "").split("Bearer ")) == 2:
|
||||
# Get bearer token from header
|
||||
bearer_token = request.headers["Authorization"].split("Bearer ")[1]
|
||||
|
@ -82,11 +97,23 @@ class UserAuthenticationBackend(AuthenticationBackend):
|
|||
.afirst()
|
||||
)
|
||||
if user_with_token:
|
||||
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
|
||||
if state.billing_enabled:
|
||||
subscription_state = await aget_user_subscription_state(user_with_token.user)
|
||||
subscribed = (
|
||||
subscription_state == SubscriptionState.SUBSCRIBED.value
|
||||
or subscription_state == SubscriptionState.TRIAL.value
|
||||
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
|
||||
)
|
||||
if subscribed:
|
||||
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(
|
||||
user_with_token.user
|
||||
)
|
||||
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user_with_token.user)
|
||||
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user_with_token.user)
|
||||
if state.anonymous_mode:
|
||||
user = await self.khojuser_manager.filter(username="default").prefetch_related("subscription").afirst()
|
||||
if user:
|
||||
return AuthCredentials(["authenticated"]), AuthenticatedKhojUser(user)
|
||||
return AuthCredentials(["authenticated", "premium"]), AuthenticatedKhojUser(user)
|
||||
|
||||
return AuthCredentials(), UnauthenticatedUser()
|
||||
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
import math
|
||||
import random
|
||||
import secrets
|
||||
from datetime import date, datetime, timezone
|
||||
import sys
|
||||
from datetime import date, datetime, timezone, timedelta
|
||||
from typing import List, Optional, Type
|
||||
from enum import Enum
|
||||
|
||||
from asgiref.sync import sync_to_async
|
||||
from django.contrib.sessions.backends.db import SessionStore
|
||||
|
@ -41,6 +43,14 @@ from khoj.utils.config import GPT4AllProcessorModel
|
|||
from khoj.utils.helpers import generate_random_name
|
||||
|
||||
|
||||
class SubscriptionState(Enum):
|
||||
TRIAL = "trial"
|
||||
SUBSCRIBED = "subscribed"
|
||||
UNSUBSCRIBED = "unsubscribed"
|
||||
EXPIRED = "expired"
|
||||
INVALID = "invalid"
|
||||
|
||||
|
||||
async def set_notion_config(token: str, user: KhojUser):
|
||||
notion_config = await NotionConfig.objects.filter(user=user).afirst()
|
||||
if not notion_config:
|
||||
|
@ -128,22 +138,38 @@ async def set_user_subscription(
|
|||
return None
|
||||
|
||||
|
||||
def subscription_to_state(subscription: Subscription) -> str:
|
||||
if not subscription:
|
||||
return SubscriptionState.INVALID.value
|
||||
elif subscription.type == Subscription.Type.TRIAL:
|
||||
# Trial subscription is valid for 7 days
|
||||
if datetime.now(tz=timezone.utc) - subscription.created_at > timedelta(days=7):
|
||||
return SubscriptionState.EXPIRED.value
|
||||
|
||||
return SubscriptionState.TRIAL.value
|
||||
elif subscription.is_recurring and subscription.renewal_date >= datetime.now(tz=timezone.utc):
|
||||
return SubscriptionState.SUBSCRIBED.value
|
||||
elif not subscription.is_recurring and subscription.renewal_date >= datetime.now(tz=timezone.utc):
|
||||
return SubscriptionState.UNSUBSCRIBED.value
|
||||
elif not subscription.is_recurring and subscription.renewal_date < datetime.now(tz=timezone.utc):
|
||||
return SubscriptionState.EXPIRED.value
|
||||
return SubscriptionState.INVALID.value
|
||||
|
||||
|
||||
def get_user_subscription_state(email: str) -> str:
|
||||
"""Get subscription state of user
|
||||
Valid state transitions: trial -> subscribed <-> unsubscribed OR expired
|
||||
"""
|
||||
user_subscription = Subscription.objects.filter(user__email=email).first()
|
||||
if not user_subscription:
|
||||
return "trial"
|
||||
elif user_subscription.type == Subscription.Type.TRIAL:
|
||||
return "trial"
|
||||
elif user_subscription.is_recurring and user_subscription.renewal_date >= datetime.now(tz=timezone.utc):
|
||||
return "subscribed"
|
||||
elif not user_subscription.is_recurring and user_subscription.renewal_date >= datetime.now(tz=timezone.utc):
|
||||
return "unsubscribed"
|
||||
elif not user_subscription.is_recurring and user_subscription.renewal_date < datetime.now(tz=timezone.utc):
|
||||
return "expired"
|
||||
return "invalid"
|
||||
return subscription_to_state(user_subscription)
|
||||
|
||||
|
||||
async def aget_user_subscription_state(email: str) -> str:
|
||||
"""Get subscription state of user
|
||||
Valid state transitions: trial -> subscribed <-> unsubscribed OR expired
|
||||
"""
|
||||
user_subscription = await Subscription.objects.filter(user__email=email).afirst()
|
||||
return subscription_to_state(user_subscription)
|
||||
|
||||
|
||||
async def get_user_by_email(email: str) -> KhojUser:
|
||||
|
@ -458,6 +484,12 @@ class EntryAdapters:
|
|||
async def adelete_all_entries(user: KhojUser):
|
||||
return await Entry.objects.filter(user=user).adelete()
|
||||
|
||||
@staticmethod
|
||||
def get_size_of_indexed_data_in_mb(user: KhojUser):
|
||||
entries = Entry.objects.filter(user=user).iterator()
|
||||
total_size = sum(sys.getsizeof(entry.compiled) for entry in entries)
|
||||
return total_size / 1024 / 1024
|
||||
|
||||
@staticmethod
|
||||
def apply_filters(user: KhojUser, query: str, file_type_filter: str = None):
|
||||
q_filter_terms = Q()
|
||||
|
|
|
@ -402,10 +402,24 @@ To get started, just start typing below. You can also type / to see a list of co
|
|||
newResponseText.removeChild(loadingSpinner);
|
||||
}
|
||||
|
||||
rawResponse += chunk;
|
||||
newResponseText.innerHTML = "";
|
||||
newResponseText.appendChild(formatHTMLMessage(rawResponse));
|
||||
readStream();
|
||||
// Try to parse the chunk as a JSON object. It will be a JSON object if there is an error.
|
||||
if (chunk.startsWith("{") && chunk.endsWith("}")) {
|
||||
try {
|
||||
const responseAsJson = JSON.parse(chunk);
|
||||
if (responseAsJson.detail) {
|
||||
newResponseText.innerHTML += responseAsJson.detail;
|
||||
}
|
||||
} catch (error) {
|
||||
// If the chunk is not a JSON object, just display it as is
|
||||
newResponseText.innerHTML += chunk;
|
||||
}
|
||||
} else {
|
||||
// If the chunk is not a JSON object, just display it as is
|
||||
rawResponse += chunk;
|
||||
newResponseText.innerHTML = "";
|
||||
newResponseText.appendChild(formatHTMLMessage(rawResponse));
|
||||
readStream();
|
||||
}
|
||||
}
|
||||
|
||||
// Scroll to bottom of chat window as chat response is streamed
|
||||
|
|
|
@ -4,6 +4,10 @@
|
|||
<div class="page">
|
||||
<div id="content" class="section">
|
||||
<h2 class="section-title">Content</h2>
|
||||
<button id="compute-index-size" class="card-button" onclick="getIndexedDataSize()">
|
||||
Data Usage
|
||||
</button>
|
||||
<p id="indexed-data-size" class="card-description"></p>
|
||||
<div class="section-cards">
|
||||
<div class="card">
|
||||
<div class="card-title-row">
|
||||
|
@ -191,7 +195,7 @@
|
|||
<p id="trial-description"
|
||||
class="card-description"
|
||||
style="display: {% if subscription_state != 'trial' %}none{% endif %}">
|
||||
Subscribe to Khoj Cloud
|
||||
Subscribe to Khoj Cloud. See <a href="https://khoj.dev/pricing">pricing</a> for details.
|
||||
</p>
|
||||
<p id="unsubscribe-description"
|
||||
class="card-description"
|
||||
|
@ -471,6 +475,15 @@
|
|||
});
|
||||
}
|
||||
|
||||
function getIndexedDataSize() {
|
||||
document.getElementById("indexed-data-size").innerHTML = "Calculating...";
|
||||
fetch('/api/config/index/size')
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
document.getElementById("indexed-data-size").innerHTML = data.indexed_data_size_in_mb + " MB used";
|
||||
});
|
||||
}
|
||||
|
||||
// List user's API keys on page load
|
||||
listApiKeys();
|
||||
|
||||
|
|
|
@ -9,8 +9,8 @@ from typing import Any, Dict, List, Optional, Union
|
|||
import uuid
|
||||
|
||||
# External Packages
|
||||
from asgiref.sync import sync_to_async
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
|
||||
from asgiref.sync import sync_to_async
|
||||
from fastapi.requests import Request
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
from starlette.authentication import requires
|
||||
|
@ -334,6 +334,18 @@ def get_default_config_data():
|
|||
return constants.empty_config
|
||||
|
||||
|
||||
@api.get("/config/index/size", response_model=Dict[str, int])
|
||||
@requires(["authenticated"])
|
||||
async def get_indexed_data_size(request: Request, common: CommonQueryParams):
|
||||
user = request.user.object
|
||||
indexed_data_size_in_mb = await sync_to_async(EntryAdapters.get_size_of_indexed_data_in_mb)(user)
|
||||
return Response(
|
||||
content=json.dumps({"indexed_data_size_in_mb": math.ceil(indexed_data_size_in_mb)}),
|
||||
media_type="application/json",
|
||||
status_code=200,
|
||||
)
|
||||
|
||||
|
||||
@api.get("/config/types", response_model=List[str])
|
||||
@requires(["authenticated"])
|
||||
def get_config_types(
|
||||
|
@ -650,8 +662,8 @@ async def chat(
|
|||
n: Optional[int] = 5,
|
||||
d: Optional[float] = 0.18,
|
||||
stream: Optional[bool] = False,
|
||||
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=30, window=60)),
|
||||
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=500, window=60 * 60 * 24)),
|
||||
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=60, window=60)),
|
||||
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)),
|
||||
) -> Response:
|
||||
user = request.user.object
|
||||
|
||||
|
|
|
@ -9,10 +9,12 @@ from functools import partial
|
|||
from time import time
|
||||
from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
# External Packages
|
||||
from fastapi import Depends, Header, HTTPException, Request
|
||||
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
||||
from starlette.authentication import has_required_scope
|
||||
from asgiref.sync import sync_to_async
|
||||
|
||||
from khoj.database.adapters import ConversationAdapters
|
||||
|
||||
from khoj.database.adapters import ConversationAdapters, EntryAdapters
|
||||
from khoj.database.models import KhojUser, Subscription
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.offline.chat_model import converse_offline, send_message_to_model_offline
|
||||
|
@ -270,13 +272,15 @@ def generate_chat_response(
|
|||
|
||||
|
||||
class ApiUserRateLimiter:
|
||||
def __init__(self, requests: int, window: int):
|
||||
def __init__(self, requests: int, subscribed_requests: int, window: int):
|
||||
self.requests = requests
|
||||
self.subscribed_requests = subscribed_requests
|
||||
self.window = window
|
||||
self.cache: dict[str, list[float]] = defaultdict(list)
|
||||
|
||||
def __call__(self, request: Request):
|
||||
user: KhojUser = request.user.object
|
||||
subscribed = has_required_scope(request, ["premium"])
|
||||
user_requests = self.cache[user.uuid]
|
||||
|
||||
# Remove requests outside of the time window
|
||||
|
@ -285,13 +289,69 @@ class ApiUserRateLimiter:
|
|||
user_requests.pop(0)
|
||||
|
||||
# Check if the user has exceeded the rate limit
|
||||
if len(user_requests) >= self.requests:
|
||||
if subscribed and len(user_requests) >= self.subscribed_requests:
|
||||
raise HTTPException(status_code=429, detail="Too Many Requests")
|
||||
if not subscribed and len(user_requests) >= self.requests:
|
||||
raise HTTPException(status_code=429, detail="Too Many Requests. Subscribe to increase your rate limit.")
|
||||
|
||||
# Add the current request to the cache
|
||||
user_requests.append(time())
|
||||
|
||||
|
||||
class ApiIndexedDataLimiter:
|
||||
def __init__(
|
||||
self,
|
||||
incoming_entries_size_limit: float,
|
||||
subscribed_incoming_entries_size_limit: float,
|
||||
total_entries_size_limit: float,
|
||||
subscribed_total_entries_size_limit: float,
|
||||
):
|
||||
self.num_entries_size = incoming_entries_size_limit
|
||||
self.subscribed_num_entries_size = subscribed_incoming_entries_size_limit
|
||||
self.total_entries_size_limit = total_entries_size_limit
|
||||
self.subscribed_total_entries_size = subscribed_total_entries_size_limit
|
||||
|
||||
def __call__(self, request: Request, files: List[UploadFile]):
|
||||
if state.billing_enabled is False:
|
||||
return
|
||||
subscribed = has_required_scope(request, ["premium"])
|
||||
incoming_data_size_mb = 0
|
||||
deletion_file_names = set()
|
||||
|
||||
if not request.user.is_authenticated:
|
||||
return
|
||||
|
||||
user: KhojUser = request.user.object
|
||||
|
||||
for file in files:
|
||||
if file.size == 0:
|
||||
deletion_file_names.add(file.filename)
|
||||
|
||||
incoming_data_size_mb += file.size / 1024 / 1024
|
||||
|
||||
num_deleted_entries = 0
|
||||
for file_path in deletion_file_names:
|
||||
deleted_count = EntryAdapters.delete_entry_by_file(user, file_path)
|
||||
num_deleted_entries += deleted_count
|
||||
|
||||
logger.info(f"Deleted {num_deleted_entries} entries for user: {user}.")
|
||||
|
||||
if subscribed and incoming_data_size_mb >= self.subscribed_num_entries_size:
|
||||
raise HTTPException(status_code=429, detail="Too much data indexed.")
|
||||
if not subscribed and incoming_data_size_mb >= self.num_entries_size:
|
||||
raise HTTPException(
|
||||
status_code=429, detail="Too much data indexed. Subscribe to increase your data index limit."
|
||||
)
|
||||
|
||||
user_size_data = EntryAdapters.get_size_of_indexed_data_in_mb(user)
|
||||
if subscribed and user_size_data + incoming_data_size_mb >= self.subscribed_total_entries_size:
|
||||
raise HTTPException(status_code=429, detail="Too much data indexed.")
|
||||
if not subscribed and user_size_data + incoming_data_size_mb >= self.total_entries_size_limit:
|
||||
raise HTTPException(
|
||||
status_code=429, detail="Too much data indexed. Subscribe to increase your data index limit."
|
||||
)
|
||||
|
||||
|
||||
class CommonQueryParamsClass:
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
@ -2,7 +2,7 @@ import asyncio
|
|||
import logging
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from fastapi import APIRouter, Header, Request, Response, UploadFile
|
||||
from fastapi import APIRouter, Header, Request, Response, UploadFile, Depends
|
||||
from pydantic import BaseModel
|
||||
from starlette.authentication import requires
|
||||
|
||||
|
@ -18,6 +18,7 @@ from khoj.search_type import image_search, text_search
|
|||
from khoj.utils import constants, state
|
||||
from khoj.utils.config import ContentIndex, SearchModels
|
||||
from khoj.utils.helpers import LRU, get_file_type
|
||||
from khoj.routers.helpers import ApiIndexedDataLimiter
|
||||
from khoj.utils.rawconfig import ContentConfig, FullConfig, SearchConfig
|
||||
from khoj.utils.yaml import save_config_to_file_updated_state
|
||||
|
||||
|
@ -53,6 +54,14 @@ async def update(
|
|||
user_agent: Optional[str] = Header(None),
|
||||
referer: Optional[str] = Header(None),
|
||||
host: Optional[str] = Header(None),
|
||||
indexed_data_limiter: ApiIndexedDataLimiter = Depends(
|
||||
ApiIndexedDataLimiter(
|
||||
incoming_entries_size_limit=10,
|
||||
subscribed_incoming_entries_size_limit=25,
|
||||
total_entries_size_limit=10,
|
||||
subscribed_total_entries_size_limit=100,
|
||||
)
|
||||
),
|
||||
):
|
||||
user = request.user.object
|
||||
try:
|
||||
|
@ -92,7 +101,7 @@ async def update(
|
|||
logger.info("📬 Initializing content index on first run.")
|
||||
default_full_config = FullConfig(
|
||||
content_type=None,
|
||||
search_type=SearchConfig.parse_obj(constants.default_config["search-type"]),
|
||||
search_type=SearchConfig.model_validate(constants.default_config["search-type"]),
|
||||
processor=None,
|
||||
)
|
||||
state.config = default_full_config
|
||||
|
@ -116,7 +125,7 @@ async def update(
|
|||
configure_content,
|
||||
state.content_index,
|
||||
state.config.content_type,
|
||||
indexer_input.dict(),
|
||||
indexer_input.model_dump(),
|
||||
state.search_models,
|
||||
force,
|
||||
t,
|
||||
|
|
|
@ -1,13 +1,15 @@
|
|||
# System Packages
|
||||
import json
|
||||
import os
|
||||
import math
|
||||
from datetime import timedelta
|
||||
|
||||
# External Packages
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Request
|
||||
from fastapi.responses import HTMLResponse, FileResponse, RedirectResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from starlette.authentication import requires
|
||||
from starlette.authentication import requires, has_required_scope
|
||||
from khoj.database import adapters
|
||||
from khoj.database.models import KhojUser
|
||||
from khoj.utils.rawconfig import (
|
||||
|
@ -37,7 +39,6 @@ templates = Jinja2Templates(directory=constants.web_directory)
|
|||
def index(request: Request):
|
||||
user = request.user.object
|
||||
user_picture = request.session.get("user", {}).get("picture")
|
||||
user_subscription_state = get_user_subscription_state(user.email)
|
||||
has_documents = EntryAdapters.user_has_entries(user=user)
|
||||
|
||||
return templates.TemplateResponse(
|
||||
|
@ -46,7 +47,7 @@ def index(request: Request):
|
|||
"request": request,
|
||||
"username": user.username,
|
||||
"user_photo": user_picture,
|
||||
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed",
|
||||
"is_active": has_required_scope(request, ["premium"]),
|
||||
"has_documents": has_documents,
|
||||
},
|
||||
)
|
||||
|
@ -57,7 +58,6 @@ def index(request: Request):
|
|||
def index_post(request: Request):
|
||||
user = request.user.object
|
||||
user_picture = request.session.get("user", {}).get("picture")
|
||||
user_subscription_state = get_user_subscription_state(user.email)
|
||||
has_documents = EntryAdapters.user_has_entries(user=user)
|
||||
|
||||
return templates.TemplateResponse(
|
||||
|
@ -66,7 +66,7 @@ def index_post(request: Request):
|
|||
"request": request,
|
||||
"username": user.username,
|
||||
"user_photo": user_picture,
|
||||
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed",
|
||||
"is_active": has_required_scope(request, ["premium"]),
|
||||
"has_documents": has_documents,
|
||||
},
|
||||
)
|
||||
|
@ -77,7 +77,6 @@ def index_post(request: Request):
|
|||
def search_page(request: Request):
|
||||
user = request.user.object
|
||||
user_picture = request.session.get("user", {}).get("picture")
|
||||
user_subscription_state = get_user_subscription_state(user.email)
|
||||
has_documents = EntryAdapters.user_has_entries(user=user)
|
||||
|
||||
return templates.TemplateResponse(
|
||||
|
@ -86,7 +85,7 @@ def search_page(request: Request):
|
|||
"request": request,
|
||||
"username": user.username,
|
||||
"user_photo": user_picture,
|
||||
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed",
|
||||
"is_active": has_required_scope(request, ["premium"]),
|
||||
"has_documents": has_documents,
|
||||
},
|
||||
)
|
||||
|
@ -97,7 +96,6 @@ def search_page(request: Request):
|
|||
def chat_page(request: Request):
|
||||
user = request.user.object
|
||||
user_picture = request.session.get("user", {}).get("picture")
|
||||
user_subscription_state = get_user_subscription_state(user.email)
|
||||
has_documents = EntryAdapters.user_has_entries(user=user)
|
||||
|
||||
return templates.TemplateResponse(
|
||||
|
@ -106,7 +104,7 @@ def chat_page(request: Request):
|
|||
"request": request,
|
||||
"username": user.username,
|
||||
"user_photo": user_picture,
|
||||
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed",
|
||||
"is_active": has_required_scope(request, ["premium"]),
|
||||
"has_documents": has_documents,
|
||||
},
|
||||
)
|
||||
|
@ -141,7 +139,7 @@ def config_page(request: Request):
|
|||
subscription_renewal_date = (
|
||||
user_subscription.renewal_date.strftime("%d %b %Y")
|
||||
if user_subscription and user_subscription.renewal_date
|
||||
else None
|
||||
else (user_subscription.created_at + timedelta(days=7)).strftime("%d %b %Y")
|
||||
)
|
||||
|
||||
enabled_content_source = set(EntryAdapters.get_unique_file_sources(user))
|
||||
|
@ -171,7 +169,7 @@ def config_page(request: Request):
|
|||
"subscription_state": user_subscription_state,
|
||||
"subscription_renewal_date": subscription_renewal_date,
|
||||
"khoj_cloud_subscription_url": os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL"),
|
||||
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed",
|
||||
"is_active": has_required_scope(request, ["premium"]),
|
||||
"has_documents": has_documents,
|
||||
},
|
||||
)
|
||||
|
@ -182,7 +180,6 @@ def config_page(request: Request):
|
|||
def github_config_page(request: Request):
|
||||
user = request.user.object
|
||||
user_picture = request.session.get("user", {}).get("picture")
|
||||
user_subscription_state = get_user_subscription_state(user.email)
|
||||
has_documents = EntryAdapters.user_has_entries(user=user)
|
||||
current_github_config = get_user_github_config(user)
|
||||
|
||||
|
@ -212,7 +209,7 @@ def github_config_page(request: Request):
|
|||
"current_config": current_config,
|
||||
"username": user.username,
|
||||
"user_photo": user_picture,
|
||||
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed",
|
||||
"is_active": has_required_scope(request, ["premium"]),
|
||||
"has_documents": has_documents,
|
||||
},
|
||||
)
|
||||
|
@ -223,7 +220,6 @@ def github_config_page(request: Request):
|
|||
def notion_config_page(request: Request):
|
||||
user = request.user.object
|
||||
user_picture = request.session.get("user", {}).get("picture")
|
||||
user_subscription_state = adapters.get_user_subscription(user.email)
|
||||
has_documents = EntryAdapters.user_has_entries(user=user)
|
||||
current_notion_config = get_user_notion_config(user)
|
||||
|
||||
|
@ -240,7 +236,7 @@ def notion_config_page(request: Request):
|
|||
"current_config": current_config,
|
||||
"username": user.username,
|
||||
"user_photo": user_picture,
|
||||
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed",
|
||||
"is_active": has_required_scope(request, ["premium"]),
|
||||
"has_documents": has_documents,
|
||||
},
|
||||
)
|
||||
|
@ -251,7 +247,6 @@ def notion_config_page(request: Request):
|
|||
def computer_config_page(request: Request):
|
||||
user = request.user.object
|
||||
user_picture = request.session.get("user", {}).get("picture")
|
||||
user_subscription_state = get_user_subscription_state(user.email)
|
||||
has_documents = EntryAdapters.user_has_entries(user=user)
|
||||
|
||||
return templates.TemplateResponse(
|
||||
|
@ -260,7 +255,7 @@ def computer_config_page(request: Request):
|
|||
"request": request,
|
||||
"username": user.username,
|
||||
"user_photo": user_picture,
|
||||
"is_active": user_subscription_state == "subscribed" or user_subscription_state == "unsubscribed",
|
||||
"is_active": has_required_scope(request, ["premium"]),
|
||||
"has_documents": has_documents,
|
||||
},
|
||||
)
|
||||
|
|
|
@ -102,6 +102,24 @@ def default_user3():
|
|||
return user
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.fixture
|
||||
def default_user4():
|
||||
"""
|
||||
This user should not have a valid subscription
|
||||
"""
|
||||
if KhojUser.objects.filter(username="default4").exists():
|
||||
return KhojUser.objects.get(username="default4")
|
||||
|
||||
user = KhojUser.objects.create(
|
||||
username="default4",
|
||||
email="default4@example.com",
|
||||
password="default4",
|
||||
)
|
||||
SubscriptionFactory(user=user, renewal_date=None)
|
||||
return user
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.fixture
|
||||
def api_user(default_user):
|
||||
|
@ -141,6 +159,19 @@ def api_user3(default_user3):
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.django_db
|
||||
@pytest.fixture
|
||||
def api_user4(default_user4):
|
||||
if KhojApiUser.objects.filter(user=default_user4).exists():
|
||||
return KhojApiUser.objects.get(user=default_user4)
|
||||
|
||||
return KhojApiUser.objects.create(
|
||||
user=default_user4,
|
||||
name="api-key",
|
||||
token="kk-diff-secret-4",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def search_models(search_config: SearchConfig):
|
||||
search_models = SearchModels()
|
||||
|
|
|
@ -125,6 +125,67 @@ def test_regenerate_with_invalid_content_type(client):
|
|||
assert response.status_code == 422
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_index_update_big_files(client):
|
||||
# Arrange
|
||||
state.billing_enabled = True
|
||||
files = get_big_size_sample_files_data()
|
||||
headers = {"Authorization": "Bearer kk-secret"}
|
||||
|
||||
# Act
|
||||
response = client.post("/api/v1/index/update", files=files, headers=headers)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 429
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_index_update_medium_file_unsubscribed(client, api_user4: KhojApiUser):
|
||||
# Arrange
|
||||
api_token = api_user4.token
|
||||
state.billing_enabled = True
|
||||
files = get_medium_size_sample_files_data()
|
||||
headers = {"Authorization": f"Bearer {api_token}"}
|
||||
|
||||
# Act
|
||||
response = client.post("/api/v1/index/update", files=files, headers=headers)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 429
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_index_update_normal_file_unsubscribed(client, api_user4: KhojApiUser):
|
||||
# Arrange
|
||||
api_token = api_user4.token
|
||||
state.billing_enabled = True
|
||||
files = get_sample_files_data()
|
||||
headers = {"Authorization": f"Bearer {api_token}"}
|
||||
|
||||
# Act
|
||||
response = client.post("/api/v1/index/update", files=files, headers=headers)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_index_update_big_files_no_billing(client):
|
||||
# Arrange
|
||||
state.billing_enabled = False
|
||||
files = get_big_size_sample_files_data()
|
||||
headers = {"Authorization": "Bearer kk-secret"}
|
||||
|
||||
# Act
|
||||
response = client.post("/api/v1/index/update", files=files, headers=headers)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_index_update(client):
|
||||
|
@ -421,3 +482,23 @@ def get_sample_files_data():
|
|||
),
|
||||
("files", ("path/to/filename2.md", "**Understanding science through the lens of art**", "text/markdown")),
|
||||
]
|
||||
|
||||
|
||||
def get_big_size_sample_files_data():
|
||||
big_text = "a" * (25 * 1024 * 1024) # a string of approximately 25 MB
|
||||
return [
|
||||
(
|
||||
"files",
|
||||
("path/to/filename.org", big_text, "text/org"),
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def get_medium_size_sample_files_data():
|
||||
big_text = "a" * (10 * 1024 * 1024) # a string of approximately 10 MB
|
||||
return [
|
||||
(
|
||||
"files",
|
||||
("path/to/filename.org", big_text, "text/org"),
|
||||
)
|
||||
]
|
||||
|
|
Loading…
Reference in a new issue