Auto-update: Thu Jul 25 01:27:56 PDT 2024

Author: sanj
Date:   2024-07-25 01:27:56 -07:00
parent 8775c05927
commit 0be1915aeb
2 changed files with 186 additions and 169 deletions

File 1 of 2

@@ -5,7 +5,6 @@ from fastapi import FastAPI, Request, HTTPException, Response
 from fastapi.responses import JSONResponse
 from fastapi.middleware.cors import CORSMiddleware
 from starlette.middleware.base import BaseHTTPMiddleware
-from starlette.middleware.base import BaseHTTPMiddleware
 from starlette.requests import ClientDisconnect
 from hypercorn.asyncio import serve
 from hypercorn.config import Config as HypercornConfig
@@ -44,7 +43,6 @@ err(f"Error message.")
 def crit(text: str): logger.critical(text)
 crit(f"Critical message.")
-

 @asynccontextmanager
 async def lifespan(app: FastAPI):
     # Startup
@@ -52,36 +50,30 @@ async def lifespan(app: FastAPI):
     crit(f"Arguments: {args}")

     # Load routers
-    for module_name in API.MODULES.__fields__:
-        if getattr(API.MODULES, module_name):
-            load_router(module_name)
+    if args.test:
+        load_router(args.test)
+    else:
+        for module_name in API.MODULES.__fields__:
+            if getattr(API.MODULES, module_name):
+                load_router(module_name)

     crit("Starting database synchronization...")
     try:
-        # Log the current TS_ID
-        crit(f"Current TS_ID: {os.environ.get('TS_ID', 'Not set')}")
-        # Log the local_db configuration
-        local_db = API.local_db
-        crit(f"Local DB configuration: {local_db}")
-        # Test local connection
-        async with API.get_connection() as conn:
-            version = await conn.fetchval("SELECT version()")
-            crit(f"Successfully connected to local database. PostgreSQL version: {version}")
+        # Initialize sync structures
+        await API.initialize_sync()

         # Sync schema across all databases
         await API.sync_schema()
         crit("Schema synchronization complete.")

-        # Attempt to pull changes from another database
-        source = await API.get_default_source()
+        # Check if other instances have more recent data
+        source = await API.get_most_recent_source()
         if source:
             crit(f"Pulling changes from {source['ts_id']}...")
             await API.pull_changes(source)
             crit("Data pull complete.")
         else:
-            crit("No available source for pulling changes. This might be the only active database.")
+            crit("No instances with more recent data found.")
     except Exception as e:
         crit(f"Error during startup: {str(e)}")
@@ -93,7 +85,6 @@ async def lifespan(app: FastAPI):
     crit("Shutting down...")
     # Perform any cleanup operations here if needed
-

 app = FastAPI(lifespan=lifespan)

 app.add_middleware(
@@ -124,7 +115,6 @@ class SimpleAPIKeyMiddleware(BaseHTTPMiddleware):
                 content={"detail": "Invalid or missing API key"}
             )
         response = await call_next(request)
-        # debug(f"Request from {client_ip} is complete")
         return response

 # Add the middleware to your FastAPI app
@@ -136,7 +126,6 @@ async def http_exception_handler(request: Request, exc: HTTPException):
     err(f"Request: {request.method} {request.url}")
     return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
-

 @app.middleware("http")
 async def handle_exception_middleware(request: Request, call_next):
     try:
@@ -149,6 +138,19 @@ async def handle_exception_middleware(request: Request, call_next):
         raise
     return response

+@app.middleware("http")
+async def sync_middleware(request: Request, call_next):
+    response = await call_next(request)
+
+    # Check if the request was a database write operation
+    if request.method in ["POST", "PUT", "PATCH", "DELETE"]:
+        try:
+            # Push changes to other databases
+            await API.push_changes_to_all()
+        except Exception as e:
+            err(f"Error pushing changes to other databases: {str(e)}")
+
+    return response
+
 def load_router(router_name):
     router_file = ROUTER_DIR / f'{router_name}.py'
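Note (not part of the commit): sync_middleware as added above awaits push_changes_to_all() before returning, so every write request pays for replication to all peers. A hedged sketch of a non-blocking alternative, assuming it is acceptable for replication to lag the response slightly:

import asyncio

@app.middleware("http")
async def sync_middleware(request, call_next):
    response = await call_next(request)
    if request.method in ("POST", "PUT", "PATCH", "DELETE"):
        # Fire-and-forget: respond immediately; push_changes_to_all()
        # already logs per-peer failures internally.
        asyncio.create_task(API.push_changes_to_all())
    return response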
@@ -160,25 +162,16 @@ def load_router(router_name):
             module = importlib.import_module(module_path)
             router = getattr(module, router_name)
             app.include_router(router)
-            # module_logger.info(f"{router_name.capitalize()} router loaded.")
         except (ImportError, AttributeError) as e:
             module_logger.critical(f"Failed to load router {router_name}: {e}")
     else:
         module_logger.error(f"Router file for {router_name} does not exist.")

 def main(argv):
-    if args.test:
-        load_router(args.test)
-    else:
-        crit(f"sijapi launched")
-        crit(f"Arguments: {args}")
-        for module_name in API.MODULES.__fields__:
-            if getattr(API.MODULES, module_name):
-                load_router(module_name)
     config = HypercornConfig()
     config.bind = [API.BIND]
+    config.startup_timeout = 3600 # 1 hour
     asyncio.run(serve(app, config))

 if __name__ == "__main__":
     main(sys.argv[1:])
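Note (not part of the commit): config.startup_timeout = 3600 matters because the lifespan handler above may pull an entire dataset before startup completes; Hypercorn's default startup window (60 seconds) could abort a long first sync. A standalone sketch of the same setting — the bind address is hypothetical, the app itself uses API.BIND:

from hypercorn.config import Config

config = Config()
config.bind = ["0.0.0.0:4444"]  # hypothetical; the real app binds API.BIND
config.startup_timeout = 3600   # allow up to an hour for the initial pull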

File 2 of 2

@@ -31,6 +31,8 @@ def warn(text: str): logger.warning(text)
 def err(text: str): logger.error(text)
 def crit(text: str): logger.critical(text)

+TS_ID=os.getenv("TS_ID", "NULL")
+
 T = TypeVar('T', bound='Configuration')

 class Configuration(BaseModel):
     HOME: Path = Path.home()
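Note (not part of the commit): the default is the string "NULL", not None, so an unset TS_ID silently matches nothing downstream rather than raising. A quick demonstration:

import os

os.environ.pop("TS_ID", None)                  # simulate an unset variable
assert os.getenv("TS_ID") is None              # plain getenv gives None
assert os.getenv("TS_ID", "NULL") == "NULL"    # the form used here gives a string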
@@ -156,6 +158,17 @@ class Configuration(BaseModel):
         arbitrary_types_allowed = True

+from pydantic import BaseModel, Field
+from typing import Any, Dict, List, Union
+from pathlib import Path
+import yaml
+import re
+import os
+import asyncpg
+from contextlib import asynccontextmanager
+from sijapi import TS_ID
+import traceback
+
 class APIConfig(BaseModel):
     HOST: str
     PORT: int
@@ -183,7 +196,6 @@
         try:
             with open(secrets_path, 'r') as file:
                 secrets_data = yaml.safe_load(file)
-                # info(f"Loaded secrets: {secrets_data}")
         except FileNotFoundError:
             err(f"Secrets file not found: {secrets_path}")
             secrets_data = {}
@@ -279,8 +291,7 @@
     @property
     def local_db(self):
-        ts_id = os.environ.get('TS_ID')
-        return next((db for db in self.POOL if db['ts_id'] == ts_id), None)
+        return next((db for db in self.POOL if db['ts_id'] == TS_ID), None)

     @asynccontextmanager
     async def get_connection(self, pool_entry: Dict[str, Any] = None):
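Note (not part of the commit): local_db now selects the POOL entry whose ts_id equals this machine's TS_ID. The diff only ever reads the 'ts_id' and 'ts_ip' keys of a pool entry; the sketch below assumes a plausible shape for illustration (every key except those two is hypothetical):

POOL = [
    {"ts_id": "server1", "ts_ip": "100.64.64.11", "db_port": 5432},
    {"ts_id": "server2", "ts_ip": "100.64.64.12", "db_port": 5432},
]
TS_ID = "server1"

local_db = next((db for db in POOL if db["ts_id"] == TS_ID), None)
assert local_db["ts_ip"] == "100.64.64.11"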
@@ -305,104 +316,158 @@ class APIConfig(BaseModel):
             crit(f"Error: {str(e)}")
             raise

-    async def push_changes(self, query: str, *args):
-        connections = []
-        try:
-            for pool_entry in self.POOL[1:]:  # Skip the first (local) database
-                conn = await self.get_connection(pool_entry).__aenter__()
-                connections.append(conn)
-
-            results = await asyncio.gather(
-                *[conn.execute(query, *args) for conn in connections],
-                return_exceptions=True
-            )
-
-            for pool_entry, result in zip(self.POOL[1:], results):
-                if isinstance(result, Exception):
-                    err(f"Failed to push to {pool_entry['ts_ip']}: {str(result)}")
-                else:
-                    info(f"Successfully pushed to {pool_entry['ts_ip']}")
-        finally:
-            for conn in connections:
-                await conn.__aexit__(None, None, None)
-
-    async def get_default_source(self):
-        local = self.local_db
-        for db in self.POOL:
-            if db != local:
-                try:
-                    async with self.get_connection(db):
-                        return db
-                except:
-                    continue
-        return None
-
-    async def pull_changes(self, source_pool_entry: Dict[str, Any] = None):
-        try:
-            if source_pool_entry is None:
-                source_pool_entry = await self.get_default_source()
-
-            if source_pool_entry is None:
-                err("No available source for pulling changes")
-                return
-
-            async with self.get_connection(source_pool_entry) as source_conn:
-                async with self.get_connection() as dest_conn:
-                    tables = await source_conn.fetch(
-                        "SELECT tablename FROM pg_tables WHERE schemaname = 'public'"
-                    )
-                    for table in tables:
-                        table_name = table['tablename']
-                        info(f"Processing table: {table_name}")
-
-                        # Get primary key column(s)
-                        pk_columns = await source_conn.fetch("""
-                            SELECT a.attname
-                            FROM pg_index i
-                            JOIN pg_attribute a ON a.attrelid = i.indrelid
-                                AND a.attnum = ANY(i.indkey)
-                            WHERE i.indrelid = $1::regclass
-                            AND i.indisprimary;
-                        """, table_name)
-
-                        pk_cols = [col['attname'] for col in pk_columns]
-                        info(f"Primary key columns for {table_name}: {pk_cols}")
-
-                        if not pk_cols:
-                            warn(f"No primary key found for table {table_name}. Skipping.")
-                            continue
-
-                        # Fetch all rows from the source table
-                        rows = await source_conn.fetch(f"SELECT * FROM {table_name}")
-                        info(f"Fetched {len(rows)} rows from {table_name}")
-
-                        if rows:
-                            columns = list(rows[0].keys())
-                            info(f"Columns for {table_name}: {columns}")
-
-                            # Upsert records to the destination table
-                            for row in rows:
-                                try:
-                                    query = f"""
-                                        INSERT INTO {table_name} ({', '.join(columns)})
-                                        VALUES ({', '.join(f'${i+1}' for i in range(len(columns)))})
-                                        ON CONFLICT ({', '.join(pk_cols)}) DO UPDATE SET
-                                        {', '.join(f"{col} = EXCLUDED.{col}" for col in columns if col not in pk_cols)}
-                                    """
-                                    info(f"Executing query: {query}")
-                                    info(f"With values: {[row[col] for col in columns]}")
-                                    await dest_conn.execute(query, *[row[col] for col in columns])
-                                except Exception as e:
-                                    err(f"Error processing row in {table_name}: {str(e)}")
-                                    err(f"Problematic row: {row}")
-
-                        info(f"Completed processing table: {table_name}")
-
-            info(f"Successfully pulled changes from {source_pool_entry['ts_ip']}")
-        except Exception as e:
-            err(f"Unexpected error in pull_changes: {str(e)}")
-            err(f"Traceback: {traceback.format_exc()}")
+    async def initialize_sync(self):
+        async with self.get_connection() as conn:
+            # Create sync_status table
+            await conn.execute("""
+                CREATE TABLE IF NOT EXISTS sync_status (
+                    table_name TEXT,
+                    server_id TEXT,
+                    last_synced_version INTEGER,
+                    PRIMARY KEY (table_name, server_id)
+                )
+            """)
+
+            # Get all tables
+            tables = await conn.fetch("""
+                SELECT tablename FROM pg_tables
+                WHERE schemaname = 'public'
+            """)
+
+            # Add version and server_id columns to all tables, create triggers
+            for table in tables:
+                table_name = table['tablename']
+                await conn.execute(f"""
+                    ALTER TABLE "{table_name}"
+                    ADD COLUMN IF NOT EXISTS version INTEGER DEFAULT 1,
+                    ADD COLUMN IF NOT EXISTS server_id TEXT DEFAULT '{TS_ID}';
+
+                    CREATE OR REPLACE FUNCTION update_version_and_server_id()
+                    RETURNS TRIGGER AS $$
+                    BEGIN
+                        NEW.version = COALESCE(OLD.version, 0) + 1;
+                        NEW.server_id = '{TS_ID}';
+                        RETURN NEW;
+                    END;
+                    $$ LANGUAGE plpgsql;
+
+                    DROP TRIGGER IF EXISTS update_version_and_server_id_trigger ON "{table_name}";
+                    CREATE TRIGGER update_version_and_server_id_trigger
+                    BEFORE INSERT OR UPDATE ON "{table_name}"
+                    FOR EACH ROW EXECUTE FUNCTION update_version_and_server_id();
+
+                    INSERT INTO sync_status (table_name, server_id, last_synced_version)
+                    VALUES ('{table_name}', '{TS_ID}', 0)
+                    ON CONFLICT (table_name, server_id) DO NOTHING;
+                """)
+
+    async def get_most_recent_source(self):
+        most_recent_source = None
+        max_version = -1
+
+        for pool_entry in self.POOL:
+            if pool_entry['ts_id'] == TS_ID:
+                continue
+
+            try:
+                async with self.get_connection(pool_entry) as conn:
+                    version = await conn.fetchval("""
+                        SELECT MAX(last_synced_version) FROM sync_status
+                    """)
+                    if version > max_version:
+                        max_version = version
+                        most_recent_source = pool_entry
+            except Exception as e:
+                err(f"Error checking version for {pool_entry['ts_id']}: {str(e)}")
+
+        return most_recent_source
+
+    async def pull_changes(self, source_pool_entry):
+        async with self.get_connection(source_pool_entry) as source_conn:
+            async with self.get_connection() as dest_conn:
+                tables = await source_conn.fetch("""
+                    SELECT tablename FROM pg_tables
+                    WHERE schemaname = 'public'
+                """)
+
+                for table in tables:
+                    table_name = table['tablename']
+                    last_synced_version = await self.get_last_synced_version(table_name, source_pool_entry['ts_id'])
+
+                    changes = await source_conn.fetch(f"""
+                        SELECT * FROM "{table_name}"
+                        WHERE version > $1 AND server_id = $2
+                        ORDER BY version ASC
+                    """, last_synced_version, source_pool_entry['ts_id'])
+
+                    for change in changes:
+                        columns = change.keys()
+                        values = [change[col] for col in columns]
+                        await dest_conn.execute(f"""
+                            INSERT INTO "{table_name}" ({', '.join(columns)})
+                            VALUES ({', '.join(f'${i+1}' for i in range(len(columns)))})
+                            ON CONFLICT (id) DO UPDATE SET
+                            {', '.join(f"{col} = EXCLUDED.{col}" for col in columns if col != 'id')}
+                        """, *values)
+
+                    if changes:
+                        await self.update_sync_status(table_name, source_pool_entry['ts_id'], changes[-1]['version'])
+
+    async def push_changes_to_all(self):
+        async with self.get_connection() as local_conn:
+            tables = await local_conn.fetch("""
+                SELECT tablename FROM pg_tables
+                WHERE schemaname = 'public'
+            """)
+
+            for pool_entry in self.POOL:
+                if pool_entry['ts_id'] == TS_ID:
+                    continue
+
+                try:
+                    async with self.get_connection(pool_entry) as remote_conn:
+                        for table in tables:
+                            table_name = table['tablename']
+                            last_synced_version = await self.get_last_synced_version(table_name, pool_entry['ts_id'])
+
+                            changes = await local_conn.fetch(f"""
+                                SELECT * FROM "{table_name}"
+                                WHERE version > $1 AND server_id = $2
+                                ORDER BY version ASC
+                            """, last_synced_version, TS_ID)
+
+                            for change in changes:
+                                columns = change.keys()
+                                values = [change[col] for col in columns]
+                                await remote_conn.execute(f"""
+                                    INSERT INTO "{table_name}" ({', '.join(columns)})
+                                    VALUES ({', '.join(f'${i+1}' for i in range(len(columns)))})
+                                    ON CONFLICT (id) DO UPDATE SET
+                                    {', '.join(f"{col} = EXCLUDED.{col}" for col in columns if col != 'id')}
+                                """, *values)
+
+                            if changes:
+                                await self.update_sync_status(table_name, pool_entry['ts_id'], changes[-1]['version'])
+
+                    info(f"Successfully pushed changes to {pool_entry['ts_id']}")
+                except Exception as e:
+                    err(f"Error pushing changes to {pool_entry['ts_id']}: {str(e)}")
+
+    async def get_last_synced_version(self, table_name, server_id):
+        async with self.get_connection() as conn:
+            return await conn.fetchval("""
+                SELECT last_synced_version FROM sync_status
+                WHERE table_name = $1 AND server_id = $2
+            """, table_name, server_id) or 0
+
+    async def update_sync_status(self, table_name, server_id, version):
+        async with self.get_connection() as conn:
+            await conn.execute("""
+                INSERT INTO sync_status (table_name, server_id, last_synced_version)
+                VALUES ($1, $2, $3)
+                ON CONFLICT (table_name, server_id) DO UPDATE
+                SET last_synced_version = EXCLUDED.last_synced_version
+            """, table_name, server_id, version)

     async def sync_schema(self):
         for pool_entry in self.POOL:
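Note (not part of the commit): the scheme above gives every row a (version, server_id) pair maintained by trigger, while sync_status records a per-(table, server) watermark of the highest version already exchanged; pulls and pushes then ship only a server's own rows above that watermark. Two caveats visible in the diff: the upserts hardcode ON CONFLICT (id), so every synced table is assumed to have an id primary key (the old pull_changes discovered key columns instead), and MAX(last_synced_version) is NULL on a fresh node, so version > max_version raises a TypeError that get_most_recent_source's except clause merely logs. A database-free sketch of the watermark logic, with illustrative table and row contents:

rows = [  # as they might exist on the source node
    {"id": 1, "version": 3, "server_id": "server1"},
    {"id": 2, "version": 5, "server_id": "server2"},
    {"id": 3, "version": 7, "server_id": "server1"},
]
sync_status = {("notes", "server1"): 3}  # server1's rows synced up to v3

def changes_to_pull(table, source_id):
    last = sync_status.get((table, source_id), 0)
    return sorted(
        (r for r in rows if r["server_id"] == source_id and r["version"] > last),
        key=lambda r: r["version"],
    )

pulled = changes_to_pull("notes", "server1")
assert [r["id"] for r in pulled] == [3]                    # only the v7 row travels
sync_status[("notes", "server1")] = pulled[-1]["version"]  # advance the watermark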
@@ -459,45 +524,6 @@ class APIConfig(BaseModel):
             END $$;
         """)

-    async def get_schema(self, pool_entry: Dict[str, Any]):
-        async with self.get_connection(pool_entry) as conn:
-            tables = await conn.fetch("""
-                SELECT table_name, column_name, data_type, character_maximum_length,
-                       is_nullable, column_default, ordinal_position
-                FROM information_schema.columns
-                WHERE table_schema = 'public'
-                ORDER BY table_name, ordinal_position
-            """)
-            indexes = await conn.fetch("""
-                SELECT indexname, indexdef
-                FROM pg_indexes
-                WHERE schemaname = 'public'
-            """)
-            constraints = await conn.fetch("""
-                SELECT conname, contype, conrelid::regclass::text as table_name,
-                       pg_get_constraintdef(oid) as definition
-                FROM pg_constraint
-                WHERE connamespace = 'public'::regnamespace
-            """)
-            return {
-                'tables': tables,
-                'indexes': indexes,
-                'constraints': constraints
-            }
-
-    async def create_sequence_if_not_exists(self, conn, sequence_name):
-        await conn.execute(f"""
-            DO $$
-            BEGIN
-                IF NOT EXISTS (SELECT 1 FROM pg_sequences WHERE schemaname = 'public' AND sequencename = '{sequence_name}') THEN
-                    CREATE SEQUENCE {sequence_name};
-                END IF;
-            END $$;
-        """)
-
     async def apply_schema_changes(self, pool_entry: Dict[str, Any], source_schema, target_schema):
         async with self.get_connection(pool_entry) as conn:
             source_tables = {t['table_name']: t for t in source_schema['tables']}
@@ -567,8 +593,6 @@ class APIConfig(BaseModel):
                     await conn.execute(sql)
                 except Exception as e:
                     err(f"Error processing table {table_name}: {str(e)}")
-                    # Optionally, you might want to raise this exception if you want to stop the entire process
-                    # raise

         try:
             source_indexes = {idx['indexname']: idx['indexdef'] for idx in source_schema['indexes']}