diff --git a/src/khoj/configure.py b/src/khoj/configure.py
index f3b75d1f..5d82fa4d 100644
--- a/src/khoj/configure.py
+++ b/src/khoj/configure.py
@@ -27,6 +27,7 @@ from khoj.database.adapters import (
     aget_or_create_user_by_phone_number,
     aget_user_by_phone_number,
     aget_user_subscription_state,
+    delete_user_requests,
     get_all_users,
     get_or_create_search_models,
 )
@@ -328,3 +329,10 @@ def upload_telemetry():
         logger.error(f"📡 Error uploading telemetry: {e}", exc_info=True)
     else:
         state.telemetry = []
+
+
+@schedule.repeat(schedule.every(31).minutes)
+def delete_old_user_requests():
+    # QuerySet.delete() returns (total_deleted, per_model_counts); log only the total
+    num_deleted, _ = delete_user_requests()
+    logger.info(f"🔥 Deleted {num_deleted} day-old user requests")
diff --git a/src/khoj/database/adapters/__init__.py b/src/khoj/database/adapters/__init__.py
index d86c244a..b4cf9c4c 100644
--- a/src/khoj/database/adapters/__init__.py
+++ b/src/khoj/database/adapters/__init__.py
@@ -34,6 +34,7 @@ from khoj.database.models import (
     Subscription,
     TextToImageModelConfig,
     UserConversationConfig,
+    UserRequests,
     UserSearchModelConfig,
 )
 from khoj.search_filter.date_filter import DateFilter
@@ -284,6 +285,11 @@ def get_user_notion_config(user: KhojUser):
     return config
 
 
+def delete_user_requests(window: timedelta = timedelta(days=1)):
+    """Delete UserRequests rows created at or before now minus `window`; returns QuerySet.delete()'s (total, per-model counts) tuple."""
+    return UserRequests.objects.filter(created_at__lte=datetime.now(tz=timezone.utc) - window).delete()
+
+
 async def set_text_content_config(user: KhojUser, object: Type[models.Model], updated_config):
     deduped_files = list(set(updated_config.input_files)) if updated_config.input_files else None
     deduped_filters = list(set(updated_config.input_filter)) if updated_config.input_filter else None
diff --git a/src/khoj/database/migrations/0029_userrequests.py b/src/khoj/database/migrations/0029_userrequests.py
new file mode 100644
index 00000000..cf4d5413
--- /dev/null
+++ b/src/khoj/database/migrations/0029_userrequests.py
@@ -0,0 +1,27 @@
+# Generated by Django 4.2.7 on 2024-01-29 08:55
+
+import django.db.models.deletion
+from django.conf import settings
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+    dependencies = [
+        ("database", "0028_khojuser_verified_phone_number"),
+    ]
+
+    operations = [
+        migrations.CreateModel(
+            name="UserRequests",
+            fields=[
+                ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
+                ("created_at", models.DateTimeField(auto_now_add=True)),
+                ("updated_at", models.DateTimeField(auto_now=True)),
+                ("slug", models.CharField(max_length=200)),
+                ("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
+            ],
+            options={
+                "abstract": False,
+            },
+        ),
+    ]
diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py
index 688ca9af..5a7ccd23 100644
--- a/src/khoj/database/models/__init__.py
+++ b/src/khoj/database/models/__init__.py
@@ -223,3 +223,8 @@ class EntryDates(BaseModel):
         indexes = [
             models.Index(fields=["date"]),
         ]
+
+
+class UserRequests(BaseModel):
+    user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
+    slug = models.CharField(max_length=200)
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index 65f938bf..2b3b7c6d 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -59,7 +59,9 @@ from khoj.utils.state import SearchType
 # Initialize Router
 api = APIRouter()
 logger = logging.getLogger(__name__)
-conversation_command_rate_limiter = ConversationCommandRateLimiter(trial_rate_limit=5, subscribed_rate_limit=100)
+conversation_command_rate_limiter = ConversationCommandRateLimiter(
+    trial_rate_limit=2, subscribed_rate_limit=100, slug="command"
+)
 
 
 @api.get("/search", response_model=List[SearchResponse])
@@ -301,8 +303,12 @@ async def transcribe(
     request: Request,
     common: CommonQueryParams,
     file: UploadFile = File(...),
-    rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=1, subscribed_requests=10, window=60)),
-    rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24)),
+    rate_limiter_per_minute=Depends(
+        ApiUserRateLimiter(requests=1, subscribed_requests=10, window=60, slug="transcribe_minute")
+    ),
+    rate_limiter_per_day=Depends(
+        ApiUserRateLimiter(requests=10, subscribed_requests=600, window=60 * 60 * 24, slug="transcribe_day")
+    ),
 ):
     user: KhojUser = request.user.object
     audio_filename = f"{user.uuid}-{str(uuid.uuid4())}.webm"
@@ -361,8 +367,12 @@ async def chat(
     n: Optional[int] = 5,
     d: Optional[float] = 0.18,
     stream: Optional[bool] = False,
-    rate_limiter_per_minute=Depends(ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60)),
-    rate_limiter_per_day=Depends(ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24)),
+    rate_limiter_per_minute=Depends(
+        ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute")
+    ),
+    rate_limiter_per_day=Depends(
+        ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
+    ),
 ) -> Response:
     user: KhojUser = request.user.object
     q = unquote(q)
