mirror of
https://github.com/khoj-ai/khoj.git
synced 2025-02-18 23:54:20 +00:00
Rate limit the count and total size of images shared via API
This commit is contained in:
parent
7646ac6779
commit
e8fb79a369
4 changed files with 75 additions and 21 deletions
src
interface/web/app
khoj/routers
|
@ -265,7 +265,8 @@ export default function Chat() {
|
|||
try {
|
||||
await readChatStream(response);
|
||||
} catch (err) {
|
||||
console.error(err);
|
||||
const apiError = await response.json();
|
||||
console.error(apiError);
|
||||
// Retrieve latest message being processed
|
||||
const currentMessage = messages.find((message) => !message.completed);
|
||||
if (!currentMessage) return;
|
||||
|
@ -274,7 +275,11 @@ export default function Chat() {
|
|||
const errorMessage = (err as Error).message;
|
||||
if (errorMessage.includes("Error in input stream"))
|
||||
currentMessage.rawResponse = `Woops! The connection broke while I was writing my thoughts down. Maybe try again in a bit or dislike this message if the issue persists?`;
|
||||
else
|
||||
else if (response.status === 429) {
|
||||
"detail" in apiError
|
||||
? (currentMessage.rawResponse = `${apiError.detail}`)
|
||||
: (currentMessage.rawResponse = `I'm a bit overwhelmed at the moment. Could you try again in a bit or dislike this message if the issue persists?`);
|
||||
} else
|
||||
currentMessage.rawResponse = `Umm, not sure what just happened. I see this error message: ${errorMessage}. Could you try again or dislike this message if the issue persists?`;
|
||||
|
||||
// Complete message streaming teardown properly
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
div.actualInputArea {
|
||||
display: grid;
|
||||
grid-template-columns: auto 1fr auto auto;
|
||||
max-width: 700px;
|
||||
}
|
||||
|
|
|
@ -30,8 +30,10 @@ from khoj.processor.speech.text_to_speech import generate_text_to_speech
|
|||
from khoj.processor.tools.online_search import read_webpages, search_online
|
||||
from khoj.routers.api import extract_references_and_questions
|
||||
from khoj.routers.helpers import (
|
||||
ApiImageRateLimiter,
|
||||
ApiUserRateLimiter,
|
||||
ChatEvent,
|
||||
ChatRequestBody,
|
||||
CommonQueryParams,
|
||||
ConversationCommandRateLimiter,
|
||||
agenerate_chat_response,
|
||||
|
@ -523,22 +525,6 @@ async def set_conversation_title(
|
|||
)
|
||||
|
||||
|
||||
class ChatRequestBody(BaseModel):
    """Request body of the chat API endpoint.

    Only `q` is required; every other field is optional and has a default.
    """

    # The user's message. The chat handler reads it as `body.q`.
    q: str
    # NOTE(review): presumably the number of results to retrieve -- confirm against the chat handler.
    n: Optional[int] = 7
    # NOTE(review): presumably a search distance threshold -- confirm against the chat handler.
    d: Optional[float] = None
    stream: Optional[bool] = False
    title: Optional[str] = None
    conversation_id: Optional[str] = None
    # NOTE(review): the next five fields look like client location details -- confirm who populates them.
    city: Optional[str] = None
    region: Optional[str] = None
    country: Optional[str] = None
    country_code: Optional[str] = None
    timezone: Optional[str] = None
    # Images attached to the message, one string per image. ApiImageRateLimiter
    # treats each entry as base64 text that may be URL-quoted and may carry a
    # "data:...;base64," prefix.
    images: Optional[list[str]] = None
    create_new: Optional[bool] = False
|
||||
|
||||
|
||||
@api_chat.post("")
|
||||
@requires(["authenticated"])
|
||||
async def chat(
|
||||
|
@ -551,6 +537,7 @@ async def chat(
|
|||
rate_limiter_per_day=Depends(
|
||||
ApiUserRateLimiter(requests=600, subscribed_requests=6000, window=60 * 60 * 24, slug="chat_day")
|
||||
),
|
||||
image_rate_limiter=Depends(ApiImageRateLimiter(max_images=10, max_combined_size_mb=10)),
|
||||
):
|
||||
# Access the parameters from the body
|
||||
q = body.q
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
|
@ -21,7 +22,7 @@ from typing import (
|
|||
Tuple,
|
||||
Union,
|
||||
)
|
||||
from urllib.parse import parse_qs, quote, urljoin, urlparse
|
||||
from urllib.parse import parse_qs, quote, unquote, urljoin, urlparse
|
||||
|
||||
import cron_descriptor
|
||||
import pytz
|
||||
|
@ -30,6 +31,7 @@ from apscheduler.job import Job
|
|||
from apscheduler.triggers.cron import CronTrigger
|
||||
from asgiref.sync import sync_to_async
|
||||
from fastapi import Depends, Header, HTTPException, Request, UploadFile
|
||||
from pydantic import BaseModel
|
||||
from starlette.authentication import has_required_scope
|
||||
from starlette.requests import URL
|
||||
|
||||
|
@ -1019,6 +1021,22 @@ def generate_chat_response(
|
|||
return chat_response, metadata
|
||||
|
||||
|
||||
class ChatRequestBody(BaseModel):
    """Request body of the chat API endpoint.

    Only `q` is required; every other field is optional and has a default.
    """

    # The user's message. The chat handler reads it as `body.q`.
    q: str
    # NOTE(review): presumably the number of results to retrieve -- confirm against the chat handler.
    n: Optional[int] = 7
    # NOTE(review): presumably a search distance threshold -- confirm against the chat handler.
    d: Optional[float] = None
    stream: Optional[bool] = False
    title: Optional[str] = None
    conversation_id: Optional[str] = None
    # NOTE(review): the next five fields look like client location details -- confirm who populates them.
    city: Optional[str] = None
    region: Optional[str] = None
    country: Optional[str] = None
    country_code: Optional[str] = None
    timezone: Optional[str] = None
    # Images attached to the message, one string per image. ApiImageRateLimiter
    # treats each entry as base64 text that may be URL-quoted and may carry a
    # "data:...;base64," prefix.
    images: Optional[list[str]] = None
    create_new: Optional[bool] = False
|
||||
|
||||
|
||||
class ApiUserRateLimiter:
|
||||
def __init__(self, requests: int, subscribed_requests: int, window: int, slug: str):
|
||||
self.requests = requests
|
||||
|
@ -1064,13 +1082,58 @@ class ApiUserRateLimiter:
|
|||
)
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail="We're glad you're enjoying Khoj! You've exceeded your usage limit for today. Come back tomorrow or subscribe to increase your usage limit via [your settings](https://app.khoj.dev/settings).",
|
||||
detail="I'm glad you're enjoying interacting with me! But you've exceeded your usage limit for today. Come back tomorrow or subscribe to increase your usage limit via [your settings](https://app.khoj.dev/settings).",
|
||||
)
|
||||
|
||||
# Add the current request to the cache
|
||||
UserRequests.objects.create(user=user, slug=self.slug)
|
||||
|
||||
|
||||
class ApiImageRateLimiter:
    """Callable dependency that limits the images attached to one chat request.

    Rejects a request when it carries more than `max_images` images or when
    the decoded images together exceed `max_combined_size_mb` megabytes.

    Args:
        max_images: Maximum number of images allowed per message.
        max_combined_size_mb: Maximum combined decoded size of all images, in MB.
    """

    def __init__(self, max_images: int = 10, max_combined_size_mb: float = 10):
        self.max_images = max_images
        self.max_combined_size_mb = max_combined_size_mb

    def __call__(self, request: Request, body: ChatRequestBody):
        """Validate the images in `body`; return None when the request is acceptable.

        Raises:
            HTTPException: 429 when there are too many images or they are too
                large in total; 400 when an image is not valid base64 data.
        """
        if state.billing_enabled is False:
            return

        # Rate limiting is disabled if user unauthenticated.
        # Other systems handle authentication
        if not request.user.is_authenticated:
            return

        if not body.images:
            return

        # Check number of images
        if len(body.images) > self.max_images:
            raise HTTPException(
                status_code=429,
                detail=f"Those are way too many images for me! I can handle up to {self.max_images} images per message.",
            )

        # Check total size of images. Sizes are summed as exact byte counts and
        # compared against the limit converted to bytes once, up front.
        max_combined_size_bytes = self.max_combined_size_mb * 1024 * 1024
        total_size_bytes = 0
        for image in body.images:
            # Unquote the image in case it's URL encoded
            image = unquote(image)
            # Assuming the image is a base64 encoded string
            # Remove the data:image/jpeg;base64, part if present
            if "," in image:
                image = image.split(",", 1)[1]

            # Decode base64 to get the actual size. The payload comes from the
            # client, so malformed data is reported as a client error rather
            # than escaping as an unhandled exception.
            # binascii.Error is a subclass of ValueError, so this also covers
            # bad padding as well as non-ASCII input.
            try:
                image_bytes = base64.b64decode(image)
            except ValueError as e:
                raise HTTPException(
                    status_code=400,
                    detail="I couldn't read one of the images you shared. Please make sure every image is valid base64 encoded data.",
                ) from e
            total_size_bytes += len(image_bytes)

            # Stop as soon as the limit is crossed; decoding the remaining
            # images cannot change the outcome.
            if total_size_bytes > max_combined_size_bytes:
                raise HTTPException(
                    status_code=429,
                    detail=f"Those images are way too large for me! I can handle up to {self.max_combined_size_mb}MB of images per message.",
                )
|
||||
|
||||
|
||||
class ConversationCommandRateLimiter:
|
||||
def __init__(self, trial_rate_limit: int, subscribed_rate_limit: int, slug: str):
|
||||
self.slug = slug
|
||||
|
|
Loading…
Add table
Reference in a new issue