Auto-update: Tue Jul 30 14:13:53 PDT 2024

This commit is contained in:
sanj 2024-07-30 14:13:53 -07:00
parent 3ab9f6bc81
commit 9a16e9f46b

View file

@ -178,6 +178,7 @@ class APIConfig(BaseModel):
TZ: str
KEYS: List[str]
GARBAGE: Dict[str, Any]
_db_pools: Dict[str, asyncpg.Pool] = {}
@classmethod
def load(cls, config_path: Union[str, Path], secrets_path: Union[str, Path]):
@ -298,27 +299,48 @@ class APIConfig(BaseModel):
if pool_entry is None:
pool_entry = self.local_db
info(f"Attempting to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}")
pool_key = f"{pool_entry['ts_ip']}:{pool_entry['db_port']}"
if pool_key not in self._db_pools:
try:
conn = await asyncpg.connect(
self._db_pools[pool_key] = await asyncpg.create_pool(
host=pool_entry['ts_ip'],
port=pool_entry['db_port'],
user=pool_entry['db_user'],
password=pool_entry['db_pass'],
database=pool_entry['db_name'],
timeout=5 # Add a timeout to prevent hanging
min_size=1,
max_size=10, # adjust as needed
timeout=5 # connection timeout in seconds
)
except Exception as e:
err(f"Failed to create connection pool for {pool_key}: {str(e)}")
raise
try:
async with self._db_pools[pool_key].acquire() as conn:
yield conn
finally:
await conn.close()
except asyncpg.exceptions.ConnectionDoesNotExistError:
err(f"Failed to acquire connection from pool for {pool_key}: Connection does not exist")
raise
except asyncpg.exceptions.ConnectionFailureError:
err(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}")
err(f"Failed to acquire connection from pool for {pool_key}: Connection failure")
raise
except Exception as e:
err(f"Unexpected error when connecting to {pool_entry['ts_ip']}:{pool_entry['db_port']}: {str(e)}")
err(f"Unexpected error when acquiring connection from pool for {pool_key}: {str(e)}")
raise
async def close_db_pools(self):
info("Closing database connection pools...")
for pool_key, pool in self._db_pools.items():
try:
await pool.close()
info(f"Closed pool for {pool_key}")
except Exception as e:
err(f"Error closing pool for {pool_key}: {str(e)}")
self._db_pools.clear()
info("All database connection pools closed.")
async def initialize_sync(self):
local_ts_id = os.environ.get('TS_ID')
for pool_entry in self.POOL:
@ -348,7 +370,14 @@ class APIConfig(BaseModel):
BEGIN
BEGIN
ALTER TABLE "{table_name}"
ADD COLUMN IF NOT EXISTS version INTEGER DEFAULT 1,
ADD COLUMN IF NOT EXISTS version INTEGER DEFAULT 1;
EXCEPTION
WHEN duplicate_column THEN
-- Do nothing, column already exists
END;
BEGIN
ALTER TABLE "{table_name}"
ADD COLUMN IF NOT EXISTS server_id TEXT DEFAULT '{os.environ.get('TS_ID')}';
EXCEPTION
WHEN duplicate_column THEN
@ -357,6 +386,16 @@ class APIConfig(BaseModel):
END $$;
""")
# Verify that the columns were added
result = await conn.fetchrow(f"""
SELECT
EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = '{table_name}' AND column_name = 'version') as has_version,
EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = '{table_name}' AND column_name = 'server_id') as has_server_id
""")
if not (result['has_version'] and result['has_server_id']):
raise Exception(f"Failed to add version and/or server_id columns to table {table_name}")
async def ensure_sync_trigger(self, conn, table_name):
await conn.execute(f"""
CREATE OR REPLACE FUNCTION update_version_and_server_id()
@ -391,9 +430,11 @@ class APIConfig(BaseModel):
continue
version = await conn.fetchval("""
SELECT COALESCE(MAX(version), -1) FROM (
SELECT MAX(version) as version FROM information_schema.columns
WHERE table_schema = 'public' AND column_name = 'version'
SELECT COALESCE(MAX(version), -1)
FROM (
SELECT MAX(version) as version
FROM pg_tables
WHERE schemaname = 'public'
) as subquery
""")
if version > max_version:
@ -413,10 +454,12 @@ class APIConfig(BaseModel):
FROM information_schema.columns
WHERE table_schema = 'public'
AND column_name = 'version'
AND table_name IN (SELECT tablename FROM pg_tables WHERE schemaname = 'public')
)
""")
return result
async def pull_changes(self, source_pool_entry, batch_size=10000):
if source_pool_entry['ts_id'] == os.environ.get('TS_ID'):
info("Skipping self-sync")