sw1tch/registration.py

424 lines
15 KiB
Python
Raw Normal View History

2025-02-02 15:01:16 -08:00
import os
import re
import yaml
import json
import smtplib
import httpx
import logging
import ipaddress
2025-02-02 15:01:16 -08:00
from datetime import datetime, timedelta
from email.message import EmailMessage
from typing import List, Dict, Optional, Tuple, Set, Pattern, Union
2025-02-02 15:01:16 -08:00
from fastapi import FastAPI, Request, Form, HTTPException
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from starlette.middleware.base import BaseHTTPMiddleware
from ipaddress import IPv4Network, IPv4Address
2025-02-02 15:01:16 -08:00
# ---------------------------------------------------------
# 1. Load configuration and setup paths
# ---------------------------------------------------------
# All data files are resolved relative to this module's directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
CONFIG_PATH = os.path.join(BASE_DIR, "config.yaml")

# Parse the YAML configuration once at import time. safe_load only builds
# plain Python data (dicts, lists, scalars).
with open(CONFIG_PATH, "r") as config_file:
    config = yaml.safe_load(config_file)

# JSON file holding the list of completed registrations
# (read by load_registrations, rewritten by save_registration).
REGISTRATIONS_PATH = os.path.join(BASE_DIR, "registrations.json")
def load_registrations(path: Optional[str] = None) -> List[Dict]:
    """Return every stored registration record.

    Reads the JSON array kept at ``path`` (``REGISTRATIONS_PATH`` by default).

    A missing file means nothing has been registered yet and yields ``[]``.
    A file that cannot be parsed, or whose top level is not a JSON array,
    also yields ``[]``. That case is logged rather than silently ignored,
    because save_registration rebuilds the file from this result and would
    otherwise discard the unreadable contents without a trace.

    Args:
        path: Optional override of the registrations file location.

    Returns:
        The list of registration dicts (possibly empty).
    """
    target = REGISTRATIONS_PATH if path is None else path
    try:
        with open(target, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        return []
    except json.JSONDecodeError as ex:
        logging.warning(f"Could not parse {target}; treating it as empty: {ex}")
        return []
    if not isinstance(data, list):
        # Callers append to / iterate the result as a list of dicts.
        logging.warning(f"{target} does not contain a JSON list; treating it as empty")
        return []
    return data
def save_registration(data: Dict):
    """Append one registration record to ``REGISTRATIONS_PATH``.

    The existing records are loaded, ``data`` is appended, and the whole list
    is written back. The write is atomic:
    - the new contents go to a sibling ``.tmp`` file and are flushed to disk;
    - ``os.replace`` then swaps that file into place.

    A crash mid-write therefore cannot leave a truncated registrations file.
    That matters because load_registrations treats an unparsable file as
    empty, and the next save would then drop every earlier record.

    Args:
        data: The registration record to append (must be JSON-serializable).
    """
    registrations = load_registrations()
    registrations.append(data)
    tmp_path = f"{REGISTRATIONS_PATH}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(registrations, f, indent=2)
        f.flush()
        # Make sure the bytes are on disk before the rename publishes them.
        os.fsync(f.fileno())
    # Atomic when source and destination are on the same filesystem.
    os.replace(tmp_path, REGISTRATIONS_PATH)
# Functions to check banned entries

def load_banned_usernames() -> List[Pattern]:
    """Compile the regexes listed in ``banned_usernames.txt``.

    The file lives in BASE_DIR and holds one pattern per line:
    - blank lines are skipped;
    - every pattern is compiled case-insensitively;
    - a line that is not a valid regex is logged and ignored.

    A missing file means no usernames are banned.

    Returns:
        The compiled patterns, in file order.
    """
    compiled: List[Pattern] = []
    banned_file = os.path.join(BASE_DIR, "banned_usernames.txt")
    try:
        handle = open(banned_file, "r")
    except FileNotFoundError:
        return compiled
    with handle:
        for raw in handle:
            entry = raw.strip()
            if not entry:
                continue
            try:
                compiled.append(re.compile(entry, re.IGNORECASE))
            except re.error:
                logging.error(f"Invalid regex pattern in banned_usernames.txt: {entry}")
    return compiled
def is_ip_banned(ip: str) -> bool:
    """Check if an IP is banned, supporting both individual IPs and CIDR ranges.

    ``banned_ips.txt`` (in BASE_DIR) holds one entry per line: either a single
    address or a network in CIDR notation (any line containing ``/``).

    Matching rules:
    - IPv4 and IPv6 are accepted, both for ``ip`` and for the file entries;
      an address only matches entries of its own IP version.
    - An IPv4-mapped IPv6 client address (``::ffff:a.b.c.d``) is checked as
      the plain IPv4 address, so IPv4 bans still apply to it.
    - CIDR entries with host bits set (e.g. ``10.0.0.5/24``) are read as
      their enclosing network instead of being rejected as invalid.

    Args:
        ip: The client address to check.

    Returns:
        True if some entry matches. Returns False when no entry matches, the
        file is missing, or ``ip`` is not a valid address (logged). Invalid
        file entries are logged and skipped.
    """
    try:
        check_ip = ipaddress.ip_address(ip)
    except ValueError:
        logging.error(f"Invalid IP address to check: {ip}")
        return False
    if isinstance(check_ip, ipaddress.IPv6Address) and check_ip.ipv4_mapped:
        check_ip = check_ip.ipv4_mapped
    try:
        with open(os.path.join(BASE_DIR, "banned_ips.txt"), "r") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    if "/" in line:  # CIDR notation
                        if check_ip in ipaddress.ip_network(line, strict=False):
                            return True
                    elif check_ip == ipaddress.ip_address(line):  # Individual IP
                        return True
                except ValueError:
                    logging.error(f"Invalid IP/CIDR in banned_ips.txt: {line}")
    except FileNotFoundError:
        return False
    return False
def is_email_banned(email: str) -> bool:
    """Check if an email matches any banned patterns.

    ``banned_emails.txt`` (in BASE_DIR) holds one case-insensitive glob per
    line:
    - ``*`` matches any run of characters; every other character is literal;
    - the pattern must match the whole address, so ``*@spam.com`` bans
      ``x@spam.com`` but not ``x@spam.com.example.org``;
    - blank lines are ignored.

    A missing file means nothing is banned.

    Args:
        email: The address submitted by the user.

    Returns:
        True if any pattern matches the full address.
    """
    try:
        with open(os.path.join(BASE_DIR, "banned_emails.txt"), "r") as f:
            for line in f:
                pattern = line.strip()
                if not pattern:
                    continue
                # Escape every regex metacharacter (not only dots), so "+",
                # "(" etc. are literal. Then turn each "*" back into ".*".
                # re.escape output always compiles, so no re.error handling
                # is needed.
                regex_pattern = ".*".join(re.escape(part) for part in pattern.split("*"))
                if re.fullmatch(regex_pattern, email, re.IGNORECASE):
                    return True
    except FileNotFoundError:
        pass
    return False
def is_username_banned(username: str) -> bool:
    """Return True when any pattern from banned_usernames.txt matches ``username``.

    The pattern list is reloaded on every call. Patterns are applied with
    ``search``, so an unanchored pattern bans any username containing a
    match anywhere.
    """
    for banned in load_banned_usernames():
        if banned.search(username):
            return True
    return False
# Read the registration token
def read_registration_token() -> Optional[str]:
    """Return the current registration token, or None if it is unavailable.

    The token is read from ``.registration_token`` in BASE_DIR, with
    surrounding whitespace (e.g. a trailing newline) stripped.

    Returns None when the file is missing, and also when it is empty or
    whitespace-only. Callers only check ``is None``, so an empty string must
    never be returned, or it would be handed out as a blank token.
    """
    token_path = os.path.join(BASE_DIR, ".registration_token")
    try:
        with open(token_path, "r") as f:
            token = f.read().strip()
    except FileNotFoundError:
        return None
    if not token:
        logging.warning(f"Registration token file is empty: {token_path}")
        return None
    return token
# ---------------------------------------------------------
# 2. Logging Configuration
# ---------------------------------------------------------
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Raise uvicorn's own loggers to WARNING; per-request lines come from
# CustomLoggingMiddleware instead.
for _uvicorn_logger_name in ("uvicorn.access", "uvicorn.error"):
    logging.getLogger(_uvicorn_logger_name).setLevel(logging.WARNING)
del _uvicorn_logger_name
logger = logging.getLogger(__name__)
class CustomLoggingMiddleware(BaseHTTPMiddleware):
    """Log one INFO line (method, path, status) per handled request.

    Requests to ``/api/time`` and any path ending in ``favicon.ico`` are
    passed through without being logged.
    """

    async def dispatch(self, request, call_next):
        path = request.url.path
        skip_logging = path == "/api/time" or path.endswith('favicon.ico')
        response = await call_next(request)
        if not skip_logging:
            logger.info(f"Request: {request.method} {request.url.path} - Status: {response.status_code}")
        return response
# ---------------------------------------------------------
# 3. Time Calculation Functions
# ---------------------------------------------------------
def get_current_utc() -> datetime:
    """Return the current UTC time as a *naive* ``datetime``.

    The value stays naive so it remains compatible with the other naive UTC
    datetimes in this module (e.g. the ISO timestamps stored in
    registrations.json). ``datetime.utcnow()`` is deprecated since Python
    3.12, so the aware UTC time is taken and its tzinfo dropped. This yields
    the same value without the deprecation warning.
    """
    from datetime import timezone

    return datetime.now(timezone.utc).replace(tzinfo=None)
def get_next_reset_time(now: datetime) -> datetime:
    """Return the first configured daily token reset strictly after ``now``.

    ``token_reset_time_utc`` is read as an HHMM integer (hundreds = hour,
    remainder = minute). If today's reset is at or before ``now``, the
    reset on the following day is returned.
    """
    reset_hour, reset_minute = divmod(config["token_reset_time_utc"], 100)
    todays_reset = now.replace(hour=reset_hour, minute=reset_minute, second=0, microsecond=0)
    if todays_reset > now:
        return todays_reset
    return todays_reset + timedelta(days=1)
def get_downtime_start(next_reset: datetime) -> datetime:
    """Return when registration closes ahead of ``next_reset``.

    This is ``downtime_before_token_reset`` minutes (from config) before the reset.
    """
    lead_time = timedelta(minutes=config["downtime_before_token_reset"])
    return next_reset - lead_time
def format_timedelta(td: timedelta) -> str:
    """Render ``td`` as e.g. '3 hours and 41 minutes'.

    Seconds are truncated to whole minutes. Zero-valued units are omitted,
    singular units are used for a count of 1, and anything under one minute
    renders as '0 minutes'.
    """
    hours, minutes = divmod(int(td.total_seconds() // 60), 60)

    def _unit(count: int, singular: str) -> Optional[str]:
        # "1 hour" / "2 hours"; None when the unit should be left out.
        if count == 1:
            return f"1 {singular}"
        if count > 1:
            return f"{count} {singular}s"
        return None

    pieces = [p for p in (_unit(hours, "hour"), _unit(minutes, "minute")) if p is not None]
    return " and ".join(pieces) if pieces else "0 minutes"
def get_time_until_reset_str(now: datetime) -> str:
    """Describe the wait from ``now`` until the next reset, e.g. '3 hours and 41 minutes'."""
    remaining = get_next_reset_time(now) - now
    return format_timedelta(remaining)
def is_registration_closed(now: datetime) -> Tuple[bool, str]:
    """
    Determine if registration is closed based on config.

    Registration is closed during the ``downtime_before_token_reset``-minute
    window that ends at the next daily token reset, and open otherwise.

    Returns:
        ``(closed, message)``. When closed, the message says when
        registration reopens (at the reset). When open, it says when
        registration next closes (at the start of the upcoming downtime).
    """
    nr = get_next_reset_time(now)
    ds = get_downtime_start(nr)
    if ds <= now < nr:
        # We are within downtime
        time_until_open = nr - now
        msg = (
            f"Registration is closed. "
            f"It reopens in {format_timedelta(time_until_open)} at {nr.strftime('%H:%M UTC')}."
        )
        return True, msg
    # get_next_reset_time() always returns a time strictly after `now`, so
    # failing the check above means `now < ds`. The next closing is therefore
    # `ds` itself and never needs to be pushed to the following day.
    time_until_close = ds - now
    msg = (
        f"Registration is open. "
        f"It will close in {format_timedelta(time_until_close)} at {ds.strftime('%H:%M UTC')}."
    )
    return False, msg
# ---------------------------------------------------------
# 4. Registration Validation
# ---------------------------------------------------------
def check_email_cooldown(email: str) -> Optional[str]:
    """Check if email is allowed to register based on cooldown and multiple account rules.

    Earlier registrations are matched case-insensitively, ignoring
    surrounding whitespace. ``User@Example.com`` therefore counts as the same
    address as ``user@example.com``, consistent with the case-insensitive
    banned-email check.

    Returns:
        ``None`` if registration is allowed. Otherwise, the message to show
        the user, in one of two cases:
        - the address was already used and ``multiple_users_per_email`` is
          false;
        - the latest registration for it is younger than ``email_cooldown``
          seconds.
    """
    import math

    wanted = email.strip().casefold()
    email_entries = [
        r for r in load_registrations()
        if str(r.get("email", "")).strip().casefold() == wanted
    ]
    if not email_entries:
        return None
    if not config.get("multiple_users_per_email", True):
        return "This email address has already been used to register an account."
    if email_cooldown := config.get("email_cooldown"):
        latest_registration = max(
            datetime.fromisoformat(e["datetime"])
            for e in email_entries
        )
        time_since = get_current_utc() - latest_registration
        remaining = email_cooldown - time_since.total_seconds()
        if remaining > 0:
            # Round up so a user is never told to wait "0 seconds".
            return f"Please wait {math.ceil(remaining)} seconds before requesting another account."
    return None
async def check_username_availability(username: str) -> bool:
    """Check if username is available on Matrix and in our registration records.

    The name must:
    - not match a banned-username pattern;
    - not already appear as ``requested_name`` in our registration records;
    - be reported available by the homeserver's
      ``/_matrix/client/v3/register/available`` endpoint.

    Any failure to get a definite answer from the homeserver counts as
    *unavailable* (logged): a network error, an unexpected status, or a
    malformed body. This way a token is never issued for a name that could
    not be verified.
    """
    if is_username_banned(username):
        logger.info(f"[USERNAME CHECK] {username}: Banned by pattern")
        return False
    registrations = load_registrations()
    if any(r.get("requested_name") == username for r in registrations):
        logger.info(f"[USERNAME CHECK] {username}: Already requested")
        return False
    url = f"https://{config['homeserver']}/_matrix/client/v3/register/available"
    async with httpx.AsyncClient() as client:
        try:
            # Pass the username via `params` so httpx URL-encodes it. Raw
            # interpolation let characters like "&" or "#" alter the query.
            response = await client.get(url, params={"username": username}, timeout=5)
        except httpx.RequestError as ex:
            logger.warning(f"[USERNAME CHECK] Could not reach homeserver: {ex}")
            return False
        if response.status_code == 200:
            try:
                data = response.json()
            except ValueError:
                logger.warning(f"[USERNAME CHECK] {username}: Unparseable response from homeserver")
                return False
            is_available = bool(data.get("available", False)) if isinstance(data, dict) else False
            logger.info(f"[USERNAME CHECK] {username}: {'Available' if is_available else 'Taken'}")
            return is_available
        if response.status_code == 400:
            logger.info(f"[USERNAME CHECK] {username}: Taken (400)")
            return False
        logger.warning(f"[USERNAME CHECK] {username}: Unexpected status {response.status_code}")
        return False
# ---------------------------------------------------------
# 5. FastAPI Setup and Routes
# ---------------------------------------------------------
app = FastAPI()
app.add_middleware(CustomLoggingMiddleware)
# Resolve asset directories against BASE_DIR, like config.yaml and the data
# files, so the app starts regardless of the current working directory.
app.mount("/static", StaticFiles(directory=os.path.join(BASE_DIR, "static")), name="static")
templates = Jinja2Templates(directory=os.path.join(BASE_DIR, "templates"))
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render the landing page with the current open/closed registration state."""
    closed, message = is_registration_closed(get_current_utc())
    reset_hour, reset_minute = divmod(config["token_reset_time_utc"], 100)
    context = {
        "request": request,
        "registration_closed": closed,
        "homeserver": config["homeserver"],
        "message": message,
        "reset_hour": reset_hour,
        "reset_minute": reset_minute,
        "downtime_minutes": config["downtime_before_token_reset"],
    }
    return templates.TemplateResponse("index.html", context)
@app.get("/api/time")
async def get_server_time():
    """Return the server's current UTC clock time as ``{"utc_time": "HH:MM:SS"}``."""
    current = get_current_utc()
    payload = {"utc_time": f"{current:%H:%M:%S}"}
    return JSONResponse(payload)
def _error_page(request: Request, message: str):
    """Render ``error.html`` showing ``message`` (used by every rejection in register)."""
    return templates.TemplateResponse(
        "error.html",
        {
            "request": request,
            "message": message
        }
    )


def _send_email_blocking(msg: EmailMessage) -> None:
    """Deliver ``msg`` via the SMTP server in ``config["smtp"]`` (blocking network I/O).

    STARTTLS is used unless ``use_tls`` is false. The optional ``timeout``
    key (seconds, default 30) bounds each socket operation, so an
    unresponsive server cannot hang the calling thread indefinitely.
    """
    smtp_conf = config["smtp"]
    timeout = smtp_conf.get("timeout", 30)
    with smtplib.SMTP(smtp_conf["host"], smtp_conf["port"], timeout=timeout) as server:
        if smtp_conf.get("use_tls", True):
            server.starttls()
        server.login(smtp_conf["username"], smtp_conf["password"])
        server.send_message(msg)


@app.post("/register", response_class=HTMLResponse)
async def register(
    request: Request,
    requested_username: str = Form(...),
    email: str = Form(...)
):
    """Validate a registration request and email the registration token.

    Checks run in this order:
    1. registration window;
    2. banned IP;
    3. banned email;
    4. email cooldown / reuse rules;
    5. username availability.

    The first failing check renders ``error.html`` with an explanation. On
    success:
    - the token is emailed to ``email``;
    - the request is appended to the registration records;
    - ``success.html`` is rendered.

    Raises:
        HTTPException: 500 if ``read_registration_token()`` returns None or
            the email could not be sent.
    """
    from fastapi.concurrency import run_in_threadpool

    now = get_current_utc()
    # request.client is None when the server supplies no peer address; use a
    # placeholder instead of crashing with AttributeError.
    client_ip = request.client.host if request.client else "unknown"
    logger.info(f"Registration attempt - Username: {requested_username}, Email: {email}, IP: {client_ip}")

    closed, message = is_registration_closed(now)
    if closed:
        logger.info("Registration rejected: Registration is closed")
        return _error_page(request, message)

    if is_ip_banned(client_ip):
        logger.info(f"Registration rejected: Banned IP {client_ip}")
        return _error_page(request, "Registration not allowed from your IP address.")

    if is_email_banned(email):
        logger.info(f"Registration rejected: Banned email {email}")
        return _error_page(request, "Registration not allowed for this email address.")

    if error_message := check_email_cooldown(email):
        logger.info(f"Registration rejected: Email cooldown - {email}")
        return _error_page(request, error_message)

    if not await check_username_availability(requested_username):
        logger.info(f"Registration rejected: Username unavailable - {requested_username}")
        return _error_page(request, f"The username '{requested_username}' is not available.")

    token = read_registration_token()
    if token is None:
        logger.error("Registration token file not found")
        raise HTTPException(status_code=500, detail="Registration token file not found.")

    email_body = config["email_body"].format(
        homeserver=config["homeserver"],
        registration_token=token,
        requested_username=requested_username,
        utc_time=now.strftime("%H:%M:%S"),
        time_until_reset=get_time_until_reset_str(now)
    )
    msg = EmailMessage()
    msg.set_content(email_body)
    msg["Subject"] = config["email_subject"].format(homeserver=config["homeserver"])
    msg["From"] = config["smtp"]["username"]
    msg["To"] = email

    try:
        # smtplib blocks; run it in a worker thread so a slow SMTP server does
        # not stall every other request on the event loop.
        await run_in_threadpool(_send_email_blocking, msg)
        logger.info(f"Registration email sent successfully to {email}")
    except Exception as ex:
        logger.error(f"Failed to send email: {ex}")
        # SMTP error text can reveal server internals; keep it in the log only.
        raise HTTPException(status_code=500, detail="Error sending email.") from ex

    registration_data = {
        "requested_name": requested_username,
        "email": email,
        "datetime": get_current_utc().isoformat(),
        "ip_address": client_ip
    }
    save_registration(registration_data)
    logger.info(f"Registration successful - Username: {requested_username}, Email: {email}")
    return templates.TemplateResponse(
        "success.html",
        {
            "request": request,
            "homeserver": config["homeserver"]
        }
    )
if __name__ == "__main__":
    # Serve the app with uvicorn when this file is executed directly.
    import uvicorn

    server_options = {
        "host": "0.0.0.0",
        "port": config["port"],
        # Restart automatically when source files change.
        "reload": True,
        # Request lines are logged by CustomLoggingMiddleware instead.
        "access_log": False,
    }
    uvicorn.run("registration:app", **server_options)