Auto-update: Tue Jul 30 11:51:53 PDT 2024

This commit is contained in:
sanj 2024-07-30 11:51:53 -07:00
parent 72047b561c
commit 66775e0a82
2 changed files with 51 additions and 20 deletions

View file

@ -31,17 +31,13 @@ args = parser.parse_args()
L.setup_from_args(args) L.setup_from_args(args)
print(f"Debug modules after setup: {L.debug_modules}") print(f"Debug modules after setup: {L.debug_modules}")
logger = L.get_module_logger("main") logger = L.get_module_logger("main")
def debug(text: str): logger.debug(text) def debug(text: str): logger.debug(text)
debug(f"Debug message.")
def info(text: str): logger.info(text) def info(text: str): logger.info(text)
info(f"Info message.")
def warn(text: str): logger.warning(text) def warn(text: str): logger.warning(text)
warn(f"Warning message.")
def err(text: str): logger.error(text) def err(text: str): logger.error(text)
err(f"Error message.")
def crit(text: str): logger.critical(text) def crit(text: str): logger.critical(text)
crit(f"Critical message.")
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
@ -79,8 +75,8 @@ async def lifespan(app: FastAPI):
# Shutdown # Shutdown
crit("Shutting down...") crit("Shutting down...")
# Perform any cleanup operations here if needed await API.close_db_pools()
crit("Database pools closed.")

View file

@ -165,6 +165,32 @@ class Configuration(BaseModel):
class DatabasePool:
def __init__(self):
self.pools = {}
async def get_connection(self, pool_entry):
pool_key = f"{pool_entry['ts_ip']}:{pool_entry['db_port']}"
if pool_key not in self.pools:
self.pools[pool_key] = await asyncpg.create_pool(
host=pool_entry['ts_ip'],
port=pool_entry['db_port'],
user=pool_entry['db_user'],
password=pool_entry['db_pass'],
database=pool_entry['db_name'],
min_size=1,
max_size=10
)
return await self.pools[pool_key].acquire()
async def release_connection(self, pool_entry, connection):
pool_key = f"{pool_entry['ts_ip']}:{pool_entry['db_port']}"
await self.pools[pool_key].release(connection)
async def close_all(self):
for pool in self.pools.values():
await pool.close()
class APIConfig(BaseModel): class APIConfig(BaseModel):
HOST: str HOST: str
PORT: int PORT: int
@ -178,6 +204,11 @@ class APIConfig(BaseModel):
TZ: str TZ: str
KEYS: List[str] KEYS: List[str]
GARBAGE: Dict[str, Any] GARBAGE: Dict[str, Any]
db_pool: DatabasePool = None
def __init__(self, **data):
super().__init__(**data)
self.db_pool = DatabasePool()
@classmethod @classmethod
def load(cls, config_path: Union[str, Path], secrets_path: Union[str, Path]): def load(cls, config_path: Union[str, Path], secrets_path: Union[str, Path]):
@ -300,19 +331,20 @@ class APIConfig(BaseModel):
info(f"Attempting to connect to database: {pool_entry}") info(f"Attempting to connect to database: {pool_entry}")
try: try:
conn = await asyncpg.connect( conn = await self.db_pool.get_connection(pool_entry)
host=pool_entry['ts_ip'],
port=pool_entry['db_port'],
user=pool_entry['db_user'],
password=pool_entry['db_pass'],
database=pool_entry['db_name']
)
try: try:
yield conn yield conn
finally: finally:
await conn.close() await self.db_pool.release_connection(pool_entry, conn)
except asyncpg.exceptions.ConnectionDoesNotExistError:
err(f"Connection to database {pool_entry['ts_ip']}:{pool_entry['db_port']} does not exist or has been closed")
raise
except asyncpg.exceptions.ConnectionFailureError as e:
err(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}")
err(f"Connection error: {str(e)}")
raise
except Exception as e: except Exception as e:
warn(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}") err(f"Unexpected error connecting to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}")
err(f"Error: {str(e)}") err(f"Error: {str(e)}")
raise raise
@ -331,11 +363,12 @@ class APIConfig(BaseModel):
await self.create_sync_trigger(conn, table_name) await self.create_sync_trigger(conn, table_name)
info(f"Sync initialization complete for {pool_entry['ts_ip']}. All tables now have version and server_id columns with appropriate triggers.") info(f"Sync initialization complete for {pool_entry['ts_ip']}. All tables now have version and server_id columns with appropriate triggers.")
except asyncpg.exceptions.ConnectionFailureError:
err(f"Failed to connect to database during initialization: {pool_entry['ts_ip']}:{pool_entry['db_port']}")
except Exception as e: except Exception as e:
err(f"Error initializing sync for {pool_entry['ts_ip']}: {str(e)}") err(f"Error initializing sync for {pool_entry['ts_ip']}: {str(e)}")
err(f"Traceback: {traceback.format_exc()}") err(f"Traceback: {traceback.format_exc()}")
async def ensure_sync_columns(self, conn, table_name): async def ensure_sync_columns(self, conn, table_name):
try: try:
await conn.execute(f""" await conn.execute(f"""
@ -356,7 +389,6 @@ class APIConfig(BaseModel):
err(f"Error ensuring sync columns for table {table_name}: {str(e)}") err(f"Error ensuring sync columns for table {table_name}: {str(e)}")
err(f"Traceback: {traceback.format_exc()}") err(f"Traceback: {traceback.format_exc()}")
async def create_sync_trigger(self, conn, table_name): async def create_sync_trigger(self, conn, table_name):
await conn.execute(f""" await conn.execute(f"""
CREATE OR REPLACE FUNCTION update_version_and_server_id() CREATE OR REPLACE FUNCTION update_version_and_server_id()
@ -416,6 +448,8 @@ class APIConfig(BaseModel):
if version > max_version: if version > max_version:
max_version = version max_version = version
most_recent_source = pool_entry most_recent_source = pool_entry
except asyncpg.exceptions.ConnectionFailureError:
err(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}")
except Exception as e: except Exception as e:
err(f"Error checking version for {pool_entry['ts_id']}: {str(e)}") err(f"Error checking version for {pool_entry['ts_id']}: {str(e)}")
@ -426,7 +460,6 @@ class APIConfig(BaseModel):
return most_recent_source return most_recent_source
async def pull_changes(self, source_pool_entry, batch_size=10000): async def pull_changes(self, source_pool_entry, batch_size=10000):
if source_pool_entry['ts_id'] == os.environ.get('TS_ID'): if source_pool_entry['ts_id'] == os.environ.get('TS_ID'):
info("Skipping self-sync") info("Skipping self-sync")
@ -684,7 +717,9 @@ class APIConfig(BaseModel):
""", table_name, column_name) """, table_name, column_name)
return exists return exists
async def close_db_pools(self):
if self.db_pool:
await self.db_pool.close_all()