Store rate limiter-related metadata in the database for more resilience (#629)

* Store rate limiter-related metadata in the database for more resilience
- This helps maintain state even between server restarts
- Allows you to scale up workers on your service without having to implement sticky routing
* Make the usage exceeded message less abrasive
* Fix rate limiter for specific conversation commands and improve the copy
This commit is contained in:
sabaimran 2024-01-29 01:57:06 -08:00 committed by GitHub
parent 71cbe5160d
commit 4fb8d5c6d4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 90 additions and 30 deletions

View file

@ -27,6 +27,7 @@ from khoj.database.adapters import (
aget_or_create_user_by_phone_number,
aget_user_by_phone_number,
aget_user_subscription_state,
delete_user_requests,
get_all_users,
get_or_create_search_models,
)
@ -328,3 +329,9 @@ def upload_telemetry():
logger.error(f"📡 Error uploading telemetry: {e}", exc_info=True)
else:
state.telemetry = []
@schedule.repeat(schedule.every(31).minutes)
def delete_old_user_requests():
num_deleted = delete_user_requests()
logger.info(f"🔥 Deleted {num_deleted} day-old user requests")

View file

@ -34,6 +34,7 @@ from khoj.database.models import (
Subscription,
TextToImageModelConfig,
UserConversationConfig,
UserRequests,
UserSearchModelConfig,
)
from khoj.search_filter.date_filter import DateFilter
@ -284,6 +285,10 @@ def get_user_notion_config(user: KhojUser):
return config
def delete_user_requests(window: timedelta = timedelta(days=1)):
return UserRequests.objects.filter(created_at__lte=datetime.now(tz=timezone.utc) - window).delete()
async def set_text_content_config(user: KhojUser, object: Type[models.Model], updated_config):
deduped_files = list(set(updated_config.input_files)) if updated_config.input_files else None
deduped_filters = list(set(updated_config.input_filter)) if updated_config.input_filter else None

View file

@ -0,0 +1,27 @@
# Generated by Django 4.2.7 on 2024-01-29 08:55
import django.db.models.deletion
from django.conf import settings
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("database", "0028_khojuser_verified_phone_number"),
]
operations = [
migrations.CreateModel(
name="UserRequests",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("slug", models.CharField(max_length=200)),
("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
],
options={
"abstract": False,
},
),
]

View file

@ -223,3 +223,8 @@ class EntryDates(BaseModel):
indexes = [
models.Index(fields=["date"]),
]
class UserRequests(BaseModel):
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
slug = models.CharField(max_length=200)

View file

