mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 15:38:55 +01:00
Add support for rate limiting the amount of data indexed
- Add a dependency on the indexer API endpoint that rounds up the amount of data indexed and uses that to determine whether the next set of data should be processed - Delete any files that are being removed for adminstering the calculation - Show current amount of data indexed in the config page
This commit is contained in:
parent
dd1badae81
commit
b2afbaa315
8 changed files with 127 additions and 11 deletions
|
@ -80,6 +80,7 @@ class UserAuthenticationBackend(AuthenticationBackend):
|
|||
subscribed = (
|
||||
subscription_state == SubscriptionState.SUBSCRIBED.value
|
||||
or subscription_state == SubscriptionState.TRIAL.value
|
||||
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
|
||||
)
|
||||
if subscribed:
|
||||
return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user)
|
||||
|
@ -101,6 +102,7 @@ class UserAuthenticationBackend(AuthenticationBackend):
|
|||
subscribed = (
|
||||
subscription_state == SubscriptionState.SUBSCRIBED.value
|
||||
or subscription_state == SubscriptionState.TRIAL.value
|
||||
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
|
||||
)
|
||||
if subscribed:
|
||||
return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import math
|
||||
import random
|
||||
import secrets
|
||||
import sys
|
||||
from datetime import date, datetime, timezone, timedelta
|
||||
from typing import List, Optional, Type
|
||||
from enum import Enum
|
||||
|
@ -474,6 +475,12 @@ class EntryAdapters:
|
|||
async def adelete_all_entries(user: KhojUser):
|
||||
return await Entry.objects.filter(user=user).adelete()
|
||||
|
||||
@staticmethod
|
||||
def get_size_of_indexed_data_in_mb(user: KhojUser):
|
||||
entries = Entry.objects.filter(user=user).iterator()
|
||||
total_size = sum(sys.getsizeof(entry.compiled) for entry in entries)
|
||||
return total_size / 1024 / 1024
|
||||
|
||||
@staticmethod
|
||||
def apply_filters(user: KhojUser, query: str, file_type_filter: str = None):
|
||||
q_filter_terms = Q()
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
<div class="page">
|
||||
<div id="content" class="section">
|
||||
<h2 class="section-title">Content</h2>
|
||||
<p id="indexed-data-size" class="card-description">{{indexed_data_size_in_mb}} MB used</p>
|
||||
<div class="section-cards">
|
||||
<div class="card">
|
||||
<div class="card-title-row">
|
||||
|
@ -171,7 +172,7 @@
|
|||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{% if not billing_enabled %}
|
||||
{% if billing_enabled %}
|
||||
<div id="billing" class="section">
|
||||
<h2 class="section-title">Billing</h2>
|
||||
<div class="section-cards">
|
||||
|
@ -191,7 +192,7 @@
|
|||
<p id="trial-description"
|
||||
class="card-description"
|
||||
style="display: {% if subscription_state != 'trial' %}none{% endif %}">
|
||||
Subscribe to Khoj Cloud
|
||||
Subscribe to Khoj Cloud. See <a href="https://khoj.dev/pricing">pricing</a> for details.
|
||||
</p>
|
||||
<p id="unsubscribe-description"
|
||||
class="card-description"
|
||||
|
|
|
@ -9,10 +9,10 @@ from typing import Any, Dict, List, Optional, Union
|
|||
from asgiref.sync import sync_to_async
|
||||
|
||||
# External Packages
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.requests import Request
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
from starlette.authentication import requires, has_required_scope
|
||||
from starlette.authentication import requires
|
||||
|
||||
# Internal Packages
|
||||
from khoj.configure import configure_server
|
||||
|
|
|
@ -9,11 +9,12 @@ from functools import partial
|
|||
from time import time
|
||||
from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
# External Packages
|
||||
from fastapi import Depends, Header, HTTPException, Request
|
||||
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
||||
from starlette.authentication import has_required_scope
|
||||
from asgiref.sync import sync_to_async
|
||||
|
||||
from khoj.database.adapters import ConversationAdapters
|
||||
|
||||
from khoj.database.adapters import ConversationAdapters, EntryAdapters
|
||||
from khoj.database.models import KhojUser, Subscription
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.gpt4all.chat_model import converse_offline, send_message_to_model_offline
|
||||
|
@ -297,6 +298,60 @@ class ApiUserRateLimiter:
|
|||
user_requests.append(time())
|
||||
|
||||
|
||||
class ApiIndexedDataLimiter:
|
||||
def __init__(
|
||||
self,
|
||||
incoming_entries_size_limit: float,
|
||||
subscribed_incoming_entries_size_limit: float,
|
||||
total_entries_size_limit: float,
|
||||
subscribed_total_entries_size_limit: float,
|
||||
):
|
||||
self.num_entries_size = incoming_entries_size_limit
|
||||
self.subscribed_num_entries_size = subscribed_incoming_entries_size_limit
|
||||
self.total_entries_size_limit = total_entries_size_limit
|
||||
self.subscribed_total_entries_size = subscribed_total_entries_size_limit
|
||||
|
||||
def __call__(self, request: Request, files: List[UploadFile]):
|
||||
if state.billing_enabled is False:
|
||||
return
|
||||
subscribed = has_required_scope(request, ["subscribed"])
|
||||
incoming_data_size_mb = 0
|
||||
deletion_file_names = set()
|
||||
|
||||
if not request.user.is_authenticated:
|
||||
return
|
||||
|
||||
user: KhojUser = request.user.object
|
||||
|
||||
for file in files:
|
||||
if file.size == 0:
|
||||
deletion_file_names.add(file.filename)
|
||||
|
||||
incoming_data_size_mb += file.size / 1024 / 1024
|
||||
|
||||
num_deleted_entries = 0
|
||||
for file_path in deletion_file_names:
|
||||
deleted_count = EntryAdapters.delete_entry_by_file(user, file_path)
|
||||
num_deleted_entries += deleted_count
|
||||
|
||||
logger.info(f"Deleted {num_deleted_entries} entries for user: {user}.")
|
||||
|
||||
if subscribed and incoming_data_size_mb >= self.subscribed_num_entries_size:
|
||||
raise HTTPException(status_code=429, detail="Too much data indexed.")
|
||||
if not subscribed and incoming_data_size_mb >= self.num_entries_size:
|
||||
raise HTTPException(
|
||||
status_code=429, detail="Too much data indexed. Subscribe to increase your data index limit."
|
||||
)
|
||||
|
||||
user_size_data = EntryAdapters.get_size_of_indexed_data_in_mb(user)
|
||||
if subscribed and user_size_data + incoming_data_size_mb >= self.subscribed_total_entries_size:
|
||||
raise HTTPException(status_code=429, detail="Too much data indexed.")
|
||||
if not subscribed and user_size_data + incoming_data_size_mb >= self.total_entries_size_limit:
|
||||
raise HTTPException(
|
||||
status_code=429, detail="Too much data indexed. Subscribe to increase your data index limit."
|
||||
)
|
||||
|
||||
|
||||
class CommonQueryParamsClass:
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
@ -2,7 +2,7 @@ import asyncio
|
|||
import logging
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from fastapi import APIRouter, Header, Request, Response, UploadFile
|
||||
from fastapi import APIRouter, Header, Request, Response, UploadFile, Depends
|
||||
from pydantic import BaseModel
|
||||
from starlette.authentication import requires
|
||||
|
||||
|
@ -18,6 +18,7 @@ from khoj.search_type import image_search, text_search
|
|||
from khoj.utils import constants, state
|
||||
from khoj.utils.config import ContentIndex, SearchModels
|
||||
from khoj.utils.helpers import LRU, get_file_type
|
||||
from khoj.routers.helpers import ApiIndexedDataLimiter
|
||||
from khoj.utils.rawconfig import ContentConfig, FullConfig, SearchConfig
|
||||
from khoj.utils.yaml import save_config_to_file_updated_state
|
||||
|
||||
|
@ -53,6 +54,14 @@ async def update(
|
|||
user_agent: Optional[str] = Header(None),
|
||||
referer: Optional[str] = Header(None),
|
||||
host: Optional[str] = Header(None),
|
||||
indexed_data_limiter: ApiIndexedDataLimiter = Depends(
|
||||
ApiIndexedDataLimiter(
|
||||
incoming_entries_size_limit=10,
|
||||
subscribed_incoming_entries_size_limit=25,
|
||||
total_entries_size_limit=10,
|
||||
subscribed_total_entries_size_limit=100,
|
||||
)
|
||||
),
|
||||
):
|
||||
user = request.user.object
|
||||
try:
|
||||
|
@ -92,7 +101,7 @@ async def update(
|
|||
logger.info("📬 Initializing content index on first run.")
|
||||
default_full_config = FullConfig(
|
||||
content_type=None,
|
||||
search_type=SearchConfig.parse_obj(constants.default_config["search-type"]),
|
||||
search_type=SearchConfig.model_validate(constants.default_config["search-type"]),
|
||||
processor=None,
|
||||
)
|
||||
state.config = default_full_config
|
||||
|
@ -116,7 +125,7 @@ async def update(
|
|||
configure_content,
|
||||
state.content_index,
|
||||
state.config.content_type,
|
||||
indexer_input.dict(),
|
||||
indexer_input.model_dump(),
|
||||
state.search_models,
|
||||
force,
|
||||
t,
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
# System Packages
|
||||
import json
|
||||
import os
|
||||
import math
|
||||
from datetime import timedelta
|
||||
|
||||
# External Packages
|
||||
from fastapi import APIRouter
|
||||
|
@ -137,8 +139,9 @@ def config_page(request: Request):
|
|||
subscription_renewal_date = (
|
||||
user_subscription.renewal_date.strftime("%d %b %Y")
|
||||
if user_subscription and user_subscription.renewal_date
|
||||
else None
|
||||
else (user_subscription.created_at + timedelta(days=7)).strftime("%d %b %Y")
|
||||
)
|
||||
indexed_data_size_in_mb = math.ceil(EntryAdapters.get_size_of_indexed_data_in_mb(user))
|
||||
|
||||
enabled_content_source = set(EntryAdapters.get_unique_file_sources(user))
|
||||
successfully_configured = {
|
||||
|
@ -169,6 +172,7 @@ def config_page(request: Request):
|
|||
"khoj_cloud_subscription_url": os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL"),
|
||||
"is_active": has_required_scope(request, ["subscribed"]),
|
||||
"has_documents": has_documents,
|
||||
"indexed_data_size_in_mb": indexed_data_size_in_mb,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
|
@ -125,6 +125,34 @@ def test_regenerate_with_invalid_content_type(client):
|
|||
assert response.status_code == 422
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_index_update_big_files(client):
|
||||
state.billing_enabled = True
|
||||
# Arrange
|
||||
files = get_big_size_sample_files_data()
|
||||
headers = {"Authorization": "Bearer kk-secret"}
|
||||
|
||||
# Act
|
||||
response = client.post("/api/v1/index/update", files=files, headers=headers)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 429
|
||||
|
||||
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_index_update_big_files_no_billing(client):
|
||||
# Arrange
|
||||
files = get_big_size_sample_files_data()
|
||||
headers = {"Authorization": "Bearer kk-secret"}
|
||||
|
||||
# Act
|
||||
response = client.post("/api/v1/index/update", files=files, headers=headers)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
@pytest.mark.django_db(transaction=True)
|
||||
def test_index_update(client):
|
||||
|
@ -421,3 +449,13 @@ def get_sample_files_data():
|
|||
),
|
||||
("files", ("path/to/filename2.md", "**Understanding science through the lens of art**", "text/markdown")),
|
||||
]
|
||||
|
||||
|
||||
def get_big_size_sample_files_data():
|
||||
big_text = "a" * (25 * 1024 * 1024) # a string of approximately 25 MB
|
||||
return [
|
||||
(
|
||||
"files",
|
||||
("path/to/filename.org", big_text, "text/org"),
|
||||
)
|
||||
]
|
||||
|
|
Loading…
Reference in a new issue