sijapi/sijapi/__main__.py

227 lines
9.2 KiB
Python
Executable file

#!/Users/sij/miniforge3/envs/api/bin/python
#__main__.py
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.base import BaseHTTPMiddleware
from hypercorn.asyncio import serve
from hypercorn.config import Config as HypercornConfig
import sys
import os
import traceback
import asyncio
import ipaddress
import importlib
from pathlib import Path
import argparse
from . import Sys, Db, Dir
from .logs import L, get_logger
def parse_args(argv=None):
    """Parse the command-line options of the API server.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is what the
            original zero-argument call did.

    Returns:
        argparse.Namespace with the attributes ``log`` (upper-case level
        name), ``debug`` and ``info`` (lists of module names) and ``test``
        (a module name or ``None``).
    """
    parser = argparse.ArgumentParser(description='Personal API.')
    # ``type=str.upper`` normalizes the value before it is checked against
    # ``choices``, so ``--log debug`` is accepted as well as ``--log DEBUG``.
    parser.add_argument('--log', type=str.upper, default='INFO',
                        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
                        help='Set overall log level (e.g., DEBUG, INFO, WARNING)')
    parser.add_argument('--debug', nargs='+', default=[],
                        help='Set DEBUG log level for specific modules')
    parser.add_argument('--info', nargs='+', default=[],
                        help='Set INFO log level for specific modules')
    parser.add_argument('--test', type=str, help='Load only the specified module.')
    return parser.parse_args(argv)
