Auto-update: Tue Jul 30 14:13:53 PDT 2024
This commit is contained in:
parent
3ab9f6bc81
commit
9a16e9f46b
1 changed files with 61 additions and 18 deletions
|
@ -178,6 +178,7 @@ class APIConfig(BaseModel):
|
||||||
TZ: str
|
TZ: str
|
||||||
KEYS: List[str]
|
KEYS: List[str]
|
||||||
GARBAGE: Dict[str, Any]
|
GARBAGE: Dict[str, Any]
|
||||||
|
_db_pools: Dict[str, asyncpg.Pool] = {}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, config_path: Union[str, Path], secrets_path: Union[str, Path]):
|
def load(cls, config_path: Union[str, Path], secrets_path: Union[str, Path]):
|
||||||
|
@ -298,27 +299,48 @@ class APIConfig(BaseModel):
|
||||||
if pool_entry is None:
|
if pool_entry is None:
|
||||||
pool_entry = self.local_db
|
pool_entry = self.local_db
|
||||||
|
|
||||||
info(f"Attempting to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}")
|
pool_key = f"{pool_entry['ts_ip']}:{pool_entry['db_port']}"
|
||||||
try:
|
|
||||||
conn = await asyncpg.connect(
|
if pool_key not in self._db_pools:
|
||||||
host=pool_entry['ts_ip'],
|
|
||||||
port=pool_entry['db_port'],
|
|
||||||
user=pool_entry['db_user'],
|
|
||||||
password=pool_entry['db_pass'],
|
|
||||||
database=pool_entry['db_name'],
|
|
||||||
timeout=5 # Add a timeout to prevent hanging
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
|
self._db_pools[pool_key] = await asyncpg.create_pool(
|
||||||
|
host=pool_entry['ts_ip'],
|
||||||
|
port=pool_entry['db_port'],
|
||||||
|
user=pool_entry['db_user'],
|
||||||
|
password=pool_entry['db_pass'],
|
||||||
|
database=pool_entry['db_name'],
|
||||||
|
min_size=1,
|
||||||
|
max_size=10, # adjust as needed
|
||||||
|
timeout=5 # connection timeout in seconds
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
err(f"Failed to create connection pool for {pool_key}: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with self._db_pools[pool_key].acquire() as conn:
|
||||||
yield conn
|
yield conn
|
||||||
finally:
|
except asyncpg.exceptions.ConnectionDoesNotExistError:
|
||||||
await conn.close()
|
err(f"Failed to acquire connection from pool for {pool_key}: Connection does not exist")
|
||||||
|
raise
|
||||||
except asyncpg.exceptions.ConnectionFailureError:
|
except asyncpg.exceptions.ConnectionFailureError:
|
||||||
err(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}")
|
err(f"Failed to acquire connection from pool for {pool_key}: Connection failure")
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
err(f"Unexpected error when connecting to {pool_entry['ts_ip']}:{pool_entry['db_port']}: {str(e)}")
|
err(f"Unexpected error when acquiring connection from pool for {pool_key}: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
async def close_db_pools(self):
|
||||||
|
info("Closing database connection pools...")
|
||||||
|
for pool_key, pool in self._db_pools.items():
|
||||||
|
try:
|
||||||
|
await pool.close()
|
||||||
|
info(f"Closed pool for {pool_key}")
|
||||||
|
except Exception as e:
|
||||||
|
err(f"Error closing pool for {pool_key}: {str(e)}")
|
||||||
|
self._db_pools.clear()
|
||||||
|
info("All database connection pools closed.")
|
||||||
|
|
||||||
async def initialize_sync(self):
|
async def initialize_sync(self):
|
||||||
local_ts_id = os.environ.get('TS_ID')
|
local_ts_id = os.environ.get('TS_ID')
|
||||||
for pool_entry in self.POOL:
|
for pool_entry in self.POOL:
|
||||||
|
@ -348,7 +370,14 @@ class APIConfig(BaseModel):
|
||||||
BEGIN
|
BEGIN
|
||||||
BEGIN
|
BEGIN
|
||||||
ALTER TABLE "{table_name}"
|
ALTER TABLE "{table_name}"
|
||||||
ADD COLUMN IF NOT EXISTS version INTEGER DEFAULT 1,
|
ADD COLUMN IF NOT EXISTS version INTEGER DEFAULT 1;
|
||||||
|
EXCEPTION
|
||||||
|
WHEN duplicate_column THEN
|
||||||
|
-- Do nothing, column already exists
|
||||||
|
END;
|
||||||
|
|
||||||
|
BEGIN
|
||||||
|
ALTER TABLE "{table_name}"
|
||||||
ADD COLUMN IF NOT EXISTS server_id TEXT DEFAULT '{os.environ.get('TS_ID')}';
|
ADD COLUMN IF NOT EXISTS server_id TEXT DEFAULT '{os.environ.get('TS_ID')}';
|
||||||
EXCEPTION
|
EXCEPTION
|
||||||
WHEN duplicate_column THEN
|
WHEN duplicate_column THEN
|
||||||
|
@ -356,6 +385,16 @@ class APIConfig(BaseModel):
|
||||||
END;
|
END;
|
||||||
END $$;
|
END $$;
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
# Verify that the columns were added
|
||||||
|
result = await conn.fetchrow(f"""
|
||||||
|
SELECT
|
||||||
|
EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = '{table_name}' AND column_name = 'version') as has_version,
|
||||||
|
EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = '{table_name}' AND column_name = 'server_id') as has_server_id
|
||||||
|
""")
|
||||||
|
|
||||||
|
if not (result['has_version'] and result['has_server_id']):
|
||||||
|
raise Exception(f"Failed to add version and/or server_id columns to table {table_name}")
|
||||||
|
|
||||||
async def ensure_sync_trigger(self, conn, table_name):
|
async def ensure_sync_trigger(self, conn, table_name):
|
||||||
await conn.execute(f"""
|
await conn.execute(f"""
|
||||||
|
@ -391,9 +430,11 @@ class APIConfig(BaseModel):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
version = await conn.fetchval("""
|
version = await conn.fetchval("""
|
||||||
SELECT COALESCE(MAX(version), -1) FROM (
|
SELECT COALESCE(MAX(version), -1)
|
||||||
SELECT MAX(version) as version FROM information_schema.columns
|
FROM (
|
||||||
WHERE table_schema = 'public' AND column_name = 'version'
|
SELECT MAX(version) as version
|
||||||
|
FROM pg_tables
|
||||||
|
WHERE schemaname = 'public'
|
||||||
) as subquery
|
) as subquery
|
||||||
""")
|
""")
|
||||||
if version > max_version:
|
if version > max_version:
|
||||||
|
@ -413,10 +454,12 @@ class APIConfig(BaseModel):
|
||||||
FROM information_schema.columns
|
FROM information_schema.columns
|
||||||
WHERE table_schema = 'public'
|
WHERE table_schema = 'public'
|
||||||
AND column_name = 'version'
|
AND column_name = 'version'
|
||||||
|
AND table_name IN (SELECT tablename FROM pg_tables WHERE schemaname = 'public')
|
||||||
)
|
)
|
||||||
""")
|
""")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
async def pull_changes(self, source_pool_entry, batch_size=10000):
|
async def pull_changes(self, source_pool_entry, batch_size=10000):
|
||||||
if source_pool_entry['ts_id'] == os.environ.get('TS_ID'):
|
if source_pool_entry['ts_id'] == os.environ.get('TS_ID'):
|
||||||
info("Skipping self-sync")
|
info("Skipping self-sync")
|
||||||
|
|
Loading…
Reference in a new issue