diff --git a/sijapi/__main__.py b/sijapi/__main__.py index b692955..f8e5d1c 100755 --- a/sijapi/__main__.py +++ b/sijapi/__main__.py @@ -80,6 +80,7 @@ async def lifespan(app: FastAPI): crit("Database pools closed.") + app = FastAPI(lifespan=lifespan) app.add_middleware( diff --git a/sijapi/classes.py b/sijapi/classes.py index ead927b..90eecf1 100644 --- a/sijapi/classes.py +++ b/sijapi/classes.py @@ -298,7 +298,7 @@ class APIConfig(BaseModel): if pool_entry is None: pool_entry = self.local_db - info(f"Attempting to connect to database: {pool_entry}") + info(f"Attempting to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}") try: conn = await asyncpg.connect( host=pool_entry['ts_ip'], @@ -312,39 +312,36 @@ class APIConfig(BaseModel): yield conn finally: await conn.close() - except asyncpg.exceptions.ConnectionDoesNotExistError: - err(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']} - Connection does not exist") - raise except asyncpg.exceptions.ConnectionFailureError: - err(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']} - Connection failure") - raise - except asyncpg.exceptions.PostgresError as e: - err(f"PostgreSQL error when connecting to {pool_entry['ts_ip']}:{pool_entry['db_port']}: {str(e)}") + err(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}") raise except Exception as e: err(f"Unexpected error when connecting to {pool_entry['ts_ip']}:{pool_entry['db_port']}: {str(e)}") raise - - async def initialize_sync(self): + local_ts_id = os.environ.get('TS_ID') for pool_entry in self.POOL: + if pool_entry['ts_id'] == local_ts_id: + continue # Skip local database try: async with self.get_connection(pool_entry) as conn: - tables = await conn.fetch(""" - SELECT tablename FROM pg_tables - WHERE schemaname = 'public' - """) - - for table in tables: - table_name = table['tablename'] - await self.ensure_sync_columns(conn, table_name) - await self.ensure_sync_trigger(conn, table_name) - + await self.ensure_sync_structure(conn) info(f"Sync initialization complete for {pool_entry['ts_ip']}. All tables now have version and server_id columns with appropriate triggers.") except Exception as e: err(f"Error initializing sync for {pool_entry['ts_ip']}: {str(e)}") + async def ensure_sync_structure(self, conn): + tables = await conn.fetch(""" + SELECT tablename FROM pg_tables + WHERE schemaname = 'public' + """) + + for table in tables: + table_name = table['tablename'] + await self.ensure_sync_columns(conn, table_name) + await self.ensure_sync_trigger(conn, table_name) + async def ensure_sync_columns(self, conn, table_name): await conn.execute(f""" DO $$ @@ -378,17 +375,21 @@ class APIConfig(BaseModel): FOR EACH ROW EXECUTE FUNCTION update_version_and_server_id(); """) - async def get_most_recent_source(self): most_recent_source = None max_version = -1 + local_ts_id = os.environ.get('TS_ID') for pool_entry in self.POOL: - if pool_entry['ts_id'] == os.environ.get('TS_ID'): - continue + if pool_entry['ts_id'] == local_ts_id: + continue # Skip local database try: async with self.get_connection(pool_entry) as conn: + if not await self.check_version_column_exists(conn): + warn(f"Version column does not exist in {pool_entry['ts_id']}. Skipping.") + continue + version = await conn.fetchval(""" SELECT COALESCE(MAX(version), -1) FROM ( SELECT MAX(version) as version FROM information_schema.columns @@ -398,14 +399,23 @@ class APIConfig(BaseModel): if version > max_version: max_version = version most_recent_source = pool_entry + except asyncpg.exceptions.ConnectionFailureError: + warn(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}") except Exception as e: - warn(f"Failed to connect to or query database for {pool_entry['ts_id']}: {str(e)}") + warn(f"Error checking version for {pool_entry['ts_id']}: {str(e)}") return most_recent_source - - - + async def check_version_column_exists(self, conn): + result = await conn.fetchval(""" + SELECT EXISTS ( + SELECT 1 + FROM information_schema.columns + WHERE table_schema = 'public' + AND column_name = 'version' + ) + """) + return result async def pull_changes(self, source_pool_entry, batch_size=10000): if source_pool_entry['ts_id'] == os.environ.get('TS_ID'): @@ -492,7 +502,6 @@ class APIConfig(BaseModel): # Ensure temporary table is dropped await conn.execute(f"DROP TABLE IF EXISTS {temp_table_name}") - async def push_changes_to_all(self): for pool_entry in self.POOL: if pool_entry['ts_id'] != os.environ.get('TS_ID'): @@ -501,7 +510,6 @@ class APIConfig(BaseModel): except Exception as e: err(f"Error pushing changes to {pool_entry['ts_id']}: {str(e)}") - async def push_changes_to_one(self, pool_entry): try: async with self.get_connection() as local_conn: @@ -543,21 +551,6 @@ class APIConfig(BaseModel): err(f"Error pushing changes to {pool_entry['ts_id']}: {str(e)}") err(f"Traceback: {traceback.format_exc()}") - async def ensure_sync_columns(self, conn, table_name): - await conn.execute(f""" - DO $$ - BEGIN - BEGIN - ALTER TABLE "{table_name}" - ADD COLUMN IF NOT EXISTS version INTEGER DEFAULT 1, - ADD COLUMN IF NOT EXISTS server_id TEXT DEFAULT '{os.environ.get('TS_ID')}'; - EXCEPTION - WHEN duplicate_column THEN - -- Do nothing, column already exists - END; - END $$; - """) - async def get_last_synced_version(self, conn, table_name, server_id): return await conn.fetchval(f""" SELECT COALESCE(MAX(version), 0)