Add a more gracefull error message when the rate limit is exceeded

This commit is contained in:
sabaimran 2024-04-08 15:20:54 +05:30
parent 11ce3e2268
commit 089e0d028b
2 changed files with 28 additions and 6 deletions

View file

@ -932,6 +932,8 @@ To get started, just start typing below. You can also type / to see a list of co
websocketState.references = references; websocketState.references = references;
} else if (chunk.type == "status") { } else if (chunk.type == "status") {
handleStreamResponse(websocketState.newResponseText, chunk.message, null, false); handleStreamResponse(websocketState.newResponseText, chunk.message, null, false);
} else if (chunk.type == "rate_limit") {
handleStreamResponse(websocketState.newResponseText, chunk.message, websocketState.loadingEllipsis, true);
} else { } else {
rawResponse = chunk.response; rawResponse = chunk.response;
} }
@ -939,7 +941,7 @@ To get started, just start typing below. You can also type / to see a list of co
// If the chunk is not a JSON object, just display it as is // If the chunk is not a JSON object, just display it as is
websocketState.rawResponse += chunk; websocketState.rawResponse += chunk;
} finally { } finally {
if (chunk.type != "status") { if (chunk.type != "status" && chunk.type != "rate_limit") {
addMessageToChatBody(websocketState.rawResponse, websocketState.newResponseText, websocketState.references); addMessageToChatBody(websocketState.rawResponse, websocketState.newResponseText, websocketState.references);
} }
} }

View file

@ -5,7 +5,7 @@ from typing import Dict, Optional
from urllib.parse import unquote from urllib.parse import unquote
from asgiref.sync import sync_to_async from asgiref.sync import sync_to_async
from fastapi import APIRouter, Depends, Request, WebSocket from fastapi import APIRouter, Depends, HTTPException, Request, WebSocket
from fastapi.requests import Request from fastapi.requests import Request
from fastapi.responses import Response, StreamingResponse from fastapi.responses import Response, StreamingResponse
from starlette.authentication import requires from starlette.authentication import requires
@ -292,14 +292,30 @@ async def websocket_endpoint(
connection_alive = False connection_alive = False
logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread") logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
async def send_rate_limit_message(message: str):
nonlocal connection_alive
if not connection_alive:
return
status_packet = {
"type": "rate_limit",
"message": message,
"content-type": "application/json",
}
try:
await websocket.send_text(json.dumps(status_packet))
except ConnectionClosedOK:
connection_alive = False
logger.info(f"User {user} disconnected web socket. Emitting rest of responses to clear thread")
user: KhojUser = websocket.user.object user: KhojUser = websocket.user.object
conversation = await ConversationAdapters.aget_conversation_by_user( conversation = await ConversationAdapters.aget_conversation_by_user(
user, client_application=websocket.user.client_app, conversation_id=conversation_id user, client_application=websocket.user.client_app, conversation_id=conversation_id
) )
hourly_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=60, window=60, slug="chat_minute") hourly_limiter = ApiUserRateLimiter(requests=1, subscribed_requests=60, window=60, slug="chat_minute")
daily_limiter = ApiUserRateLimiter(requests=5, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day") daily_limiter = ApiUserRateLimiter(requests=1, subscribed_requests=600, window=60 * 60 * 24, slug="chat_day")
await is_ready_to_chat(user) await is_ready_to_chat(user)
@ -318,8 +334,12 @@ async def websocket_endpoint(
logger.debug(f"User {user} disconnected web socket") logger.debug(f"User {user} disconnected web socket")
break break
await sync_to_async(hourly_limiter)(websocket) try:
await sync_to_async(daily_limiter)(websocket) await sync_to_async(hourly_limiter)(websocket)
await sync_to_async(daily_limiter)(websocket)
except HTTPException as e:
await send_rate_limit_message(e.detail)
break
conversation_commands = [get_conversation_command(query=q, any_references=True)] conversation_commands = [get_conversation_command(query=q, any_references=True)]