Add support for rate limiting the amount of data indexed

- Add a dependency on the indexer API endpoint that rounds up the amount of data indexed and uses that to determine whether the next set of data should be processed
- Delete any files that are being removed for administering the calculation
- Show current amount of data indexed in the config page
This commit is contained in:
sabaimran 2023-11-25 20:28:04 -08:00
parent dd1badae81
commit b2afbaa315
8 changed files with 127 additions and 11 deletions

View file

@ -80,6 +80,7 @@ class UserAuthenticationBackend(AuthenticationBackend):
subscribed = ( subscribed = (
subscription_state == SubscriptionState.SUBSCRIBED.value subscription_state == SubscriptionState.SUBSCRIBED.value
or subscription_state == SubscriptionState.TRIAL.value or subscription_state == SubscriptionState.TRIAL.value
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
) )
if subscribed: if subscribed:
return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user) return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user)
@ -101,6 +102,7 @@ class UserAuthenticationBackend(AuthenticationBackend):
subscribed = ( subscribed = (
subscription_state == SubscriptionState.SUBSCRIBED.value subscription_state == SubscriptionState.SUBSCRIBED.value
or subscription_state == SubscriptionState.TRIAL.value or subscription_state == SubscriptionState.TRIAL.value
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
) )
if subscribed: if subscribed:
return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser( return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(

View file

@ -1,6 +1,7 @@
import math import math
import random import random
import secrets import secrets
import sys
from datetime import date, datetime, timezone, timedelta from datetime import date, datetime, timezone, timedelta
from typing import List, Optional, Type from typing import List, Optional, Type
from enum import Enum from enum import Enum
@ -474,6 +475,12 @@ class EntryAdapters:
async def adelete_all_entries(user: KhojUser): async def adelete_all_entries(user: KhojUser):
return await Entry.objects.filter(user=user).adelete() return await Entry.objects.filter(user=user).adelete()
@staticmethod
def get_size_of_indexed_data_in_mb(user: KhojUser):
    """Return the size, in megabytes, of the compiled text indexed for *user*.

    The size of each entry is measured with ``sys.getsizeof`` on its
    ``compiled`` value, i.e. the in-memory size of the Python object
    (interpreter overhead included), not the encoded byte length of the text.
    The per-entry sizes are summed and converted from bytes to MB.
    """
    # Only the ``compiled`` column is needed; avoid hydrating full Entry rows.
    # ``iterator()`` streams results instead of caching the whole queryset.
    compiled_values = Entry.objects.filter(user=user).values_list("compiled", flat=True).iterator()
    total_size = sum(sys.getsizeof(compiled) for compiled in compiled_values)
    return total_size / 1024 / 1024
@staticmethod @staticmethod
def apply_filters(user: KhojUser, query: str, file_type_filter: str = None): def apply_filters(user: KhojUser, query: str, file_type_filter: str = None):
q_filter_terms = Q() q_filter_terms = Q()

View file

@ -4,6 +4,7 @@
<div class="page"> <div class="page">
<div id="content" class="section"> <div id="content" class="section">
<h2 class="section-title">Content</h2> <h2 class="section-title">Content</h2>
<p id="indexed-data-size" class="card-description">{{indexed_data_size_in_mb}} MB used</p>
<div class="section-cards"> <div class="section-cards">
<div class="card"> <div class="card">
<div class="card-title-row"> <div class="card-title-row">
@ -171,7 +172,7 @@
</div> </div>
</div> </div>
</div> </div>
{% if not billing_enabled %} {% if billing_enabled %}
<div id="billing" class="section"> <div id="billing" class="section">
<h2 class="section-title">Billing</h2> <h2 class="section-title">Billing</h2>
<div class="section-cards"> <div class="section-cards">
@ -191,7 +192,7 @@
<p id="trial-description" <p id="trial-description"
class="card-description" class="card-description"
style="display: {% if subscription_state != 'trial' %}none{% endif %}"> style="display: {% if subscription_state != 'trial' %}none{% endif %}">
Subscribe to Khoj Cloud Subscribe to Khoj Cloud. See <a href="https://khoj.dev/pricing">pricing</a> for details.
</p> </p>
<p id="unsubscribe-description" <p id="unsubscribe-description"
class="card-description" class="card-description"

View file

@ -9,10 +9,10 @@ from typing import Any, Dict, List, Optional, Union
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
# External Packages # External Packages
from fastapi import APIRouter, Depends, Header, HTTPException, Request from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.requests import Request from fastapi.requests import Request
from fastapi.responses import Response, StreamingResponse from fastapi.responses import Response, StreamingResponse
from starlette.authentication import requires, has_required_scope from starlette.authentication import requires
# Internal Packages # Internal Packages
from khoj.configure import configure_server from khoj.configure import configure_server

View file

@ -9,11 +9,12 @@ from functools import partial
from time import time from time import time
from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union
# External Packages from fastapi import Depends, Header, HTTPException, Request, UploadFile
from fastapi import Depends, Header, HTTPException, Request
from starlette.authentication import has_required_scope from starlette.authentication import has_required_scope
from asgiref.sync import sync_to_async
from khoj.database.adapters import ConversationAdapters
from khoj.database.adapters import ConversationAdapters, EntryAdapters
from khoj.database.models import KhojUser, Subscription from khoj.database.models import KhojUser, Subscription
from khoj.processor.conversation import prompts from khoj.processor.conversation import prompts
from khoj.processor.conversation.gpt4all.chat_model import converse_offline, send_message_to_model_offline from khoj.processor.conversation.gpt4all.chat_model import converse_offline, send_message_to_model_offline
@ -297,6 +298,60 @@ class ApiUserRateLimiter:
user_requests.append(time()) user_requests.append(time())
class ApiIndexedDataLimiter:
    """Callable dependency that rejects index updates exceeding a data-size quota.

    Four limits are configured, all compared against sizes expressed in MB:
    a per-request limit and a total (already indexed + incoming) limit, each
    with a separate value for users holding the ``subscribed`` scope.
    """

    def __init__(
        self,
        incoming_entries_size_limit: float,
        subscribed_incoming_entries_size_limit: float,
        total_entries_size_limit: float,
        subscribed_total_entries_size_limit: float,
    ):
        # Attribute names are kept unchanged so existing readers keep working.
        self.num_entries_size = incoming_entries_size_limit
        self.subscribed_num_entries_size = subscribed_incoming_entries_size_limit
        self.total_entries_size_limit = total_entries_size_limit
        self.subscribed_total_entries_size = subscribed_total_entries_size_limit

    def __call__(self, request: Request, files: List[UploadFile]):
        """Enforce the indexed-data limits for the uploaded *files*.

        Does nothing when billing is disabled or the request user is not
        authenticated. Files whose size is 0 are treated as deletions: their
        entries are removed before the limits are evaluated.

        Raises:
            HTTPException: with status 429 when the incoming data, or the
                incoming data plus the user's already indexed data, reaches
                the limit applicable to the user.
        """
        if state.billing_enabled is False:
            return

        subscribed = has_required_scope(request, ["subscribed"])
        if not request.user.is_authenticated:
            return

        user: KhojUser = request.user.object

        incoming_data_size_mb = 0.0
        deletion_file_names = set()
        for file in files:
            if file.size == 0:
                deletion_file_names.add(file.filename)
            # ``UploadFile.size`` is optional; count an unknown size as 0 bytes
            # instead of failing with a TypeError on ``None / 1024``.
            incoming_data_size_mb += (file.size or 0) / 1024 / 1024

        # Remove entries of files marked for deletion first, so they no longer
        # count towards the user's indexed total computed below.
        num_deleted_entries = 0
        for file_path in deletion_file_names:
            num_deleted_entries += EntryAdapters.delete_entry_by_file(user, file_path)

        logger.info(f"Deleted {num_deleted_entries} entries for user: {user}.")

        # Pick the limits and error message matching the user's tier once.
        if subscribed:
            incoming_limit = self.subscribed_num_entries_size
            total_limit = self.subscribed_total_entries_size
            detail = "Too much data indexed."
        else:
            incoming_limit = self.num_entries_size
            total_limit = self.total_entries_size_limit
            detail = "Too much data indexed. Subscribe to increase your data index limit."

        if incoming_data_size_mb >= incoming_limit:
            raise HTTPException(status_code=429, detail=detail)

        user_size_data = EntryAdapters.get_size_of_indexed_data_in_mb(user)
        if user_size_data + incoming_data_size_mb >= total_limit:
            raise HTTPException(status_code=429, detail=detail)
class CommonQueryParamsClass: class CommonQueryParamsClass:
def __init__( def __init__(
self, self,

View file

@ -2,7 +2,7 @@ import asyncio
import logging import logging
from typing import Dict, Optional, Union from typing import Dict, Optional, Union
from fastapi import APIRouter, Header, Request, Response, UploadFile from fastapi import APIRouter, Header, Request, Response, UploadFile, Depends
from pydantic import BaseModel from pydantic import BaseModel
from starlette.authentication import requires from starlette.authentication import requires
@ -18,6 +18,7 @@ from khoj.search_type import image_search, text_search
from khoj.utils import constants, state from khoj.utils import constants, state
from khoj.utils.config import ContentIndex, SearchModels from khoj.utils.config import ContentIndex, SearchModels
from khoj.utils.helpers import LRU, get_file_type from khoj.utils.helpers import LRU, get_file_type
from khoj.routers.helpers import ApiIndexedDataLimiter
from khoj.utils.rawconfig import ContentConfig, FullConfig, SearchConfig from khoj.utils.rawconfig import ContentConfig, FullConfig, SearchConfig
from khoj.utils.yaml import save_config_to_file_updated_state from khoj.utils.yaml import save_config_to_file_updated_state
@ -53,6 +54,14 @@ async def update(
user_agent: Optional[str] = Header(None), user_agent: Optional[str] = Header(None),
referer: Optional[str] = Header(None), referer: Optional[str] = Header(None),
host: Optional[str] = Header(None), host: Optional[str] = Header(None),
indexed_data_limiter: ApiIndexedDataLimiter = Depends(
ApiIndexedDataLimiter(
incoming_entries_size_limit=10,
subscribed_incoming_entries_size_limit=25,
total_entries_size_limit=10,
subscribed_total_entries_size_limit=100,
)
),
): ):
user = request.user.object user = request.user.object
try: try:
@ -92,7 +101,7 @@ async def update(
logger.info("📬 Initializing content index on first run.") logger.info("📬 Initializing content index on first run.")
default_full_config = FullConfig( default_full_config = FullConfig(
content_type=None, content_type=None,
search_type=SearchConfig.parse_obj(constants.default_config["search-type"]), search_type=SearchConfig.model_validate(constants.default_config["search-type"]),
processor=None, processor=None,
) )
state.config = default_full_config state.config = default_full_config
@ -116,7 +125,7 @@ async def update(
configure_content, configure_content,
state.content_index, state.content_index,
state.config.content_type, state.config.content_type,
indexer_input.dict(), indexer_input.model_dump(),
state.search_models, state.search_models,
force, force,
t, t,

View file

@ -1,6 +1,8 @@
# System Packages # System Packages
import json import json
import os import os
import math
from datetime import timedelta
# External Packages # External Packages
from fastapi import APIRouter from fastapi import APIRouter
@ -137,8 +139,9 @@ def config_page(request: Request):
subscription_renewal_date = ( subscription_renewal_date = (
user_subscription.renewal_date.strftime("%d %b %Y") user_subscription.renewal_date.strftime("%d %b %Y")
if user_subscription and user_subscription.renewal_date if user_subscription and user_subscription.renewal_date
else None else (user_subscription.created_at + timedelta(days=7)).strftime("%d %b %Y")
) )
indexed_data_size_in_mb = math.ceil(EntryAdapters.get_size_of_indexed_data_in_mb(user))
enabled_content_source = set(EntryAdapters.get_unique_file_sources(user)) enabled_content_source = set(EntryAdapters.get_unique_file_sources(user))
successfully_configured = { successfully_configured = {
@ -169,6 +172,7 @@ def config_page(request: Request):
"khoj_cloud_subscription_url": os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL"), "khoj_cloud_subscription_url": os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL"),
"is_active": has_required_scope(request, ["subscribed"]), "is_active": has_required_scope(request, ["subscribed"]),
"has_documents": has_documents, "has_documents": has_documents,
"indexed_data_size_in_mb": indexed_data_size_in_mb,
}, },
) )

View file

@ -125,6 +125,34 @@ def test_regenerate_with_invalid_content_type(client):
assert response.status_code == 422 assert response.status_code == 422
# ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True)
def test_index_update_big_files(client):
    """Uploading an oversized file is rejected with 429 when billing is enabled."""
    # Billing is process-wide state: remember it and restore it afterwards so
    # this test does not change the behavior of tests that run after it.
    previous_billing_enabled = state.billing_enabled
    state.billing_enabled = True
    try:
        # Arrange
        files = get_big_size_sample_files_data()
        headers = {"Authorization": "Bearer kk-secret"}

        # Act
        response = client.post("/api/v1/index/update", files=files, headers=headers)

        # Assert
        assert response.status_code == 429
    finally:
        state.billing_enabled = previous_billing_enabled
@pytest.mark.django_db(transaction=True)
def test_index_update_big_files_no_billing(client):
    """Uploading an oversized file succeeds when billing is disabled."""
    # Disable billing explicitly instead of relying on whatever value an
    # earlier test left in the shared state; restore it once done.
    previous_billing_enabled = state.billing_enabled
    state.billing_enabled = False
    try:
        # Arrange
        files = get_big_size_sample_files_data()
        headers = {"Authorization": "Bearer kk-secret"}

        # Act
        response = client.post("/api/v1/index/update", files=files, headers=headers)

        # Assert
        assert response.status_code == 200
    finally:
        state.billing_enabled = previous_billing_enabled
# ---------------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------------
@pytest.mark.django_db(transaction=True) @pytest.mark.django_db(transaction=True)
def test_index_update(client): def test_index_update(client):
@ -421,3 +449,13 @@ def get_sample_files_data():
), ),
("files", ("path/to/filename2.md", "**Understanding science through the lens of art**", "text/markdown")), ("files", ("path/to/filename2.md", "**Understanding science through the lens of art**", "text/markdown")),
] ]
def get_big_size_sample_files_data():
    """Build a multipart ``files`` payload holding a single 25 MiB org file."""
    payload_length = 25 * 1024 * 1024  # number of "a" characters in the file body
    org_file = ("path/to/filename.org", "a" * payload_length, "text/org")
    return [("files", org_file)]