diff --git a/sijapi/classes.py b/sijapi/classes.py index 40bdaf3..b043445 100644 --- a/sijapi/classes.py +++ b/sijapi/classes.py @@ -8,6 +8,7 @@ import re import aiofiles import aiohttp import asyncpg +import traceback import reverse_geocoder as rg from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Union, TypeVar @@ -391,7 +392,7 @@ class APIConfig(BaseModel): try: async with self.get_connection(source_pool_entry) as source_conn: - async with self.get_connection() as dest_conn: + async with self.get_connection(self.local_db) as dest_conn: # Connect to local DB explicitly # Compare tables source_tables = await self.get_tables(source_conn) dest_tables = await self.get_tables(dest_conn) @@ -412,6 +413,10 @@ class APIConfig(BaseModel): total_updates += updates table_changes[table] = {'inserts': inserts, 'updates': updates} + # Optionally, handle tables only in source + for table in tables_only_in_source: + warn(f"Table '{table}' exists in source but not in destination. Consider manual migration.") + info(f"Comprehensive sync complete from {source_id} ({source_ip}) to {dest_id} ({dest_ip})") info(f"Total changes: {total_inserts} inserts, {total_updates} updates") info("Changes by table:") @@ -420,10 +425,12 @@ class APIConfig(BaseModel): except Exception as e: err(f"Error during sync process: {str(e)}") + err(f"Traceback: {traceback.format_exc()}") return total_inserts + total_updates + async def get_tables(self, conn): tables = await conn.fetch(""" SELECT tablename FROM pg_tables @@ -464,9 +471,10 @@ class APIConfig(BaseModel): try: primary_keys = await self.get_primary_keys(dest_conn, table_name) if not primary_keys: - warn(f"Table {table_name} has no primary keys. Skipping data sync.") - return inserts, updates - + warn(f"Table {table_name} has no primary keys. Using all columns for comparison.") + columns = await self.get_table_columns(dest_conn, table_name) + primary_keys = columns # Use all columns if no primary key + last_synced_version = await self.get_last_synced_version(table_name, source_id) changes = await source_conn.fetch(f""" @@ -479,11 +487,14 @@ class APIConfig(BaseModel): columns = list(change.keys()) values = [change[col] for col in columns] + conflict_clause = f"({', '.join(primary_keys)})" + update_clause = ', '.join(f"{col} = EXCLUDED.{col}" for col in columns if col not in primary_keys) + insert_query = f""" INSERT INTO "{table_name}" ({', '.join(columns)}) VALUES ({', '.join(f'${i+1}' for i in range(len(columns)))}) - ON CONFLICT ({', '.join(primary_keys)}) DO UPDATE SET - {', '.join(f"{col} = EXCLUDED.{col}" for col in columns if col not in primary_keys)} + ON CONFLICT {conflict_clause} DO UPDATE SET + {update_clause} """ try: @@ -512,6 +523,14 @@ class APIConfig(BaseModel): return inserts, updates + async def get_table_columns(self, conn, table_name): + columns = await conn.fetch(""" + SELECT column_name + FROM information_schema.columns + WHERE table_name = $1 + ORDER BY ordinal_position + """, table_name) + return [col['column_name'] for col in columns] async def get_primary_keys(self, conn, table_name): primary_keys = await conn.fetch(""" @@ -584,16 +603,18 @@ class APIConfig(BaseModel): async def sync_schema(self): + local_id = os.environ.get('TS_ID') source_entry = self.local_db source_schema = await self.get_schema(source_entry) - for pool_entry in self.POOL[1:]: - try: - target_schema = await self.get_schema(pool_entry) - await self.apply_schema_changes(pool_entry, source_schema, target_schema) - info(f"Synced schema to {pool_entry['ts_ip']}") - except Exception as e: - err(f"Failed to sync schema to {pool_entry['ts_ip']}: {str(e)}") + for pool_entry in self.POOL: + if pool_entry['ts_id'] != local_id: # Skip the local instance + try: + target_schema = await self.get_schema(pool_entry) + await self.apply_schema_changes(pool_entry, source_schema, target_schema) + info(f"Synced schema to {pool_entry['ts_ip']}") + except Exception as e: + err(f"Failed to sync schema to {pool_entry['ts_ip']}: {str(e)}") async def get_schema(self, pool_entry: Dict[str, Any]): async with self.get_connection(pool_entry) as conn: