From 9a16e9f46b8f95cfdc9dacdc1abc4a206793bfc3 Mon Sep 17 00:00:00 2001 From: sanj <67624670+iodrift@users.noreply.github.com> Date: Tue, 30 Jul 2024 14:13:53 -0700 Subject: [PATCH] Auto-update: Tue Jul 30 14:13:53 PDT 2024 --- sijapi/classes.py | 79 ++++++++++++++++++++++++++++++++++++----------- 1 file changed, 61 insertions(+), 18 deletions(-) diff --git a/sijapi/classes.py b/sijapi/classes.py index 90eecf1..0df567d 100644 --- a/sijapi/classes.py +++ b/sijapi/classes.py @@ -178,6 +178,7 @@ class APIConfig(BaseModel): TZ: str KEYS: List[str] GARBAGE: Dict[str, Any] + _db_pools: Dict[str, asyncpg.Pool] = {} @classmethod def load(cls, config_path: Union[str, Path], secrets_path: Union[str, Path]): @@ -298,27 +299,48 @@ class APIConfig(BaseModel): if pool_entry is None: pool_entry = self.local_db - info(f"Attempting to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}") - try: - conn = await asyncpg.connect( - host=pool_entry['ts_ip'], - port=pool_entry['db_port'], - user=pool_entry['db_user'], - password=pool_entry['db_pass'], - database=pool_entry['db_name'], - timeout=5 # Add a timeout to prevent hanging - ) + pool_key = f"{pool_entry['ts_ip']}:{pool_entry['db_port']}" + + if pool_key not in self._db_pools: try: + self._db_pools[pool_key] = await asyncpg.create_pool( + host=pool_entry['ts_ip'], + port=pool_entry['db_port'], + user=pool_entry['db_user'], + password=pool_entry['db_pass'], + database=pool_entry['db_name'], + min_size=1, + max_size=10, # adjust as needed + timeout=5 # connection timeout in seconds + ) + except Exception as e: + err(f"Failed to create connection pool for {pool_key}: {str(e)}") + raise + + try: + async with self._db_pools[pool_key].acquire() as conn: yield conn - finally: - await conn.close() + except asyncpg.exceptions.ConnectionDoesNotExistError: + err(f"Failed to acquire connection from pool for {pool_key}: Connection does not exist") + raise except asyncpg.exceptions.ConnectionFailureError: - err(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}") + err(f"Failed to acquire connection from pool for {pool_key}: Connection failure") raise except Exception as e: - err(f"Unexpected error when connecting to {pool_entry['ts_ip']}:{pool_entry['db_port']}: {str(e)}") + err(f"Unexpected error when acquiring connection from pool for {pool_key}: {str(e)}") raise + async def close_db_pools(self): + info("Closing database connection pools...") + for pool_key, pool in self._db_pools.items(): + try: + await pool.close() + info(f"Closed pool for {pool_key}") + except Exception as e: + err(f"Error closing pool for {pool_key}: {str(e)}") + self._db_pools.clear() + info("All database connection pools closed.") + async def initialize_sync(self): local_ts_id = os.environ.get('TS_ID') for pool_entry in self.POOL: @@ -348,7 +370,14 @@ class APIConfig(BaseModel): BEGIN BEGIN ALTER TABLE "{table_name}" - ADD COLUMN IF NOT EXISTS version INTEGER DEFAULT 1, + ADD COLUMN IF NOT EXISTS version INTEGER DEFAULT 1; + EXCEPTION + WHEN duplicate_column THEN + -- Do nothing, column already exists + END; + + BEGIN + ALTER TABLE "{table_name}" ADD COLUMN IF NOT EXISTS server_id TEXT DEFAULT '{os.environ.get('TS_ID')}'; EXCEPTION WHEN duplicate_column THEN @@ -356,6 +385,16 @@ class APIConfig(BaseModel): END; END $$; """) + + # Verify that the columns were added + result = await conn.fetchrow(f""" + SELECT + EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = '{table_name}' AND column_name = 'version') as has_version, + EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = '{table_name}' AND column_name = 'server_id') as has_server_id + """) + + if not (result['has_version'] and result['has_server_id']): + raise Exception(f"Failed to add version and/or server_id columns to table {table_name}") async def ensure_sync_trigger(self, conn, table_name): await conn.execute(f""" @@ -391,9 +430,11 @@ class APIConfig(BaseModel): continue version = await conn.fetchval(""" - SELECT COALESCE(MAX(version), -1) FROM ( - SELECT MAX(version) as version FROM information_schema.columns - WHERE table_schema = 'public' AND column_name = 'version' + SELECT COALESCE(MAX(version), -1) + FROM ( + SELECT MAX(version) as version + FROM pg_tables + WHERE schemaname = 'public' ) as subquery """) if version > max_version: @@ -413,10 +454,12 @@ class APIConfig(BaseModel): FROM information_schema.columns WHERE table_schema = 'public' AND column_name = 'version' + AND table_name IN (SELECT tablename FROM pg_tables WHERE schemaname = 'public') ) """) return result + async def pull_changes(self, source_pool_entry, batch_size=10000): if source_pool_entry['ts_id'] == os.environ.get('TS_ID'): info("Skipping self-sync")