From 8665cfeebcc0f0b73c40e0f955254b51e71a03ff Mon Sep 17 00:00:00 2001
From: sanj <67624670+iodrift@users.noreply.github.com>
Date: Tue, 30 Jul 2024 14:13:53 -0700
Subject: [PATCH] Auto-update: Tue Jul 30 14:13:53 PDT 2024

---
 sijapi/classes.py | 79 ++++++++++++++++++++++++++++++++++++-----------
 1 file changed, 61 insertions(+), 18 deletions(-)

diff --git a/sijapi/classes.py b/sijapi/classes.py
index 90eecf1..0df567d 100644
--- a/sijapi/classes.py
+++ b/sijapi/classes.py
@@ -178,6 +178,7 @@ class APIConfig(BaseModel):
     TZ: str
     KEYS: List[str]
     GARBAGE: Dict[str, Any]
+    _db_pools: Dict[str, asyncpg.Pool] = {}
 
     @classmethod
     def load(cls, config_path: Union[str, Path], secrets_path: Union[str, Path]):
@@ -298,27 +299,48 @@ class APIConfig(BaseModel):
         if pool_entry is None:
             pool_entry = self.local_db
         
-        info(f"Attempting to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}")
-        try:
-            conn = await asyncpg.connect(
-                host=pool_entry['ts_ip'],
-                port=pool_entry['db_port'],
-                user=pool_entry['db_user'],
-                password=pool_entry['db_pass'],
-                database=pool_entry['db_name'],
-                timeout=5  # Add a timeout to prevent hanging
-            )
+        pool_key = f"{pool_entry['ts_ip']}:{pool_entry['db_port']}"
+        
+        if pool_key not in self._db_pools:
             try:
+                self._db_pools[pool_key] = await asyncpg.create_pool(
+                    host=pool_entry['ts_ip'],
+                    port=pool_entry['db_port'],
+                    user=pool_entry['db_user'],
+                    password=pool_entry['db_pass'],
+                    database=pool_entry['db_name'],
+                    min_size=1,
+                    max_size=10,  # adjust as needed
+                    timeout=5  # connection timeout in seconds
+                )
+            except Exception as e:
+                err(f"Failed to create connection pool for {pool_key}: {str(e)}")
+                raise
+
+        try:
+            async with self._db_pools[pool_key].acquire() as conn:
                 yield conn
-            finally:
-                await conn.close()
+        except asyncpg.exceptions.ConnectionDoesNotExistError:
+            err(f"Failed to acquire connection from pool for {pool_key}: Connection does not exist")
+            raise
         except asyncpg.exceptions.ConnectionFailureError:
-            err(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}")
+            err(f"Failed to acquire connection from pool for {pool_key}: Connection failure")
             raise
         except Exception as e:
-            err(f"Unexpected error when connecting to {pool_entry['ts_ip']}:{pool_entry['db_port']}: {str(e)}")
+            err(f"Unexpected error when acquiring connection from pool for {pool_key}: {str(e)}")
             raise
 
+    async def close_db_pools(self):
+        info("Closing database connection pools...")
+        for pool_key, pool in self._db_pools.items():
+            try:
+                await pool.close()
+                info(f"Closed pool for {pool_key}")
+            except Exception as e:
+                err(f"Error closing pool for {pool_key}: {str(e)}")
+        self._db_pools.clear()
+        info("All database connection pools closed.")
+
     async def initialize_sync(self):
         local_ts_id = os.environ.get('TS_ID')
         for pool_entry in self.POOL:
@@ -348,7 +370,14 @@ class APIConfig(BaseModel):
             BEGIN 
                 BEGIN
                     ALTER TABLE "{table_name}" 
-                    ADD COLUMN IF NOT EXISTS version INTEGER DEFAULT 1,
+                    ADD COLUMN IF NOT EXISTS version INTEGER DEFAULT 1;
+                EXCEPTION
+                    WHEN duplicate_column THEN 
+                        -- Do nothing, column already exists
+                END;
+                
+                BEGIN
+                    ALTER TABLE "{table_name}" 
                     ADD COLUMN IF NOT EXISTS server_id TEXT DEFAULT '{os.environ.get('TS_ID')}';
                 EXCEPTION
                     WHEN duplicate_column THEN 
@@ -356,6 +385,16 @@ class APIConfig(BaseModel):
                 END;
             END $$;
         """)
+        
+        # Verify that the columns were added
+        result = await conn.fetchrow(f"""
+            SELECT 
+                EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = '{table_name}' AND column_name = 'version') as has_version,
+                EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = '{table_name}' AND column_name = 'server_id') as has_server_id
+        """)
+        
+        if not (result['has_version'] and result['has_server_id']):
+            raise Exception(f"Failed to add version and/or server_id columns to table {table_name}")
 
     async def ensure_sync_trigger(self, conn, table_name):
         await conn.execute(f"""
@@ -391,9 +430,11 @@ class APIConfig(BaseModel):
                         continue
 
                     version = await conn.fetchval("""
-                        SELECT COALESCE(MAX(version), -1) FROM (
-                            SELECT MAX(version) as version FROM information_schema.columns
-                            WHERE table_schema = 'public' AND column_name = 'version'
+                        SELECT COALESCE(MAX(version), -1)
+                        FROM (
+                            SELECT MAX(version) as version
+                            FROM pg_tables
+                            WHERE schemaname = 'public'
                         ) as subquery
                     """)
                     if version > max_version:
@@ -413,10 +454,12 @@ class APIConfig(BaseModel):
                 FROM information_schema.columns 
                 WHERE table_schema = 'public' 
                 AND column_name = 'version'
+                AND table_name IN (SELECT tablename FROM pg_tables WHERE schemaname = 'public')
             )
         """)
         return result
 
+
     async def pull_changes(self, source_pool_entry, batch_size=10000):
         if source_pool_entry['ts_id'] == os.environ.get('TS_ID'):
             info("Skipping self-sync")