mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-11-30 10:53:02 +01:00
Store rate limiter-related metadata in the database for more resilience (#629)
* Store rate limiter-related metadata in the database for more resilience - This helps maintain state even between server restarts - Allows you to scale up workers on your service without having to implement sticky routing * Make the usage exceeded message less abrasive * Fix rate limiter for specific conversation commands and improve the copy
This commit is contained in:
parent
71cbe5160d
commit
4fb8d5c6d4
6 changed files with 90 additions and 30 deletions
|
@ -27,6 +27,7 @@ from khoj.database.adapters import (
|
|||
aget_or_create_user_by_phone_number,
|
||||
aget_user_by_phone_number,
|
||||
aget_user_subscription_state,
|
||||
delete_user_requests,
|
||||
get_all_users,
|
||||
get_or_create_search_models,
|
||||
)
|
||||
|
@ -328,3 +329,9 @@ def upload_telemetry():
|
|||
logger.error(f"📡 Error uploading telemetry: {e}", exc_info=True)
|
||||
else:
|
||||
state.telemetry = []
|
||||
|
||||
|
||||
@schedule.repeat(schedule.every(31).minutes)
|
||||
def delete_old_user_requests():
|
||||
num_deleted = delete_user_requests()
|
||||
logger.info(f"🔥 Deleted {num_deleted} day-old user requests")
|
||||
|
|
|
@ -34,6 +34,7 @@ from khoj.database.models import (
|
|||
Subscription,
|
||||
TextToImageModelConfig,
|
||||
UserConversationConfig,
|
||||
UserRequests,
|
||||
UserSearchModelConfig,
|
||||
)
|
||||
from khoj.search_filter.date_filter import DateFilter
|
||||
|
@ -284,6 +285,10 @@ def get_user_notion_config(user: KhojUser):
|
|||
return config
|
||||
|
||||
|
||||
def delete_user_requests(window: timedelta = timedelta(days=1)):
|
||||
return UserRequests.objects.filter(created_at__lte=datetime.now(tz=timezone.utc) - window).delete()
|
||||
|
||||
|
||||
async def set_text_content_config(user: KhojUser, object: Type[models.Model], updated_config):
|
||||
deduped_files = list(set(updated_config.input_files)) if updated_config.input_files else None
|
||||
deduped_filters = list(set(updated_config.input_filter)) if updated_config.input_filter else None
|
||||
|
|
27
src/khoj/database/migrations/0029_userrequests.py
Normal file
27
src/khoj/database/migrations/0029_userrequests.py
Normal file
|
@ -0,0 +1,27 @@
|
|||
# Generated by Django 4.2.7 on 2024-01-29 08:55
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.conf import settings
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("database", "0028_khojuser_verified_phone_number"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name="UserRequests",
|
||||
fields=[
|
||||
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
|
||||
("created_at", models.DateTimeField(auto_now_add=True)),
|
||||
("updated_at", models.DateTimeField(auto_now=True)),
|
||||
("slug", models.CharField(max_length=200)),
|
||||
("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
|
||||
],
|
||||
options={
|
||||
"abstract": False,
|
||||
},
|
||||
),
|
||||
]
|
|
@ -223,3 +223,8 @@ class EntryDates(BaseModel):
|
|||
indexes = [
|
||||
models.Index(fields=["date"]),
|
||||
]
|
||||
|
||||
|
||||
class UserRequests(BaseModel):
|
||||
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
|
||||
slug = models.CharField(max_length=200)
|
||||
|
|
|
@ -59,7 +59,9 @@ from khoj.utils.state import SearchType
|
|||
# Initialize Router
|
||||
api = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
conversation_command_rate_limiter = ConversationCommandRateLimiter(trial_rate_limit=5, subscribed_rate_limit=100)
|
||||
conversation_command_rate_limiter = ConversationCommandRateLimiter(
|
||||
trial_rate_limit=2, subscribed_rate_limit=100, slug="command"
|
||||
)
|
||||
|
||||
|
||||
@api.get("/search", response_model=List[SearchResponse])
|
||||
|
@ -301,8 +303,12 @@ async def transcribe(
|
|||
request: Request,
|
||||
common: CommonQueryParams,
|
||||
file: UploadFile = File(...),
|
||||
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=1, subscribed_requests=10, window=60)),
|
||||
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)),
|
||||
rate_limiter_per_minute=Depends(
|
||||
ApiUserRateLimiter(requests=1, subscribed_requests=10, window=60, slug="transcribe_minute")
|
||||
),
|
||||
rate_limiter_per_day=Depends(
|
||||
ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24, slug="transcribe_day")
|
||||
),
|
||||
):
|
||||
user: KhojUser = request.user.object
|
||||
audio_filename = f"{user.uuid}-{str(uuid.uuid4())}.webm"
|
||||
|
@ -361,8 +367,12 @@ async def chat(
|
|||
n: Optional[int] = 5,
|
||||
d: Optional[float] = 0.18,
|
||||
stream: Optional[bool] = False,
|
||||
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60)),
|
||||
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24)),
|
||||
rate_limiter_per_minute=Depends(
|
||||
ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
|
||||
),
|
||||
rate_limiter_per_day=Depends(
|
||||
ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
|
||||
),
|
||||
) -> Response:
|
||||
user: KhojUser = request.user.object
|
||||
q = unquote(q)
|
||||
|
@ -370,7 +380,7 @@ async def chat(
|
|||
await is_ready_to_chat(user)
|
||||
conversation_command = get_conversation_command(query=q, any_references=True)
|
||||
|
||||
conversation_command_rate_limiter.update_and_check_if_valid(request, conversation_command)
|
||||
await conversation_command_rate_limiter.update_and_check_if_valid(request, conversation_command)
|
||||
|
||||
q = q.replace(f"/{conversation_command.value}", "").strip()
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ import json
|
|||
import logging
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from functools import partial
|
||||
from time import time
|
||||
from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union
|
||||
|
@ -19,6 +19,7 @@ from khoj.database.models import (
|
|||
KhojUser,
|
||||
Subscription,
|
||||
TextToImageModelConfig,
|
||||
UserRequests,
|
||||
)
|
||||
from khoj.processor.conversation import prompts
|
||||
from khoj.processor.conversation.offline.chat_model import (
|
||||
|
@ -336,11 +337,11 @@ async def text_to_image(message: str, conversation_log: dict) -> Tuple[Optional[
|
|||
|
||||
|
||||
class ApiUserRateLimiter:
|
||||
def __init__(self, requests: int, subscribed_requests: int, window: int):
|
||||
def __init__(self, requests: int, subscribed_requests: int, window: int, slug: str):
|
||||
self.requests = requests
|
||||
self.subscribed_requests = subscribed_requests
|
||||
self.window = window
|
||||
self.cache: dict[str, list[float]] = defaultdict(list)
|
||||
self.slug = slug
|
||||
|
||||
def __call__(self, request: Request):
|
||||
# Rate limiting is disabled if user unauthenticated.
|
||||
|
@ -350,31 +351,32 @@ class ApiUserRateLimiter:
|
|||
|
||||
user: KhojUser = request.user.object
|
||||
subscribed = has_required_scope(request, ["premium"])
|
||||
user_requests = self.cache[user.uuid]
|
||||
|
||||
# Remove requests outside of the time window
|
||||
cutoff = time() - self.window
|
||||
while user_requests and user_requests[0] < cutoff:
|
||||
user_requests.pop(0)
|
||||
cutoff = datetime.now(tz=timezone.utc) - timedelta(seconds=self.window)
|
||||
count_requests = UserRequests.objects.filter(user=user, created_at__gte=cutoff, slug=self.slug).count()
|
||||
|
||||
# Check if the user has exceeded the rate limit
|
||||
if subscribed and len(user_requests) >= self.subscribed_requests:
|
||||
raise HTTPException(status_code=429, detail="Too Many Requests")
|
||||
if not subscribed and len(user_requests) >= self.requests:
|
||||
raise HTTPException(status_code=429, detail="Too Many Requests. Subscribe to increase your rate limit.")
|
||||
if subscribed and count_requests >= self.subscribed_requests:
|
||||
raise HTTPException(status_code=429, detail="Slow down! Too Many Requests")
|
||||
if not subscribed and count_requests >= self.requests:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail="We're glad you're enjoying Khoj! You've exceeded your usage limit for today. Come back tomorrow or subscribe to increase your rate limit via [your settings](https://app.khoj.dev/config).",
|
||||
)
|
||||
|
||||
# Add the current request to the cache
|
||||
user_requests.append(time())
|
||||
UserRequests.objects.create(user=user, slug=self.slug)
|
||||
|
||||
|
||||
class ConversationCommandRateLimiter:
|
||||
def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int):
|
||||
self.cache: Dict[str, Dict[str, List[float]]] = defaultdict(lambda: defaultdict(list))
|
||||
def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int, slug: str):
|
||||
self.slug = slug
|
||||
self.trial_rate_limit = trial_rate_limit
|
||||
self.subscribed_rate_limit = subscribed_rate_limit
|
||||
self.restricted_commands = [ConversationCommand.Online, ConversationCommand.Image]
|
||||
|
||||
def update_and_check_if_valid(self, request: Request, conversation_command: ConversationCommand):
|
||||
async def update_and_check_if_valid(self, request: Request, conversation_command: ConversationCommand):
|
||||
if state.billing_enabled is False:
|
||||
return
|
||||
|
||||
|
@ -385,19 +387,23 @@ class ConversationCommandRateLimiter:
|
|||
return
|
||||
|
||||
user: KhojUser = request.user.object
|
||||
user_cache = self.cache[user.uuid]
|
||||
subscribed = has_required_scope(request, ["premium"])
|
||||
user_cache[conversation_command].append(time())
|
||||
|
||||
# Remove requests outside of the 24-hr time window
|
||||
cutoff = time() - 60 * 60 * 24
|
||||
while user_cache[conversation_command] and user_cache[conversation_command][0] < cutoff:
|
||||
user_cache[conversation_command].pop(0)
|
||||
cutoff = datetime.now(tz=timezone.utc) - timedelta(seconds=60 * 60 * 24)
|
||||
command_slug = f"{self.slug}_{conversation_command.value}"
|
||||
count_requests = await UserRequests.objects.filter(
|
||||
user=user, created_at__gte=cutoff, slug=command_slug
|
||||
).acount()
|
||||
|
||||
if subscribed and len(user_cache[conversation_command]) > self.subscribed_rate_limit:
|
||||
raise HTTPException(status_code=429, detail="Too Many Requests")
|
||||
if not subscribed and len(user_cache[conversation_command]) > self.trial_rate_limit:
|
||||
raise HTTPException(status_code=429, detail="Too Many Requests. Subscribe to increase your rate limit.")
|
||||
if subscribed and count_requests >= self.subscribed_rate_limit:
|
||||
raise HTTPException(status_code=429, detail="Slow down! Too Many Requests")
|
||||
if not subscribed and count_requests >= self.trial_rate_limit:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"We're glad you're enjoying Khoj! You've exceeded your `/{conversation_command.value}` command usage limit for today. You can increase your rate limit via [your settings](https://app.khoj.dev/config).",
|
||||
)
|
||||
await UserRequests.objects.acreate(user=user, slug=command_slug)
|
||||
return
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue