From 2e27cc69394644799c5d598db4c9615b3febc7e4 Mon Sep 17 00:00:00 2001
From: sanj <67624670+iodrift@users.noreply.github.com>
Date: Tue, 30 Jul 2024 10:14:37 -0700
Subject: [PATCH] Auto-update: Tue Jul 30 10:14:37 PDT 2024

---
 sijapi/__main__.py |   1 +
 sijapi/classes.py  | 332 +++++++++++++++++++++++++++++----------------
 2 files changed, 217 insertions(+), 116 deletions(-)

diff --git a/sijapi/__main__.py b/sijapi/__main__.py
index 6bfe6b4..c7fe3c8 100755
--- a/sijapi/__main__.py
+++ b/sijapi/__main__.py
@@ -83,6 +83,7 @@ async def lifespan(app: FastAPI):
 
 
 
+
 app = FastAPI(lifespan=lifespan)
 
 app.add_middleware(
diff --git a/sijapi/classes.py b/sijapi/classes.py
index 045dadc..be3875f 100644
--- a/sijapi/classes.py
+++ b/sijapi/classes.py
@@ -5,6 +5,7 @@ import yaml
 import math
 import os
 import re
+import uuid
 import aiofiles
 import aiohttp
 import asyncpg
@@ -315,60 +316,59 @@ class APIConfig(BaseModel):
             err(f"Error: {str(e)}")
             raise
 
-
     async def initialize_sync(self):
         for pool_entry in self.POOL:
-            async with self.get_connection(pool_entry) as conn:
-                tables = await conn.fetch("""
-                    SELECT tablename FROM pg_tables 
-                    WHERE schemaname = 'public'
-                """)
-                
-                for table in tables:
-                    table_name = table['tablename']
-                    # Add version and server_id columns if they don't exist
-                    await conn.execute(f"""
-                        DO $$ 
-                        BEGIN 
-                            BEGIN
-                                ALTER TABLE "{table_name}" 
-                                ADD COLUMN IF NOT EXISTS version INTEGER DEFAULT 1,
-                                ADD COLUMN IF NOT EXISTS server_id TEXT DEFAULT '{os.environ.get('TS_ID')}';
-                            EXCEPTION
-                                WHEN duplicate_column THEN 
-                                    -- Do nothing, column already exists
-                            END;
-                        END $$;
+            try:
+                async with self.get_connection(pool_entry) as conn:
+                    tables = await conn.fetch("""
+                        SELECT tablename FROM pg_tables 
+                        WHERE schemaname = 'public'
                     """)
+                    
+                    for table in tables:
+                        table_name = table['tablename']
+                        await self.ensure_sync_columns(conn, table_name)
+                        await self.create_sync_trigger(conn, table_name)
 
-                    # Create or replace the trigger function
-                    await conn.execute(f"""
-                        CREATE OR REPLACE FUNCTION update_version_and_server_id()
-                        RETURNS TRIGGER AS $$
-                        BEGIN
-                            NEW.version = COALESCE(OLD.version, 0) + 1;
-                            NEW.server_id = '{os.environ.get('TS_ID')}';
-                            RETURN NEW;
-                        END;
-                        $$ LANGUAGE plpgsql;
-                    """)
-
-                    # Create the trigger if it doesn't exist
-                    await conn.execute(f"""
-                        DO $$ 
-                        BEGIN
-                            IF NOT EXISTS (SELECT 1 FROM pg_trigger WHERE tgname = 'update_version_and_server_id_trigger' AND tgrelid = '{table_name}'::regclass) THEN
-                                CREATE TRIGGER update_version_and_server_id_trigger
-                                BEFORE INSERT OR UPDATE ON "{table_name}"
-                                FOR EACH ROW EXECUTE FUNCTION update_version_and_server_id();
-                            END IF;
-                        END $$;
-                    """)
-
-            info(f"Sync initialization complete for {pool_entry['ts_ip']}. All tables now have version and server_id columns with appropriate triggers.")
+                info(f"Sync initialization complete for {pool_entry['ts_ip']}. All tables now have version and server_id columns with appropriate triggers.")
+            except Exception as e:
+                err(f"Error initializing sync for {pool_entry['ts_ip']}: {str(e)}")
 
+    async def ensure_sync_columns(self, conn, table_name):
+        await conn.execute(f"""
+            DO $$ 
+            BEGIN 
+                BEGIN
+                    ALTER TABLE "{table_name}" 
+                    ADD COLUMN IF NOT EXISTS version INTEGER DEFAULT 1,
+                    ADD COLUMN IF NOT EXISTS server_id TEXT DEFAULT '{os.environ.get('TS_ID')}';
+                EXCEPTION
+                    WHEN duplicate_column THEN 
+                        -- Do nothing, column already exists
+                END;
+            END $$;
+        """)
 
+    async def create_sync_trigger(self, conn, table_name):
+        await conn.execute(f"""
+            CREATE OR REPLACE FUNCTION update_version_and_server_id()
+            RETURNS TRIGGER AS $$
+            BEGIN
+                NEW.version = COALESCE(OLD.version, 0) + 1;
+                NEW.server_id = '{os.environ.get('TS_ID')}';
+                RETURN NEW;
+            END;
+            $$ LANGUAGE plpgsql;
 
+            DO $$ 
+            BEGIN
+                IF NOT EXISTS (SELECT 1 FROM pg_trigger WHERE tgname = 'update_version_and_server_id_trigger' AND tgrelid = '{table_name}'::regclass) THEN
+                    CREATE TRIGGER update_version_and_server_id_trigger
+                    BEFORE INSERT OR UPDATE ON "{table_name}"
+                    FOR EACH ROW EXECUTE FUNCTION update_version_and_server_id();
+                END IF;
+            END $$;
+        """)
 
     async def get_most_recent_source(self):
         most_recent_source = None
@@ -394,27 +394,26 @@ class APIConfig(BaseModel):
         
         return most_recent_source
 
-
-
-    async def pull_changes(self, source_pool_entry):
+    async def pull_changes(self, source_pool_entry, batch_size=10000):
         if source_pool_entry['ts_id'] == os.environ.get('TS_ID'):
             info("Skipping self-sync")
             return 0
 
-        total_inserts = 0
-        total_updates = 0
-        table_changes = {}
-
+        total_changes = 0
         source_id = source_pool_entry['ts_id']
         source_ip = source_pool_entry['ts_ip']
         dest_id = os.environ.get('TS_ID')
         dest_ip = self.local_db['ts_ip']
 
-        info(f"Starting comprehensive sync from source {source_id} ({source_ip}) to destination {dest_id} ({dest_ip})")
+        info(f"Starting sync from source {source_id} ({source_ip}) to destination {dest_id} ({dest_ip})")
 
         try:
             async with self.get_connection(source_pool_entry) as source_conn:
                 async with self.get_connection(self.local_db) as dest_conn:
+                    # Sync schema first
+                    schema_changes = await self.detect_schema_changes(source_conn, dest_conn)
+                    await self.apply_schema_changes(dest_conn, schema_changes)
+
                     tables = await source_conn.fetch("""
                         SELECT tablename FROM pg_tables 
                         WHERE schemaname = 'public'
@@ -424,62 +423,92 @@ class APIConfig(BaseModel):
                         table_name = table['tablename']
                         last_synced_version = await self.get_last_synced_version(dest_conn, table_name, source_id)
                         
-                        changes = await source_conn.fetch(f"""
-                            SELECT * FROM "{table_name}"
-                            WHERE version > $1 AND server_id = $2
-                            ORDER BY version ASC
-                        """, last_synced_version, source_id)
-                        
-                        inserts = 0
-                        updates = 0
-                        for change in changes:
-                            columns = list(change.keys())
-                            values = [change[col] for col in columns]
-                            placeholders = [f'${i+1}' for i in range(len(columns))]
+                        while True:
+                            changes = await source_conn.fetch(f"""
+                                SELECT * FROM "{table_name}"
+                                WHERE version > $1 AND server_id = $2
+                                ORDER BY version ASC
+                                LIMIT $3
+                            """, last_synced_version, source_id, batch_size)
                             
-                            insert_query = f"""
-                                INSERT INTO "{table_name}" ({', '.join(columns)})
-                                VALUES ({', '.join(placeholders)})
-                                ON CONFLICT (id) DO UPDATE SET
-                                {', '.join(f"{col} = EXCLUDED.{col}" for col in columns if col != 'id')}
-                            """
-                            
-                            result = await dest_conn.execute(insert_query, *values)
-                            if 'UPDATE' in result:
-                                updates += 1
-                            else:
-                                inserts += 1
-                        
-                        if changes:
-                            await self.update_last_synced_version(dest_conn, table_name, source_id, changes[-1]['version'])
-                        
-                        total_inserts += inserts
-                        total_updates += updates
-                        table_changes[table_name] = {'inserts': inserts, 'updates': updates}
-                        
-                        info(f"Synced {table_name} from {source_id} to {dest_id}: {inserts} inserts, {updates} updates")
+                            if not changes:
+                                break
 
-            info(f"Comprehensive sync complete from {source_id} ({source_ip}) to {dest_id} ({dest_ip})")
-            info(f"Total changes: {total_inserts} inserts, {total_updates} updates")
-            info("Changes by table:")
-            for table, changes in table_changes.items():
-                info(f"  {table}: {changes['inserts']} inserts, {changes['updates']} updates")
+                            changes_count = await self.apply_batch_changes(dest_conn, table_name, changes)
+                            total_changes += changes_count
+                            
+                            last_synced_version = changes[-1]['version']
+                            await self.update_last_synced_version(dest_conn, table_name, source_id, last_synced_version)
+                            
+                            info(f"Synced batch for {table_name}: {changes_count} changes. Total so far: {total_changes}")
+
+            info(f"Sync complete from {source_id} ({source_ip}) to {dest_id} ({dest_ip}). Total changes: {total_changes}")
 
         except Exception as e:
             err(f"Error during sync process: {str(e)}")
             err(f"Traceback: {traceback.format_exc()}")
 
-        return total_inserts + total_updates
+        return total_changes
+
+    async def apply_batch_changes(self, conn, table_name, changes):
+        if not changes:
+            return 0
+
+        temp_table_name = f"temp_{table_name}_{uuid.uuid4().hex[:8]}"
+        
+        try:
+            # Create temporary table
+            await conn.execute(f"""
+                CREATE TEMPORARY TABLE {temp_table_name} (LIKE "{table_name}" INCLUDING ALL)
+                ON COMMIT DROP
+            """)
+
+            # Bulk insert changes into temporary table
+            columns = changes[0].keys()
+            await conn.copy_records_to_table(temp_table_name, records=[tuple(change[col] for col in columns) for change in changes])
+
+            # Perform upsert with spatial awareness
+            result = await conn.execute(f"""
+                INSERT INTO "{table_name}" 
+                SELECT tc.* 
+                FROM {temp_table_name} tc
+                LEFT JOIN "{table_name}" t ON t.id = tc.id
+                WHERE t.id IS NULL
+                ON CONFLICT (id) DO UPDATE SET
+                {', '.join(f"{col} = EXCLUDED.{col}" for col in columns if col != 'id')}
+                WHERE (
+                    CASE 
+                        WHEN "{table_name}".geometry IS NOT NULL AND EXCLUDED.geometry IS NOT NULL 
+                        THEN NOT ST_Equals("{table_name}".geometry, EXCLUDED.geometry)
+                        ELSE FALSE
+                    END
+                ) OR {' OR '.join(f"COALESCE({col} <> EXCLUDED.{col}, TRUE)" for col in columns if col not in ['id', 'geometry'])}
+            """)
+
+            # Parse the result to get the number of affected rows
+            affected_rows = int(result.split()[-1])
+            return affected_rows
+
+        finally:
+            # Ensure temporary table is dropped
+            await conn.execute(f"DROP TABLE IF EXISTS {temp_table_name}")
 
     async def push_changes_to_all(self):
         for pool_entry in self.POOL:
             if pool_entry['ts_id'] != os.environ.get('TS_ID'):
-                await self.push_changes_to_one(pool_entry)
+                try:
+                    await self.push_changes_to_one(pool_entry)
+                except Exception as e:
+                    err(f"Error pushing changes to {pool_entry['ts_id']}: {str(e)}")
 
-    async def push_changes_to_one(self, pool_entry):
+    async def push_changes_to_one(self, pool_entry, batch_size=10000):
         try:
             async with self.get_connection() as local_conn:
                 async with self.get_connection(pool_entry) as remote_conn:
+                    # Sync schema first
+                    schema_changes = await self.detect_schema_changes(local_conn, remote_conn)
+                    await self.apply_schema_changes(remote_conn, schema_changes)
+
                     tables = await local_conn.fetch("""
                         SELECT tablename FROM pg_tables 
                         WHERE schemaname = 'public'
@@ -489,28 +518,23 @@ class APIConfig(BaseModel):
                         table_name = table['tablename']
                         last_synced_version = await self.get_last_synced_version(remote_conn, table_name, os.environ.get('TS_ID'))
                         
-                        changes = await local_conn.fetch(f"""
-                            SELECT * FROM "{table_name}"
-                            WHERE version > $1 AND server_id = $2
-                            ORDER BY version ASC
-                        """, last_synced_version, os.environ.get('TS_ID'))
-                        
-                        for change in changes:
-                            columns = list(change.keys())
-                            values = [change[col] for col in columns]
-                            placeholders = [f'${i+1}' for i in range(len(columns))]
+                        while True:
+                            changes = await local_conn.fetch(f"""
+                                SELECT * FROM "{table_name}"
+                                WHERE version > $1 AND server_id = $2
+                                ORDER BY version ASC
+                                LIMIT $3
+                            """, last_synced_version, os.environ.get('TS_ID'), batch_size)
                             
-                            insert_query = f"""
-                                INSERT INTO "{table_name}" ({', '.join(columns)})
-                                VALUES ({', '.join(placeholders)})
-                                ON CONFLICT (id) DO UPDATE SET
-                                {', '.join(f"{col} = EXCLUDED.{col}" for col in columns if col != 'id')}
-                            """
+                            if not changes:
+                                break
+
+                            changes_count = await self.apply_batch_changes(remote_conn, table_name, changes)
                             
-                            await remote_conn.execute(insert_query, *values)
-                        
-                        if changes:
-                            await self.update_last_synced_version(remote_conn, table_name, os.environ.get('TS_ID'), changes[-1]['version'])
+                            last_synced_version = changes[-1]['version']
+                            await self.update_last_synced_version(remote_conn, table_name, os.environ.get('TS_ID'), last_synced_version)
+                            
+                            info(f"Pushed batch for {table_name}: {changes_count} changes to {pool_entry['ts_id']}")
             
             info(f"Successfully pushed changes to {pool_entry['ts_id']}")
         except Exception as e:
@@ -552,6 +576,82 @@ class APIConfig(BaseModel):
         END $$;
         """)
 
+    async def detect_schema_changes(self, source_conn, dest_conn):
+        schema_changes = {
+            'new_tables': [],
+            'new_columns': {}
+        }
+        
+        # Detect new tables
+        source_tables = await source_conn.fetch("SELECT tablename FROM pg_tables WHERE schemaname = 'public'")
+        dest_tables = await dest_conn.fetch("SELECT tablename FROM pg_tables WHERE schemaname = 'public'")
+        
+        source_table_names = set(table['tablename'] for table in source_tables)
+        dest_table_names = set(table['tablename'] for table in dest_tables)
+        
+        new_tables = source_table_names - dest_table_names
+        schema_changes['new_tables'] = list(new_tables)
+        
+        # Detect new columns
+        for table_name in source_table_names:
+            if table_name in dest_table_names:
+                source_columns = await source_conn.fetch(f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}'")
+                dest_columns = await dest_conn.fetch(f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}'")
+                
+                source_column_names = set(column['column_name'] for column in source_columns)
+                dest_column_names = set(column['column_name'] for column in dest_columns)
+                
+                new_columns = source_column_names - dest_column_names
+                if new_columns:
+                    schema_changes['new_columns'][table_name] = [
+                        {'name': column['column_name'], 'type': column['data_type']}
+                        for column in source_columns if column['column_name'] in new_columns
+                    ]
+        
+        return schema_changes
+
+    async def apply_schema_changes(self, conn, schema_changes):
+        for table_name in schema_changes['new_tables']:
+            create_table_sql = await self.get_table_creation_sql(conn, table_name)
+            await conn.execute(create_table_sql)
+            info(f"Created new table: {table_name}")
+        
+        for table_name, columns in schema_changes['new_columns'].items():
+            for column in columns:
+                await conn.execute(f"""
+                    ALTER TABLE "{table_name}" 
+                    ADD COLUMN IF NOT EXISTS {column['name']} {column['type']}
+                """)
+                info(f"Added new column {column['name']} to table {table_name}")
+
+    async def get_table_creation_sql(self, conn, table_name):
+        create_table_sql = await conn.fetchval(f"""
+            SELECT pg_get_tabledef('{table_name}'::regclass::oid)
+        """)
+        return create_table_sql
+
+    async def table_exists(self, conn, table_name):
+        exists = await conn.fetchval(f"""
+            SELECT EXISTS (
+                SELECT FROM information_schema.tables 
+                WHERE table_schema = 'public' 
+                AND table_name = $1
+            )
+        """, table_name)
+        return exists
+
+    async def column_exists(self, conn, table_name, column_name):
+        exists = await conn.fetchval(f"""
+            SELECT EXISTS (
+                SELECT FROM information_schema.columns 
+                WHERE table_schema = 'public' 
+                AND table_name = $1 
+                AND column_name = $2
+            )
+        """, table_name, column_name)
+        return exists
+
+