@@ -370,7 +380,7 @@ async def chat(
     await is_ready_to_chat(user)
     conversation_command = get_conversation_command(query=q, any_references=True)
 
-    conversation_command_rate_limiter.update_and_check_if_valid(request, conversation_command)
+    await conversation_command_rate_limiter.update_and_check_if_valid(request, conversation_command)
 
     q = q.replace(f"/{conversation_command.value}", "").strip()
 
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index d7f9214b..b0d8c348 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -3,7 +3,7 @@ import json
 import logging
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
-from datetime import datetime
+from datetime import datetime, timedelta, timezone
 from functools import partial
 from time import time
 from typing import Annotated, Any, Dict, Iterator, List, Optional, Tuple, Union
@@ -19,6 +19,7 @@ from khoj.database.models import (
     KhojUser,
     Subscription,
     TextToImageModelConfig,
+    UserRequests,
 )
 from khoj.processor.conversation import prompts
 from khoj.processor.conversation.offline.chat_model import (
@@ -336,11 +337,11 @@ async def text_to_image(message: str, conversation_log: dict) -> Tuple[Optional[
 
 
 class ApiUserRateLimiter:
-    def __init__(self, requests: int, subscribed_requests: int, window: int):
+    def __init__(self, requests: int, subscribed_requests: int, window: int, slug: str):
         self.requests = requests
         self.subscribed_requests = subscribed_requests
         self.window = window
-        self.cache: dict[str, list[float]] = defaultdict(list)
+        self.slug = slug
 
     def __call__(self, request: Request):
         # Rate limiting is disabled if user unauthenticated.
@@ -350,31 +351,32 @@ class ApiUserRateLimiter:
 
         user: KhojUser = request.user.object
         subscribed = has_required_scope(request, ["premium"])
-        user_requests = self.cache[user.uuid]
 
-        # Remove requests outside of the time window
-        cutoff = time() - self.window
-        while user_requests and user_requests[0] < cutoff:
-            user_requests.pop(0)
+        # Count this user's requests for this slug made within the time window
+        cutoff = datetime.now(tz=timezone.utc) - timedelta(seconds=self.window)
+        count_requests = UserRequests.objects.filter(user=user, created_at__gte=cutoff, slug=self.slug).count()
 
         # Check if the user has exceeded the rate limit
-        if subscribed and len(user_requests) >= self.subscribed_requests:
-            raise HTTPException(status_code=429, detail="Too Many Requests")
-        if not subscribed and len(user_requests) >= self.requests:
-            raise HTTPException(status_code=429, detail="Too Many Requests. Subscribe to increase your rate limit.")
+        if subscribed and count_requests >= self.subscribed_requests:
+            raise HTTPException(status_code=429, detail="Slow down! Too Many Requests")
+        if not subscribed and count_requests >= self.requests:
+            raise HTTPException(
+                status_code=429,
+                detail="We're glad you're enjoying Khoj! You've exceeded your usage limit for today. Come back tomorrow or subscribe to increase your rate limit via [your settings](https://app.khoj.dev/config).",
+            )
 
-        # Add the current request to the cache
-        user_requests.append(time())
+        # Record the current request in the database
+        UserRequests.objects.create(user=user, slug=self.slug)
 
 
 class ConversationCommandRateLimiter:
-    def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int):
-        self.cache: Dict[str, Dict[str, List[float]]] = defaultdict(lambda: defaultdict(list))
+    def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int, slug: str):
+        self.slug = slug
         self.trial_rate_limit = trial_rate_limit
         self.subscribed_rate_limit = subscribed_rate_limit
         self.restricted_commands = [ConversationCommand.Online, ConversationCommand.Image]
 
-    def update_and_check_if_valid(self, request: Request, conversation_command: ConversationCommand):
+    async def update_and_check_if_valid(self, request: Request, conversation_command: ConversationCommand):
         if state.billing_enabled is False:
             return
 
@@ -385,19 +387,23 @@ class ConversationCommandRateLimiter:
             return
 
         user: KhojUser = request.user.object
-        user_cache = self.cache[user.uuid]
         subscribed = has_required_scope(request, ["premium"])
-        user_cache[conversation_command].append(time())
 
-        # Remove requests outside of the 24-hr time window
-        cutoff = time() - 60 * 60 * 24
-        while user_cache[conversation_command] and user_cache[conversation_command][0] < cutoff:
-            user_cache[conversation_command].pop(0)
+        # Count this user's requests for this command made within the 24-hr time window
+        cutoff = datetime.now(tz=timezone.utc) - timedelta(seconds=60 * 60 * 24)
+        command_slug = f"{self.slug}_{conversation_command.value}"
+        count_requests = await UserRequests.objects.filter(
+            user=user, created_at__gte=cutoff, slug=command_slug
+        ).acount()
 
-        if subscribed and len(user_cache[conversation_command]) > self.subscribed_rate_limit:
-            raise HTTPException(status_code=429, detail="Too Many Requests")
-        if not subscribed and len(user_cache[conversation_command]) > self.trial_rate_limit:
-            raise HTTPException(status_code=429, detail="Too Many Requests. Subscribe to increase your rate limit.")
+        if subscribed and count_requests >= self.subscribed_rate_limit:
+            raise HTTPException(status_code=429, detail="Slow down! Too Many Requests")
+        if not subscribed and count_requests >= self.trial_rate_limit:
+            raise HTTPException(
+                status_code=429,
+                detail=f"We're glad you're enjoying Khoj! You've exceeded your `/{conversation_command.value}` command usage limit for today. You can increase your rate limit via [your settings](https://app.khoj.dev/config).",
+            )
+        await UserRequests.objects.acreate(user=user, slug=command_slug)
         return
 
 