sw1tch/registration.py
2025-02-02 15:01:16 -08:00

397 lines
14 KiB
Python

import os
import re
import yaml
import json
import smtplib
import httpx
import logging
from datetime import datetime, timedelta
from email.message import EmailMessage
from typing import List, Dict, Optional, Tuple, Set, Pattern
from fastapi import FastAPI, Request, Form, HTTPException
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from starlette.middleware.base import BaseHTTPMiddleware
# ---------------------------------------------------------
# 1. Load configuration and setup paths
# ---------------------------------------------------------
# Directory containing this module; every data/config file used below is
# resolved relative to it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
CONFIG_PATH = os.path.join(BASE_DIR, "config.yaml")

# Parse the YAML settings once at import time; the rest of the module reads
# from ``config``.  The file is opened explicitly as UTF-8 so non-ASCII text
# in the settings decodes identically on every platform instead of depending
# on the locale's default encoding.
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)
# Initialize or load registrations.json
REGISTRATIONS_PATH = os.path.join(BASE_DIR, "registrations.json")


def load_registrations() -> List[Dict]:
    """Return the registration records stored in ``registrations.json``.

    A missing file, or a file that does not contain valid JSON, is treated
    as "nothing recorded yet" and yields an empty list.
    """
    try:
        with open(REGISTRATIONS_PATH, "r") as src:
            records = json.load(src)
    except (FileNotFoundError, json.JSONDecodeError):
        records = []
    return records
def save_registration(data: Dict):
    """Append one registration record to ``registrations.json``.

    The full list is re-read, extended with ``data`` and written back.

    The new contents go to a sibling temporary file first and are then moved
    over the real file with ``os.replace``.  Writing in place would truncate
    the file before the new JSON is complete; an interruption at that point
    would leave invalid JSON behind, which ``load_registrations`` reports as
    an empty list, silently discarding every earlier record.

    Args:
        data: The JSON-serialisable record to append.
    """
    registrations = load_registrations()
    registrations.append(data)

    tmp_path = REGISTRATIONS_PATH + ".tmp"
    with open(tmp_path, "w") as f:
        json.dump(registrations, f, indent=2)
    # Swap the finished file into place in a single step.
    os.replace(tmp_path, REGISTRATIONS_PATH)
# Load banned IPs and emails
def load_banned_list(filename: str) -> Set[str]:
    """Read a one-entry-per-line ban file located next to this module.

    Surrounding whitespace is stripped from every line and blank lines are
    dropped.  A file that does not exist yields an empty set.
    """
    path = os.path.join(BASE_DIR, filename)
    try:
        with open(path, "r") as src:
            stripped = (raw.strip() for raw in src)
            return {entry for entry in stripped if entry}
    except FileNotFoundError:
        return set()
def load_banned_usernames() -> List[Pattern]:
    """Compile the regexes listed in ``banned_usernames.txt``.

    Each non-blank line is compiled case-insensitively.  A line that is not
    a valid regular expression is logged and skipped; a missing file simply
    produces an empty list.
    """
    source = os.path.join(BASE_DIR, "banned_usernames.txt")
    compiled = []
    try:
        with open(source, "r") as src:
            # Strip each line and skip the ones that end up empty.
            for expression in filter(None, map(str.strip, src)):
                try:
                    compiled.append(re.compile(expression, re.IGNORECASE))
                except re.error:
                    logging.error(f"Invalid regex pattern in banned_usernames.txt: {expression}")
    except FileNotFoundError:
        pass
    return compiled
# The ban lists are read once, when this module is imported; nothing in this
# module re-reads the files afterwards.
banned_ips = load_banned_list("banned_ips.txt")
banned_emails = load_banned_list("banned_emails.txt")
banned_username_patterns = load_banned_usernames()
# Read the registration token
def read_registration_token():
    """Return the contents of ``.registration_token`` with whitespace trimmed.

    Returns ``None`` when the token file does not exist.
    """
    token_file = os.path.join(BASE_DIR, ".registration_token")
    try:
        with open(token_file, "r") as src:
            contents = src.read()
    except FileNotFoundError:
        return None
    return contents.strip()
# ---------------------------------------------------------
# 2. Logging Configuration
# ---------------------------------------------------------
# Set up logging format
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Uvicorn's own access/error loggers are limited to warnings and above.
for _uvicorn_logger in ("uvicorn.access", "uvicorn.error"):
    logging.getLogger(_uvicorn_logger).setLevel(logging.WARNING)

