diff --git a/sijapi/classes.py b/sijapi/classes.py index 0c8dfd8..546e78e 100644 --- a/sijapi/classes.py +++ b/sijapi/classes.py @@ -427,46 +427,59 @@ class APIConfig(BaseModel): async def ensure_sync_columns(self, conn, table_name): try: + # Check if the table has a primary key + has_primary_key = await conn.fetchval(f""" + SELECT EXISTS ( + SELECT 1 + FROM information_schema.table_constraints + WHERE table_name = '{table_name}' + AND constraint_type = 'PRIMARY KEY' + ) + """) + await conn.execute(f""" DO $$ BEGIN + -- Ensure version column exists + IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = '{table_name}' AND column_name = 'version') THEN + ALTER TABLE "{table_name}" ADD COLUMN version INTEGER DEFAULT 1; + END IF; + + -- Ensure server_id column exists + IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = '{table_name}' AND column_name = 'server_id') THEN + ALTER TABLE "{table_name}" ADD COLUMN server_id TEXT DEFAULT '{os.environ.get('TS_ID')}'; + END IF; + + -- Create or replace the trigger function + CREATE OR REPLACE FUNCTION update_version_and_server_id() + RETURNS TRIGGER AS $$ BEGIN - ALTER TABLE "{table_name}" - ADD COLUMN IF NOT EXISTS version INTEGER DEFAULT 1; - EXCEPTION - WHEN duplicate_column THEN - NULL; - END; - - BEGIN - ALTER TABLE "{table_name}" - ADD COLUMN IF NOT EXISTS server_id TEXT DEFAULT '{os.environ.get('TS_ID')}'; - EXCEPTION - WHEN duplicate_column THEN - NULL; + NEW.version = COALESCE(OLD.version, 0) + 1; + NEW.server_id = '{os.environ.get('TS_ID')}'; + RETURN NEW; END; + $$ LANGUAGE plpgsql; + + -- Create the trigger if it doesn't exist + IF NOT EXISTS (SELECT 1 FROM pg_trigger WHERE tgname = 'update_version_and_server_id_trigger' AND tgrelid = '{table_name}'::regclass) THEN + CREATE TRIGGER update_version_and_server_id_trigger + BEFORE INSERT OR UPDATE ON "{table_name}" + FOR EACH ROW EXECUTE FUNCTION update_version_and_server_id(); + END IF; END $$; """) - # Verify that the columns were added - result = await conn.fetchrow(f""" - SELECT - EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = '{table_name}' AND column_name = 'version') as has_version, - EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = '{table_name}' AND column_name = 'server_id') as has_server_id - """) - - if result['has_version'] and result['has_server_id']: - info(f"Successfully added/verified version and server_id columns for table {table_name}") - return True - else: - err(f"Failed to add version and/or server_id columns to table {table_name}") - return False + info(f"Successfully ensured sync columns and trigger for table {table_name}. Has primary key: {has_primary_key}") + return has_primary_key + except Exception as e: err(f"Error ensuring sync columns for table {table_name}: {str(e)}") err(f"Traceback: {traceback.format_exc()}") return False + + async def ensure_sync_trigger(self, conn, table_name): await conn.execute(f""" CREATE OR REPLACE FUNCTION update_version_and_server_id() @@ -601,34 +614,30 @@ class APIConfig(BaseModel): WHERE schemaname = 'public' """) - async for table in tqdm(tables, desc="Syncing tables", unit="table"): + for table in tables: table_name = table['tablename'] try: - if table_name == 'spatial_ref_sys': - changes_count = await self.sync_spatial_ref_sys(source_conn, dest_conn) + has_primary_key = await self.ensure_sync_columns(dest_conn, table_name) + last_synced_version = await self.get_last_synced_version(dest_conn, table_name, source_id) + + changes = await source_conn.fetch(f""" + SELECT * FROM "{table_name}" + WHERE version > $1 AND server_id = $2 + ORDER BY version ASC + LIMIT $3 + """, last_synced_version, source_id, batch_size) + + if changes: + changes_count = await self.apply_batch_changes(dest_conn, table_name, changes, has_primary_key) total_changes += changes_count - info(f"Synced spatial_ref_sys: {changes_count} changes. Total so far: {total_changes}") + + if changes_count > 0: + last_synced_version = changes[-1]['version'] + await self.update_sync_status(dest_conn, table_name, source_id, last_synced_version) + + info(f"Synced batch for {table_name}: {changes_count} changes. Total so far: {total_changes}") else: - last_synced_version = await self.get_last_synced_version(dest_conn, table_name, source_id) - - changes = await source_conn.fetch(f""" - SELECT * FROM "{table_name}" - WHERE version > $1 AND server_id = $2 - ORDER BY version ASC - LIMIT $3 - """, last_synced_version, source_id, batch_size) - - if changes: - changes_count = await self.apply_batch_changes(dest_conn, table_name, changes) - total_changes += changes_count - - if changes_count > 0: - last_synced_version = changes[-1]['version'] - await self.update_sync_status(dest_conn, table_name, source_id, last_synced_version) - - info(f"Synced batch for {table_name}: {changes_count} changes. Total so far: {total_changes}") - else: - info(f"No changes to sync for {table_name}") + info(f"No changes to sync for {table_name}") except Exception as e: err(f"Error syncing table {table_name}: {str(e)}") @@ -651,36 +660,39 @@ class APIConfig(BaseModel): - async def apply_batch_changes(self, conn, table_name, changes): + async def apply_batch_changes(self, conn, table_name, changes, has_primary_key): if not changes: return 0 try: - # Convert the keys to a list columns = list(changes[0].keys()) placeholders = [f'${i+1}' for i in range(len(columns))] - - # Check if 'id' column exists - id_exists = 'id' in columns - - if id_exists: + + if has_primary_key: insert_query = f""" INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)}) VALUES ({', '.join(placeholders)}) - ON CONFLICT (id) DO UPDATE SET - {', '.join(f'"{col}" = EXCLUDED."{col}"' for col in columns if col != 'id')} + ON CONFLICT ON CONSTRAINT {table_name}_pkey DO UPDATE SET + {', '.join(f'"{col}" = EXCLUDED."{col}"' for col in columns if col not in ['version', 'server_id'])}, + version = EXCLUDED.version, + server_id = EXCLUDED.server_id + WHERE "{table_name}".version < EXCLUDED.version + OR ("{table_name}".version = EXCLUDED.version AND "{table_name}".server_id < EXCLUDED.server_id) """ else: - # For tables without 'id', use all columns as conflict target + # For tables without a primary key, we'll use all columns for conflict detection insert_query = f""" INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)}) VALUES ({', '.join(placeholders)}) - ON CONFLICT DO NOTHING + ON CONFLICT ({', '.join(f'"{col}"' for col in columns if col not in ['version', 'server_id'])}) DO UPDATE SET + version = EXCLUDED.version, + server_id = EXCLUDED.server_id + WHERE "{table_name}".version < EXCLUDED.version + OR ("{table_name}".version = EXCLUDED.version AND "{table_name}".server_id < EXCLUDED.server_id) """ debug(f"Generated insert query for {table_name}: {insert_query}") - # Execute the insert for each change affected_rows = 0 async for change in tqdm(changes, desc=f"Syncing {table_name}", unit="row"): values = [change[col] for col in columns] @@ -696,6 +708,7 @@ class APIConfig(BaseModel): return 0 + async def sync_spatial_ref_sys(self, source_conn, dest_conn): try: # Get all entries from the source @@ -777,30 +790,40 @@ class APIConfig(BaseModel): for table in tables: table_name = table['tablename'] - last_synced_version = await self.get_last_synced_version(remote_conn, table_name, os.environ.get('TS_ID')) - - changes = await local_conn.fetch(f""" - SELECT * FROM "{table_name}" - WHERE version > $1 AND server_id = $2 - ORDER BY version ASC - """, last_synced_version, os.environ.get('TS_ID')) - - for change in changes: - columns = list(change.keys()) - values = [change[col] for col in columns] - placeholders = [f'${i+1}' for i in range(len(columns))] + try: + last_synced_version = await self.get_last_synced_version(remote_conn, table_name, os.environ.get('TS_ID')) - insert_query = f""" - INSERT INTO "{table_name}" ({', '.join(columns)}) - VALUES ({', '.join(placeholders)}) - ON CONFLICT (id) DO UPDATE SET - {', '.join(f"{col} = EXCLUDED.{col}" for col in columns if col != 'id')} - """ + changes = await local_conn.fetch(f""" + SELECT * FROM "{table_name}" + WHERE version > $1 AND server_id = $2 + ORDER BY version ASC + """, last_synced_version, os.environ.get('TS_ID')) - await remote_conn.execute(insert_query, *values) + if changes: + debug(f"Pushing changes for table {table_name}") + debug(f"Columns: {', '.join(changes[0].keys())}") + + columns = list(changes[0].keys()) + placeholders = [f'${i+1}' for i in range(len(columns))] + + insert_query = f""" + INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)}) + VALUES ({', '.join(placeholders)}) + ON CONFLICT (id) DO UPDATE SET + {', '.join(f'"{col}" = EXCLUDED."{col}"' for col in columns if col != 'id')} + """ + + debug(f"Insert query: {insert_query}") + + for change in changes: + values = [change[col] for col in columns] + await remote_conn.execute(insert_query, *values) + + await self.update_sync_status(remote_conn, table_name, os.environ.get('TS_ID'), changes[-1]['version']) - if changes: - await self.update_last_synced_version(remote_conn, table_name, os.environ.get('TS_ID'), changes[-1]['version']) + except Exception as e: + err(f"Error pushing changes for table {table_name}: {str(e)}") + err(f"Traceback: {traceback.format_exc()}") info(f"Successfully pushed changes to {pool_entry['ts_id']}") except Exception as e: @@ -808,6 +831,7 @@ class APIConfig(BaseModel): err(f"Traceback: {traceback.format_exc()}") + async def update_sync_status(self, conn, table_name, server_id, version): await conn.execute(""" INSERT INTO sync_status (table_name, server_id, last_synced_version, last_sync_time)