mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-12-18 10:37:11 +00:00
Add rate limiting to OTP login attempts and update email template for conciseness
This commit is contained in:
parent
b7783357fa
commit
ae9750e58e
6 changed files with 90 additions and 10 deletions
|
@ -36,6 +36,8 @@ export interface LoginPromptProps {
|
|||
|
||||
const fetcher = (url: string) => fetch(url).then((res) => res.json());
|
||||
|
||||
const ALLOWED_OTP_ATTEMPTS = 5;
|
||||
|
||||
interface Provider {
|
||||
client_id: string;
|
||||
redirect_uri: string;
|
||||
|
@ -230,10 +232,16 @@ function EmailSignInContext({
|
|||
}) {
|
||||
const [otp, setOTP] = useState("");
|
||||
const [otpError, setOTPError] = useState("");
|
||||
const [numFailures, setNumFailures] = useState(0);
|
||||
|
||||
function checkOTPAndRedirect() {
|
||||
const verifyUrl = `/auth/magic?code=${otp}&email=${email}`;
|
||||
|
||||
if (numFailures >= ALLOWED_OTP_ATTEMPTS) {
|
||||
setOTPError("Too many failed attempts. Please try again tomorrow.");
|
||||
return;
|
||||
}
|
||||
|
||||
fetch(verifyUrl, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
|
@ -246,8 +254,16 @@ function EmailSignInContext({
|
|||
if (res.redirected) {
|
||||
window.location.href = res.url;
|
||||
}
|
||||
} else if (res.status === 401) {
|
||||
setOTPError("Invalid OTP.");
|
||||
setNumFailures(numFailures + 1);
|
||||
if (numFailures + 1 >= ALLOWED_OTP_ATTEMPTS) {
|
||||
setOTPError("Too many failed attempts. Please try again tomorrow.");
|
||||
}
|
||||
} else if (res.status === 429) {
|
||||
setOTPError("Too many failed attempts. Please try again tomorrow.");
|
||||
setNumFailures(ALLOWED_OTP_ATTEMPTS);
|
||||
} else {
|
||||
setOTPError("Invalid OTP");
|
||||
throw new Error("Failed to verify OTP");
|
||||
}
|
||||
})
|
||||
|
@ -309,6 +325,7 @@ function EmailSignInContext({
|
|||
maxLength={6}
|
||||
value={otp || ""}
|
||||
onChange={setOTP}
|
||||
disabled={numFailures >= ALLOWED_OTP_ATTEMPTS}
|
||||
onComplete={() =>
|
||||
setTimeout(() => {
|
||||
checkOTPAndRedirect();
|
||||
|
@ -324,7 +341,11 @@ function EmailSignInContext({
|
|||
<InputOTPSlot index={5} />
|
||||
</InputOTPGroup>
|
||||
</InputOTP>
|
||||
<div className="text-red-500 text-sm">{otpError}</div>
|
||||
{otpError && (
|
||||
<div className="text-red-500 text-sm">
|
||||
{otpError} {ALLOWED_OTP_ATTEMPTS - numFailures} remaining attempts.
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
|
|
|
@ -437,10 +437,14 @@ def is_user_subscribed(user: KhojUser) -> bool:
|
|||
return subscribed
|
||||
|
||||
|
||||
async def get_user_by_email(email: str) -> KhojUser:
|
||||
async def aget_user_by_email(email: str) -> KhojUser:
|
||||
return await KhojUser.objects.filter(email=email).afirst()
|
||||
|
||||
|
||||
def get_user_by_email(email: str) -> KhojUser:
|
||||
return KhojUser.objects.filter(email=email).first()
|
||||
|
||||
|
||||
async def aget_user_by_uuid(uuid: str) -> KhojUser:
|
||||
return await KhojUser.objects.filter(uuid=uuid).afirst()
|
||||
|
||||
|
|
|
@ -16,12 +16,12 @@
|
|||
<img src="https://assets.khoj.dev/khoj_logo.png" alt="Khoj Logo" style="width: 120px;">
|
||||
</a>
|
||||
|
||||
<p style="font-size: 16px; color: #333; margin-bottom: 20px;">Hi!</p>
|
||||
|
||||
<p style="font-size: 16px; color: #333; margin-bottom: 20px;">Use this code (valid for 5 minutes) to login to Khoj:</p>
|
||||
|
||||
<h1 style="font-size: 24px; color: #2c3e50; margin-bottom: 20px; text-align: center;">{{ code }}</h1>
|
||||
|
||||
<p style="font-size: 16px; color: #333; margin-bottom: 20px;">It will be valid for 5 minutes.</p>
|
||||
|
||||
<p style="font-size: 16px; color: #333; margin-bottom: 20px;">Alternatively, <a href="{{ link }}" target="_blank"
|
||||
style="color: #FFA07A; text-decoration: none; font-weight: bold;">Click here to sign in on this
|
||||
browser.</a></p>
|
||||
|
|
|
@ -5,7 +5,7 @@ import os
|
|||
from typing import Optional
|
||||
|
||||
import requests
|
||||
from fastapi import APIRouter
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel, EmailStr
|
||||
from starlette.authentication import requires
|
||||
from starlette.config import Config
|
||||
|
@ -22,7 +22,11 @@ from khoj.database.adapters import (
|
|||
get_or_create_user,
|
||||
)
|
||||
from khoj.routers.email import send_magic_link_email, send_welcome_email
|
||||
from khoj.routers.helpers import get_next_url, update_telemetry_state
|
||||
from khoj.routers.helpers import (
|
||||
EmailVerificationApiRateLimiter,
|
||||
get_next_url,
|
||||
update_telemetry_state,
|
||||
)
|
||||
from khoj.utils import state
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -99,7 +103,14 @@ async def login_magic_link(request: Request, form: MagicLinkForm):
|
|||
|
||||
|
||||
@auth_router.get("/magic")
|
||||
async def sign_in_with_magic_link(request: Request, code: str, email: str):
|
||||
async def sign_in_with_magic_link(
|
||||
request: Request,
|
||||
code: str,
|
||||
email: str,
|
||||
rate_limiter=Depends(
|
||||
EmailVerificationApiRateLimiter(requests=10, window=60 * 60 * 24, slug="magic_link_verification")
|
||||
),
|
||||
):
|
||||
user = await aget_user_validated_by_email_verification_code(code, email)
|
||||
if user:
|
||||
id_info = {
|
||||
|
@ -108,7 +119,7 @@ async def sign_in_with_magic_link(request: Request, code: str, email: str):
|
|||
|
||||
request.session["user"] = dict(id_info)
|
||||
return RedirectResponse(url="/")
|
||||
return RedirectResponse(request.app.url_path_for("login_page"))
|
||||
return Response(status_code=401)
|
||||
|
||||
|
||||
@auth_router.post("/token")
|
||||
|
|
|
@ -47,7 +47,7 @@ async def send_magic_link_email(email, unique_id, host):
|
|||
{
|
||||
"sender": os.environ.get("RESEND_EMAIL", "noreply@khoj.dev"),
|
||||
"to": email,
|
||||
"subject": f"{unique_id} - Sign in to Khoj 🚀",
|
||||
"subject": f"Your unique login to Khoj",
|
||||
"html": html_content,
|
||||
}
|
||||
)
|
||||
|
|
|
@ -49,6 +49,7 @@ from khoj.database.adapters import (
|
|||
ais_user_subscribed,
|
||||
create_khoj_token,
|
||||
get_khoj_tokens,
|
||||
get_user_by_email,
|
||||
get_user_name,
|
||||
get_user_notion_config,
|
||||
get_user_subscription_state,
|
||||
|
@ -1363,6 +1364,49 @@ class FeedbackData(BaseModel):
|
|||
sentiment: str
|
||||
|
||||
|
||||
class EmailVerificationApiRateLimiter:
|
||||
def __init__(self, requests: int, window: int, slug: str):
|
||||
self.requests = requests
|
||||
self.window = window
|
||||
self.slug = slug
|
||||
|
||||
def __call__(self, request: Request):
|
||||
# Rate limiting disabled if billing is disabled
|
||||
if state.billing_enabled is False:
|
||||
return
|
||||
|
||||
# Extract the email query parameter
|
||||
email = request.query_params.get("email")
|
||||
|
||||
if email:
|
||||
logger.info(f"Email query parameter: {email}")
|
||||
|
||||
user: KhojUser = get_user_by_email(email)
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="User not found.",
|
||||
)
|
||||
|
||||
# Remove requests outside of the time window
|
||||
cutoff = datetime.now(tz=timezone.utc) - timedelta(seconds=self.window)
|
||||
count_requests = UserRequests.objects.filter(user=user, created_at__gte=cutoff, slug=self.slug).count()
|
||||
|
||||
# Check if the user has exceeded the rate limit
|
||||
if count_requests >= self.requests:
|
||||
logger.info(
|
||||
f"Rate limit: {count_requests}/{self.requests} requests not allowed in {self.window} seconds for email: {email}."
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail="Ran out of login attempts",
|
||||
)
|
||||
|
||||
# Add the current request to the db
|
||||
UserRequests.objects.create(user=user, slug=self.slug)
|
||||
|
||||
|
||||
class ApiUserRateLimiter:
|
||||
def __init__(self, requests: int, subscribed_requests: int, window: int, slug: str):
|
||||
self.requests = requests
|
||||
|
|
Loading…
Reference in a new issue