From bca1f89a00129f8b02be8affdc9c275e489af056 Mon Sep 17 00:00:00 2001
From: sanj <67624670+iodrift@users.noreply.github.com>
Date: Tue, 30 Jul 2024 16:32:22 -0700
Subject: [PATCH] Auto-update: Tue Jul 30 16:32:22 PDT 2024

---
 sijapi/classes.py | 190 ++++++++++++++++++++++++++--------------------
 1 file changed, 107 insertions(+), 83 deletions(-)

diff --git a/sijapi/classes.py b/sijapi/classes.py
index 0c8dfd8..546e78e 100644
--- a/sijapi/classes.py
+++ b/sijapi/classes.py
@@ -427,46 +427,59 @@ class APIConfig(BaseModel):
 
     async def ensure_sync_columns(self, conn, table_name):
         try:
+            # Check if the table has a primary key
+            has_primary_key = await conn.fetchval(f"""
+                SELECT EXISTS (
+                    SELECT 1
+                    FROM information_schema.table_constraints
+                    WHERE table_name = '{table_name}'
+                    AND constraint_type = 'PRIMARY KEY'
+                )
+            """)
+
             await conn.execute(f"""
                 DO $$ 
                 BEGIN 
+                    -- Ensure version column exists
+                    IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = '{table_name}' AND column_name = 'version') THEN
+                        ALTER TABLE "{table_name}" ADD COLUMN version INTEGER DEFAULT 1;
+                    END IF;
+
+                    -- Ensure server_id column exists
+                    IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = '{table_name}' AND column_name = 'server_id') THEN
+                        ALTER TABLE "{table_name}" ADD COLUMN server_id TEXT DEFAULT '{os.environ.get('TS_ID')}';
+                    END IF;
+
+                    -- Create or replace the trigger function
+                    CREATE OR REPLACE FUNCTION update_version_and_server_id()
+                    RETURNS TRIGGER AS $$
                     BEGIN
-                        ALTER TABLE "{table_name}" 
-                        ADD COLUMN IF NOT EXISTS version INTEGER DEFAULT 1;
-                    EXCEPTION
-                        WHEN duplicate_column THEN 
-                            NULL;
-                    END;
-                    
-                    BEGIN
-                        ALTER TABLE "{table_name}" 
-                        ADD COLUMN IF NOT EXISTS server_id TEXT DEFAULT '{os.environ.get('TS_ID')}';
-                    EXCEPTION
-                        WHEN duplicate_column THEN 
-                            NULL;
+                        NEW.version = COALESCE(OLD.version, 0) + 1;
+                        NEW.server_id = '{os.environ.get('TS_ID')}';
+                        RETURN NEW;
                     END;
+                    $$ LANGUAGE plpgsql;
+
+                    -- Create the trigger if it doesn't exist
+                    IF NOT EXISTS (SELECT 1 FROM pg_trigger WHERE tgname = 'update_version_and_server_id_trigger' AND tgrelid = '{table_name}'::regclass) THEN
+                        CREATE TRIGGER update_version_and_server_id_trigger
+                        BEFORE INSERT OR UPDATE ON "{table_name}"
+                        FOR EACH ROW EXECUTE FUNCTION update_version_and_server_id();
+                    END IF;
                 END $$;
             """)
             
-            # Verify that the columns were added
-            result = await conn.fetchrow(f"""
-                SELECT 
-                    EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = '{table_name}' AND column_name = 'version') as has_version,
-                    EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = '{table_name}' AND column_name = 'server_id') as has_server_id
-            """)
-            
-            if result['has_version'] and result['has_server_id']:
-                info(f"Successfully added/verified version and server_id columns for table {table_name}")
-                return True
-            else:
-                err(f"Failed to add version and/or server_id columns to table {table_name}")
-                return False
+            info(f"Successfully ensured sync columns and trigger for table {table_name}. Has primary key: {has_primary_key}")
+            return has_primary_key
+
         except Exception as e:
             err(f"Error ensuring sync columns for table {table_name}: {str(e)}")
             err(f"Traceback: {traceback.format_exc()}")
             return False
 
 