# Command-line arguments are parsed once, at import time; the resulting
# namespace is read later by `lifespan` (args.test).
args = parse_args()
# Setup logging: hand the parsed arguments to the project logging facade `L`,
# then obtain the logger used throughout this module.
L.setup_from_args(args)
l = get_logger("main")
l.info(f"Logging initialized. Debug modules: {L.debug_modules}")
l.info(f"Command line arguments: {args}")
# Debug-level diagnostics about where the process was started from and where
# this file lives on disk.
l.debug(f"Current working directory: {os.getcwd()}")
l.debug(f"__file__ path: {__file__}")
l.debug(f"Absolute path of __file__: {os.path.abspath(__file__)}")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run application startup before the ``yield`` and shutdown after it.

    Startup loads either the single router named by ``--test`` or every
    module whose flag on ``Sys.MODULES`` is truthy, then initializes the
    database engines and the query-tracking table. An error in the database
    step is logged and swallowed, so the application still starts. Shutdown
    closes the database, waiting at most 20 seconds.
    """
    # --- Startup -----------------------------------------------------------
    l.critical("sijapi launched")
    l.info(f"Arguments: {args}")

    # Pick the routers to load. The generator is lazy, so each module flag is
    # read immediately before that module's router is loaded.
    if args.test:
        selected_routers = (args.test,)
    else:
        selected_routers = (
            field_name
            for field_name in Sys.MODULES.__fields__
            if getattr(Sys.MODULES, field_name)
        )
    for selected in selected_routers:
        load_router(selected)

    try:
        await Db.initialize_engines()
        await Db.ensure_query_tracking_table()
    except Exception as e:
        l.critical(f"Error during startup: {str(e)}")
        l.critical(f"Traceback: {traceback.format_exc()}")

    try:
        yield  # This is where the app runs
    finally:
        # --- Shutdown ------------------------------------------------------
        l.critical("Shutting down...")
        try:
            closing = Db.close()
            await asyncio.wait_for(closing, timeout=20)
            l.critical("Database pools closed.")
        except asyncio.TimeoutError:
            l.critical("Timeout while closing database pools.")
        except Exception as e:
            l.critical(f"Error during shutdown: {str(e)}")
            l.critical(f"Traceback: {traceback.format_exc()}")
# The ASGI application; startup and shutdown work is delegated to `lifespan`.
app = FastAPI(lifespan=lifespan)
# CORS: every origin, method and header is allowed, with credentials enabled.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive — confirm this is intended for this deployment.
app.add_middleware(
CORSMiddleware,
allow_origins=['*'],
allow_credentials=True,
allow_methods=['*'],
allow_headers=['*'],
)
class SimpleAPIKeyMiddleware(BaseHTTPMiddleware):
    """Reject requests that lack a valid API key.

    A request passes without a key when it is an ``OPTIONS`` request, when its
    path is listed in ``Sys.PUBLIC``, or when the client address falls inside
    one of ``Sys.TRUSTED_SUBNETS``. Every other request must carry one of
    ``Sys.KEYS`` (compared case-insensitively) either in the ``Authorization``
    header, optionally prefixed with ``Bearer ``, or in the ``api_key`` query
    parameter. Otherwise a 401 JSON response is returned.

    Key material is never written to the log.
    """

    @staticmethod
    def _is_trusted_client(request: Request) -> bool:
        """Return True when the client address lies in a trusted subnet."""
        client = request.client
        if client is None:
            # No peer address is available, so the request cannot be trusted.
            return False
        try:
            client_ip = ipaddress.ip_address(client.host)
        except ValueError:
            # The host is not an IP literal and cannot match any subnet.
            return False
        return any(
            client_ip in ipaddress.ip_network(subnet)
            for subnet in Sys.TRUSTED_SUBNETS
        )

    async def dispatch(self, request: Request, call_next):
        """Authorize ``request`` and forward it, or answer with HTTP 401.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request down the stack.

        Returns:
            The downstream response, or a 401 ``JSONResponse`` when no valid
            API key accompanies a request that needs one.
        """
        if request.method == "OPTIONS":
            # Forward CORS preflight requests instead of answering them here.
            # CORSMiddleware is registered before this middleware, so it sits
            # further down the stack; only a forwarded request lets it attach
            # the Access-Control-* headers the browser expects.
            return await call_next(request)

        if request.url.path in Sys.PUBLIC or self._is_trusted_client(request):
            return await call_next(request)

        # Keys are compared case-insensitively, as before.
        valid_keys = {key.lower() for key in Sys.KEYS}

        header_key = request.headers.get("Authorization")
        if header_key:
            # Accept both "Bearer <key>" and a bare key in the header.
            header_key = header_key.lower().split("bearer ")[-1]

        query_key = request.query_params.get("api_key")
        if query_key:
            query_key = query_key.lower()

        header_ok = header_key is not None and header_key in valid_keys
        query_ok = query_key is not None and query_key in valid_keys

        if not (header_ok or query_ok):
            l.error(f"Invalid API key provided by a requester.")
            # Record only where a key was supplied, never the key itself.
            l.debug(
                f"Key supplied in header: {bool(header_key)}; "
                f"key supplied in query: {bool(query_key)}"
            )
            return JSONResponse(
                status_code=401,
                content={"detail": "Invalid or missing API key"}
            )

        if header_ok:
            l.debug("Valid API key provided in header.")
        if query_ok:
            l.debug("Valid API key provided in query.")
        return await call_next(request)


# Add the middleware to your FastAPI app
app.add_middleware(SimpleAPIKeyMiddleware)
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Log an ``HTTPException`` and return it to the client as JSON.

    Args:
        request: The request during which the exception was raised.
        exc: The raised exception.

    Returns:
        A ``JSONResponse`` with the exception's status code, a
        ``{"detail": ...}`` body, and any headers attached to the exception.
    """
    l.error(f"HTTP Exception: {exc.status_code} - {exc.detail}")
    l.error(f"Request: {request.method} {request.url}")
    # Forward headers set on the exception (e.g. WWW-Authenticate or
    # Retry-After); without this they would be dropped from the response.
    return JSONResponse(
        status_code=exc.status_code,
        content={"detail": exc.detail},
        headers=getattr(exc, "headers", None),
    )
@app.middleware("http")
async def handle_exception_middleware(request: Request, call_next):
    """Forward the request and turn any uncaught exception into an HTTP 500.

    The exception message and traceback are logged; the client only receives
    a generic ``{"detail": "Internal Server Error"}`` body.
    """
    try:
        return await call_next(request)
    except Exception as exc:
        l.error(f"Unhandled exception in request: {request.method} {request.url}")
        l.error(f"Exception: {str(exc)}")
        l.error(f"Traceback: {traceback.format_exc()}")

    # Only reached when the handler above logged a failure.
    error_body = {"detail": "Internal Server Error"}
    return JSONResponse(content=error_body, status_code=500)
