149 lines
No EOL
5.1 KiB
Python
Executable file
149 lines
No EOL
5.1 KiB
Python
Executable file
#!/Users/sij/miniforge3/envs/api/bin/python
|
|
from fastapi import FastAPI, Request, HTTPException, Response
|
|
from fastapi.responses import JSONResponse
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from starlette.requests import ClientDisconnect
|
|
from hypercorn.asyncio import serve
|
|
from hypercorn.config import Config
|
|
import sys
|
|
import asyncio
|
|
import httpx
|
|
import argparse
|
|
import json
|
|
import ipaddress
|
|
import importlib
|
|
from dotenv import load_dotenv
|
|
from pathlib import Path
|
|
from datetime import datetime
|
|
import argparse
|
|
from . import LOGGER, LOGS_DIR, OBSIDIAN_VAULT_DIR
|
|
from .logs import Logger
|
|
from .utilities import list_and_correct_impermissible_files
|
|
|
|
# Command-line options are parsed at import time; the logger setup below
# receives the parsed namespace.
parser = argparse.ArgumentParser(description="Personal API.")
parser.add_argument(
    "--debug",
    action="store_true",
    help="Set log level to INFO",
)
parser.add_argument(
    "--test",
    type=str,
    help="Load only the specified module.",
)
args = parser.parse_args()

# Construct the Logger named "main" with LOGS_DIR and hand it the parsed
# arguments.
main_logger = Logger("main", LOGS_DIR)
main_logger.setup_from_args(args)

# Module-level alias for the package logger.
logger = LOGGER

# Emit one message at DEBUG and one at INFO through the package logger.
logger.debug("Debug Log")
logger.info("Info Log")
|
|
|
|
|
|
from sijapi import DEBUG, INFO, WARN, ERR, CRITICAL
|
|
|
|
from sijapi import HOST, ENV_PATH, GLOBAL_API_KEY, REQUESTS_DIR, ROUTER_DIR, REQUESTS_LOG_PATH, PUBLIC_SERVICES, TRUSTED_SUBNETS, ROUTERS
|
|
|
|
|
|
# The FastAPI application that the middleware and routers in this module
# are attached to.
api = FastAPI()

# CORS policy: any origin, method and header, with credentials allowed.
# NOTE(review): wildcard origins together with allow_credentials=True is a
# very permissive combination -- confirm it is intended.
api.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
class SimpleAPIKeyMiddleware(BaseHTTPMiddleware):
    """Require the global API key on requests that are not otherwise exempt.

    A request is passed on without a key when its path is listed in
    PUBLIC_SERVICES or when the client address lies inside one of
    TRUSTED_SUBNETS. OPTIONS requests are answered directly with an empty
    200 response. Every other request must present GLOBAL_API_KEY, either
    in the ``Authorization`` header (with or without a ``Bearer`` prefix)
    or in the ``api_key`` query parameter; otherwise a 401 JSON response
    is returned.
    """

    @staticmethod
    def _client_is_trusted(request: Request) -> bool:
        """Return True when the peer address is inside a trusted subnet."""
        # request.client is None when no peer address is available.
        if request.client is None:
            return False
        try:
            client_ip = ipaddress.ip_address(request.client.host)
        except ValueError:
            # The host is not an IP literal; treat the caller as untrusted
            # instead of failing the whole request.
            return False
        return any(client_ip in subnet for subnet in TRUSTED_SUBNETS)

    @staticmethod
    def _presented_keys(request: Request) -> tuple:
        """Return the (header, query) API keys offered by the request."""
        header_key = request.headers.get("Authorization")
        if header_key:
            # Only the scheme is matched case-insensitively; the key itself
            # keeps its case so keys containing uppercase characters can
            # match GLOBAL_API_KEY.
            scheme, _, token = header_key.partition(" ")
            if scheme.lower() == "bearer":
                header_key = token.strip()
        return header_key, request.query_params.get("api_key")

    async def dispatch(self, request: Request, call_next):
        """Authorize the request, then forward it to the next handler.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request down the stack.

        Returns:
            The downstream response, an empty 200 response for OPTIONS, or
            a 401 JSON response when no valid key was presented.
        """
        if request.method == "OPTIONS":
            # Allow CORS preflight requests. A plain Response is used
            # because JSONResponse requires a content argument.
            # NOTE(review): this middleware is added after CORSMiddleware,
            # so this reply presumably never reaches it and carries no CORS
            # headers -- confirm the intended middleware order.
            return Response(status_code=200)

        needs_key = (
            request.url.path not in PUBLIC_SERVICES
            and not self._client_is_trusted(request)
        )
        if needs_key:
            # An absent or empty key never counts as a match, so an unset
            # GLOBAL_API_KEY cannot authorize key-less requests.
            authorized = any(
                key and key == GLOBAL_API_KEY
                for key in self._presented_keys(request)
            )
            if not authorized:
                ERR("Invalid API key provided by a requester.")
                return JSONResponse(
                    status_code=401,
                    content={"detail": "Invalid or missing API key"},
                )

        return await call_next(request)


