mirror of
https://github.com/khoj-ai/khoj.git
synced 2024-12-18 18:47:11 +00:00
Add rate limiting to OTP login attempts and update email template for conciseness
This commit is contained in:
parent
b7783357fa
commit
ae9750e58e
6 changed files with 90 additions and 10 deletions
|
@ -36,6 +36,8 @@ export interface LoginPromptProps {
|
||||||
|
|
||||||
const fetcher = (url: string) => fetch(url).then((res) => res.json());
|
const fetcher = (url: string) => fetch(url).then((res) => res.json());
|
||||||
|
|
||||||
|
const ALLOWED_OTP_ATTEMPTS = 5;
|
||||||
|
|
||||||
interface Provider {
|
interface Provider {
|
||||||
client_id: string;
|
client_id: string;
|
||||||
redirect_uri: string;
|
redirect_uri: string;
|
||||||
|
@ -230,10 +232,16 @@ function EmailSignInContext({
|
||||||
}) {
|
}) {
|
||||||
const [otp, setOTP] = useState("");
|
const [otp, setOTP] = useState("");
|
||||||
const [otpError, setOTPError] = useState("");
|
const [otpError, setOTPError] = useState("");
|
||||||
|
const [numFailures, setNumFailures] = useState(0);
|
||||||
|
|
||||||
function checkOTPAndRedirect() {
|
function checkOTPAndRedirect() {
|
||||||
const verifyUrl = `/auth/magic?code=${otp}&email=${email}`;
|
const verifyUrl = `/auth/magic?code=${otp}&email=${email}`;
|
||||||
|
|
||||||
|
if (numFailures >= ALLOWED_OTP_ATTEMPTS) {
|
||||||
|
setOTPError("Too many failed attempts. Please try again tomorrow.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
fetch(verifyUrl, {
|
fetch(verifyUrl, {
|
||||||
method: "GET",
|
method: "GET",
|
||||||
headers: {
|
headers: {
|
||||||
|
@ -246,8 +254,16 @@ function EmailSignInContext({
|
||||||
if (res.redirected) {
|
if (res.redirected) {
|
||||||
window.location.href = res.url;
|
window.location.href = res.url;
|
||||||
}
|
}
|
||||||
|
} else if (res.status === 401) {
|
||||||
|
setOTPError("Invalid OTP.");
|
||||||
|
setNumFailures(numFailures + 1);
|
||||||
|
if (numFailures + 1 >= ALLOWED_OTP_ATTEMPTS) {
|
||||||
|
setOTPError("Too many failed attempts. Please try again tomorrow.");
|
||||||
|
}
|
||||||
|
} else if (res.status === 429) {
|
||||||
|
setOTPError("Too many failed attempts. Please try again tomorrow.");
|
||||||
|
setNumFailures(ALLOWED_OTP_ATTEMPTS);
|
||||||
} else {
|
} else {
|
||||||
setOTPError("Invalid OTP");
|
|
||||||
throw new Error("Failed to verify OTP");
|
throw new Error("Failed to verify OTP");
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -309,6 +325,7 @@ function EmailSignInContext({
|
||||||
maxLength={6}
|
maxLength={6}
|
||||||
value={otp || ""}
|
value={otp || ""}
|
||||||
onChange={setOTP}
|
onChange={setOTP}
|
||||||
|
disabled={numFailures >= ALLOWED_OTP_ATTEMPTS}
|
||||||
onComplete={() =>
|
onComplete={() =>
|
||||||
setTimeout(() => {
|
setTimeout(() => {
|
||||||
checkOTPAndRedirect();
|
checkOTPAndRedirect();
|
||||||
|
@ -324,7 +341,11 @@ function EmailSignInContext({
|
||||||
<InputOTPSlot index={5} />
|
<InputOTPSlot index={5} />
|
||||||
</InputOTPGroup>
|
</InputOTPGroup>
|
||||||
</InputOTP>
|
</InputOTP>
|
||||||
<div className="text-red-500 text-sm">{otpError}</div>
|
{otpError && (
|
||||||
|
<div className="text-red-500 text-sm">
|
||||||
|
{otpError} {ALLOWED_OTP_ATTEMPTS - numFailures} remaining attempts.
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
|
|
|
@ -437,10 +437,14 @@ def is_user_subscribed(user: KhojUser) -> bool:
|
||||||
return subscribed
|
return subscribed
|
||||||
|
|
||||||
|
|
||||||
async def get_user_by_email(email: str) -> KhojUser:
|
async def aget_user_by_email(email: str) -> KhojUser:
|
||||||
return await KhojUser.objects.filter(email=email).afirst()
|
return await KhojUser.objects.filter(email=email).afirst()
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_by_email(email: str) -> KhojUser:
|
||||||
|
return KhojUser.objects.filter(email=email).first()
|
||||||
|
|
||||||
|
|
||||||
async def aget_user_by_uuid(uuid: str) -> KhojUser:
|
async def aget_user_by_uuid(uuid: str) -> KhojUser:
|
||||||
return await KhojUser.objects.filter(uuid=uuid).afirst()
|
return await KhojUser.objects.filter(uuid=uuid).afirst()
|
||||||
|
|
||||||
|
|
|
@ -16,12 +16,12 @@
|
||||||
<img src="https://assets.khoj.dev/khoj_logo.png" alt="Khoj Logo" style="width: 120px;">
|
<img src="https://assets.khoj.dev/khoj_logo.png" alt="Khoj Logo" style="width: 120px;">
|
||||||
</a>
|
</a>
|
||||||
|
|
||||||
<p style="font-size: 16px; color: #333; margin-bottom: 20px;">Hi!</p>
|
|
||||||
|
|
||||||
<p style="font-size: 16px; color: #333; margin-bottom: 20px;">Use this code (valid for 5 minutes) to login to Khoj:</p>
|
<p style="font-size: 16px; color: #333; margin-bottom: 20px;">Use this code (valid for 5 minutes) to login to Khoj:</p>
|
||||||
|
|
||||||
<h1 style="font-size: 24px; color: #2c3e50; margin-bottom: 20px; text-align: center;">{{ code }}</h1>
|
<h1 style="font-size: 24px; color: #2c3e50; margin-bottom: 20px; text-align: center;">{{ code }}</h1>
|
||||||
|
|
||||||
|
<p style="font-size: 16px; color: #333; margin-bottom: 20px;">It will be valid for 5 minutes.</p>
|
||||||
|
|
||||||
<p style="font-size: 16px; color: #333; margin-bottom: 20px;">Alternatively, <a href="{{ link }}" target="_blank"
|
<p style="font-size: 16px; color: #333; margin-bottom: 20px;">Alternatively, <a href="{{ link }}" target="_blank"
|
||||||
style="color: #FFA07A; text-decoration: none; font-weight: bold;">Click here to sign in on this
|
style="color: #FFA07A; text-decoration: none; font-weight: bold;">Click here to sign in on this
|
||||||
browser.</a></p>
|
browser.</a></p>
|
||||||
|
|
|
@ -5,7 +5,7 @@ import os
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter, Depends
|
||||||
from pydantic import BaseModel, EmailStr
|
from pydantic import BaseModel, EmailStr
|
||||||
from starlette.authentication import requires
|
from starlette.authentication import requires
|
||||||
from starlette.config import Config
|
from starlette.config import Config
|
||||||
|
@ -22,7 +22,11 @@ from khoj.database.adapters import (
|
||||||
get_or_create_user,
|
get_or_create_user,
|
||||||
)
|
)
|
||||||
from khoj.routers.email import send_magic_link_email, send_welcome_email
|
from khoj.routers.email import send_magic_link_email, send_welcome_email
|
||||||
from khoj.routers.helpers import get_next_url, update_telemetry_state
|
from khoj.routers.helpers import (
|
||||||
|
EmailVerificationApiRateLimiter,
|
||||||
|
get_next_url,
|
||||||
|
update_telemetry_state,
|
||||||
|
)
|
||||||
from khoj.utils import state
|
from khoj.utils import state
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -99,7 +103,14 @@ async def login_magic_link(request: Request, form: MagicLinkForm):
|
||||||
|
|
||||||
|
|
||||||
@auth_router.get("/magic")
|
@auth_router.get("/magic")
|
||||||
async def sign_in_with_magic_link(request: Request, code: str, email: str):
|
async def sign_in_with_magic_link(
|
||||||
|
request: Request,
|
||||||
|
code: str,
|
||||||
|
email: str,
|
||||||
|
rate_limiter=Depends(
|
||||||
|
EmailVerificationApiRateLimiter(requests=10, window=60 * 60 * 24, slug="magic_link_verification")
|
||||||
|
),
|
||||||
|
):
|
||||||
user = await aget_user_validated_by_email_verification_code(code, email)
|
user = await aget_user_validated_by_email_verification_code(code, email)
|
||||||
if user:
|
if user:
|
||||||
id_info = {
|
id_info = {
|
||||||
|
@ -108,7 +119,7 @@ async def sign_in_with_magic_link(request: Request, code: str, email: str):
|
||||||
|
|
||||||
request.session["user"] = dict(id_info)
|
request.session["user"] = dict(id_info)
|
||||||
return RedirectResponse(url="/")
|
return RedirectResponse(url="/")
|
||||||
return RedirectResponse(request.app.url_path_for("login_page"))
|
return Response(status_code=401)
|
||||||
|
|
||||||
|
|
||||||
@auth_router.post("/token")
|
@auth_router.post("/token")
|
||||||
|
|
|
@ -47,7 +47,7 @@ async def send_magic_link_email(email, unique_id, host):
|
||||||
{
|
{
|
||||||
"sender": os.environ.get("RESEND_EMAIL", "noreply@khoj.dev"),
|
"sender": os.environ.get("RESEND_EMAIL", "noreply@khoj.dev"),
|
||||||
"to": email,
|
"to": email,
|
||||||
"subject": f"{unique_id} - Sign in to Khoj 🚀",
|
"subject": f"Your unique login to Khoj",
|
||||||
"html": html_content,
|
"html": html_content,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
|
@ -49,6 +49,7 @@ from khoj.database.adapters import (
|
||||||
ais_user_subscribed,
|
ais_user_subscribed,
|
||||||
create_khoj_token,
|
create_khoj_token,
|
||||||
get_khoj_tokens,
|
get_khoj_tokens,
|
||||||
|
get_user_by_email,
|
||||||
get_user_name,
|
get_user_name,
|
||||||
get_user_notion_config,
|
get_user_notion_config,
|
||||||
get_user_subscription_state,
|
get_user_subscription_state,
|
||||||
|
@ -1363,6 +1364,49 @@ class FeedbackData(BaseModel):
|
||||||
sentiment: str
|
sentiment: str
|
||||||
|
|
||||||
|
|
||||||
|
class EmailVerificationApiRateLimiter:
|
||||||
|
def __init__(self, requests: int, window: int, slug: str):
|
||||||
|
self.requests = requests
|
||||||
|
self.window = window
|
||||||
|
self.slug = slug
|
||||||
|
|
||||||
|
def __call__(self, request: Request):
|
||||||
|
# Rate limiting disabled if billing is disabled
|
||||||
|
if state.billing_enabled is False:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Extract the email query parameter
|
||||||
|
email = request.query_params.get("email")
|
||||||
|
|
||||||
|
if email:
|
||||||
|
logger.info(f"Email query parameter: {email}")
|
||||||
|
|
||||||
|
user: KhojUser = get_user_by_email(email)
|
||||||
|
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail="User not found.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Remove requests outside of the time window
|
||||||
|
cutoff = datetime.now(tz=timezone.utc) - timedelta(seconds=self.window)
|
||||||
|
count_requests = UserRequests.objects.filter(user=user, created_at__gte=cutoff, slug=self.slug).count()
|
||||||
|
|
||||||
|
# Check if the user has exceeded the rate limit
|
||||||
|
if count_requests >= self.requests:
|
||||||
|
logger.info(
|
||||||
|
f"Rate limit: {count_requests}/{self.requests} requests not allowed in {self.window} seconds for email: {email}."
|
||||||
|
)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=429,
|
||||||
|
detail="Ran out of login attempts",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add the current request to the db
|
||||||
|
UserRequests.objects.create(user=user, slug=self.slug)
|
||||||
|
|
||||||
|
|
||||||
class ApiUserRateLimiter:
|
class ApiUserRateLimiter:
|
||||||
def __init__(self, requests: int, subscribed_requests: int, window: int, slug: str):
|
def __init__(self, requests: int, subscribed_requests: int, window: int, slug: str):
|
||||||
self.requests = requests
|
self.requests = requests
|
||||||
|
|
Loading…
Reference in a new issue