@app.post("/sync/pull")
async def pull_changes():
    """Pull changes from the most recent source instance.

    The work runs in two separately guarded phases, so a failure is reported
    under the phase that actually failed:

    1. Ensure primary keys exist on the local and remote tables.
    2. Look up the most recent source and, if there is one, pull from it.

    Returns:
        A ``JSONResponse`` with ``status`` ``"success"`` (including the source
        and the number of changes) or ``"info"`` when no source was found.

    Raises:
        HTTPException: With status 500 when either phase fails.
    """
    l.info(f"Received request to /sync/pull")

    # Phase 1: primary keys. This is kept outside the pull phase's ``try`` so
    # that an HTTPException raised by the pull phase is not caught and
    # re-labelled as a primary-key failure.
    try:
        await Sys.add_primary_keys_to_local_tables()
        await Sys.add_primary_keys_to_remote_tables()
    except Exception as e:
        l.error(f"Error while ensuring primary keys to tables: {str(e)}")
        l.error(f"Traceback: {traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=f"Error during primary key insurance: {str(e)}") from e

    # Phase 2: the pull itself.
    try:
        source = await Sys.get_most_recent_source()
        if not source:
            return JSONResponse(content={
                "status": "info",
                "message": "No instances with more recent data found or all instances are offline."
            })

        # Pull changes from the source
        total_changes = await Sys.pull_changes(source)
        return JSONResponse(content={
            "status": "success",
            "message": f"Pull complete. Total changes: {total_changes}",
            "source": f"{source['ts_id']} ({source['ts_ip']})",
            "changes": total_changes
        })
    except Exception as e:
        l.error(f"Error in /sync/pull: {str(e)}")
        l.error(f"Traceback: {traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=f"Error during pull: {str(e)}") from e
    finally:
        l.info(f"Finished processing /sync/pull request")
def load_router(router_name):
    """Import ``sijapi.routers.<router_name>`` and mount its router on ``app``.

    The module must live at ``Dir.ROUTER / '<router_name>.py'`` and expose an
    attribute named like the module. A missing file, a failed import or a
    missing attribute is logged and otherwise ignored; nothing is returned.
    """
    router_logger = get_logger(f"router.{router_name}")
    router_logger.debug(f"Attempting to load {router_name.capitalize()}...")

    # Log the full path being checked
    router_file = Dir.ROUTER / f'{router_name}.py'
    router_logger.debug(f"Checking for router file at: {router_file.absolute()}")

    # Guard clause: without the file there is nothing to import.
    if not router_file.exists():
        router_logger.error(f"Router file for {router_name} does not exist at {router_file.absolute()}")
        router_logger.debug(f"Contents of router directory: {list(Dir.ROUTER.iterdir())}")
        return

    router_logger.debug(f"Router file found: {router_file}")
    module_path = f'sijapi.routers.{router_name}'
    router_logger.debug(f"Attempting to import module: {module_path}")
    try:
        module = importlib.import_module(module_path)
        router_logger.debug(f"Module imported successfully: {module}")
        router = getattr(module, router_name)
        router_logger.debug(f"Router object retrieved: {router}")
        app.include_router(router)
        router_logger.info(f"Router {router_name} loaded successfully")
    except (ImportError, AttributeError) as e:
        router_logger.critical(f"Failed to load router {router_name}: {e}")
        # Extra context for diagnosing import problems.
        router_logger.debug(f"Current working directory: {os.getcwd()}")
        router_logger.debug(f"Python path: {sys.path}")
def main(argv):
    """Serve ``app`` with Hypercorn on the address configured in ``Sys.BIND``.

    ``argv`` is accepted from the entry-point call but is not read here.
    """
    server_config = HypercornConfig()
    server_config.bind = [Sys.BIND]
    server_config.startup_timeout = 300  # 5 minutes
    server_config.shutdown_timeout = 15  # 15 seconds
    server = serve(app, server_config)
    asyncio.run(server)


if __name__ == "__main__":
    main(sys.argv[1:])