Auto-update: Tue Jul 30 23:10:36 PDT 2024

This commit is contained in:
sanj 2024-07-30 23:10:36 -07:00
parent 8f38810626
commit 1ae0f506cf
2 changed files with 242 additions and 451 deletions

View file

@ -6,7 +6,7 @@ import multiprocessing
from dotenv import load_dotenv from dotenv import load_dotenv
from dateutil import tz from dateutil import tz
from pathlib import Path from pathlib import Path
from .classes import Database, Geocoder, APIConfig, Configuration from .classes import Geocoder, APIConfig, Configuration
from .logs import Logger from .logs import Logger
# INITIALization # INITIALization

View file

@ -167,6 +167,7 @@ class Configuration(BaseModel):
class APIConfig(BaseModel): class APIConfig(BaseModel):
HOST: str HOST: str
PORT: int PORT: int
@ -182,6 +183,8 @@ class APIConfig(BaseModel):
GARBAGE: Dict[str, Any] GARBAGE: Dict[str, Any]
_db_pools: Dict[str, asyncpg.Pool] = {} _db_pools: Dict[str, asyncpg.Pool] = {}
SPECIAL_TABLES = ['spatial_ref_sys'] # Tables that can't have server_id and version columns
@classmethod @classmethod
def load(cls, config_path: Union[str, Path], secrets_path: Union[str, Path]): def load(cls, config_path: Union[str, Path], secrets_path: Union[str, Path]):
config_path = cls._resolve_path(config_path, 'config') config_path = cls._resolve_path(config_path, 'config')
@ -301,36 +304,37 @@ class APIConfig(BaseModel):
if pool_entry is None: if pool_entry is None:
pool_entry = self.local_db pool_entry = self.local_db
info(f"Attempting to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}") pool_key = f"{pool_entry['ts_ip']}:{pool_entry['db_port']}"
try:
conn = await asyncpg.connect( if pool_key not in self._db_pools:
host=pool_entry['ts_ip'],
port=pool_entry['db_port'],
user=pool_entry['db_user'],
password=pool_entry['db_pass'],
database=pool_entry['db_name'],
timeout=5 # Add a timeout to prevent hanging
)
info(f"Successfully connected to {pool_entry['ts_ip']}:{pool_entry['db_port']}")
try: try:
self._db_pools[pool_key] = await asyncpg.create_pool(
host=pool_entry['ts_ip'],
port=pool_entry['db_port'],
user=pool_entry['db_user'],
password=pool_entry['db_pass'],
database=pool_entry['db_name'],
min_size=1,
max_size=10, # adjust as needed
timeout=5 # connection timeout in seconds
)
except Exception as e:
err(f"Failed to create connection pool for {pool_key}: {str(e)}")
yield None
return
try:
async with self._db_pools[pool_key].acquire() as conn:
yield conn yield conn
finally:
await conn.close()
info(f"Closed connection to {pool_entry['ts_ip']}:{pool_entry['db_port']}")
except asyncpg.exceptions.ConnectionDoesNotExistError: except asyncpg.exceptions.ConnectionDoesNotExistError:
err(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']} - Connection does not exist") err(f"Failed to acquire connection from pool for {pool_key}: Connection does not exist")
raise yield None
except asyncpg.exceptions.ConnectionFailureError: except asyncpg.exceptions.ConnectionFailureError:
err(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']} - Connection failure") err(f"Failed to acquire connection from pool for {pool_key}: Connection failure")
raise yield None
except asyncpg.exceptions.PostgresError as e:
err(f"PostgreSQL error when connecting to {pool_entry['ts_ip']}:{pool_entry['db_port']}: {str(e)}")
raise
except Exception as e: except Exception as e:
err(f"Unexpected error when connecting to {pool_entry['ts_ip']}:{pool_entry['db_port']}: {str(e)}") err(f"Unexpected error when acquiring connection from pool for {pool_key}: {str(e)}")
raise yield None
async def close_db_pools(self): async def close_db_pools(self):
info("Closing database connection pools...") info("Closing database connection pools...")
@ -343,14 +347,18 @@ class APIConfig(BaseModel):
self._db_pools.clear() self._db_pools.clear()
info("All database connection pools closed.") info("All database connection pools closed.")
async def initialize_sync(self): async def initialize_sync(self):
local_ts_id = os.environ.get('TS_ID') local_ts_id = os.environ.get('TS_ID')
for pool_entry in self.POOL: online_hosts = await self.get_online_hosts()
for pool_entry in online_hosts:
if pool_entry['ts_id'] == local_ts_id: if pool_entry['ts_id'] == local_ts_id:
continue # Skip local database continue # Skip local database
try: try:
async with self.get_connection(pool_entry) as conn: async with self.get_connection(pool_entry) as conn:
if conn is None:
continue # Skip this database if connection failed
info(f"Starting sync initialization for {pool_entry['ts_ip']}...") info(f"Starting sync initialization for {pool_entry['ts_ip']}...")
# Check PostGIS installation # Check PostGIS installation
@ -358,130 +366,67 @@ class APIConfig(BaseModel):
if not postgis_installed: if not postgis_installed:
warn(f"PostGIS is not installed on {pool_entry['ts_id']} ({pool_entry['ts_ip']}). Some spatial operations may fail.") warn(f"PostGIS is not installed on {pool_entry['ts_id']} ({pool_entry['ts_ip']}). Some spatial operations may fail.")
# Initialize sync_status table
await self.initialize_sync_status_table(conn)
# Continue with sync initialization
tables = await conn.fetch(""" tables = await conn.fetch("""
SELECT tablename FROM pg_tables SELECT tablename FROM pg_tables
WHERE schemaname = 'public' WHERE schemaname = 'public'
""") """)
all_tables_synced = True
for table in tables: for table in tables:
table_name = table['tablename'] table_name = table['tablename']
if not await self.ensure_sync_columns(conn, table_name): await self.ensure_sync_columns(conn, table_name)
all_tables_synced = False
if all_tables_synced: info(f"Sync initialization complete for {pool_entry['ts_ip']}. All tables now have necessary sync columns and triggers.")
info(f"Sync initialization complete for {pool_entry['ts_ip']}. All tables now have version and server_id columns with appropriate triggers.")
else:
warn(f"Sync initialization partially complete for {pool_entry['ts_ip']}. Some tables may be missing version or server_id columns.")
except Exception as e: except Exception as e:
err(f"Error initializing sync for {pool_entry['ts_ip']}: {str(e)}") err(f"Error initializing sync for {pool_entry['ts_ip']}: {str(e)}")
err(f"Traceback: {traceback.format_exc()}") err(f"Traceback: {traceback.format_exc()}")
async def initialize_sync_status_table(self, conn):
await conn.execute("""
CREATE TABLE IF NOT EXISTS sync_status (
table_name TEXT,
server_id TEXT,
last_synced_version INTEGER,
last_sync_time TIMESTAMP WITH TIME ZONE,
PRIMARY KEY (table_name, server_id)
)
""")
# Check if the last_sync_time column exists, and add it if it doesn't
column_exists = await conn.fetchval("""
SELECT EXISTS (
SELECT 1
FROM information_schema.columns
WHERE table_name = 'sync_status' AND column_name = 'last_sync_time'
)
""")
if not column_exists:
await conn.execute("""
ALTER TABLE sync_status
ADD COLUMN last_sync_time TIMESTAMP WITH TIME ZONE
""")
async def ensure_sync_structure(self, conn):
tables = await conn.fetch("""
SELECT tablename FROM pg_tables
WHERE schemaname = 'public'
""")
for table in tables:
table_name = table['tablename']
await self.ensure_sync_columns(conn, table_name)
await self.ensure_sync_trigger(conn, table_name)
async def ensure_sync_columns(self, conn, table_name): async def ensure_sync_columns(self, conn, table_name):
if table_name in self.SPECIAL_TABLES:
info(f"Skipping sync columns for special table: {table_name}")
return None
try: try:
# Check if the table has a primary key # Get primary key information
has_primary_key = await conn.fetchval(f""" primary_key = await conn.fetchval(f"""
SELECT EXISTS ( SELECT a.attname
SELECT 1 FROM pg_index i
FROM information_schema.table_constraints JOIN pg_attribute a ON a.attrelid = i.indrelid
WHERE table_name = '{table_name}' AND a.attnum = ANY(i.indkey)
AND constraint_type = 'PRIMARY KEY' WHERE i.indrelid = '{table_name}'::regclass
) AND i.indisprimary;
""") """)
# Check if version column exists # Ensure version column exists
version_exists = await conn.fetchval(f""" await conn.execute(f"""
SELECT EXISTS ( ALTER TABLE "{table_name}"
SELECT 1 FROM information_schema.columns ADD COLUMN IF NOT EXISTS version INTEGER DEFAULT 1;
WHERE table_name = '{table_name}' AND column_name = 'version'
)
""") """)
# Check if server_id column exists # Ensure server_id column exists
server_id_exists = await conn.fetchval(f""" await conn.execute(f"""
SELECT EXISTS ( ALTER TABLE "{table_name}"
SELECT 1 FROM information_schema.columns ADD COLUMN IF NOT EXISTS server_id TEXT DEFAULT '{os.environ.get('TS_ID')}';
WHERE table_name = '{table_name}' AND column_name = 'server_id'
)
""") """)
# Add version column if it doesn't exist
if not version_exists:
await conn.execute(f"""
ALTER TABLE "{table_name}" ADD COLUMN version INTEGER DEFAULT 1
""")
# Add server_id column if it doesn't exist
if not server_id_exists:
await conn.execute(f"""
ALTER TABLE "{table_name}" ADD COLUMN server_id TEXT DEFAULT '{os.environ.get('TS_ID')}'
""")
# Create or replace the trigger function # Create or replace the trigger function
await conn.execute(""" await conn.execute(f"""
CREATE OR REPLACE FUNCTION update_version_and_server_id() CREATE OR REPLACE FUNCTION update_version_and_server_id()
RETURNS TRIGGER AS $$ RETURNS TRIGGER AS $$
BEGIN BEGIN
NEW.version = COALESCE(OLD.version, 0) + 1; NEW.version = COALESCE(OLD.version, 0) + 1;
NEW.server_id = $1; NEW.server_id = '{os.environ.get('TS_ID')}';
RETURN NEW; RETURN NEW;
END; END;
$$ LANGUAGE plpgsql; $$ LANGUAGE plpgsql;
""") """)
# Create the trigger if it doesn't exist # Check if the trigger exists and create it if it doesn't
trigger_exists = await conn.fetchval(f""" trigger_exists = await conn.fetchval(f"""
SELECT EXISTS ( SELECT EXISTS (
SELECT 1 FROM pg_trigger SELECT 1
WHERE tgname = 'update_version_and_server_id_trigger' FROM pg_trigger
WHERE tgname = 'update_version_and_server_id_trigger'
AND tgrelid = '{table_name}'::regclass AND tgrelid = '{table_name}'::regclass
) )
""") """)
@ -490,128 +435,58 @@ class APIConfig(BaseModel):
await conn.execute(f""" await conn.execute(f"""
CREATE TRIGGER update_version_and_server_id_trigger CREATE TRIGGER update_version_and_server_id_trigger
BEFORE INSERT OR UPDATE ON "{table_name}" BEFORE INSERT OR UPDATE ON "{table_name}"
FOR EACH ROW EXECUTE FUNCTION update_version_and_server_id('{os.environ.get('TS_ID')}') FOR EACH ROW EXECUTE FUNCTION update_version_and_server_id();
""") """)
info(f"Successfully ensured sync columns and trigger for table {table_name}. Has primary key: {has_primary_key}") info(f"Successfully ensured sync columns and trigger for table {table_name}")
return has_primary_key return primary_key
except Exception as e: except Exception as e:
err(f"Error ensuring sync columns for table {table_name}: {str(e)}") err(f"Error ensuring sync columns for table {table_name}: {str(e)}")
err(f"Traceback: {traceback.format_exc()}") err(f"Traceback: {traceback.format_exc()}")
return False
async def ensure_sync_trigger(self, conn, table_name): async def apply_batch_changes(self, conn, table_name, changes, primary_key):
await conn.execute(f""" if not changes:
CREATE OR REPLACE FUNCTION update_version_and_server_id() return 0
RETURNS TRIGGER AS $$
BEGIN
NEW.version = COALESCE(OLD.version, 0) + 1;
NEW.server_id = '{os.environ.get('TS_ID')}';
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
DROP TRIGGER IF EXISTS update_version_and_server_id_trigger ON "{table_name}"; try:
columns = list(changes[0].keys())
CREATE TRIGGER update_version_and_server_id_trigger placeholders = [f'${i+1}' for i in range(len(columns))]
BEFORE INSERT OR UPDATE ON "{table_name}"
FOR EACH ROW EXECUTE FUNCTION update_version_and_server_id();
""")
async def get_most_recent_source(self):
most_recent_source = None
max_version = -1
local_ts_id = os.environ.get('TS_ID')
for pool_entry in self.POOL:
if pool_entry['ts_id'] == local_ts_id:
continue # Skip local database
if not await self.is_server_accessible(pool_entry['ts_ip'], pool_entry['db_port']): if primary_key:
warn(f"Server {pool_entry['ts_id']} ({pool_entry['ts_ip']}:{pool_entry['db_port']}) is not accessible. Skipping.") insert_query = f"""
continue INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)})
VALUES ({', '.join(placeholders)})
ON CONFLICT ("{primary_key}") DO UPDATE SET
{', '.join(f'"{col}" = EXCLUDED."{col}"' for col in columns if col not in [primary_key, 'version', 'server_id'])},
version = EXCLUDED.version,
server_id = EXCLUDED.server_id
WHERE "{table_name}".version < EXCLUDED.version
OR ("{table_name}".version = EXCLUDED.version AND "{table_name}".server_id < EXCLUDED.server_id)
"""
else:
# For tables without a primary key, we'll use all columns for conflict resolution
insert_query = f"""
INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)})
VALUES ({', '.join(placeholders)})
ON CONFLICT DO NOTHING
"""
try: debug(f"Generated insert query for {table_name}: {insert_query}")
async with self.get_connection(pool_entry) as conn:
tables = await conn.fetch("""
SELECT tablename FROM pg_tables
WHERE schemaname = 'public'
""")
for table in tables:
table_name = table['tablename']
try:
result = await conn.fetchrow(f"""
SELECT MAX(version) as max_version, server_id
FROM "{table_name}"
WHERE version = (SELECT MAX(version) FROM "{table_name}")
GROUP BY server_id
ORDER BY MAX(version) DESC
LIMIT 1
""")
if result:
version, server_id = result['max_version'], result['server_id']
info(f"Max version for {pool_entry['ts_id']}, table {table_name}: {version} (from server {server_id})")
if version > max_version:
max_version = version
most_recent_source = pool_entry
else:
info(f"No data in table {table_name} for {pool_entry['ts_id']}")
except asyncpg.exceptions.UndefinedColumnError:
warn(f"Version or server_id column does not exist in table {table_name} for {pool_entry['ts_id']}. Attempting to add...")
await self.ensure_sync_columns(conn, table_name)
except Exception as e:
err(f"Error checking version for {pool_entry['ts_id']}, table {table_name}: {str(e)}")
except asyncpg.exceptions.ConnectionFailureError as e: affected_rows = 0
err(f"Failed to establish database connection with {pool_entry['ts_id']} ({pool_entry['ts_ip']}:{pool_entry['db_port']}): {str(e)}") async for change in tqdm(changes, desc=f"Syncing {table_name}", unit="row"):
except Exception as e: values = [change[col] for col in columns]
err(f"Unexpected error occurred while checking version for {pool_entry['ts_id']}: {str(e)}") debug(f"Executing query for {table_name} with values: {values}")
err(f"Traceback: {traceback.format_exc()}") result = await conn.execute(insert_query, *values)
affected_rows += int(result.split()[-1])
return most_recent_source
return affected_rows
async def is_server_accessible(self, host, port, timeout=2):
try:
future = asyncio.open_connection(host, port)
await asyncio.wait_for(future, timeout=timeout)
return True
except (asyncio.TimeoutError, ConnectionRefusedError, socket.gaierror):
return False
async def check_version_column_exists(self, conn):
try:
result = await conn.fetchval("""
SELECT EXISTS (
SELECT 1
FROM information_schema.columns
WHERE table_schema = 'public'
AND column_name = 'version'
AND table_name IN (SELECT tablename FROM pg_tables WHERE schemaname = 'public')
)
""")
if not result:
tables_without_version = await conn.fetch("""
SELECT tablename
FROM pg_tables
WHERE schemaname = 'public'
AND tablename NOT IN (
SELECT table_name
FROM information_schema.columns
WHERE table_schema = 'public' AND column_name = 'version'
)
""")
table_names = ", ".join([t['tablename'] for t in tables_without_version])
warn(f"Tables without 'version' column: {table_names}")
return result
except Exception as e: except Exception as e:
err(f"Error checking for 'version' column existence: {str(e)}") err(f"Error applying batch changes to {table_name}: {str(e)}")
return False err(f"Traceback: {traceback.format_exc()}")
return 0
async def pull_changes(self, source_pool_entry, batch_size=10000): async def pull_changes(self, source_pool_entry, batch_size=10000):
if source_pool_entry['ts_id'] == os.environ.get('TS_ID'): if source_pool_entry['ts_id'] == os.environ.get('TS_ID'):
@ -634,41 +509,34 @@ class APIConfig(BaseModel):
WHERE schemaname = 'public' WHERE schemaname = 'public'
""") """)
for table in tables: async for table in tqdm(tables, desc="Syncing tables", unit="table"):
table_name = table['tablename'] table_name = table['tablename']
try: try:
debug(f"Processing table: {table_name}") if table_name in self.SPECIAL_TABLES:
has_primary_key = await self.ensure_sync_columns(dest_conn, table_name) await self.sync_special_table(source_conn, dest_conn, table_name)
debug(f"Table {table_name} has primary key: {has_primary_key}")
last_synced_version = await self.get_last_synced_version(dest_conn, table_name, source_id)
debug(f"Last synced version for {table_name}: {last_synced_version}")
changes = await source_conn.fetch(f"""
SELECT * FROM "{table_name}"
WHERE version > $1 AND server_id = $2
ORDER BY version ASC
LIMIT $3
""", last_synced_version, source_id, batch_size)
debug(f"Number of changes for {table_name}: {len(changes)}")
if changes:
debug(f"Sample change for {table_name}: {changes[0]}")
changes_count = await self.apply_batch_changes(dest_conn, table_name, changes, has_primary_key)
total_changes += changes_count
if changes_count > 0:
last_synced_version = changes[-1]['version']
await self.update_sync_status(dest_conn, table_name, source_id, last_synced_version)
info(f"Synced batch for {table_name}: {changes_count} changes. Total so far: {total_changes}")
else: else:
info(f"No changes to sync for {table_name}") primary_key = await self.ensure_sync_columns(dest_conn, table_name)
last_synced_version = await self.get_last_synced_version(dest_conn, table_name, source_id)
changes = await source_conn.fetch(f"""
SELECT * FROM "{table_name}"
WHERE version > $1 AND server_id = $2
ORDER BY version ASC
LIMIT $3
""", last_synced_version, source_id, batch_size)
if changes:
changes_count = await self.apply_batch_changes(dest_conn, table_name, changes, primary_key)
total_changes += changes_count
if changes_count > 0:
info(f"Synced batch for {table_name}: {changes_count} changes. Total so far: {total_changes}")
else:
info(f"No changes to sync for {table_name}")
except Exception as e: except Exception as e:
err(f"Error syncing table {table_name}: {str(e)}") err(f"Error syncing table {table_name}: {str(e)}")
err(f"Traceback: {traceback.format_exc()}") err(f"Traceback: {traceback.format_exc()}")
# Continue with the next table
info(f"Sync complete from {source_id} ({source_ip}) to {dest_id} ({dest_ip}). Total changes: {total_changes}") info(f"Sync complete from {source_id} ({source_ip}) to {dest_id} ({dest_ip}). Total changes: {total_changes}")
@ -684,63 +552,91 @@ class APIConfig(BaseModel):
return total_changes return total_changes
async def get_online_hosts(self) -> List[Dict[str, Any]]:
online_hosts = []
for pool_entry in self.POOL:
try:
async with self.get_connection(pool_entry) as conn:
if conn is not None:
online_hosts.append(pool_entry)
except Exception as e:
err(f"Error checking host {pool_entry['ts_ip']}:{pool_entry['db_port']}: {str(e)}")
return online_hosts
async def apply_batch_changes(self, conn, table_name, changes, has_primary_key): async def push_changes_to_all(self):
if not changes: for pool_entry in self.POOL:
return 0 if pool_entry['ts_id'] != os.environ.get('TS_ID'):
try:
columns = list(changes[0].keys())
placeholders = [f'${i}' for i in range(1, len(columns) + 1)]
if has_primary_key:
insert_query = f"""
INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)})
VALUES ({', '.join(placeholders)})
ON CONFLICT ON CONSTRAINT {table_name}_pkey DO UPDATE SET
{', '.join(f'"{col}" = EXCLUDED."{col}"' for col in columns if col not in ['version', 'server_id'])},
version = EXCLUDED.version,
server_id = EXCLUDED.server_id
WHERE "{table_name}".version < EXCLUDED.version
OR ("{table_name}".version = EXCLUDED.version AND "{table_name}".server_id < EXCLUDED.server_id)
"""
else:
# For tables without a primary key, we'll use all columns for conflict detection
insert_query = f"""
INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)})
VALUES ({', '.join(placeholders)})
ON CONFLICT ({', '.join(f'"{col}"' for col in columns if col not in ['version', 'server_id'])}) DO UPDATE SET
version = EXCLUDED.version,
server_id = EXCLUDED.server_id
WHERE "{table_name}".version < EXCLUDED.version
OR ("{table_name}".version = EXCLUDED.version AND "{table_name}".server_id < EXCLUDED.server_id)
"""
debug(f"Generated insert query for {table_name}: {insert_query}")
# Prepare the statement
stmt = await conn.prepare(insert_query)
affected_rows = 0
for change in changes:
try: try:
values = [change[col] for col in columns] await self.push_changes_to_one(pool_entry)
result = await stmt.fetchval(*values)
affected_rows += 1
except Exception as e: except Exception as e:
err(f"Error inserting row into {table_name}: {str(e)}") err(f"Error pushing changes to {pool_entry['ts_id']}: {str(e)}")
err(f"Row data: {change}")
# Continue with the next row
return affected_rows
async def push_changes_to_one(self, pool_entry):
try:
async with self.get_connection() as local_conn:
async with self.get_connection(pool_entry) as remote_conn:
tables = await local_conn.fetch("""
SELECT tablename FROM pg_tables
WHERE schemaname = 'public'
""")
for table in tables:
table_name = table['tablename']
try:
if table_name in self.SPECIAL_TABLES:
await self.sync_special_table(local_conn, remote_conn, table_name)
else:
primary_key = await self.ensure_sync_columns(remote_conn, table_name)
last_synced_version = await self.get_last_synced_version(remote_conn, table_name, os.environ.get('TS_ID'))
changes = await local_conn.fetch(f"""
SELECT * FROM "{table_name}"
WHERE version > $1 AND server_id = $2
ORDER BY version ASC
""", last_synced_version, os.environ.get('TS_ID'))
if changes:
changes_count = await self.apply_batch_changes(remote_conn, table_name, changes, primary_key)
if changes_count > 0:
info(f"Pushed {changes_count} changes for table {table_name} to {pool_entry['ts_id']}")
except Exception as e:
err(f"Error pushing changes for table {table_name} to {pool_entry['ts_id']}: {str(e)}")
err(f"Traceback: {traceback.format_exc()}")
info(f"Successfully pushed changes to {pool_entry['ts_id']}")
except Exception as e: except Exception as e:
err(f"Error applying batch changes to {table_name}: {str(e)}") err(f"Error pushing changes to {pool_entry['ts_id']}: {str(e)}")
err(f"Traceback: {traceback.format_exc()}") err(f"Traceback: {traceback.format_exc()}")
return 0
async def get_last_synced_version(self, conn, table_name, server_id):
if table_name in self.SPECIAL_TABLES:
return 0 # Special handling for tables without version column
return await conn.fetchval(f"""
SELECT COALESCE(MAX(version), 0)
FROM "{table_name}"
WHERE server_id = $1
""", server_id)
async def check_postgis(self, conn):
try:
result = await conn.fetchval("SELECT PostGIS_version();")
if result:
info(f"PostGIS version: {result}")
return True
else:
warn("PostGIS is not installed or not working properly")
return False
except Exception as e:
err(f"Error checking PostGIS: {str(e)}")
return False
async def sync_special_table(self, source_conn, dest_conn, table_name):
if table_name == 'spatial_ref_sys':
return await self.sync_spatial_ref_sys(source_conn, dest_conn)
# Add more special cases as needed
async def sync_spatial_ref_sys(self, source_conn, dest_conn): async def sync_spatial_ref_sys(self, source_conn, dest_conn):
try: try:
@ -803,126 +699,59 @@ class APIConfig(BaseModel):
err(f"Traceback: {traceback.format_exc()}") err(f"Traceback: {traceback.format_exc()}")
return 0 return 0
async def get_most_recent_source(self):
async def push_changes_to_all(self): most_recent_source = None
for pool_entry in self.POOL: max_version = -1
if pool_entry['ts_id'] != os.environ.get('TS_ID'): local_ts_id = os.environ.get('TS_ID')
try: online_hosts = await self.get_online_hosts()
await self.push_changes_to_one(pool_entry)
except Exception as e: for pool_entry in online_hosts:
err(f"Error pushing changes to {pool_entry['ts_id']}: {str(e)}") if pool_entry['ts_id'] == local_ts_id:
continue # Skip local database
async def push_changes_to_one(self, pool_entry):
try: try:
async with self.get_connection() as local_conn: async with self.get_connection(pool_entry) as conn:
async with self.get_connection(pool_entry) as remote_conn: tables = await conn.fetch("""
tables = await local_conn.fetch("""
SELECT tablename FROM pg_tables SELECT tablename FROM pg_tables
WHERE schemaname = 'public' WHERE schemaname = 'public'
""") """)
for table in tables: for table in tables:
table_name = table['tablename'] table_name = table['tablename']
if table_name in self.SPECIAL_TABLES:
continue # Skip special tables for version comparison
try: try:
last_synced_version = await self.get_last_synced_version(remote_conn, table_name, os.environ.get('TS_ID')) result = await conn.fetchrow(f"""
SELECT MAX(version) as max_version, server_id
changes = await local_conn.fetch(f""" FROM "{table_name}"
SELECT * FROM "{table_name}" WHERE version = (SELECT MAX(version) FROM "{table_name}")
WHERE version > $1 AND server_id = $2 GROUP BY server_id
ORDER BY version ASC ORDER BY MAX(version) DESC
""", last_synced_version, os.environ.get('TS_ID')) LIMIT 1
""")
if changes: if result:
debug(f"Pushing changes for table {table_name}") version, server_id = result['max_version'], result['server_id']
debug(f"Columns: {', '.join(changes[0].keys())}") info(f"Max version for {pool_entry['ts_id']}, table {table_name}: {version} (from server {server_id})")
if version > max_version:
columns = list(changes[0].keys()) max_version = version
placeholders = [f'${i+1}' for i in range(len(columns))] most_recent_source = pool_entry
else:
insert_query = f""" info(f"No data in table {table_name} for {pool_entry['ts_id']}")
INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)}) except asyncpg.exceptions.UndefinedColumnError:
VALUES ({', '.join(placeholders)}) warn(f"Version or server_id column does not exist in table {table_name} for {pool_entry['ts_id']}. Skipping.")
ON CONFLICT (id) DO UPDATE SET
{', '.join(f'"{col}" = EXCLUDED."{col}"' for col in columns if col != 'id')}
"""
debug(f"Insert query: {insert_query}")
for change in changes:
values = [change[col] for col in columns]
await remote_conn.execute(insert_query, *values)
await self.update_sync_status(remote_conn, table_name, os.environ.get('TS_ID'), changes[-1]['version'])
except Exception as e: except Exception as e:
err(f"Error pushing changes for table {table_name}: {str(e)}") err(f"Error checking version for {pool_entry['ts_id']}, table {table_name}: {str(e)}")
err(f"Traceback: {traceback.format_exc()}")
except asyncpg.exceptions.ConnectionFailureError:
info(f"Successfully pushed changes to {pool_entry['ts_id']}") warn(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}")
except Exception as e: except Exception as e:
err(f"Error pushing changes to {pool_entry['ts_id']}: {str(e)}") err(f"Unexpected error occurred while checking version for {pool_entry['ts_id']}: {str(e)}")
err(f"Traceback: {traceback.format_exc()}") err(f"Traceback: {traceback.format_exc()}")
return most_recent_source
async def update_sync_status(self, conn, table_name, server_id, version):
await conn.execute("""
INSERT INTO sync_status (table_name, server_id, last_synced_version, last_sync_time)
VALUES ($1, $2, $3, NOW())
ON CONFLICT (table_name, server_id) DO UPDATE
SET last_synced_version = EXCLUDED.last_synced_version,
last_sync_time = EXCLUDED.last_sync_time
""", table_name, server_id, version)
async def get_last_synced_version(self, conn, table_name, server_id):
return await conn.fetchval(f"""
SELECT COALESCE(MAX(version), 0)
FROM "{table_name}"
WHERE server_id = $1
""", server_id)
async def update_last_synced_version(self, conn, table_name, server_id, version):
await conn.execute(f"""
INSERT INTO "{table_name}" (server_id, version)
VALUES ($1, $2)
ON CONFLICT (server_id) DO UPDATE
SET version = EXCLUDED.version
WHERE "{table_name}".version < EXCLUDED.version
""", server_id, version)
async def get_schema_version(self, pool_entry):
async with self.get_connection(pool_entry) as conn:
return await conn.fetchval("""
SELECT COALESCE(MAX(version), 0) FROM (
SELECT MAX(version) as version FROM pg_tables
WHERE schemaname = 'public'
) as subquery
""")
async def create_sequence_if_not_exists(self, conn, sequence_name):
await conn.execute(f"""
DO $$
BEGIN
IF NOT EXISTS (SELECT 1 FROM pg_sequences WHERE schemaname = 'public' AND sequencename = '{sequence_name}') THEN
CREATE SEQUENCE {sequence_name};
END IF;
END $$;
""")
async def check_postgis(self, conn):
try:
result = await conn.fetchval("SELECT PostGIS_version();")
if result:
info(f"PostGIS version: {result}")
return True
else:
warn("PostGIS is not installed or not working properly")
return False
except Exception as e:
err(f"Error checking PostGIS: {str(e)}")
return False
@ -1231,44 +1060,6 @@ class Geocoder:
def __del__(self): def __del__(self):
self.executor.shutdown() self.executor.shutdown()
class Database(BaseModel):
host: str = Field(..., description="Database host")
port: int = Field(5432, description="Database port")
user: str = Field(..., description="Database user")
password: str = Field(..., description="Database password")
database: str = Field(..., description="Database name")
db_schema: Optional[str] = Field(None, description="Database schema")
@asynccontextmanager
async def get_connection(self):
conn = await asyncpg.connect(
host=self.host,
port=self.port,
user=self.user,
password=self.password,
database=self.database
)
try:
if self.db_schema:
await conn.execute(f"SET search_path TO {self.db_schema}")
yield conn
finally:
await conn.close()
@classmethod
def from_env(cls):
import os
return cls(
host=os.getenv("DB_HOST", "localhost"),
port=int(os.getenv("DB_PORT", 5432)),
user=os.getenv("DB_USER"),
password=os.getenv("DB_PASSWORD"),
database=os.getenv("DB_NAME"),
db_schema=os.getenv("DB_SCHEMA")
)
def to_dict(self):
return self.dict(exclude_none=True)
class IMAPConfig(BaseModel): class IMAPConfig(BaseModel):
username: str username: str