# Module-level logger used throughout the rest of this file.
logger = logging.getLogger(__name__)
class CustomLoggingMiddleware(BaseHTTPMiddleware):
    """Log method, path and response status for each handled request.

    Requests for ``/api/time`` and for any path ending in ``favicon.ico``
    are passed straight through without producing a log line.
    """

    async def dispatch(self, request, call_next):
        path = request.url.path
        silent = path == "/api/time" or path.endswith('favicon.ico')

        response = await call_next(request)
        if not silent:
            logger.info(f"Request: {request.method} {request.url.path} - Status: {response.status_code}")
        return response
# ---------------------------------------------------------
# 3. Time Calculation Functions
# ---------------------------------------------------------
def get_current_utc() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    ``datetime.utcnow()`` is deprecated (since Python 3.12), so the time is
    taken from an aware UTC clock and the tzinfo is dropped afterwards.  The
    result stays naive on purpose: the rest of this module compares it with
    other naive values (e.g. timestamps parsed by ``datetime.fromisoformat``),
    and naive and aware datetimes cannot be mixed.
    """
    # Local import: only ``datetime`` and ``timedelta`` are imported at file level.
    from datetime import timezone

    return datetime.now(timezone.utc).replace(tzinfo=None)
def get_next_reset_time(now: datetime) -> datetime:
    """Return the first configured reset moment that lies after ``now``.

    ``config["token_reset_time_utc"]`` is split into hour (value // 100) and
    minute (value % 100).  If today's reset has already been reached, the
    result is the same time on the following day.
    """
    hour, minute = divmod(config["token_reset_time_utc"], 100)
    todays_reset = now.replace(hour=hour, minute=minute, second=0, microsecond=0)
    if todays_reset > now:
        return todays_reset
    # Today's reset is now or in the past, so the next one is tomorrow.
    return todays_reset + timedelta(days=1)
def get_downtime_start(next_reset: datetime) -> datetime:
    """Return when downtime begins ahead of ``next_reset``.

    The lead time is ``config["downtime_before_token_reset"]`` minutes.
    """
    lead_minutes = config["downtime_before_token_reset"]
    return next_reset - timedelta(minutes=lead_minutes)
def format_timedelta(td: timedelta) -> str:
    """Format a duration as e.g. ``'3 hours and 41 minutes'``.

    Seconds are truncated to whole minutes.  A zero-valued component is
    omitted (``'2 hours'``, ``'5 minutes'``), and anything shorter than one
    minute is reported as ``'0 minutes'``.

    Negative durations are clamped to zero.  Without the clamp, floor
    division turns a small negative value into a misleading positive
    reading (for example, -30 seconds would render as ``'59 minutes'``).

    Args:
        td: The duration to describe.

    Returns:
        A human-readable description in hours and minutes.
    """
    total_minutes = max(0, int(td.total_seconds() // 60))
    hours, minutes = divmod(total_minutes, 60)

    parts = [
        f"{amount} {unit}" if amount == 1 else f"{amount} {unit}s"
        for amount, unit in ((hours, "hour"), (minutes, "minute"))
        if amount
    ]
    return " and ".join(parts) if parts else "0 minutes"
def get_time_until_reset_str(now: datetime) -> str:
    """Describe the time remaining from ``now`` until the next reset.

    The text comes from ``format_timedelta``, e.g. '3 hours and 41 minutes'.
    """
    return format_timedelta(get_next_reset_time(now) - now)
def is_registration_closed(now: datetime) -> Tuple[bool, str]:
    """Report whether ``now`` falls inside the pre-reset downtime window.

    The window runs from ``get_downtime_start(next_reset)`` up to, but not
    including, the next reset time.

    Args:
        now: The moment to evaluate.

    Returns:
        ``(closed, message)``.  ``message`` is a user-facing sentence saying
        either when registration reopens or when it will next close.
    """
    next_reset = get_next_reset_time(now)
    downtime_start = get_downtime_start(next_reset)

    if downtime_start <= now < next_reset:
        # Inside the downtime window: closed until the reset happens.
        remaining = format_timedelta(next_reset - now)
        reopens_at = next_reset.strftime('%H:%M UTC')
        return True, f"Registration is closed. It reopens in {remaining} at {reopens_at}."

    # get_next_reset_time() always returns a moment after ``now``, so reaching
    # this point means ``now`` is before ``downtime_start``.  The upcoming
    # downtime is therefore always the one just computed, and no "roll over
    # to tomorrow" adjustment is needed.
    remaining = format_timedelta(downtime_start - now)
    closes_at = downtime_start.strftime('%H:%M UTC')
    return False, f"Registration is open. It will close in {remaining} at {closes_at}."
# ---------------------------------------------------------
# 4. Registration Validation
# ---------------------------------------------------------
def is_username_banned(username: str) -> bool:
    """Return True when any banned-username regex is found in ``username``."""
    for banned in banned_username_patterns:
        if banned.search(username):
            return True
    return False
def check_email_cooldown(email: str) -> Optional[str]:
    """Apply the per-email rules to a new registration request.

    Addresses are compared with surrounding whitespace removed and case
    folded, so re-submitting the same address with different capitalisation
    does not bypass the one-account rule or the cooldown.

    Args:
        email: The address supplied with the current request.

    Returns:
        ``None`` when the request may proceed, otherwise a user-facing
        message explaining the rejection.
    """
    wanted = email.strip().casefold()
    email_entries = [
        r for r in load_registrations()
        if r["email"].strip().casefold() == wanted
    ]
    if not email_entries:
        return None

    if not config.get("multiple_users_per_email", True):
        return "This email address has already been used to register an account."

    # ``email_cooldown`` is compared against elapsed seconds; a missing or
    # falsy value disables the cooldown.
    email_cooldown = config.get("email_cooldown")
    if email_cooldown:
        latest_registration = max(
            datetime.fromisoformat(e["datetime"]) for e in email_entries
        )
        # Use the module's clock helper rather than calling utcnow() directly.
        elapsed = (get_current_utc() - latest_registration).total_seconds()
        if elapsed < email_cooldown:
            wait_time = email_cooldown - elapsed
            return f"Please wait {int(wait_time)} seconds before requesting another account."
    return None
async def check_username_availability(username: str) -> bool:
    """Check whether ``username`` may be requested.

    The checks run in order: banned patterns, names already present in the
    local registration records, then the homeserver's
    ``/_matrix/client/v3/register/available`` endpoint.

    Args:
        username: The name submitted by the user (untrusted input).

    Returns:
        ``True`` only when the homeserver answers 200 with a truthy
        ``available`` field.  A 400 response, any other status, a network
        error or an unparseable body all yield ``False``.
    """
    # Check banned usernames first
    if is_username_banned(username):
        logger.info(f"[USERNAME CHECK] {username}: Banned by pattern")
        return False

    # Check local registrations
    registrations = load_registrations()
    if any(r["requested_name"] == username for r in registrations):
        logger.info(f"[USERNAME CHECK] {username}: Already requested")
        return False

    # Check Matrix homeserver.  The username is passed via ``params`` so that
    # httpx URL-encodes it; interpolating raw user input into the query string
    # would let characters such as '&', '#' or spaces alter the request.
    url = f"https://{config['homeserver']}/_matrix/client/v3/register/available"
    async with httpx.AsyncClient() as client:
        try:
            response = await client.get(url, params={"username": username}, timeout=5)
        except httpx.RequestError as ex:
            logger.warning(f"[USERNAME CHECK] Could not reach homeserver: {ex}")
            return False

    if response.status_code == 200:
        try:
            data = response.json()
        except ValueError:
            # A 200 with a non-JSON body must not crash the request handler.
            logger.warning(f"[USERNAME CHECK] {username}: Unparseable homeserver response")
            return False
        is_available = bool(data.get("available", False)) if isinstance(data, dict) else False
        logger.info(f"[USERNAME CHECK] {username}: {'Available' if is_available else 'Taken'}")
        return is_available

    if response.status_code == 400:
        logger.info(f"[USERNAME CHECK] {username}: Taken (400)")
        return False

    logger.warning(f"[USERNAME CHECK] {username}: Unexpected status {response.status_code}")
    return False
# ---------------------------------------------------------
# 5. FastAPI Setup and Routes
# ---------------------------------------------------------
app = FastAPI()
app.add_middleware(CustomLoggingMiddleware)

# Resolve the asset directories relative to this file, the same way
# config.yaml and the data files are located, so the app does not depend on
# the process's current working directory.
app.mount("/static", StaticFiles(directory=os.path.join(BASE_DIR, "static")), name="static")
templates = Jinja2Templates(directory=os.path.join(BASE_DIR, "templates"))
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render the landing page with the current open/closed status."""
    closed, message = is_registration_closed(get_current_utc())
    context = {
        "request": request,
        "registration_closed": closed,
        "homeserver": config["homeserver"],
        "message": message,
    }
    return templates.TemplateResponse("index.html", context)
@app.get("/api/time")
async def get_server_time():
    """Return the server's current UTC clock time as ``{"utc_time": "HH:MM:SS"}``."""
    payload = {"utc_time": get_current_utc().strftime("%H:%M:%S")}
    return JSONResponse(payload)
def _render_error(request: Request, message: str):
    """Render ``error.html`` showing ``message`` for the given request."""
    return templates.TemplateResponse(
        "error.html",
        {
            "request": request,
            "message": message
        }
    )


def _send_email(msg: EmailMessage) -> None:
    """Deliver ``msg`` through the SMTP server described by ``config["smtp"]``."""
    smtp_conf = config["smtp"]
    with smtplib.SMTP(smtp_conf["host"], smtp_conf["port"]) as server:
        if smtp_conf.get("use_tls", True):
            server.starttls()
        server.login(smtp_conf["username"], smtp_conf["password"])
        server.send_message(msg)


@app.post("/register", response_class=HTMLResponse)
async def register(
    request: Request,
    requested_username: str = Form(...),
    email: str = Form(...)
):
    """Handle a registration request and e-mail the registration token.

    Checks run in this order, and the first failure renders ``error.html``:
    registration window, banned IP, banned e-mail, e-mail cooldown, and
    username availability.  On success the token is e-mailed, the request is
    appended to the registration records and ``success.html`` is rendered.

    Raises:
        HTTPException: 500 when the token file is missing or the e-mail
            cannot be sent.
    """
    # Local import: starlette is already a dependency of this module.
    from starlette.concurrency import run_in_threadpool

    now = get_current_utc()
    # ``request.client`` may be None, in which case there is no address to
    # read; avoid an AttributeError.
    client_ip = request.client.host if request.client else None
    logger.info(f"Registration attempt - Username: {requested_username}, Email: {email}, IP: {client_ip}")

    # Check if registration is closed
    closed, message = is_registration_closed(now)
    if closed:
        logger.info("Registration rejected: Registration is closed")
        return _render_error(request, message)

    # Check bans
    if client_ip in banned_ips:
        logger.info(f"Registration rejected: Banned IP {client_ip}")
        return _render_error(request, "Registration not allowed from your IP address.")
    if email in banned_emails:
        logger.info(f"Registration rejected: Banned email {email}")
        return _render_error(request, "Registration not allowed for this email address.")

    # Check email cooldown
    if error_message := check_email_cooldown(email):
        logger.info(f"Registration rejected: Email cooldown - {email}")
        return _render_error(request, error_message)

    # Check username availability
    available = await check_username_availability(requested_username)
    if not available:
        logger.info(f"Registration rejected: Username unavailable - {requested_username}")
        return _render_error(request, f"The username '{requested_username}' is not available.")

    # Read token and prepare email
    token = read_registration_token()
    if token is None:
        logger.error("Registration token file not found")
        raise HTTPException(status_code=500, detail="Registration token file not found.")

    email_body = config["email_body"].format(
        homeserver=config["homeserver"],
        registration_token=token,
        requested_username=requested_username,
        utc_time=now.strftime("%H:%M:%S"),
        time_until_reset=get_time_until_reset_str(now)
    )
    msg = EmailMessage()
    msg.set_content(email_body)
    msg["Subject"] = config["email_subject"].format(homeserver=config["homeserver"])
    msg["From"] = config["smtp"]["username"]
    msg["To"] = email

    # Send email.  smtplib is blocking, so it runs in a worker thread to keep
    # the event loop free for other requests during the SMTP exchange.
    try:
        await run_in_threadpool(_send_email, msg)
    except Exception as ex:
        # Full details go to the server log only; the client gets a generic
        # message so internal SMTP errors are not exposed.
        logger.exception("Failed to send email")
        raise HTTPException(status_code=500, detail="Error sending email.") from ex
    logger.info(f"Registration email sent successfully to {email}")

    # Log registration
    save_registration({
        "requested_name": requested_username,
        "email": email,
        "datetime": get_current_utc().isoformat(),
        "ip_address": client_ip
    })
    logger.info(f"Registration successful - Username: {requested_username}, Email: {email}")

    return templates.TemplateResponse(
        "success.html",
        {
            "request": request,
            "homeserver": config["homeserver"]
        }
    )
if __name__ == "__main__":
    import uvicorn

    # Serve the app when this file is executed directly.  Uvicorn's access
    # log is switched off here.
    server_options = {
        "host": "0.0.0.0",
        "port": config["port"],
        "reload": True,
        "access_log": False,
    }
    uvicorn.run("registration:app", **server_options)