api.add_middleware(SimpleAPIKeyMiddleware)
|
|
|
|
# Disabled request-logging middleware, kept as an inert string rather than
# live code: nothing in it is registered or executed.
canceled_middleware = """
@api.middleware("http")
async def log_requests(request: Request, call_next):
    DEBUG(f"Incoming request: {request.method} {request.url}")
    DEBUG(f"Request headers: {request.headers}")
    DEBUG(f"Request body: {await request.body()}")
    response = await call_next(request)
    return response

async def log_outgoing_request(request):
    INFO(f"Outgoing request: {request.method} {request.url}")
    DEBUG(f"Request headers: {request.headers}")
    DEBUG(f"Request body: {request.content}")
"""
|
|
|
|
@api.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Log an HTTPException and return it as a JSON error response.

    Args:
        request: The request during which the exception was raised.
        exc: The raised HTTPException.

    Returns:
        A JSONResponse with the exception's status code, a ``detail`` body,
        and any headers attached to the exception.
    """
    ERR(f"HTTP Exception: {exc.status_code} - {exc.detail}")
    ERR(f"Request: {request.method} {request.url}")
    return JSONResponse(
        status_code=exc.status_code,
        content={"detail": exc.detail},
        # Forward headers set on the exception (e.g. WWW-Authenticate);
        # they would otherwise be dropped from the error response.
        headers=getattr(exc, "headers", None),
    )
|
|
|
|
@api.middleware("http")
async def handle_exception_middleware(request: Request, call_next):
    """Forward the request and log Content-Length mismatches from downstream.

    Args:
        request: The incoming request.
        call_next: Callable that forwards the request down the stack.

    Returns:
        The downstream response.

    Raises:
        RuntimeError: Re-raised unchanged when the downstream stack raises.
    """
    try:
        return await call_next(request)
    except RuntimeError as exc:
        # When call_next raises, no response object exists, so its
        # Content-Length header cannot be patched here. Record the known
        # mismatch and propagate the original error rather than failing on
        # an unbound local variable.
        if str(exc) == "Response content longer than Content-Length":
            ERR(f"Request: {request.method} {request.url} failed: {exc}")
        raise
|
|
|
|
|
|
|
|
def load_router(router_name):
    """Import ``sijapi.routers.<router_name>`` and attach its router to ``api``.

    The module must expose an attribute with the same name as the module;
    that attribute is passed to ``api.include_router``. A missing router
    file is logged with ERR, and an ImportError or AttributeError during
    loading is logged with CRITICAL; neither is raised to the caller.

    Args:
        router_name: Name of the router module (without ``.py``) inside
            ROUTER_DIR.
    """
    router_file = ROUTER_DIR / f'{router_name}.py'
    DEBUG(f"Attempting to load {router_name.capitalize()}...")

    # Guard clause: nothing to import when the file is absent.
    if not router_file.exists():
        ERR(f"Router file for {router_name} does not exist.")
        return

    try:
        router_module = importlib.import_module(f'sijapi.routers.{router_name}')
        api.include_router(getattr(router_module, router_name))
        INFO(f"{router_name.capitalize()} router loaded.")
    except (ImportError, AttributeError) as load_error:
        CRITICAL(f"Failed to load router {router_name}: {load_error}")
|
|
|
|
def main(argv):
    """Load the routers, process the journal folder, and serve the API.

    With ``--test NAME`` only that router is loaded; otherwise every entry
    in ROUTERS is loaded. The function then runs
    ``list_and_correct_impermissible_files`` on the vault's ``journal``
    folder with ``rename=True`` and serves ``api`` with Hypercorn bound to
    HOST until the server stops.

    Args:
        argv: Command-line arguments supplied by the entry point. Currently
            unused: options are read from the module-level ``args`` parsed
            at import time.
    """
    if args.test:
        load_router(args.test)
    else:
        CRITICAL("sijapi launched")
        # Log the parsed namespace itself rather than the repr of its
        # private, uncalled ``_get_args`` method.
        CRITICAL(f"{args}")
        for router_name in ROUTERS:
            load_router(router_name)

    journal = OBSIDIAN_VAULT_DIR / "journal"
    list_and_correct_impermissible_files(journal, rename=True)

    config = Config()
    config.keep_alive_timeout = 1200
    config.bind = [HOST]
    asyncio.run(serve(api, config))


if __name__ == "__main__":
    main(sys.argv[1:])