@ -59,7 +59,9 @@ from khoj.utils.state import SearchType
# Initialize Router
api = APIRouter()
logger = logging.getLogger(__name__)
conversation_command_rate_limiter = ConversationCommandRateLimiter(trial_rate_limit=5, subscribed_rate_limit=100)
conversation_command_rate_limiter = ConversationCommandRateLimiter(
trial_rate_limit=2, subscribed_rate_limit=100, slug="command"
)
@api.get("/search", response_model=List[SearchResponse])
@ -301,8 +303,12 @@ async def transcribe(
request: Request,
common: CommonQueryParams,
file: UploadFile = File(...),
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=1, subscribed_requests=10, window=60)),
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)),
rate_limiter_per_minute=Depends(
ApiUserRateLimiter(requests=1, subscribed_requests=10, window=60, slug="transcribe_minute")
),
rate_limiter_per_day=Depends(
ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24, slug="transcribe_day")
),
):
user: KhojUser = request.user.object
audio_filename = f"{user.uuid}-{str(uuid.uuid4())}.webm"
@ -361,8 +367,12 @@ async def chat(
n: Optional[int] = 5,
d: Optional[float] = 0.18,
stream: Optional[bool] = False,
rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60)),
rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24)),
rate_limiter_per_minute=Depends(
ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
),
rate_limiter_per_day=Depends(
ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
),
) -> Response:
user: KhojUser = request.user.object
q = unquote(q)
@ -370,7 +380,7 @@ async def chat(
await is_ready_to_chat(user)
conversation_command = get_conversation_command(query=q, any_references=True)
conversation_command_rate_limiter.update_and_check_if_valid(request, conversation_command)
await conversation_command_rate_limiter.update_and_check_if_valid(request, conversation_command)
q = q.replace(f"/{conversation_command.value}", "").strip()

View file

@ -3,7 +3,7 @@ import json
import logging
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from datetime import datetime, timedelta, timezone
from functools import partial
from time import time
from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union
@ -19,6 +19,7 @@ from khoj.database.models import (
KhojUser,
Subscription,
TextToImageModelConfig,
UserRequests,
)
from khoj.processor.conversation import prompts
from khoj.processor.conversation.offline.chat_model import (
@ -336,11 +337,11 @@ async def text_to_image(message: str, conversation_log: dict) -> Tuple[Optional[
class ApiUserRateLimiter:
def __init__(self, requests: int, subscribed_requests: int, window: int):
def __init__(self, requests: int, subscribed_requests: int, window: int, slug: str):
self.requests = requests
self.subscribed_requests = subscribed_requests
self.window = window
self.cache: dict[str, list[float]] = defaultdict(list)
self.slug = slug
def __call__(self, request: Request):
# Rate limiting is disabled if user unauthenticated.
@ -350,31 +351,32 @@ class ApiUserRateLimiter:
user: KhojUser = request.user.object
subscribed = has_required_scope(request, ["premium"])
user_requests = self.cache[user.uuid]
# Remove requests outside of the time window
cutoff = time() - self.window
while user_requests and user_requests[0] < cutoff:
user_requests.pop(0)
cutoff = datetime.now(tz=timezone.utc) - timedelta(seconds=self.window)
count_requests = UserRequests.objects.filter(user=user, created_at__gte=cutoff, slug=self.slug).count()
# Check if the user has exceeded the rate limit
if subscribed and len(user_requests) >= self.subscribed_requests:
raise HTTPException(status_code=429, detail="Too Many Requests")
if not subscribed and len(user_requests) >= self.requests:
raise HTTPException(status_code=429, detail="Too Many Requests. Subscribe to increase your rate limit.")
if subscribed and count_requests >= self.subscribed_requests:
raise HTTPException(status_code=429, detail="Slow down! Too Many Requests")
if not subscribed and count_requests >= self.requests:
raise HTTPException(
status_code=429,
detail="We're glad you're enjoying Khoj! You've exceeded your usage limit for today. Come back tomorrow or subscribe to increase your rate limit via [your settings](https://app.khoj.dev/config).",
)
# Add the current request to the cache
user_requests.append(time())
UserRequests.objects.create(user=user, slug=self.slug)
class ConversationCommandRateLimiter:
def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int):
self.cache: Dict[str, Dict[str, List[float]]] = defaultdict(lambda: defaultdict(list))
def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int, slug: str):
self.slug = slug
self.trial_rate_limit = trial_rate_limit
self.subscribed_rate_limit = subscribed_rate_limit
self.restricted_commands = [ConversationCommand.Online, ConversationCommand.Image]
def update_and_check_if_valid(self, request: Request, conversation_command: ConversationCommand):
async def update_and_check_if_valid(self, request: Request, conversation_command: ConversationCommand):
if state.billing_enabled is False:
return
@ -385,19 +387,23 @@ class ConversationCommandRateLimiter:
return
user: KhojUser = request.user.object
user_cache = self.cache[user.uuid]
subscribed = has_required_scope(request, ["premium"])
user_cache[conversation_command].append(time())
# Remove requests outside of the 24-hr time window
cutoff = time() - 60 * 60 * 24
while user_cache[conversation_command] and user_cache[conversation_command][0] < cutoff:
user_cache[conversation_command].pop(0)
cutoff = datetime.now(tz=timezone.utc) - timedelta(seconds=60 * 60 * 24)
command_slug = f"{self.slug}_{conversation_command.value}"
count_requests = await UserRequests.objects.filter(
user=user, created_at__gte=cutoff, slug=command_slug
).acount()
if subscribed and len(user_cache[conversation_command]) > self.subscribed_rate_limit:
raise HTTPException(status_code=429, detail="Too Many Requests")
if not subscribed and len(user_cache[conversation_command]) > self.trial_rate_limit:
raise HTTPException(status_code=429, detail="Too Many Requests. Subscribe to increase your rate limit.")
if subscribed and count_requests >= self.subscribed_rate_limit:
raise HTTPException(status_code=429, detail="Slow down! Too Many Requests")
if not subscribed and count_requests >= self.trial_rate_limit:
raise HTTPException(
status_code=429,
detail=f"We're glad you're enjoying Khoj! You've exceeded your `/{conversation_command.value}` command usage limit for today. You can increase your rate limit via [your settings](https://app.khoj.dev/config).",
)
await UserRequests.objects.acreate(user=user, slug=command_slug)
return