+
+
     async def ensure_sync_trigger(self, conn, table_name):
         await conn.execute(f"""
             CREATE OR REPLACE FUNCTION update_version_and_server_id()
@@ -601,34 +614,30 @@ class APIConfig(BaseModel):
                         WHERE schemaname = 'public'
                     """)
                     
-                    async for table in tqdm(tables, desc="Syncing tables", unit="table"):
+                    for table in tables:
                         table_name = table['tablename']
                         try:
-                            if table_name == 'spatial_ref_sys':
-                                changes_count = await self.sync_spatial_ref_sys(source_conn, dest_conn)
+                            has_primary_key = await self.ensure_sync_columns(dest_conn, table_name)
+                            last_synced_version = await self.get_last_synced_version(dest_conn, table_name, source_id)
+                            
+                            changes = await source_conn.fetch(f"""
+                                SELECT * FROM "{table_name}"
+                                WHERE version > $1 AND server_id = $2
+                                ORDER BY version ASC
+                                LIMIT $3
+                            """, last_synced_version, source_id, batch_size)
+                            
+                            if changes:
+                                changes_count = await self.apply_batch_changes(dest_conn, table_name, changes, has_primary_key)
                                 total_changes += changes_count
-                                info(f"Synced spatial_ref_sys: {changes_count} changes. Total so far: {total_changes}")
+                                
+                                if changes_count > 0:
+                                    last_synced_version = changes[-1]['version']
+                                    await self.update_sync_status(dest_conn, table_name, source_id, last_synced_version)
+                                
+                                info(f"Synced batch for {table_name}: {changes_count} changes. Total so far: {total_changes}")
                             else:
-                                last_synced_version = await self.get_last_synced_version(dest_conn, table_name, source_id)
-                                
-                                changes = await source_conn.fetch(f"""
-                                    SELECT * FROM "{table_name}"
-                                    WHERE version > $1 AND server_id = $2
-                                    ORDER BY version ASC
-                                    LIMIT $3
-                                """, last_synced_version, source_id, batch_size)
-                                
-                                if changes:
-                                    changes_count = await self.apply_batch_changes(dest_conn, table_name, changes)
-                                    total_changes += changes_count
-                                    
-                                    if changes_count > 0:
-                                        last_synced_version = changes[-1]['version']
-                                        await self.update_sync_status(dest_conn, table_name, source_id, last_synced_version)
-                                    
-                                    info(f"Synced batch for {table_name}: {changes_count} changes. Total so far: {total_changes}")
-                                else:
-                                    info(f"No changes to sync for {table_name}")
+                                info(f"No changes to sync for {table_name}")
 
                         except Exception as e:
                             err(f"Error syncing table {table_name}: {str(e)}")
@@ -651,36 +660,39 @@ class APIConfig(BaseModel):
 
 
 
-    async def apply_batch_changes(self, conn, table_name, changes):
+    async def apply_batch_changes(self, conn, table_name, changes, has_primary_key):
         if not changes:
             return 0
 
         try:
-            # Convert the keys to a list
             columns = list(changes[0].keys())
             placeholders = [f'${i+1}' for i in range(len(columns))]
-
-            # Check if 'id' column exists
-            id_exists = 'id' in columns
-
-            if id_exists:
+            
+            if has_primary_key:
                 insert_query = f"""
                     INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)})
                     VALUES ({', '.join(placeholders)})
-                    ON CONFLICT (id) DO UPDATE SET
-                    {', '.join(f'"{col}" = EXCLUDED."{col}"' for col in columns if col != 'id')}
+                    ON CONFLICT ON CONSTRAINT {table_name}_pkey DO UPDATE SET
+                    {', '.join(f'"{col}" = EXCLUDED."{col}"' for col in columns if col not in ['version', 'server_id'])},
+                    version = EXCLUDED.version,
+                    server_id = EXCLUDED.server_id
+                    WHERE "{table_name}".version < EXCLUDED.version
+                    OR ("{table_name}".version = EXCLUDED.version AND "{table_name}".server_id < EXCLUDED.server_id)
                 """
             else:
-                # For tables without 'id', use all columns as conflict target
+                # For tables without a primary key, use every column except version and server_id as the conflict target
                 insert_query = f"""
                     INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)})
                     VALUES ({', '.join(placeholders)})
-                    ON CONFLICT DO NOTHING
+                    ON CONFLICT ({', '.join(f'"{col}"' for col in columns if col not in ['version', 'server_id'])}) DO UPDATE SET
+                    version = EXCLUDED.version,
+                    server_id = EXCLUDED.server_id
+                    WHERE "{table_name}".version < EXCLUDED.version
+                    OR ("{table_name}".version = EXCLUDED.version AND "{table_name}".server_id < EXCLUDED.server_id)
                 """
 
             debug(f"Generated insert query for {table_name}: {insert_query}")
 
-            # Execute the insert for each change
             affected_rows = 0
             async for change in tqdm(changes, desc=f"Syncing {table_name}", unit="row"):
                 values = [change[col] for col in columns]
@@ -696,6 +708,7 @@ class APIConfig(BaseModel):
             return 0
 
 
+
     async def sync_spatial_ref_sys(self, source_conn, dest_conn):
         try:
             # Get all entries from the source
@@ -777,30 +790,40 @@ class APIConfig(BaseModel):
                     
                     for table in tables:
                         table_name = table['tablename']
-                        last_synced_version = await self.get_last_synced_version(remote_conn, table_name, os.environ.get('TS_ID'))
-                        
-                        changes = await local_conn.fetch(f"""
-                            SELECT * FROM "{table_name}"
-                            WHERE version > $1 AND server_id = $2
-                            ORDER BY version ASC
-                        """, last_synced_version, os.environ.get('TS_ID'))
-                        
-                        for change in changes:
-                            columns = list(change.keys())
-                            values = [change[col] for col in columns]
-                            placeholders = [f'${i+1}' for i in range(len(columns))]
+                        try:
+                            last_synced_version = await self.get_last_synced_version(remote_conn, table_name, os.environ.get('TS_ID'))
                             
-                            insert_query = f"""
-                                INSERT INTO "{table_name}" ({', '.join(columns)})
-                                VALUES ({', '.join(placeholders)})
-                                ON CONFLICT (id) DO UPDATE SET
-                                {', '.join(f"{col} = EXCLUDED.{col}" for col in columns if col != 'id')}
-                            """
+                            changes = await local_conn.fetch(f"""
+                                SELECT * FROM "{table_name}"
+                                WHERE version > $1 AND server_id = $2
+                                ORDER BY version ASC
+                            """, last_synced_version, os.environ.get('TS_ID'))
                             
-                            await remote_conn.execute(insert_query, *values)
+                            if changes:
+                                debug(f"Pushing changes for table {table_name}")
+                                debug(f"Columns: {', '.join(changes[0].keys())}")
+                                
+                                columns = list(changes[0].keys())
+                                placeholders = [f'${i+1}' for i in range(len(columns))]
+                                
+                                insert_query = f"""
+                                    INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)})
+                                    VALUES ({', '.join(placeholders)})
+                                    ON CONFLICT (id) DO UPDATE SET
+                                    {', '.join(f'"{col}" = EXCLUDED."{col}"' for col in columns if col != 'id')}
+                                """
+                                
+                                debug(f"Insert query: {insert_query}")
+                                
+                                for change in changes:
+                                    values = [change[col] for col in columns]
+                                    await remote_conn.execute(insert_query, *values)
+                                
+                                await self.update_sync_status(remote_conn, table_name, os.environ.get('TS_ID'), changes[-1]['version'])
                         
-                        if changes:
-                            await self.update_last_synced_version(remote_conn, table_name, os.environ.get('TS_ID'), changes[-1]['version'])
+                        except Exception as e:
+                            err(f"Error pushing changes for table {table_name}: {str(e)}")
+                            err(f"Traceback: {traceback.format_exc()}")
             
             info(f"Successfully pushed changes to {pool_entry['ts_id']}")
         except Exception as e:
@@ -808,6 +831,7 @@ class APIConfig(BaseModel):
             err(f"Traceback: {traceback.format_exc()}")
 
 
+
     async def update_sync_status(self, conn, table_name, server_id, version):
         await conn.execute("""
             INSERT INTO sync_status (table_name, server_id, last_synced_version, last_sync_time)