diff --git a/sijapi/__init__.py b/sijapi/__init__.py index 94cc479..4ebcbcb 100644 --- a/sijapi/__init__.py +++ b/sijapi/__init__.py @@ -6,7 +6,7 @@ import multiprocessing from dotenv import load_dotenv from dateutil import tz from pathlib import Path -from .classes import Database, Geocoder, APIConfig, Configuration +from .classes import Geocoder, APIConfig, Configuration from .logs import Logger # INITIALization diff --git a/sijapi/classes.py b/sijapi/classes.py index 4d354ac..6533ff6 100644 --- a/sijapi/classes.py +++ b/sijapi/classes.py @@ -167,6 +167,7 @@ class Configuration(BaseModel): + class APIConfig(BaseModel): HOST: str PORT: int @@ -182,6 +183,8 @@ class APIConfig(BaseModel): GARBAGE: Dict[str, Any] _db_pools: Dict[str, asyncpg.Pool] = {} + SPECIAL_TABLES = ['spatial_ref_sys'] # Tables that can't have server_id and version columns + @classmethod def load(cls, config_path: Union[str, Path], secrets_path: Union[str, Path]): config_path = cls._resolve_path(config_path, 'config') @@ -301,36 +304,37 @@ class APIConfig(BaseModel): if pool_entry is None: pool_entry = self.local_db - info(f"Attempting to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}") - try: - conn = await asyncpg.connect( - host=pool_entry['ts_ip'], - port=pool_entry['db_port'], - user=pool_entry['db_user'], - password=pool_entry['db_pass'], - database=pool_entry['db_name'], - timeout=5 # Add a timeout to prevent hanging - ) - info(f"Successfully connected to {pool_entry['ts_ip']}:{pool_entry['db_port']}") + pool_key = f"{pool_entry['ts_ip']}:{pool_entry['db_port']}" + + if pool_key not in self._db_pools: try: + self._db_pools[pool_key] = await asyncpg.create_pool( + host=pool_entry['ts_ip'], + port=pool_entry['db_port'], + user=pool_entry['db_user'], + password=pool_entry['db_pass'], + database=pool_entry['db_name'], + min_size=1, + max_size=10, # adjust as needed + timeout=5 # connection timeout in seconds + ) + except Exception as e: + err(f"Failed to create connection pool for {pool_key}: {str(e)}") + yield None + return + + try: + async with self._db_pools[pool_key].acquire() as conn: yield conn - finally: - await conn.close() - info(f"Closed connection to {pool_entry['ts_ip']}:{pool_entry['db_port']}") except asyncpg.exceptions.ConnectionDoesNotExistError: - err(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']} - Connection does not exist") - raise + err(f"Failed to acquire connection from pool for {pool_key}: Connection does not exist") + yield None except asyncpg.exceptions.ConnectionFailureError: - err(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']} - Connection failure") - raise - except asyncpg.exceptions.PostgresError as e: - err(f"PostgreSQL error when connecting to {pool_entry['ts_ip']}:{pool_entry['db_port']}: {str(e)}") - raise + err(f"Failed to acquire connection from pool for {pool_key}: Connection failure") + yield None except Exception as e: - err(f"Unexpected error when connecting to {pool_entry['ts_ip']}:{pool_entry['db_port']}: {str(e)}") - raise - - + err(f"Unexpected error when acquiring connection from pool for {pool_key}: {str(e)}") + yield None async def close_db_pools(self): info("Closing database connection pools...") @@ -343,14 +347,18 @@ class APIConfig(BaseModel): self._db_pools.clear() info("All database connection pools closed.") - async def initialize_sync(self): local_ts_id = os.environ.get('TS_ID') - for pool_entry in self.POOL: + online_hosts = await self.get_online_hosts() + + for pool_entry in online_hosts: if pool_entry['ts_id'] == local_ts_id: continue # Skip local database try: async with self.get_connection(pool_entry) as conn: + if conn is None: + continue # Skip this database if connection failed + info(f"Starting sync initialization for {pool_entry['ts_ip']}...") # Check PostGIS installation @@ -358,130 +366,67 @@ class APIConfig(BaseModel): if not postgis_installed: warn(f"PostGIS is not installed on {pool_entry['ts_id']} ({pool_entry['ts_ip']}). Some spatial operations may fail.") - # Initialize sync_status table - await self.initialize_sync_status_table(conn) - - # Continue with sync initialization tables = await conn.fetch(""" SELECT tablename FROM pg_tables WHERE schemaname = 'public' """) - all_tables_synced = True for table in tables: table_name = table['tablename'] - if not await self.ensure_sync_columns(conn, table_name): - all_tables_synced = False + await self.ensure_sync_columns(conn, table_name) - if all_tables_synced: - info(f"Sync initialization complete for {pool_entry['ts_ip']}. All tables now have version and server_id columns with appropriate triggers.") - else: - warn(f"Sync initialization partially complete for {pool_entry['ts_ip']}. Some tables may be missing version or server_id columns.") + info(f"Sync initialization complete for {pool_entry['ts_ip']}. All tables now have necessary sync columns and triggers.") except Exception as e: err(f"Error initializing sync for {pool_entry['ts_ip']}: {str(e)}") err(f"Traceback: {traceback.format_exc()}") - - - async def initialize_sync_status_table(self, conn): - await conn.execute(""" - CREATE TABLE IF NOT EXISTS sync_status ( - table_name TEXT, - server_id TEXT, - last_synced_version INTEGER, - last_sync_time TIMESTAMP WITH TIME ZONE, - PRIMARY KEY (table_name, server_id) - ) - """) - - # Check if the last_sync_time column exists, and add it if it doesn't - column_exists = await conn.fetchval(""" - SELECT EXISTS ( - SELECT 1 - FROM information_schema.columns - WHERE table_name = 'sync_status' AND column_name = 'last_sync_time' - ) - """) - - if not column_exists: - await conn.execute(""" - ALTER TABLE sync_status - ADD COLUMN last_sync_time TIMESTAMP WITH TIME ZONE - """) - - - - - async def ensure_sync_structure(self, conn): - tables = await conn.fetch(""" - SELECT tablename FROM pg_tables - WHERE schemaname = 'public' - """) - - for table in tables: - table_name = table['tablename'] - await self.ensure_sync_columns(conn, table_name) - await self.ensure_sync_trigger(conn, table_name) - - async def ensure_sync_columns(self, conn, table_name): + if table_name in self.SPECIAL_TABLES: + info(f"Skipping sync columns for special table: {table_name}") + return None + try: - # Check if the table has a primary key - has_primary_key = await conn.fetchval(f""" - SELECT EXISTS ( - SELECT 1 - FROM information_schema.table_constraints - WHERE table_name = '{table_name}' - AND constraint_type = 'PRIMARY KEY' - ) + # Get primary key information + primary_key = await conn.fetchval(f""" + SELECT a.attname + FROM pg_index i + JOIN pg_attribute a ON a.attrelid = i.indrelid + AND a.attnum = ANY(i.indkey) + WHERE i.indrelid = '{table_name}'::regclass + AND i.indisprimary; """) - # Check if version column exists - version_exists = await conn.fetchval(f""" - SELECT EXISTS ( - SELECT 1 FROM information_schema.columns - WHERE table_name = '{table_name}' AND column_name = 'version' - ) + # Ensure version column exists + await conn.execute(f""" + ALTER TABLE "{table_name}" + ADD COLUMN IF NOT EXISTS version INTEGER DEFAULT 1; """) - # Check if server_id column exists - server_id_exists = await conn.fetchval(f""" - SELECT EXISTS ( - SELECT 1 FROM information_schema.columns - WHERE table_name = '{table_name}' AND column_name = 'server_id' - ) + # Ensure server_id column exists + await conn.execute(f""" + ALTER TABLE "{table_name}" + ADD COLUMN IF NOT EXISTS server_id TEXT DEFAULT '{os.environ.get('TS_ID')}'; """) - # Add version column if it doesn't exist - if not version_exists: - await conn.execute(f""" - ALTER TABLE "{table_name}" ADD COLUMN version INTEGER DEFAULT 1 - """) - - # Add server_id column if it doesn't exist - if not server_id_exists: - await conn.execute(f""" - ALTER TABLE "{table_name}" ADD COLUMN server_id TEXT DEFAULT '{os.environ.get('TS_ID')}' - """) - # Create or replace the trigger function - await conn.execute(""" + await conn.execute(f""" CREATE OR REPLACE FUNCTION update_version_and_server_id() RETURNS TRIGGER AS $$ BEGIN NEW.version = COALESCE(OLD.version, 0) + 1; - NEW.server_id = $1; + NEW.server_id = '{os.environ.get('TS_ID')}'; RETURN NEW; END; $$ LANGUAGE plpgsql; """) - # Create the trigger if it doesn't exist + # Check if the trigger exists and create it if it doesn't trigger_exists = await conn.fetchval(f""" SELECT EXISTS ( - SELECT 1 FROM pg_trigger - WHERE tgname = 'update_version_and_server_id_trigger' + SELECT 1 + FROM pg_trigger + WHERE tgname = 'update_version_and_server_id_trigger' AND tgrelid = '{table_name}'::regclass ) """) @@ -490,128 +435,58 @@ class APIConfig(BaseModel): await conn.execute(f""" CREATE TRIGGER update_version_and_server_id_trigger BEFORE INSERT OR UPDATE ON "{table_name}" - FOR EACH ROW EXECUTE FUNCTION update_version_and_server_id('{os.environ.get('TS_ID')}') + FOR EACH ROW EXECUTE FUNCTION update_version_and_server_id(); """) - - info(f"Successfully ensured sync columns and trigger for table {table_name}. Has primary key: {has_primary_key}") - return has_primary_key + + info(f"Successfully ensured sync columns and trigger for table {table_name}") + return primary_key except Exception as e: err(f"Error ensuring sync columns for table {table_name}: {str(e)}") err(f"Traceback: {traceback.format_exc()}") - return False - async def ensure_sync_trigger(self, conn, table_name): - await conn.execute(f""" - CREATE OR REPLACE FUNCTION update_version_and_server_id() - RETURNS TRIGGER AS $$ - BEGIN - NEW.version = COALESCE(OLD.version, 0) + 1; - NEW.server_id = '{os.environ.get('TS_ID')}'; - RETURN NEW; - END; - $$ LANGUAGE plpgsql; + async def apply_batch_changes(self, conn, table_name, changes, primary_key): + if not changes: + return 0 - DROP TRIGGER IF EXISTS update_version_and_server_id_trigger ON "{table_name}"; - - CREATE TRIGGER update_version_and_server_id_trigger - BEFORE INSERT OR UPDATE ON "{table_name}" - FOR EACH ROW EXECUTE FUNCTION update_version_and_server_id(); - """) - - async def get_most_recent_source(self): - most_recent_source = None - max_version = -1 - local_ts_id = os.environ.get('TS_ID') - - for pool_entry in self.POOL: - if pool_entry['ts_id'] == local_ts_id: - continue # Skip local database + try: + columns = list(changes[0].keys()) + placeholders = [f'${i+1}' for i in range(len(columns))] - if not await self.is_server_accessible(pool_entry['ts_ip'], pool_entry['db_port']): - warn(f"Server {pool_entry['ts_id']} ({pool_entry['ts_ip']}:{pool_entry['db_port']}) is not accessible. Skipping.") - continue + if primary_key: + insert_query = f""" + INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)}) + VALUES ({', '.join(placeholders)}) + ON CONFLICT ("{primary_key}") DO UPDATE SET + {', '.join(f'"{col}" = EXCLUDED."{col}"' for col in columns if col not in [primary_key, 'version', 'server_id'])}, + version = EXCLUDED.version, + server_id = EXCLUDED.server_id + WHERE "{table_name}".version < EXCLUDED.version + OR ("{table_name}".version = EXCLUDED.version AND "{table_name}".server_id < EXCLUDED.server_id) + """ + else: + # For tables without a primary key, we'll use all columns for conflict resolution + insert_query = f""" + INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)}) + VALUES ({', '.join(placeholders)}) + ON CONFLICT DO NOTHING + """ - try: - async with self.get_connection(pool_entry) as conn: - tables = await conn.fetch(""" - SELECT tablename FROM pg_tables - WHERE schemaname = 'public' - """) - - for table in tables: - table_name = table['tablename'] - try: - result = await conn.fetchrow(f""" - SELECT MAX(version) as max_version, server_id - FROM "{table_name}" - WHERE version = (SELECT MAX(version) FROM "{table_name}") - GROUP BY server_id - ORDER BY MAX(version) DESC - LIMIT 1 - """) - if result: - version, server_id = result['max_version'], result['server_id'] - info(f"Max version for {pool_entry['ts_id']}, table {table_name}: {version} (from server {server_id})") - if version > max_version: - max_version = version - most_recent_source = pool_entry - else: - info(f"No data in table {table_name} for {pool_entry['ts_id']}") - except asyncpg.exceptions.UndefinedColumnError: - warn(f"Version or server_id column does not exist in table {table_name} for {pool_entry['ts_id']}. Attempting to add...") - await self.ensure_sync_columns(conn, table_name) - except Exception as e: - err(f"Error checking version for {pool_entry['ts_id']}, table {table_name}: {str(e)}") + debug(f"Generated insert query for {table_name}: {insert_query}") - except asyncpg.exceptions.ConnectionFailureError as e: - err(f"Failed to establish database connection with {pool_entry['ts_id']} ({pool_entry['ts_ip']}:{pool_entry['db_port']}): {str(e)}") - except Exception as e: - err(f"Unexpected error occurred while checking version for {pool_entry['ts_id']}: {str(e)}") - err(f"Traceback: {traceback.format_exc()}") - - return most_recent_source + affected_rows = 0 + async for change in tqdm(changes, desc=f"Syncing {table_name}", unit="row"): + values = [change[col] for col in columns] + debug(f"Executing query for {table_name} with values: {values}") + result = await conn.execute(insert_query, *values) + affected_rows += int(result.split()[-1]) + return affected_rows - - - async def is_server_accessible(self, host, port, timeout=2): - try: - future = asyncio.open_connection(host, port) - await asyncio.wait_for(future, timeout=timeout) - return True - except (asyncio.TimeoutError, ConnectionRefusedError, socket.gaierror): - return False - - async def check_version_column_exists(self, conn): - try: - result = await conn.fetchval(""" - SELECT EXISTS ( - SELECT 1 - FROM information_schema.columns - WHERE table_schema = 'public' - AND column_name = 'version' - AND table_name IN (SELECT tablename FROM pg_tables WHERE schemaname = 'public') - ) - """) - if not result: - tables_without_version = await conn.fetch(""" - SELECT tablename - FROM pg_tables - WHERE schemaname = 'public' - AND tablename NOT IN ( - SELECT table_name - FROM information_schema.columns - WHERE table_schema = 'public' AND column_name = 'version' - ) - """) - table_names = ", ".join([t['tablename'] for t in tables_without_version]) - warn(f"Tables without 'version' column: {table_names}") - return result except Exception as e: - err(f"Error checking for 'version' column existence: {str(e)}") - return False - + err(f"Error applying batch changes to {table_name}: {str(e)}") + err(f"Traceback: {traceback.format_exc()}") + return 0 async def pull_changes(self, source_pool_entry, batch_size=10000): if source_pool_entry['ts_id'] == os.environ.get('TS_ID'): @@ -634,41 +509,34 @@ class APIConfig(BaseModel): WHERE schemaname = 'public' """) - for table in tables: + async for table in tqdm(tables, desc="Syncing tables", unit="table"): table_name = table['tablename'] try: - debug(f"Processing table: {table_name}") - has_primary_key = await self.ensure_sync_columns(dest_conn, table_name) - debug(f"Table {table_name} has primary key: {has_primary_key}") - last_synced_version = await self.get_last_synced_version(dest_conn, table_name, source_id) - debug(f"Last synced version for {table_name}: {last_synced_version}") - - changes = await source_conn.fetch(f""" - SELECT * FROM "{table_name}" - WHERE version > $1 AND server_id = $2 - ORDER BY version ASC - LIMIT $3 - """, last_synced_version, source_id, batch_size) - - debug(f"Number of changes for {table_name}: {len(changes)}") - - if changes: - debug(f"Sample change for {table_name}: {changes[0]}") - changes_count = await self.apply_batch_changes(dest_conn, table_name, changes, has_primary_key) - total_changes += changes_count - - if changes_count > 0: - last_synced_version = changes[-1]['version'] - await self.update_sync_status(dest_conn, table_name, source_id, last_synced_version) - - info(f"Synced batch for {table_name}: {changes_count} changes. Total so far: {total_changes}") + if table_name in self.SPECIAL_TABLES: + await self.sync_special_table(source_conn, dest_conn, table_name) else: - info(f"No changes to sync for {table_name}") + primary_key = await self.ensure_sync_columns(dest_conn, table_name) + last_synced_version = await self.get_last_synced_version(dest_conn, table_name, source_id) + + changes = await source_conn.fetch(f""" + SELECT * FROM "{table_name}" + WHERE version > $1 AND server_id = $2 + ORDER BY version ASC + LIMIT $3 + """, last_synced_version, source_id, batch_size) + + if changes: + changes_count = await self.apply_batch_changes(dest_conn, table_name, changes, primary_key) + total_changes += changes_count + + if changes_count > 0: + info(f"Synced batch for {table_name}: {changes_count} changes. Total so far: {total_changes}") + else: + info(f"No changes to sync for {table_name}") except Exception as e: err(f"Error syncing table {table_name}: {str(e)}") err(f"Traceback: {traceback.format_exc()}") - # Continue with the next table info(f"Sync complete from {source_id} ({source_ip}) to {dest_id} ({dest_ip}). Total changes: {total_changes}") @@ -684,63 +552,91 @@ class APIConfig(BaseModel): return total_changes + async def get_online_hosts(self) -> List[Dict[str, Any]]: + online_hosts = [] + for pool_entry in self.POOL: + try: + async with self.get_connection(pool_entry) as conn: + if conn is not None: + online_hosts.append(pool_entry) + except Exception as e: + err(f"Error checking host {pool_entry['ts_ip']}:{pool_entry['db_port']}: {str(e)}") + return online_hosts - async def apply_batch_changes(self, conn, table_name, changes, has_primary_key): - if not changes: - return 0 - - try: - columns = list(changes[0].keys()) - placeholders = [f'${i}' for i in range(1, len(columns) + 1)] - - if has_primary_key: - insert_query = f""" - INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)}) - VALUES ({', '.join(placeholders)}) - ON CONFLICT ON CONSTRAINT {table_name}_pkey DO UPDATE SET - {', '.join(f'"{col}" = EXCLUDED."{col}"' for col in columns if col not in ['version', 'server_id'])}, - version = EXCLUDED.version, - server_id = EXCLUDED.server_id - WHERE "{table_name}".version < EXCLUDED.version - OR ("{table_name}".version = EXCLUDED.version AND "{table_name}".server_id < EXCLUDED.server_id) - """ - else: - # For tables without a primary key, we'll use all columns for conflict detection - insert_query = f""" - INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)}) - VALUES ({', '.join(placeholders)}) - ON CONFLICT ({', '.join(f'"{col}"' for col in columns if col not in ['version', 'server_id'])}) DO UPDATE SET - version = EXCLUDED.version, - server_id = EXCLUDED.server_id - WHERE "{table_name}".version < EXCLUDED.version - OR ("{table_name}".version = EXCLUDED.version AND "{table_name}".server_id < EXCLUDED.server_id) - """ - - debug(f"Generated insert query for {table_name}: {insert_query}") - - # Prepare the statement - stmt = await conn.prepare(insert_query) - - affected_rows = 0 - for change in changes: + async def push_changes_to_all(self): + for pool_entry in self.POOL: + if pool_entry['ts_id'] != os.environ.get('TS_ID'): try: - values = [change[col] for col in columns] - result = await stmt.fetchval(*values) - affected_rows += 1 + await self.push_changes_to_one(pool_entry) except Exception as e: - err(f"Error inserting row into {table_name}: {str(e)}") - err(f"Row data: {change}") - # Continue with the next row - - return affected_rows + err(f"Error pushing changes to {pool_entry['ts_id']}: {str(e)}") + async def push_changes_to_one(self, pool_entry): + try: + async with self.get_connection() as local_conn: + async with self.get_connection(pool_entry) as remote_conn: + tables = await local_conn.fetch(""" + SELECT tablename FROM pg_tables + WHERE schemaname = 'public' + """) + + for table in tables: + table_name = table['tablename'] + try: + if table_name in self.SPECIAL_TABLES: + await self.sync_special_table(local_conn, remote_conn, table_name) + else: + primary_key = await self.ensure_sync_columns(remote_conn, table_name) + last_synced_version = await self.get_last_synced_version(remote_conn, table_name, os.environ.get('TS_ID')) + + changes = await local_conn.fetch(f""" + SELECT * FROM "{table_name}" + WHERE version > $1 AND server_id = $2 + ORDER BY version ASC + """, last_synced_version, os.environ.get('TS_ID')) + + if changes: + changes_count = await self.apply_batch_changes(remote_conn, table_name, changes, primary_key) + + if changes_count > 0: + info(f"Pushed {changes_count} changes for table {table_name} to {pool_entry['ts_id']}") + + except Exception as e: + err(f"Error pushing changes for table {table_name} to {pool_entry['ts_id']}: {str(e)}") + err(f"Traceback: {traceback.format_exc()}") + + info(f"Successfully pushed changes to {pool_entry['ts_id']}") except Exception as e: - err(f"Error applying batch changes to {table_name}: {str(e)}") + err(f"Error pushing changes to {pool_entry['ts_id']}: {str(e)}") err(f"Traceback: {traceback.format_exc()}") - return 0 + async def get_last_synced_version(self, conn, table_name, server_id): + if table_name in self.SPECIAL_TABLES: + return 0 # Special handling for tables without version column + return await conn.fetchval(f""" + SELECT COALESCE(MAX(version), 0) + FROM "{table_name}" + WHERE server_id = $1 + """, server_id) + async def check_postgis(self, conn): + try: + result = await conn.fetchval("SELECT PostGIS_version();") + if result: + info(f"PostGIS version: {result}") + return True + else: + warn("PostGIS is not installed or not working properly") + return False + except Exception as e: + err(f"Error checking PostGIS: {str(e)}") + return False + + async def sync_special_table(self, source_conn, dest_conn, table_name): + if table_name == 'spatial_ref_sys': + return await self.sync_spatial_ref_sys(source_conn, dest_conn) + # Add more special cases as needed async def sync_spatial_ref_sys(self, source_conn, dest_conn): try: @@ -803,126 +699,59 @@ class APIConfig(BaseModel): err(f"Traceback: {traceback.format_exc()}") return 0 - - async def push_changes_to_all(self): - for pool_entry in self.POOL: - if pool_entry['ts_id'] != os.environ.get('TS_ID'): - try: - await self.push_changes_to_one(pool_entry) - except Exception as e: - err(f"Error pushing changes to {pool_entry['ts_id']}: {str(e)}") - - async def push_changes_to_one(self, pool_entry): - try: - async with self.get_connection() as local_conn: - async with self.get_connection(pool_entry) as remote_conn: - tables = await local_conn.fetch(""" + async def get_most_recent_source(self): + most_recent_source = None + max_version = -1 + local_ts_id = os.environ.get('TS_ID') + online_hosts = await self.get_online_hosts() + + for pool_entry in online_hosts: + if pool_entry['ts_id'] == local_ts_id: + continue # Skip local database + + try: + async with self.get_connection(pool_entry) as conn: + tables = await conn.fetch(""" SELECT tablename FROM pg_tables WHERE schemaname = 'public' """) for table in tables: table_name = table['tablename'] + if table_name in self.SPECIAL_TABLES: + continue # Skip special tables for version comparison try: - last_synced_version = await self.get_last_synced_version(remote_conn, table_name, os.environ.get('TS_ID')) - - changes = await local_conn.fetch(f""" - SELECT * FROM "{table_name}" - WHERE version > $1 AND server_id = $2 - ORDER BY version ASC - """, last_synced_version, os.environ.get('TS_ID')) - - if changes: - debug(f"Pushing changes for table {table_name}") - debug(f"Columns: {', '.join(changes[0].keys())}") - - columns = list(changes[0].keys()) - placeholders = [f'${i+1}' for i in range(len(columns))] - - insert_query = f""" - INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)}) - VALUES ({', '.join(placeholders)}) - ON CONFLICT (id) DO UPDATE SET - {', '.join(f'"{col}" = EXCLUDED."{col}"' for col in columns if col != 'id')} - """ - - debug(f"Insert query: {insert_query}") - - for change in changes: - values = [change[col] for col in columns] - await remote_conn.execute(insert_query, *values) - - await self.update_sync_status(remote_conn, table_name, os.environ.get('TS_ID'), changes[-1]['version']) - + result = await conn.fetchrow(f""" + SELECT MAX(version) as max_version, server_id + FROM "{table_name}" + WHERE version = (SELECT MAX(version) FROM "{table_name}") + GROUP BY server_id + ORDER BY MAX(version) DESC + LIMIT 1 + """) + if result: + version, server_id = result['max_version'], result['server_id'] + info(f"Max version for {pool_entry['ts_id']}, table {table_name}: {version} (from server {server_id})") + if version > max_version: + max_version = version + most_recent_source = pool_entry + else: + info(f"No data in table {table_name} for {pool_entry['ts_id']}") + except asyncpg.exceptions.UndefinedColumnError: + warn(f"Version or server_id column does not exist in table {table_name} for {pool_entry['ts_id']}. Skipping.") except Exception as e: - err(f"Error pushing changes for table {table_name}: {str(e)}") - err(f"Traceback: {traceback.format_exc()}") - - info(f"Successfully pushed changes to {pool_entry['ts_id']}") - except Exception as e: - err(f"Error pushing changes to {pool_entry['ts_id']}: {str(e)}") - err(f"Traceback: {traceback.format_exc()}") + err(f"Error checking version for {pool_entry['ts_id']}, table {table_name}: {str(e)}") + + except asyncpg.exceptions.ConnectionFailureError: + warn(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}") + except Exception as e: + err(f"Unexpected error occurred while checking version for {pool_entry['ts_id']}: {str(e)}") + err(f"Traceback: {traceback.format_exc()}") + + return most_recent_source - async def update_sync_status(self, conn, table_name, server_id, version): - await conn.execute(""" - INSERT INTO sync_status (table_name, server_id, last_synced_version, last_sync_time) - VALUES ($1, $2, $3, NOW()) - ON CONFLICT (table_name, server_id) DO UPDATE - SET last_synced_version = EXCLUDED.last_synced_version, - last_sync_time = EXCLUDED.last_sync_time - """, table_name, server_id, version) - - - async def get_last_synced_version(self, conn, table_name, server_id): - return await conn.fetchval(f""" - SELECT COALESCE(MAX(version), 0) - FROM "{table_name}" - WHERE server_id = $1 - """, server_id) - - async def update_last_synced_version(self, conn, table_name, server_id, version): - await conn.execute(f""" - INSERT INTO "{table_name}" (server_id, version) - VALUES ($1, $2) - ON CONFLICT (server_id) DO UPDATE - SET version = EXCLUDED.version - WHERE "{table_name}".version < EXCLUDED.version - """, server_id, version) - - async def get_schema_version(self, pool_entry): - async with self.get_connection(pool_entry) as conn: - return await conn.fetchval(""" - SELECT COALESCE(MAX(version), 0) FROM ( - SELECT MAX(version) as version FROM pg_tables - WHERE schemaname = 'public' - ) as subquery - """) - - async def create_sequence_if_not_exists(self, conn, sequence_name): - await conn.execute(f""" - DO $$ - BEGIN - IF NOT EXISTS (SELECT 1 FROM pg_sequences WHERE schemaname = 'public' AND sequencename = '{sequence_name}') THEN - CREATE SEQUENCE {sequence_name}; - END IF; - END $$; - """) - - async def check_postgis(self, conn): - try: - result = await conn.fetchval("SELECT PostGIS_version();") - if result: - info(f"PostGIS version: {result}") - return True - else: - warn("PostGIS is not installed or not working properly") - return False - except Exception as e: - err(f"Error checking PostGIS: {str(e)}") - return False - @@ -1231,44 +1060,6 @@ class Geocoder: def __del__(self): self.executor.shutdown() -class Database(BaseModel): - host: str = Field(..., description="Database host") - port: int = Field(5432, description="Database port") - user: str = Field(..., description="Database user") - password: str = Field(..., description="Database password") - database: str = Field(..., description="Database name") - db_schema: Optional[str] = Field(None, description="Database schema") - - @asynccontextmanager - async def get_connection(self): - conn = await asyncpg.connect( - host=self.host, - port=self.port, - user=self.user, - password=self.password, - database=self.database - ) - try: - if self.db_schema: - await conn.execute(f"SET search_path TO {self.db_schema}") - yield conn - finally: - await conn.close() - - @classmethod - def from_env(cls): - import os - return cls( - host=os.getenv("DB_HOST", "localhost"), - port=int(os.getenv("DB_PORT", 5432)), - user=os.getenv("DB_USER"), - password=os.getenv("DB_PASSWORD"), - database=os.getenv("DB_NAME"), - db_schema=os.getenv("DB_SCHEMA") - ) - - def to_dict(self): - return self.dict(exclude_none=True) class IMAPConfig(BaseModel): username: str