From 66775e0a8278ee598eaba9a56c15ce0bff0b901b Mon Sep 17 00:00:00 2001 From: sanj <67624670+iodrift@users.noreply.github.com> Date: Tue, 30 Jul 2024 11:51:53 -0700 Subject: [PATCH] Auto-update: Tue Jul 30 11:51:53 PDT 2024 --- sijapi/__main__.py | 10 +++----- sijapi/classes.py | 61 ++++++++++++++++++++++++++++++++++++---------- 2 files changed, 51 insertions(+), 20 deletions(-) diff --git a/sijapi/__main__.py b/sijapi/__main__.py index c7fe3c8..ca102bb 100755 --- a/sijapi/__main__.py +++ b/sijapi/__main__.py @@ -31,17 +31,13 @@ args = parser.parse_args() L.setup_from_args(args) print(f"Debug modules after setup: {L.debug_modules}") + logger = L.get_module_logger("main") def debug(text: str): logger.debug(text) -debug(f"Debug message.") def info(text: str): logger.info(text) -info(f"Info message.") def warn(text: str): logger.warning(text) -warn(f"Warning message.") def err(text: str): logger.error(text) -err(f"Error message.") def crit(text: str): logger.critical(text) -crit(f"Critical message.") @asynccontextmanager async def lifespan(app: FastAPI): @@ -79,8 +75,8 @@ async def lifespan(app: FastAPI): # Shutdown crit("Shutting down...") - # Perform any cleanup operations here if needed - + await API.close_db_pools() + crit("Database pools closed.") diff --git a/sijapi/classes.py b/sijapi/classes.py index 83a02cb..a810cc0 100644 --- a/sijapi/classes.py +++ b/sijapi/classes.py @@ -165,6 +165,32 @@ class Configuration(BaseModel): +class DatabasePool: + def __init__(self): + self.pools = {} + + async def get_connection(self, pool_entry): + pool_key = f"{pool_entry['ts_ip']}:{pool_entry['db_port']}" + if pool_key not in self.pools: + self.pools[pool_key] = await asyncpg.create_pool( + host=pool_entry['ts_ip'], + port=pool_entry['db_port'], + user=pool_entry['db_user'], + password=pool_entry['db_pass'], + database=pool_entry['db_name'], + min_size=1, + max_size=10 + ) + return await self.pools[pool_key].acquire() + + async def release_connection(self, pool_entry, connection): + pool_key = f"{pool_entry['ts_ip']}:{pool_entry['db_port']}" + await self.pools[pool_key].release(connection) + + async def close_all(self): + for pool in self.pools.values(): + await pool.close() + class APIConfig(BaseModel): HOST: str PORT: int @@ -178,6 +204,11 @@ class APIConfig(BaseModel): TZ: str KEYS: List[str] GARBAGE: Dict[str, Any] + db_pool: DatabasePool = None + + def __init__(self, **data): + super().__init__(**data) + self.db_pool = DatabasePool() @classmethod def load(cls, config_path: Union[str, Path], secrets_path: Union[str, Path]): @@ -300,19 +331,20 @@ class APIConfig(BaseModel): info(f"Attempting to connect to database: {pool_entry}") try: - conn = await asyncpg.connect( - host=pool_entry['ts_ip'], - port=pool_entry['db_port'], - user=pool_entry['db_user'], - password=pool_entry['db_pass'], - database=pool_entry['db_name'] - ) + conn = await self.db_pool.get_connection(pool_entry) try: yield conn finally: - await conn.close() + await self.db_pool.release_connection(pool_entry, conn) + except asyncpg.exceptions.ConnectionDoesNotExistError: + err(f"Connection to database {pool_entry['ts_ip']}:{pool_entry['db_port']} does not exist or has been closed") + raise + except asyncpg.exceptions.ConnectionFailureError as e: + err(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}") + err(f"Connection error: {str(e)}") + raise except Exception as e: - warn(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}") + err(f"Unexpected error connecting to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}") err(f"Error: {str(e)}") raise @@ -331,11 +363,12 @@ class APIConfig(BaseModel): await self.create_sync_trigger(conn, table_name) info(f"Sync initialization complete for {pool_entry['ts_ip']}. All tables now have version and server_id columns with appropriate triggers.") + except asyncpg.exceptions.ConnectionFailureError: + err(f"Failed to connect to database during initialization: {pool_entry['ts_ip']}:{pool_entry['db_port']}") except Exception as e: err(f"Error initializing sync for {pool_entry['ts_ip']}: {str(e)}") err(f"Traceback: {traceback.format_exc()}") - async def ensure_sync_columns(self, conn, table_name): try: await conn.execute(f""" @@ -356,7 +389,6 @@ class APIConfig(BaseModel): err(f"Error ensuring sync columns for table {table_name}: {str(e)}") err(f"Traceback: {traceback.format_exc()}") - async def create_sync_trigger(self, conn, table_name): await conn.execute(f""" CREATE OR REPLACE FUNCTION update_version_and_server_id() @@ -416,6 +448,8 @@ class APIConfig(BaseModel): if version > max_version: max_version = version most_recent_source = pool_entry + except asyncpg.exceptions.ConnectionFailureError: + err(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}") except Exception as e: err(f"Error checking version for {pool_entry['ts_id']}: {str(e)}") @@ -426,7 +460,6 @@ class APIConfig(BaseModel): return most_recent_source - async def pull_changes(self, source_pool_entry, batch_size=10000): if source_pool_entry['ts_id'] == os.environ.get('TS_ID'): info("Skipping self-sync") @@ -684,7 +717,9 @@ class APIConfig(BaseModel): """, table_name, column_name) return exists - + async def close_db_pools(self): + if self.db_pool: + await self.db_pool.close_all()