Add support for rate limiting the amount of data indexed

- Add a dependency on the indexer API endpoint that rounds up the amount of data indexed and uses that to determine whether the next set of data should be processed
- Delete any files that are being removed for adminstering the calculation
- Show current amount of data indexed in the config page
This commit is contained in:
sabaimran 2023-11-25 20:28:04 -08:00
parent dd1badae81
commit b2afbaa315
8 changed files with 127 additions and 11 deletions

View file

@ -80,6 +80,7 @@ class UserAuthenticationBackend(AuthenticationBackend):
subscribed = (
subscription_state == SubscriptionState.SUBSCRIBED.value
or subscription_state == SubscriptionState.TRIAL.value
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
)
if subscribed:
return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user)
@ -101,6 +102,7 @@ class UserAuthenticationBackend(AuthenticationBackend):
subscribed = (
subscription_state == SubscriptionState.SUBSCRIBED.value
or subscription_state == SubscriptionState.TRIAL.value
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
)
if subscribed:
return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(

View file

@ -1,6 +1,7 @@
import math
import random
import secrets
import sys
from datetime import date, datetime, timezone, timedelta
from typing import List, Optional, Type
from enum import Enum
@ -474,6 +475,12 @@ class EntryAdapters:
async def adelete_all_entries(user: KhojUser):
return await Entry.objects.filter(user=user).adelete()
@staticmethod
def get_size_of_indexed_data_in_mb(user: KhojUser):
entries = Entry.objects.filter(user=user).iterator()
total_size = sum(sys.getsizeof(entry.compiled) for entry in entries)
return total_size / 1024 / 1024
@staticmethod
def apply_filters(user: KhojUser, query: str, file_type_filter: str = None):
q_filter_terms = Q()

View file

@ -4,6 +4,7 @@
<div class="page">
<div id="content" class="section">
<h2 class="section-title">Content</h2>
<p id="indexed-data-size" class="card-description">{{indexed_data_size_in_mb}} MB used</p>
<div class="section-cards">
<div class="card">
<div class="card-title-row">
@ -171,7 +172,7 @@
</div>
</div>
</div>
{% if not billing_enabled %}
{% if billing_enabled %}
<div id="billing" class="section">
<h2 class="section-title">Billing</h2>
<div class="section-cards">
@ -191,7 +192,7 @@
<p id="trial-description"
class="card-description"
style="display: {% if subscription_state != 'trial' %}none{% endif %}">
Subscribe to Khoj Cloud
Subscribe to Khoj Cloud. See <a href="https://khoj.dev/pricing">pricing</a> for details.
</p>
<p id="unsubscribe-description"
class="card-description"

View file

@ -9,10 +9,10 @@ from typing import Any, Dict, List, Optional, Union
from asgiref.sync import sync_to_async
# External Packages
from fastapi import APIRouter, Depends, Header, HTTPException, Request
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.requests import Request
from fastapi.responses import Response, StreamingResponse
from starlette.authentication import requires, has_required_scope
from starlette.authentication import requires
# Internal Packages
from khoj.configure import configure_server

View file

@ -9,11 +9,12 @@ from functools import partial
from time import time
from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union
# External Packages
from fastapi import Depends, Header, HTTPException, Request
from fastapi import Depends, Header, HTTPException, Request, UploadFile
from starlette.authentication import has_required_scope
from asgiref.sync import sync_to_async
from khoj.database.adapters import ConversationAdapters
from khoj.database.adapters import ConversationAdapters, EntryAdapters
from khoj.database.models import KhojUser, Subscription
from khoj.processor.conversation import prompts
from khoj.processor.conversation.gpt4all.chat_model import converse_offline, send_message_to_model_offline
@ -297,6 +298,60 @@ class ApiUserRateLimiter:
user_requests.append(time())
class ApiIndexedDataLimiter:
def __init__(
self,
incoming_entries_size_limit: float,
subscribed_incoming_entries_size_limit: float,
total_entries_size_limit: float,
subscribed_total_entries_size_limit: float,
):
self.num_entries_size = incoming_entries_size_limit
self.subscribed_num_entries_size = subscribed_incoming_entries_size_limit
self.total_entries_size_limit = total_entries_size_limit
self.subscribed_total_entries_size = subscribed_total_entries_size_limit
def __call__(self, request: Request, files: List[UploadFile]):
if state.billing_enabled is False:
return
subscribed = has_required_scope(request, ["subscribed"])
incoming_data_size_mb = 0
deletion_file_names = set()
if not request.user.is_authenticated:
return
user: KhojUser = request.user.object
for file in files:
if file.size == 0:
deletion_file_names.add(file.filename)
incoming_data_size_mb += file.size / 1024 / 1024
num_deleted_entries = 0
for file_path in deletion_file_names:
deleted_count = EntryAdapters.delete_entry_by_file(user, file_path)
num_deleted_entries += deleted_count
logger.info(f"Deleted {num_deleted_entries} entries for user: {user}.")
if subscribed and incoming_data_size_mb >= self.subscribed_num_entries_size:
raise HTTPException(status_code=429, detail="Too much data indexed.")
if not subscribed and incoming_data_size_mb >= self.num_entries_size:
raise HTTPException(
status_code=429, detail="Too much data indexed. Subscribe to increase your data index limit."
)
user_size_data = EntryAdapters.get_size_of_indexed_data_in_mb(user)
if subscribed and user_size_data + incoming_data_size_mb >= self.subscribed_total_entries_size:
raise HTTPException(status_code=429, detail="Too much data indexed.")
if not subscribed and user_size_data + incoming_data_size_mb >= self.total_entries_size_limit:
raise HTTPException(
status_code=429, detail="Too much data indexed. Subscribe to increase your data index limit."
)
class CommonQueryParamsClass:
def __init__(
self,

View file

@ -2,7 +2,7 @@ import asyncio
import logging
from typing import Dict, Optional, Union
from fastapi import APIRouter, Header, Request, Response, UploadFile
from fastapi import APIRouter, Header, Request, Response, UploadFile, Depends
from pydantic import BaseModel
from starlette.authentication import requires
@ -18,6 +18,7 @@ from khoj.search_type import image_search, text_search
from khoj.utils import constants, state
from khoj.utils.config import ContentIndex, SearchModels
from khoj.utils.helpers import LRU, get_file_type
from khoj.routers.helpers import ApiIndexedDataLimiter
from khoj.utils.rawconfig import ContentConfig, FullConfig, SearchConfig
from khoj.utils.yaml import save_config_to_file_updated_state
@ -53,6 +54,14 @@ async def update(
user_agent: Optional[str] = Header(None),
referer: Optional[str] = Header(None),
host: Optional[str] = Header(None),
indexed_data_limiter: ApiIndexedDataLimiter = Depends(
ApiIndexedDataLimiter(
incoming_entries_size_limit=10,
subscribed_incoming_entries_size_limit=25,
total_entries_size_limit=10,
subscribed_total_entries_size_limit=100,
)
),
):
user = request.user.object
try:
@ -92,7 +101,7 @@ async def update(
logger.info("📬 Initializing content index on first run.")
default_full_config = FullConfig(
content_type=None,
search_type=SearchConfig.parse_obj(constants.default_config["search-type"]),
search_type=SearchConfig.model_validate(constants.default_config["search-type"]),
processor=None,
)
state.config = default_full_config
@ -116,7 +125,7 @@ async def update(
configure_content,
state.content_index,
state.config.content_type,
indexer_input.dict(),
indexer_input.model_dump(),
state.search_models,
force,
t,

View file

@ -1,6 +1,8 @@
# System Packages
import json
import os
import math
from datetime import timedelta
# External Packages
from fastapi import APIRouter
@ -137,8 +139,9 @@ def config_page(request: Request):
subscription_renewal_date = (
user_subscription.renewal_date.strftime("%d %b %Y")
if user_subscription and user_subscription.renewal_date
else None
else (user_subscription.created_at + timedelta(days=7)).strftime("%d %b %Y")
)
indexed_data_size_in_mb = math.ceil(EntryAdapters.get_size_of_indexed_data_in_mb(user))
enabled_content_source = set(EntryAdapters.get_unique_file_sources(user))
successfully_configured = {
@ -169,6 +172,7 @@ def config_page(request: Request):
"khoj_cloud_subscription_url": os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL"),
"is_active": has_required_scope(request, ["subscribed"]),
"has_documents": has_documents,
"indexed_data_size_in_mb": indexed_data_size_in_mb,
},
)

View file

@ -125,6 +125,34 @@ def test_regenerate_with_invalid_content_type(client):
assert response.status_code == 422
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_index_update_big_files(client):
state.billing_enabled = True
# Arrange
files = get_big_size_sample_files_data()
headers = {"Authorization": "Bearer kk-secret"}
# Act
response = client.post("/api/v1/index/update", files=files, headers=headers)
# Assert
assert response.status_code == 429
@pytest.mark.django_db(transaction=True)
def test_index_update_big_files_no_billing(client):
# Arrange
files = get_big_size_sample_files_data()
headers = {"Authorization": "Bearer kk-secret"}
# Act
response = client.post("/api/v1/index/update", files=files, headers=headers)
# Assert
assert response.status_code == 200
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_index_update(client):
@ -421,3 +449,13 @@ def get_sample_files_data():
),
("files", ("path/to/filename2.md", "**Understanding science through the lens of art**", "text/markdown")),
]
def get_big_size_sample_files_data():
big_text = "a" * (25 * 1024 * 1024) # a string of approximately 25 MB
return [
(
"files",
("path/to/filename.org", big_text, "text/org"),
)
]