mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-23 23:48:56 +01:00
Add support for rate limiting the amount of data indexed
- Add a dependency on the indexer API endpoint that rounds up the amount of data indexed and uses that to determine whether the next set of data should be processed - Delete any files that are being removed for adminstering the calculation - Show current amount of data indexed in the config page
This commit is contained in:
parent
dd1badae81
commit
b2afbaa315
8 changed files with 127 additions and 11 deletions
|
@ -80,6 +80,7 @@ class UserAuthenticationBackend(AuthenticationBackend):
|
||||||
subscribed = (
|
subscribed = (
|
||||||
subscription_state == SubscriptionState.SUBSCRIBED.value
|
subscription_state == SubscriptionState.SUBSCRIBED.value
|
||||||
or subscription_state == SubscriptionState.TRIAL.value
|
or subscription_state == SubscriptionState.TRIAL.value
|
||||||
|
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
|
||||||
)
|
)
|
||||||
if subscribed:
|
if subscribed:
|
||||||
return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user)
|
return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(user)
|
||||||
|
@ -101,6 +102,7 @@ class UserAuthenticationBackend(AuthenticationBackend):
|
||||||
subscribed = (
|
subscribed = (
|
||||||
subscription_state == SubscriptionState.SUBSCRIBED.value
|
subscription_state == SubscriptionState.SUBSCRIBED.value
|
||||||
or subscription_state == SubscriptionState.TRIAL.value
|
or subscription_state == SubscriptionState.TRIAL.value
|
||||||
|
or subscription_state == SubscriptionState.UNSUBSCRIBED.value
|
||||||
)
|
)
|
||||||
if subscribed:
|
if subscribed:
|
||||||
return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(
|
return AuthCredentials(["authenticated", "subscribed"]), AuthenticatedKhojUser(
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
import secrets
|
import secrets
|
||||||
|
import sys
|
||||||
from datetime import date, datetime, timezone, timedelta
|
from datetime import date, datetime, timezone, timedelta
|
||||||
from typing import List, Optional, Type
|
from typing import List, Optional, Type
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
@ -474,6 +475,12 @@ class EntryAdapters:
|
||||||
async def adelete_all_entries(user: KhojUser):
|
async def adelete_all_entries(user: KhojUser):
|
||||||
return await Entry.objects.filter(user=user).adelete()
|
return await Entry.objects.filter(user=user).adelete()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_size_of_indexed_data_in_mb(user: KhojUser):
|
||||||
|
entries = Entry.objects.filter(user=user).iterator()
|
||||||
|
total_size = sum(sys.getsizeof(entry.compiled) for entry in entries)
|
||||||
|
return total_size / 1024 / 1024
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def apply_filters(user: KhojUser, query: str, file_type_filter: str = None):
|
def apply_filters(user: KhojUser, query: str, file_type_filter: str = None):
|
||||||
q_filter_terms = Q()
|
q_filter_terms = Q()
|
||||||
|
|
|
@ -4,6 +4,7 @@
|
||||||
<div class="page">
|
<div class="page">
|
||||||
<div id="content" class="section">
|
<div id="content" class="section">
|
||||||
<h2 class="section-title">Content</h2>
|
<h2 class="section-title">Content</h2>
|
||||||
|
<p id="indexed-data-size" class="card-description">{{indexed_data_size_in_mb}} MB used</p>
|
||||||
<div class="section-cards">
|
<div class="section-cards">
|
||||||
<div class="card">
|
<div class="card">
|
||||||
<div class="card-title-row">
|
<div class="card-title-row">
|
||||||
|
@ -171,7 +172,7 @@
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
{% if not billing_enabled %}
|
{% if billing_enabled %}
|
||||||
<div id="billing" class="section">
|
<div id="billing" class="section">
|
||||||
<h2 class="section-title">Billing</h2>
|
<h2 class="section-title">Billing</h2>
|
||||||
<div class="section-cards">
|
<div class="section-cards">
|
||||||
|
@ -191,7 +192,7 @@
|
||||||
<p id="trial-description"
|
<p id="trial-description"
|
||||||
class="card-description"
|
class="card-description"
|
||||||
style="display: {% if subscription_state != 'trial' %}none{% endif %}">
|
style="display: {% if subscription_state != 'trial' %}none{% endif %}">
|
||||||
Subscribe to Khoj Cloud
|
Subscribe to Khoj Cloud. See <a href="https://khoj.dev/pricing">pricing</a> for details.
|
||||||
</p>
|
</p>
|
||||||
<p id="unsubscribe-description"
|
<p id="unsubscribe-description"
|
||||||
class="card-description"
|
class="card-description"
|
||||||
|
|
|
@ -9,10 +9,10 @@ from typing import Any, Dict, List, Optional, Union
|
||||||
from asgiref.sync import sync_to_async
|
from asgiref.sync import sync_to_async
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
from fastapi.requests import Request
|
from fastapi.requests import Request
|
||||||
from fastapi.responses import Response, StreamingResponse
|
from fastapi.responses import Response, StreamingResponse
|
||||||
from starlette.authentication import requires, has_required_scope
|
from starlette.authentication import requires
|
||||||
|
|
||||||
# Internal Packages
|
# Internal Packages
|
||||||
from khoj.configure import configure_server
|
from khoj.configure import configure_server
|
||||||
|
|
|
@ -9,11 +9,12 @@ from functools import partial
|
||||||
from time import time
|
from time import time
|
||||||
from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union
|
from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union
|
||||||
|
|
||||||
# External Packages
|
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
||||||
from fastapi import Depends, Header, HTTPException, Request
|
|
||||||
from starlette.authentication import has_required_scope
|
from starlette.authentication import has_required_scope
|
||||||
|
from asgiref.sync import sync_to_async
|
||||||
|
|
||||||
from khoj.database.adapters import ConversationAdapters
|
|
||||||
|
from khoj.database.adapters import ConversationAdapters, EntryAdapters
|
||||||
from khoj.database.models import KhojUser, Subscription
|
from khoj.database.models import KhojUser, Subscription
|
||||||
from khoj.processor.conversation import prompts
|
from khoj.processor.conversation import prompts
|
||||||
from khoj.processor.conversation.gpt4all.chat_model import converse_offline, send_message_to_model_offline
|
from khoj.processor.conversation.gpt4all.chat_model import converse_offline, send_message_to_model_offline
|
||||||
|
@ -297,6 +298,60 @@ class ApiUserRateLimiter:
|
||||||
user_requests.append(time())
|
user_requests.append(time())
|
||||||
|
|
||||||
|
|
||||||
|
class ApiIndexedDataLimiter:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
incoming_entries_size_limit: float,
|
||||||
|
subscribed_incoming_entries_size_limit: float,
|
||||||
|
total_entries_size_limit: float,
|
||||||
|
subscribed_total_entries_size_limit: float,
|
||||||
|
):
|
||||||
|
self.num_entries_size = incoming_entries_size_limit
|
||||||
|
self.subscribed_num_entries_size = subscribed_incoming_entries_size_limit
|
||||||
|
self.total_entries_size_limit = total_entries_size_limit
|
||||||
|
self.subscribed_total_entries_size = subscribed_total_entries_size_limit
|
||||||
|
|
||||||
|
def __call__(self, request: Request, files: List[UploadFile]):
|
||||||
|
if state.billing_enabled is False:
|
||||||
|
return
|
||||||
|
subscribed = has_required_scope(request, ["subscribed"])
|
||||||
|
incoming_data_size_mb = 0
|
||||||
|
deletion_file_names = set()
|
||||||
|
|
||||||
|
if not request.user.is_authenticated:
|
||||||
|
return
|
||||||
|
|
||||||
|
user: KhojUser = request.user.object
|
||||||
|
|
||||||
|
for file in files:
|
||||||
|
if file.size == 0:
|
||||||
|
deletion_file_names.add(file.filename)
|
||||||
|
|
||||||
|
incoming_data_size_mb += file.size / 1024 / 1024
|
||||||
|
|
||||||
|
num_deleted_entries = 0
|
||||||
|
for file_path in deletion_file_names:
|
||||||
|
deleted_count = EntryAdapters.delete_entry_by_file(user, file_path)
|
||||||
|
num_deleted_entries += deleted_count
|
||||||
|
|
||||||
|
logger.info(f"Deleted {num_deleted_entries} entries for user: {user}.")
|
||||||
|
|
||||||
|
if subscribed and incoming_data_size_mb >= self.subscribed_num_entries_size:
|
||||||
|
raise HTTPException(status_code=429, detail="Too much data indexed.")
|
||||||
|
if not subscribed and incoming_data_size_mb >= self.num_entries_size:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=429, detail="Too much data indexed. Subscribe to increase your data index limit."
|
||||||
|
)
|
||||||
|
|
||||||
|
user_size_data = EntryAdapters.get_size_of_indexed_data_in_mb(user)
|
||||||
|
if subscribed and user_size_data + incoming_data_size_mb >= self.subscribed_total_entries_size:
|
||||||
|
raise HTTPException(status_code=429, detail="Too much data indexed.")
|
||||||
|
if not subscribed and user_size_data + incoming_data_size_mb >= self.total_entries_size_limit:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=429, detail="Too much data indexed. Subscribe to increase your data index limit."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class CommonQueryParamsClass:
|
class CommonQueryParamsClass:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -2,7 +2,7 @@ import asyncio
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, Optional, Union
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
from fastapi import APIRouter, Header, Request, Response, UploadFile
|
from fastapi import APIRouter, Header, Request, Response, UploadFile, Depends
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from starlette.authentication import requires
|
from starlette.authentication import requires
|
||||||
|
|
||||||
|
@ -18,6 +18,7 @@ from khoj.search_type import image_search, text_search
|
||||||
from khoj.utils import constants, state
|
from khoj.utils import constants, state
|
||||||
from khoj.utils.config import ContentIndex, SearchModels
|
from khoj.utils.config import ContentIndex, SearchModels
|
||||||
from khoj.utils.helpers import LRU, get_file_type
|
from khoj.utils.helpers import LRU, get_file_type
|
||||||
|
from khoj.routers.helpers import ApiIndexedDataLimiter
|
||||||
from khoj.utils.rawconfig import ContentConfig, FullConfig, SearchConfig
|
from khoj.utils.rawconfig import ContentConfig, FullConfig, SearchConfig
|
||||||
from khoj.utils.yaml import save_config_to_file_updated_state
|
from khoj.utils.yaml import save_config_to_file_updated_state
|
||||||
|
|
||||||
|
@ -53,6 +54,14 @@ async def update(
|
||||||
user_agent: Optional[str] = Header(None),
|
user_agent: Optional[str] = Header(None),
|
||||||
referer: Optional[str] = Header(None),
|
referer: Optional[str] = Header(None),
|
||||||
host: Optional[str] = Header(None),
|
host: Optional[str] = Header(None),
|
||||||
|
indexed_data_limiter: ApiIndexedDataLimiter = Depends(
|
||||||
|
ApiIndexedDataLimiter(
|
||||||
|
incoming_entries_size_limit=10,
|
||||||
|
subscribed_incoming_entries_size_limit=25,
|
||||||
|
total_entries_size_limit=10,
|
||||||
|
subscribed_total_entries_size_limit=100,
|
||||||
|
)
|
||||||
|
),
|
||||||
):
|
):
|
||||||
user = request.user.object
|
user = request.user.object
|
||||||
try:
|
try:
|
||||||
|
@ -92,7 +101,7 @@ async def update(
|
||||||
logger.info("📬 Initializing content index on first run.")
|
logger.info("📬 Initializing content index on first run.")
|
||||||
default_full_config = FullConfig(
|
default_full_config = FullConfig(
|
||||||
content_type=None,
|
content_type=None,
|
||||||
search_type=SearchConfig.parse_obj(constants.default_config["search-type"]),
|
search_type=SearchConfig.model_validate(constants.default_config["search-type"]),
|
||||||
processor=None,
|
processor=None,
|
||||||
)
|
)
|
||||||
state.config = default_full_config
|
state.config = default_full_config
|
||||||
|
@ -116,7 +125,7 @@ async def update(
|
||||||
configure_content,
|
configure_content,
|
||||||
state.content_index,
|
state.content_index,
|
||||||
state.config.content_type,
|
state.config.content_type,
|
||||||
indexer_input.dict(),
|
indexer_input.model_dump(),
|
||||||
state.search_models,
|
state.search_models,
|
||||||
force,
|
force,
|
||||||
t,
|
t,
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
# System Packages
|
# System Packages
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import math
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
# External Packages
|
# External Packages
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
|
@ -137,8 +139,9 @@ def config_page(request: Request):
|
||||||
subscription_renewal_date = (
|
subscription_renewal_date = (
|
||||||
user_subscription.renewal_date.strftime("%d %b %Y")
|
user_subscription.renewal_date.strftime("%d %b %Y")
|
||||||
if user_subscription and user_subscription.renewal_date
|
if user_subscription and user_subscription.renewal_date
|
||||||
else None
|
else (user_subscription.created_at + timedelta(days=7)).strftime("%d %b %Y")
|
||||||
)
|
)
|
||||||
|
indexed_data_size_in_mb = math.ceil(EntryAdapters.get_size_of_indexed_data_in_mb(user))
|
||||||
|
|
||||||
enabled_content_source = set(EntryAdapters.get_unique_file_sources(user))
|
enabled_content_source = set(EntryAdapters.get_unique_file_sources(user))
|
||||||
successfully_configured = {
|
successfully_configured = {
|
||||||
|
@ -169,6 +172,7 @@ def config_page(request: Request):
|
||||||
"khoj_cloud_subscription_url": os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL"),
|
"khoj_cloud_subscription_url": os.getenv("KHOJ_CLOUD_SUBSCRIPTION_URL"),
|
||||||
"is_active": has_required_scope(request, ["subscribed"]),
|
"is_active": has_required_scope(request, ["subscribed"]),
|
||||||
"has_documents": has_documents,
|
"has_documents": has_documents,
|
||||||
|
"indexed_data_size_in_mb": indexed_data_size_in_mb,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -125,6 +125,34 @@ def test_regenerate_with_invalid_content_type(client):
|
||||||
assert response.status_code == 422
|
assert response.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------------------------------------------------------------------------------------
|
||||||
|
@pytest.mark.django_db(transaction=True)
|
||||||
|
def test_index_update_big_files(client):
|
||||||
|
state.billing_enabled = True
|
||||||
|
# Arrange
|
||||||
|
files = get_big_size_sample_files_data()
|
||||||
|
headers = {"Authorization": "Bearer kk-secret"}
|
||||||
|
|
||||||
|
# Act
|
||||||
|
response = client.post("/api/v1/index/update", files=files, headers=headers)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert response.status_code == 429
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.django_db(transaction=True)
|
||||||
|
def test_index_update_big_files_no_billing(client):
|
||||||
|
# Arrange
|
||||||
|
files = get_big_size_sample_files_data()
|
||||||
|
headers = {"Authorization": "Bearer kk-secret"}
|
||||||
|
|
||||||
|
# Act
|
||||||
|
response = client.post("/api/v1/index/update", files=files, headers=headers)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
@pytest.mark.django_db(transaction=True)
|
@pytest.mark.django_db(transaction=True)
|
||||||
def test_index_update(client):
|
def test_index_update(client):
|
||||||
|
@ -421,3 +449,13 @@ def get_sample_files_data():
|
||||||
),
|
),
|
||||||
("files", ("path/to/filename2.md", "**Understanding science through the lens of art**", "text/markdown")),
|
("files", ("path/to/filename2.md", "**Understanding science through the lens of art**", "text/markdown")),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_big_size_sample_files_data():
|
||||||
|
big_text = "a" * (25 * 1024 * 1024) # a string of approximately 25 MB
|
||||||
|
return [
|
||||||
|
(
|
||||||
|
"files",
|
||||||
|
("path/to/filename.org", big_text, "text/org"),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
Loading…
Reference in a new issue