Add rate limiting to OTP login attempts and update email template for conciseness

This commit is contained in:
sabaimran 2024-12-16 11:59:45 -08:00
parent b7783357fa
commit ae9750e58e
6 changed files with 90 additions and 10 deletions

View file

@ -36,6 +36,8 @@ export interface LoginPromptProps {
const fetcher = (url: string) => fetch(url).then((res) => res.json()); const fetcher = (url: string) => fetch(url).then((res) => res.json());
const ALLOWED_OTP_ATTEMPTS = 5;
interface Provider { interface Provider {
client_id: string; client_id: string;
redirect_uri: string; redirect_uri: string;
@ -230,10 +232,16 @@ function EmailSignInContext({
}) { }) {
const [otp, setOTP] = useState(""); const [otp, setOTP] = useState("");
const [otpError, setOTPError] = useState(""); const [otpError, setOTPError] = useState("");
const [numFailures, setNumFailures] = useState(0);
function checkOTPAndRedirect() { function checkOTPAndRedirect() {
const verifyUrl = `/auth/magic?code=${otp}&email=${email}`; const verifyUrl = `/auth/magic?code=${otp}&email=${email}`;
if (numFailures >= ALLOWED_OTP_ATTEMPTS) {
setOTPError("Too many failed attempts. Please try again tomorrow.");
return;
}
fetch(verifyUrl, { fetch(verifyUrl, {
method: "GET", method: "GET",
headers: { headers: {
@ -246,8 +254,16 @@ function EmailSignInContext({
if (res.redirected) { if (res.redirected) {
window.location.href = res.url; window.location.href = res.url;
} }
} else if (res.status === 401) {
setOTPError("Invalid OTP.");
setNumFailures(numFailures + 1);
if (numFailures + 1 >= ALLOWED_OTP_ATTEMPTS) {
setOTPError("Too many failed attempts. Please try again tomorrow.");
}
} else if (res.status === 429) {
setOTPError("Too many failed attempts. Please try again tomorrow.");
setNumFailures(ALLOWED_OTP_ATTEMPTS);
} else { } else {
setOTPError("Invalid OTP");
throw new Error("Failed to verify OTP"); throw new Error("Failed to verify OTP");
} }
}) })
@ -309,6 +325,7 @@ function EmailSignInContext({
maxLength={6} maxLength={6}
value={otp || ""} value={otp || ""}
onChange={setOTP} onChange={setOTP}
disabled={numFailures >= ALLOWED_OTP_ATTEMPTS}
onComplete={() => onComplete={() =>
setTimeout(() => { setTimeout(() => {
checkOTPAndRedirect(); checkOTPAndRedirect();
@ -324,7 +341,11 @@ function EmailSignInContext({
<InputOTPSlot index={5} /> <InputOTPSlot index={5} />
</InputOTPGroup> </InputOTPGroup>
</InputOTP> </InputOTP>
<div className="text-red-500 text-sm">{otpError}</div> {otpError && (
<div className="text-red-500 text-sm">
{otpError} {ALLOWED_OTP_ATTEMPTS - numFailures} remaining attempts.
</div>
)}
</div> </div>
)} )}

View file

@ -437,10 +437,14 @@ def is_user_subscribed(user: KhojUser) -> bool:
return subscribed return subscribed
async def get_user_by_email(email: str) -> KhojUser: async def aget_user_by_email(email: str) -> KhojUser:
return await KhojUser.objects.filter(email=email).afirst() return await KhojUser.objects.filter(email=email).afirst()
def get_user_by_email(email: str) -> KhojUser:
return KhojUser.objects.filter(email=email).first()
async def aget_user_by_uuid(uuid: str) -> KhojUser: async def aget_user_by_uuid(uuid: str) -> KhojUser:
return await KhojUser.objects.filter(uuid=uuid).afirst() return await KhojUser.objects.filter(uuid=uuid).afirst()

View file

@ -16,12 +16,12 @@
<img src="https://assets.khoj.dev/khoj_logo.png" alt="Khoj Logo" style="width: 120px;"> <img src="https://assets.khoj.dev/khoj_logo.png" alt="Khoj Logo" style="width: 120px;">
</a> </a>
<p style="font-size: 16px; color: #333; margin-bottom: 20px;">Hi!</p>
<p style="font-size: 16px; color: #333; margin-bottom: 20px;">Use this code (valid for 5 minutes) to login to Khoj:</p> <p style="font-size: 16px; color: #333; margin-bottom: 20px;">Use this code (valid for 5 minutes) to login to Khoj:</p>
<h1 style="font-size: 24px; color: #2c3e50; margin-bottom: 20px; text-align: center;">{{ code }}</h1> <h1 style="font-size: 24px; color: #2c3e50; margin-bottom: 20px; text-align: center;">{{ code }}</h1>
<p style="font-size: 16px; color: #333; margin-bottom: 20px;">It will be valid for 5 minutes.</p>
<p style="font-size: 16px; color: #333; margin-bottom: 20px;">Alternatively, <a href="{{ link }}" target="_blank" <p style="font-size: 16px; color: #333; margin-bottom: 20px;">Alternatively, <a href="{{ link }}" target="_blank"
style="color: #FFA07A; text-decoration: none; font-weight: bold;">Click here to sign in on this style="color: #FFA07A; text-decoration: none; font-weight: bold;">Click here to sign in on this
browser.</a></p> browser.</a></p>

View file

@ -5,7 +5,7 @@ import os
from typing import Optional from typing import Optional
import requests import requests
from fastapi import APIRouter from fastapi import APIRouter, Depends
from pydantic import BaseModel, EmailStr from pydantic import BaseModel, EmailStr
from starlette.authentication import requires from starlette.authentication import requires
from starlette.config import Config from starlette.config import Config
@ -22,7 +22,11 @@ from khoj.database.adapters import (
get_or_create_user, get_or_create_user,
) )
from khoj.routers.email import send_magic_link_email, send_welcome_email from khoj.routers.email import send_magic_link_email, send_welcome_email
from khoj.routers.helpers import get_next_url, update_telemetry_state from khoj.routers.helpers import (
EmailVerificationApiRateLimiter,
get_next_url,
update_telemetry_state,
)
from khoj.utils import state from khoj.utils import state
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -99,7 +103,14 @@ async def login_magic_link(request: Request, form: MagicLinkForm):
@auth_router.get("/magic") @auth_router.get("/magic")
async def sign_in_with_magic_link(request: Request, code: str, email: str): async def sign_in_with_magic_link(
request: Request,
code: str,
email: str,
rate_limiter=Depends(
EmailVerificationApiRateLimiter(requests=10, window=60 * 60 * 24, slug="magic_link_verification")
),
):
user = await aget_user_validated_by_email_verification_code(code, email) user = await aget_user_validated_by_email_verification_code(code, email)
if user: if user:
id_info = { id_info = {
@ -108,7 +119,7 @@ async def sign_in_with_magic_link(request: Request, code: str, email: str):
request.session["user"] = dict(id_info) request.session["user"] = dict(id_info)
return RedirectResponse(url="/") return RedirectResponse(url="/")
return RedirectResponse(request.app.url_path_for("login_page")) return Response(status_code=401)
@auth_router.post("/token") @auth_router.post("/token")

View file

@ -47,7 +47,7 @@ async def send_magic_link_email(email, unique_id, host):
{ {
"sender": os.environ.get("RESEND_EMAIL", "noreply@khoj.dev"), "sender": os.environ.get("RESEND_EMAIL", "noreply@khoj.dev"),
"to": email, "to": email,
"subject": f"{unique_id} - Sign in to Khoj 🚀", "subject": f"Your unique login to Khoj",
"html": html_content, "html": html_content,
} }
) )

View file

@ -49,6 +49,7 @@ from khoj.database.adapters import (
ais_user_subscribed, ais_user_subscribed,
create_khoj_token, create_khoj_token,
get_khoj_tokens, get_khoj_tokens,
get_user_by_email,
get_user_name, get_user_name,
get_user_notion_config, get_user_notion_config,
get_user_subscription_state, get_user_subscription_state,
@ -1363,6 +1364,49 @@ class FeedbackData(BaseModel):
sentiment: str sentiment: str
class EmailVerificationApiRateLimiter:
def __init__(self, requests: int, window: int, slug: str):
self.requests = requests
self.window = window
self.slug = slug
def __call__(self, request: Request):
# Rate limiting disabled if billing is disabled
if state.billing_enabled is False:
return
# Extract the email query parameter
email = request.query_params.get("email")
if email:
logger.info(f"Email query parameter: {email}")
user: KhojUser = get_user_by_email(email)
if not user:
raise HTTPException(
status_code=404,
detail="User not found.",
)
# Remove requests outside of the time window
cutoff = datetime.now(tz=timezone.utc) - timedelta(seconds=self.window)
count_requests = UserRequests.objects.filter(user=user, created_at__gte=cutoff, slug=self.slug).count()
# Check if the user has exceeded the rate limit
if count_requests >= self.requests:
logger.info(
f"Rate limit: {count_requests}/{self.requests} requests not allowed in {self.window} seconds for email: {email}."
)
raise HTTPException(
status_code=429,
detail="Ran out of login attempts",
)
# Add the current request to the db
UserRequests.objects.create(user=user, slug=self.slug)
class ApiUserRateLimiter: class ApiUserRateLimiter:
def __init__(self, requests: int, subscribed_requests: int, window: int, slug: str): def __init__(self, requests: int, subscribed_requests: int, window: int, slug: str):
self.requests = requests self.requests = requests