From 0be1915aebffde44ead2838b68487c9b276a3f03 Mon Sep 17 00:00:00 2001 From: sanj <67624670+iodrift@users.noreply.github.com> Date: Thu, 25 Jul 2024 01:27:56 -0700 Subject: [PATCH] Auto-update: Thu Jul 25 01:27:56 PDT 2024 --- sijapi/__main__.py | 61 +++++----- sijapi/classes.py | 294 ++++++++++++++++++++++++--------------------- 2 files changed, 186 insertions(+), 169 deletions(-) diff --git a/sijapi/__main__.py b/sijapi/__main__.py index e9619c7..43bf81b 100755 --- a/sijapi/__main__.py +++ b/sijapi/__main__.py @@ -5,7 +5,6 @@ from fastapi import FastAPI, Request, HTTPException, Response from fastapi.responses import JSONResponse from fastapi.middleware.cors import CORSMiddleware from starlette.middleware.base import BaseHTTPMiddleware -from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import ClientDisconnect from hypercorn.asyncio import serve from hypercorn.config import Config as HypercornConfig @@ -44,7 +43,6 @@ err(f"Error message.") def crit(text: str): logger.critical(text) crit(f"Critical message.") - @asynccontextmanager async def lifespan(app: FastAPI): # Startup @@ -52,36 +50,30 @@ async def lifespan(app: FastAPI): crit(f"Arguments: {args}") # Load routers - for module_name in API.MODULES.__fields__: - if getattr(API.MODULES, module_name): - load_router(module_name) + if args.test: + load_router(args.test) + else: + for module_name in API.MODULES.__fields__: + if getattr(API.MODULES, module_name): + load_router(module_name) crit("Starting database synchronization...") try: - # Log the current TS_ID - crit(f"Current TS_ID: {os.environ.get('TS_ID', 'Not set')}") - - # Log the local_db configuration - local_db = API.local_db - crit(f"Local DB configuration: {local_db}") - - # Test local connection - async with API.get_connection() as conn: - version = await conn.fetchval("SELECT version()") - crit(f"Successfully connected to local database. PostgreSQL version: {version}") - + # Initialize sync structures + await API.initialize_sync() + # Sync schema across all databases await API.sync_schema() crit("Schema synchronization complete.") - # Attempt to pull changes from another database - source = await API.get_default_source() + # Check if other instances have more recent data + source = await API.get_most_recent_source() if source: crit(f"Pulling changes from {source['ts_id']}...") await API.pull_changes(source) crit("Data pull complete.") else: - crit("No available source for pulling changes. This might be the only active database.") + crit("No instances with more recent data found.") except Exception as e: crit(f"Error during startup: {str(e)}") @@ -93,7 +85,6 @@ async def lifespan(app: FastAPI): crit("Shutting down...") # Perform any cleanup operations here if needed - app = FastAPI(lifespan=lifespan) app.add_middleware( @@ -124,7 +115,6 @@ class SimpleAPIKeyMiddleware(BaseHTTPMiddleware): content={"detail": "Invalid or missing API key"} ) response = await call_next(request) - # debug(f"Request from {client_ip} is complete") return response # Add the middleware to your FastAPI app @@ -136,7 +126,6 @@ async def http_exception_handler(request: Request, exc: HTTPException): err(f"Request: {request.method} {request.url}") return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail}) - @app.middleware("http") async def handle_exception_middleware(request: Request, call_next): try: @@ -149,6 +138,19 @@ async def handle_exception_middleware(request: Request, call_next): raise return response +@app.middleware("http") +async def sync_middleware(request: Request, call_next): + response = await call_next(request) + + # Check if the request was a database write operation + if request.method in ["POST", "PUT", "PATCH", "DELETE"]: + try: + # Push changes to other databases + await API.push_changes_to_all() + except Exception as e: + err(f"Error pushing changes to other databases: {str(e)}") + + return response def load_router(router_name): router_file = ROUTER_DIR / f'{router_name}.py' @@ -160,25 +162,16 @@ def load_router(router_name): module = importlib.import_module(module_path) router = getattr(module, router_name) app.include_router(router) - # module_logger.info(f"{router_name.capitalize()} router loaded.") except (ImportError, AttributeError) as e: module_logger.critical(f"Failed to load router {router_name}: {e}") else: module_logger.error(f"Router file for {router_name} does not exist.") def main(argv): - if args.test: - load_router(args.test) - else: - crit(f"sijapi launched") - crit(f"Arguments: {args}") - for module_name in API.MODULES.__fields__: - if getattr(API.MODULES, module_name): - load_router(module_name) - config = HypercornConfig() config.bind = [API.BIND] + config.startup_timeout = 3600 # 1 hour asyncio.run(serve(app, config)) if __name__ == "__main__": - main(sys.argv[1:]) \ No newline at end of file + main(sys.argv[1:]) diff --git a/sijapi/classes.py b/sijapi/classes.py index d8e7fcf..1d4291d 100644 --- a/sijapi/classes.py +++ b/sijapi/classes.py @@ -31,6 +31,8 @@ def warn(text: str): logger.warning(text) def err(text: str): logger.error(text) def crit(text: str): logger.critical(text) +TS_ID=os.getenv("TS_ID", "NULL") + T = TypeVar('T', bound='Configuration') class Configuration(BaseModel): HOME: Path = Path.home() @@ -156,6 +158,17 @@ class Configuration(BaseModel): arbitrary_types_allowed = True +from pydantic import BaseModel, Field +from typing import Any, Dict, List, Union +from pathlib import Path +import yaml +import re +import os +import asyncpg +from contextlib import asynccontextmanager +from sijapi import TS_ID +import traceback + class APIConfig(BaseModel): HOST: str PORT: int @@ -183,7 +196,6 @@ class APIConfig(BaseModel): try: with open(secrets_path, 'r') as file: secrets_data = yaml.safe_load(file) - # info(f"Loaded secrets: {secrets_data}") except FileNotFoundError: err(f"Secrets file not found: {secrets_path}") secrets_data = {} @@ -279,8 +291,7 @@ class APIConfig(BaseModel): @property def local_db(self): - ts_id = os.environ.get('TS_ID') - return next((db for db in self.POOL if db['ts_id'] == ts_id), None) + return next((db for db in self.POOL if db['ts_id'] == TS_ID), None) @asynccontextmanager async def get_connection(self, pool_entry: Dict[str, Any] = None): @@ -305,104 +316,158 @@ class APIConfig(BaseModel): crit(f"Error: {str(e)}") raise - async def push_changes(self, query: str, *args): - connections = [] - try: - for pool_entry in self.POOL[1:]: # Skip the first (local) database - conn = await self.get_connection(pool_entry).__aenter__() - connections.append(conn) + async def initialize_sync(self): + async with self.get_connection() as conn: + # Create sync_status table + await conn.execute(""" + CREATE TABLE IF NOT EXISTS sync_status ( + table_name TEXT, + server_id TEXT, + last_synced_version INTEGER, + PRIMARY KEY (table_name, server_id) + ) + """) + + # Get all tables + tables = await conn.fetch(""" + SELECT tablename FROM pg_tables + WHERE schemaname = 'public' + """) + + # Add version and server_id columns to all tables, create triggers + for table in tables: + table_name = table['tablename'] + await conn.execute(f""" + ALTER TABLE "{table_name}" + ADD COLUMN IF NOT EXISTS version INTEGER DEFAULT 1, + ADD COLUMN IF NOT EXISTS server_id TEXT DEFAULT '{TS_ID}'; - results = await asyncio.gather( - *[conn.execute(query, *args) for conn in connections], - return_exceptions=True - ) + CREATE OR REPLACE FUNCTION update_version_and_server_id() + RETURNS TRIGGER AS $$ + BEGIN + NEW.version = COALESCE(OLD.version, 0) + 1; + NEW.server_id = '{TS_ID}'; + RETURN NEW; + END; + $$ LANGUAGE plpgsql; - for pool_entry, result in zip(self.POOL[1:], results): - if isinstance(result, Exception): - err(f"Failed to push to {pool_entry['ts_ip']}: {str(result)}") - else: - info(f"Successfully pushed to {pool_entry['ts_ip']}") + DROP TRIGGER IF EXISTS update_version_and_server_id_trigger ON "{table_name}"; + CREATE TRIGGER update_version_and_server_id_trigger + BEFORE INSERT OR UPDATE ON "{table_name}" + FOR EACH ROW EXECUTE FUNCTION update_version_and_server_id(); - finally: - for conn in connections: - await conn.__aexit__(None, None, None) + INSERT INTO sync_status (table_name, server_id, last_synced_version) + VALUES ('{table_name}', '{TS_ID}', 0) + ON CONFLICT (table_name, server_id) DO NOTHING; + """) - async def get_default_source(self): - local = self.local_db - for db in self.POOL: - if db != local: - try: - async with self.get_connection(db): - return db - except: + async def get_most_recent_source(self): + most_recent_source = None + max_version = -1 + + for pool_entry in self.POOL: + if pool_entry['ts_id'] == TS_ID: + continue + + try: + async with self.get_connection(pool_entry) as conn: + version = await conn.fetchval(""" + SELECT MAX(last_synced_version) FROM sync_status + """) + if version > max_version: + max_version = version + most_recent_source = pool_entry + except Exception as e: + err(f"Error checking version for {pool_entry['ts_id']}: {str(e)}") + + return most_recent_source + + async def pull_changes(self, source_pool_entry): + async with self.get_connection(source_pool_entry) as source_conn: + async with self.get_connection() as dest_conn: + tables = await source_conn.fetch(""" + SELECT tablename FROM pg_tables + WHERE schemaname = 'public' + """) + + for table in tables: + table_name = table['tablename'] + last_synced_version = await self.get_last_synced_version(table_name, source_pool_entry['ts_id']) + + changes = await source_conn.fetch(f""" + SELECT * FROM "{table_name}" + WHERE version > $1 AND server_id = $2 + ORDER BY version ASC + """, last_synced_version, source_pool_entry['ts_id']) + + for change in changes: + columns = change.keys() + values = [change[col] for col in columns] + await dest_conn.execute(f""" + INSERT INTO "{table_name}" ({', '.join(columns)}) + VALUES ({', '.join(f'${i+1}' for i in range(len(columns)))}) + ON CONFLICT (id) DO UPDATE SET + {', '.join(f"{col} = EXCLUDED.{col}" for col in columns if col != 'id')} + """, *values) + + if changes: + await self.update_sync_status(table_name, source_pool_entry['ts_id'], changes[-1]['version']) + + async def push_changes_to_all(self): + async with self.get_connection() as local_conn: + tables = await local_conn.fetch(""" + SELECT tablename FROM pg_tables + WHERE schemaname = 'public' + """) + + for pool_entry in self.POOL: + if pool_entry['ts_id'] == TS_ID: continue - return None - - - async def pull_changes(self, source_pool_entry: Dict[str, Any] = None): - try: - if source_pool_entry is None: - source_pool_entry = await self.get_default_source() - - if source_pool_entry is None: - err("No available source for pulling changes") - return - - async with self.get_connection(source_pool_entry) as source_conn: - async with self.get_connection() as dest_conn: - tables = await source_conn.fetch( - "SELECT tablename FROM pg_tables WHERE schemaname = 'public'" - ) - for table in tables: - table_name = table['tablename'] - info(f"Processing table: {table_name}") - - # Get primary key column(s) - pk_columns = await source_conn.fetch(""" - SELECT a.attname - FROM pg_index i - JOIN pg_attribute a ON a.attrelid = i.indrelid - AND a.attnum = ANY(i.indkey) - WHERE i.indrelid = $1::regclass - AND i.indisprimary; - """, table_name) - - pk_cols = [col['attname'] for col in pk_columns] - info(f"Primary key columns for {table_name}: {pk_cols}") - if not pk_cols: - warn(f"No primary key found for table {table_name}. Skipping.") - continue - - # Fetch all rows from the source table - rows = await source_conn.fetch(f"SELECT * FROM {table_name}") - info(f"Fetched {len(rows)} rows from {table_name}") - if rows: - columns = list(rows[0].keys()) - info(f"Columns for {table_name}: {columns}") - # Upsert records to the destination table - for row in rows: - try: - query = f""" - INSERT INTO {table_name} ({', '.join(columns)}) - VALUES ({', '.join(f'${i+1}' for i in range(len(columns)))}) - ON CONFLICT ({', '.join(pk_cols)}) DO UPDATE SET - {', '.join(f"{col} = EXCLUDED.{col}" for col in columns if col not in pk_cols)} - """ - info(f"Executing query: {query}") - info(f"With values: {[row[col] for col in columns]}") - await dest_conn.execute(query, *[row[col] for col in columns]) - except Exception as e: - err(f"Error processing row in {table_name}: {str(e)}") - err(f"Problematic row: {row}") - - info(f"Completed processing table: {table_name}") - - info(f"Successfully pulled changes from {source_pool_entry['ts_ip']}") - except Exception as e: - err(f"Unexpected error in pull_changes: {str(e)}") - err(f"Traceback: {traceback.format_exc()}") + + try: + async with self.get_connection(pool_entry) as remote_conn: + for table in tables: + table_name = table['tablename'] + last_synced_version = await self.get_last_synced_version(table_name, pool_entry['ts_id']) + + changes = await local_conn.fetch(f""" + SELECT * FROM "{table_name}" + WHERE version > $1 AND server_id = $2 + ORDER BY version ASC + """, last_synced_version, TS_ID) + + for change in changes: + columns = change.keys() + values = [change[col] for col in columns] + await remote_conn.execute(f""" + INSERT INTO "{table_name}" ({', '.join(columns)}) + VALUES ({', '.join(f'${i+1}' for i in range(len(columns)))}) + ON CONFLICT (id) DO UPDATE SET + {', '.join(f"{col} = EXCLUDED.{col}" for col in columns if col != 'id')} + """, *values) + + if changes: + await self.update_sync_status(table_name, pool_entry['ts_id'], changes[-1]['version']) + + info(f"Successfully pushed changes to {pool_entry['ts_id']}") + except Exception as e: + err(f"Error pushing changes to {pool_entry['ts_id']}: {str(e)}") + async def get_last_synced_version(self, table_name, server_id): + async with self.get_connection() as conn: + return await conn.fetchval(""" + SELECT last_synced_version FROM sync_status + WHERE table_name = $1 AND server_id = $2 + """, table_name, server_id) or 0 + async def update_sync_status(self, table_name, server_id, version): + async with self.get_connection() as conn: + await conn.execute(""" + INSERT INTO sync_status (table_name, server_id, last_synced_version) + VALUES ($1, $2, $3) + ON CONFLICT (table_name, server_id) DO UPDATE + SET last_synced_version = EXCLUDED.last_synced_version + """, table_name, server_id, version) async def sync_schema(self): for pool_entry in self.POOL: @@ -459,45 +524,6 @@ class APIConfig(BaseModel): END $$; """) - async def get_schema(self, pool_entry: Dict[str, Any]): - async with self.get_connection(pool_entry) as conn: - tables = await conn.fetch(""" - SELECT table_name, column_name, data_type, character_maximum_length, - is_nullable, column_default, ordinal_position - FROM information_schema.columns - WHERE table_schema = 'public' - ORDER BY table_name, ordinal_position - """) - - indexes = await conn.fetch(""" - SELECT indexname, indexdef - FROM pg_indexes - WHERE schemaname = 'public' - """) - - constraints = await conn.fetch(""" - SELECT conname, contype, conrelid::regclass::text as table_name, - pg_get_constraintdef(oid) as definition - FROM pg_constraint - WHERE connamespace = 'public'::regnamespace - """) - - return { - 'tables': tables, - 'indexes': indexes, - 'constraints': constraints - } - - async def create_sequence_if_not_exists(self, conn, sequence_name): - await conn.execute(f""" - DO $$ - BEGIN - IF NOT EXISTS (SELECT 1 FROM pg_sequences WHERE schemaname = 'public' AND sequencename = '{sequence_name}') THEN - CREATE SEQUENCE {sequence_name}; - END IF; - END $$; - """) - async def apply_schema_changes(self, pool_entry: Dict[str, Any], source_schema, target_schema): async with self.get_connection(pool_entry) as conn: source_tables = {t['table_name']: t for t in source_schema['tables']} @@ -567,8 +593,6 @@ class APIConfig(BaseModel): await conn.execute(sql) except Exception as e: err(f"Error processing table {table_name}: {str(e)}") - # Optionally, you might want to raise this exception if you want to stop the entire process - # raise try: source_indexes = {idx['indexname']: idx['indexdef'] for idx in source_schema['indexes']}