Auto-update: Tue Jul 30 14:13:53 PDT 2024
This commit is contained in:
parent
3ab9f6bc81
commit
9a16e9f46b
1 changed files with 61 additions and 18 deletions
|
@ -178,6 +178,7 @@ class APIConfig(BaseModel):
|
|||
TZ: str
|
||||
KEYS: List[str]
|
||||
GARBAGE: Dict[str, Any]
|
||||
_db_pools: Dict[str, asyncpg.Pool] = {}
|
||||
|
||||
@classmethod
|
||||
def load(cls, config_path: Union[str, Path], secrets_path: Union[str, Path]):
|
||||
|
@ -298,27 +299,48 @@ class APIConfig(BaseModel):
|
|||
if pool_entry is None:
|
||||
pool_entry = self.local_db
|
||||
|
||||
info(f"Attempting to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}")
|
||||
try:
|
||||
conn = await asyncpg.connect(
|
||||
host=pool_entry['ts_ip'],
|
||||
port=pool_entry['db_port'],
|
||||
user=pool_entry['db_user'],
|
||||
password=pool_entry['db_pass'],
|
||||
database=pool_entry['db_name'],
|
||||
timeout=5 # Add a timeout to prevent hanging
|
||||
)
|
||||
pool_key = f"{pool_entry['ts_ip']}:{pool_entry['db_port']}"
|
||||
|
||||
if pool_key not in self._db_pools:
|
||||
try:
|
||||
self._db_pools[pool_key] = await asyncpg.create_pool(
|
||||
host=pool_entry['ts_ip'],
|
||||
port=pool_entry['db_port'],
|
||||
user=pool_entry['db_user'],
|
||||
password=pool_entry['db_pass'],
|
||||
database=pool_entry['db_name'],
|
||||
min_size=1,
|
||||
max_size=10, # adjust as needed
|
||||
timeout=5 # connection timeout in seconds
|
||||
)
|
||||
except Exception as e:
|
||||
err(f"Failed to create connection pool for {pool_key}: {str(e)}")
|
||||
raise
|
||||
|
||||
try:
|
||||
async with self._db_pools[pool_key].acquire() as conn:
|
||||
yield conn
|
||||
finally:
|
||||
await conn.close()
|
||||
except asyncpg.exceptions.ConnectionDoesNotExistError:
|
||||
err(f"Failed to acquire connection from pool for {pool_key}: Connection does not exist")
|
||||
raise
|
||||
except asyncpg.exceptions.ConnectionFailureError:
|
||||
err(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}")
|
||||
err(f"Failed to acquire connection from pool for {pool_key}: Connection failure")
|
||||
raise
|
||||
except Exception as e:
|
||||
err(f"Unexpected error when connecting to {pool_entry['ts_ip']}:{pool_entry['db_port']}: {str(e)}")
|
||||
err(f"Unexpected error when acquiring connection from pool for {pool_key}: {str(e)}")
|
||||
raise
|
||||
|
||||
async def close_db_pools(self):
|
||||
info("Closing database connection pools...")
|
||||
for pool_key, pool in self._db_pools.items():
|
||||
try:
|
||||
await pool.close()
|
||||
info(f"Closed pool for {pool_key}")
|
||||
except Exception as e:
|
||||
err(f"Error closing pool for {pool_key}: {str(e)}")
|
||||
self._db_pools.clear()
|
||||
info("All database connection pools closed.")
|
||||
|
||||
async def initialize_sync(self):
|
||||
local_ts_id = os.environ.get('TS_ID')
|
||||
for pool_entry in self.POOL:
|
||||
|
@ -348,7 +370,14 @@ class APIConfig(BaseModel):
|
|||
BEGIN
|
||||
BEGIN
|
||||
ALTER TABLE "{table_name}"
|
||||
ADD COLUMN IF NOT EXISTS version INTEGER DEFAULT 1,
|
||||
ADD COLUMN IF NOT EXISTS version INTEGER DEFAULT 1;
|
||||
EXCEPTION
|
||||
WHEN duplicate_column THEN
|
||||
-- Do nothing, column already exists
|
||||
END;
|
||||
|
||||
BEGIN
|
||||
ALTER TABLE "{table_name}"
|
||||
ADD COLUMN IF NOT EXISTS server_id TEXT DEFAULT '{os.environ.get('TS_ID')}';
|
||||
EXCEPTION
|
||||
WHEN duplicate_column THEN
|
||||
|
@ -357,6 +386,16 @@ class APIConfig(BaseModel):
|
|||
END $$;
|
||||
""")
|
||||
|
||||
# Verify that the columns were added
|
||||
result = await conn.fetchrow(f"""
|
||||
SELECT
|
||||
EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = '{table_name}' AND column_name = 'version') as has_version,
|
||||
EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = '{table_name}' AND column_name = 'server_id') as has_server_id
|
||||
""")
|
||||
|
||||
if not (result['has_version'] and result['has_server_id']):
|
||||
raise Exception(f"Failed to add version and/or server_id columns to table {table_name}")
|
||||
|
||||
async def ensure_sync_trigger(self, conn, table_name):
|
||||
await conn.execute(f"""
|
||||
CREATE OR REPLACE FUNCTION update_version_and_server_id()
|
||||
|
@ -391,9 +430,11 @@ class APIConfig(BaseModel):
|
|||
continue
|
||||
|
||||
version = await conn.fetchval("""
|
||||
SELECT COALESCE(MAX(version), -1) FROM (
|
||||
SELECT MAX(version) as version FROM information_schema.columns
|
||||
WHERE table_schema = 'public' AND column_name = 'version'
|
||||
SELECT COALESCE(MAX(version), -1)
|
||||
FROM (
|
||||
SELECT MAX(version) as version
|
||||
FROM pg_tables
|
||||
WHERE schemaname = 'public'
|
||||
) as subquery
|
||||
""")
|
||||
if version > max_version:
|
||||
|
@ -413,10 +454,12 @@ class APIConfig(BaseModel):
|
|||
FROM information_schema.columns
|
||||
WHERE table_schema = 'public'
|
||||
AND column_name = 'version'
|
||||
AND table_name IN (SELECT tablename FROM pg_tables WHERE schemaname = 'public')
|
||||
)
|
||||
""")
|
||||
return result
|
||||
|
||||
|
||||
async def pull_changes(self, source_pool_entry, batch_size=10000):
|
||||
if source_pool_entry['ts_id'] == os.environ.get('TS_ID'):
|
||||
info("Skipping self-sync")
|
||||
|
|
Loading…
Reference in a new issue