Auto-update: Tue Jul 30 23:10:36 PDT 2024

sanj committed 2024-07-30 23:10:36 -07:00
parent 8f38810626
commit 1ae0f506cf
2 changed files with 242 additions and 451 deletions

View file

@@ -6,7 +6,7 @@ import multiprocessing
 from dotenv import load_dotenv
 from dateutil import tz
 from pathlib import Path
-from .classes import Database, Geocoder, APIConfig, Configuration
+from .classes import Geocoder, APIConfig, Configuration
 from .logs import Logger
 # INITIALization

View file

@@ -167,6 +167,7 @@ class Configuration(BaseModel):
 class APIConfig(BaseModel):
     HOST: str
     PORT: int
@@ -182,6 +183,8 @@ class APIConfig(BaseModel):
     GARBAGE: Dict[str, Any]
     _db_pools: Dict[str, asyncpg.Pool] = {}
+    SPECIAL_TABLES = ['spatial_ref_sys']  # Tables that can't have server_id and version columns

     @classmethod
     def load(cls, config_path: Union[str, Path], secrets_path: Union[str, Path]):
         config_path = cls._resolve_path(config_path, 'config')
@@ -301,36 +304,37 @@ class APIConfig(BaseModel):
         if pool_entry is None:
             pool_entry = self.local_db

-        info(f"Attempting to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}")
-        try:
-            conn = await asyncpg.connect(
-                host=pool_entry['ts_ip'],
-                port=pool_entry['db_port'],
-                user=pool_entry['db_user'],
-                password=pool_entry['db_pass'],
-                database=pool_entry['db_name'],
-                timeout=5  # Add a timeout to prevent hanging
-            )
-            info(f"Successfully connected to {pool_entry['ts_ip']}:{pool_entry['db_port']}")
-            try:
-                yield conn
-            finally:
-                await conn.close()
-                info(f"Closed connection to {pool_entry['ts_ip']}:{pool_entry['db_port']}")
-        except asyncpg.exceptions.ConnectionDoesNotExistError:
-            err(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']} - Connection does not exist")
-            raise
-        except asyncpg.exceptions.ConnectionFailureError:
-            err(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']} - Connection failure")
-            raise
-        except asyncpg.exceptions.PostgresError as e:
-            err(f"PostgreSQL error when connecting to {pool_entry['ts_ip']}:{pool_entry['db_port']}: {str(e)}")
-            raise
-        except Exception as e:
-            err(f"Unexpected error when connecting to {pool_entry['ts_ip']}:{pool_entry['db_port']}: {str(e)}")
-            raise
+        pool_key = f"{pool_entry['ts_ip']}:{pool_entry['db_port']}"
+
+        if pool_key not in self._db_pools:
+            try:
+                self._db_pools[pool_key] = await asyncpg.create_pool(
+                    host=pool_entry['ts_ip'],
+                    port=pool_entry['db_port'],
+                    user=pool_entry['db_user'],
+                    password=pool_entry['db_pass'],
+                    database=pool_entry['db_name'],
+                    min_size=1,
+                    max_size=10,  # adjust as needed
+                    timeout=5  # connection timeout in seconds
+                )
+            except Exception as e:
+                err(f"Failed to create connection pool for {pool_key}: {str(e)}")
+                yield None
+                return
+
+        try:
+            async with self._db_pools[pool_key].acquire() as conn:
+                yield conn
+        except asyncpg.exceptions.ConnectionDoesNotExistError:
+            err(f"Failed to acquire connection from pool for {pool_key}: Connection does not exist")
+            yield None
+        except asyncpg.exceptions.ConnectionFailureError:
+            err(f"Failed to acquire connection from pool for {pool_key}: Connection failure")
+            yield None
+        except Exception as e:
+            err(f"Unexpected error when acquiring connection from pool for {pool_key}: {str(e)}")
+            yield None

     async def close_db_pools(self):
         info("Closing database connection pools...")
@@ -343,14 +347,18 @@ class APIConfig(BaseModel):
         self._db_pools.clear()
         info("All database connection pools closed.")

     async def initialize_sync(self):
         local_ts_id = os.environ.get('TS_ID')
-        for pool_entry in self.POOL:
+        online_hosts = await self.get_online_hosts()
+        for pool_entry in online_hosts:
             if pool_entry['ts_id'] == local_ts_id:
                 continue  # Skip local database
             try:
                 async with self.get_connection(pool_entry) as conn:
+                    if conn is None:
+                        continue  # Skip this database if connection failed
                     info(f"Starting sync initialization for {pool_entry['ts_ip']}...")

                     # Check PostGIS installation
@@ -358,62 +366,6 @@ class APIConfig(BaseModel):
                     if not postgis_installed:
                         warn(f"PostGIS is not installed on {pool_entry['ts_id']} ({pool_entry['ts_ip']}). Some spatial operations may fail.")

-                    # Initialize sync_status table
-                    await self.initialize_sync_status_table(conn)
-
-                    # Continue with sync initialization
-                    tables = await conn.fetch("""
-                        SELECT tablename FROM pg_tables
-                        WHERE schemaname = 'public'
-                    """)
-
-                    all_tables_synced = True
-                    for table in tables:
-                        table_name = table['tablename']
-                        if not await self.ensure_sync_columns(conn, table_name):
-                            all_tables_synced = False
-
-                    if all_tables_synced:
-                        info(f"Sync initialization complete for {pool_entry['ts_ip']}. All tables now have version and server_id columns with appropriate triggers.")
-                    else:
-                        warn(f"Sync initialization partially complete for {pool_entry['ts_ip']}. Some tables may be missing version or server_id columns.")
-
-            except Exception as e:
-                err(f"Error initializing sync for {pool_entry['ts_ip']}: {str(e)}")
-                err(f"Traceback: {traceback.format_exc()}")
-
-    async def initialize_sync_status_table(self, conn):
-        await conn.execute("""
-            CREATE TABLE IF NOT EXISTS sync_status (
-                table_name TEXT,
-                server_id TEXT,
-                last_synced_version INTEGER,
-                last_sync_time TIMESTAMP WITH TIME ZONE,
-                PRIMARY KEY (table_name, server_id)
-            )
-        """)
-
-        # Check if the last_sync_time column exists, and add it if it doesn't
-        column_exists = await conn.fetchval("""
-            SELECT EXISTS (
-                SELECT 1
-                FROM information_schema.columns
-                WHERE table_name = 'sync_status' AND column_name = 'last_sync_time'
-            )
-        """)
-
-        if not column_exists:
-            await conn.execute("""
-                ALTER TABLE sync_status
-                ADD COLUMN last_sync_time TIMESTAMP WITH TIME ZONE
-            """)
-
-    async def ensure_sync_structure(self, conn):
                     tables = await conn.fetch("""
                         SELECT tablename FROM pg_tables
                         WHERE schemaname = 'public'
@@ -422,86 +374,42 @@ class APIConfig(BaseModel):
                     for table in tables:
                         table_name = table['tablename']
                         await self.ensure_sync_columns(conn, table_name)
-            await self.ensure_sync_trigger(conn, table_name)
+
+                    info(f"Sync initialization complete for {pool_entry['ts_ip']}. All tables now have necessary sync columns and triggers.")
+
+            except Exception as e:
+                err(f"Error initializing sync for {pool_entry['ts_ip']}: {str(e)}")
+                err(f"Traceback: {traceback.format_exc()}")

     async def ensure_sync_columns(self, conn, table_name):
+        if table_name in self.SPECIAL_TABLES:
+            info(f"Skipping sync columns for special table: {table_name}")
+            return None
+
         try:
-            # Check if the table has a primary key
-            has_primary_key = await conn.fetchval(f"""
-                SELECT EXISTS (
-                    SELECT 1
-                    FROM information_schema.table_constraints
-                    WHERE table_name = '{table_name}'
-                    AND constraint_type = 'PRIMARY KEY'
-                )
-            """)
+            # Get primary key information
+            primary_key = await conn.fetchval(f"""
+                SELECT a.attname
+                FROM pg_index i
+                JOIN pg_attribute a ON a.attrelid = i.indrelid
+                AND a.attnum = ANY(i.indkey)
+                WHERE i.indrelid = '{table_name}'::regclass
+                AND i.indisprimary;
+            """)

-            # Check if version column exists
-            version_exists = await conn.fetchval(f"""
-                SELECT EXISTS (
-                    SELECT 1 FROM information_schema.columns
-                    WHERE table_name = '{table_name}' AND column_name = 'version'
-                )
-            """)
-
-            # Check if server_id column exists
-            server_id_exists = await conn.fetchval(f"""
-                SELECT EXISTS (
-                    SELECT 1 FROM information_schema.columns
-                    WHERE table_name = '{table_name}' AND column_name = 'server_id'
-                )
-            """)
-
-            # Add version column if it doesn't exist
-            if not version_exists:
-                await conn.execute(f"""
-                    ALTER TABLE "{table_name}" ADD COLUMN version INTEGER DEFAULT 1
-                """)
+            # Ensure version column exists
+            await conn.execute(f"""
+                ALTER TABLE "{table_name}"
+                ADD COLUMN IF NOT EXISTS version INTEGER DEFAULT 1;
+            """)

-            # Add server_id column if it doesn't exist
-            if not server_id_exists:
-                await conn.execute(f"""
-                    ALTER TABLE "{table_name}" ADD COLUMN server_id TEXT DEFAULT '{os.environ.get('TS_ID')}'
-                """)
+            # Ensure server_id column exists
+            await conn.execute(f"""
+                ALTER TABLE "{table_name}"
+                ADD COLUMN IF NOT EXISTS server_id TEXT DEFAULT '{os.environ.get('TS_ID')}';
+            """)

             # Create or replace the trigger function
-            await conn.execute("""
-                CREATE OR REPLACE FUNCTION update_version_and_server_id()
-                RETURNS TRIGGER AS $$
-                BEGIN
-                    NEW.version = COALESCE(OLD.version, 0) + 1;
-                    NEW.server_id = $1;
-                    RETURN NEW;
-                END;
-                $$ LANGUAGE plpgsql;
-            """)
-
-            # Create the trigger if it doesn't exist
-            trigger_exists = await conn.fetchval(f"""
-                SELECT EXISTS (
-                    SELECT 1 FROM pg_trigger
-                    WHERE tgname = 'update_version_and_server_id_trigger'
-                    AND tgrelid = '{table_name}'::regclass
-                )
-            """)
-
-            if not trigger_exists:
-                await conn.execute(f"""
-                    CREATE TRIGGER update_version_and_server_id_trigger
-                    BEFORE INSERT OR UPDATE ON "{table_name}"
-                    FOR EACH ROW EXECUTE FUNCTION update_version_and_server_id('{os.environ.get('TS_ID')}')
-                """)
-
-            info(f"Successfully ensured sync columns and trigger for table {table_name}. Has primary key: {has_primary_key}")
-            return has_primary_key
-        except Exception as e:
-            err(f"Error ensuring sync columns for table {table_name}: {str(e)}")
-            err(f"Traceback: {traceback.format_exc()}")
-            return False
-
-    async def ensure_sync_trigger(self, conn, table_name):
             await conn.execute(f"""
                 CREATE OR REPLACE FUNCTION update_version_and_server_id()
                 RETURNS TRIGGER AS $$
@@ -511,107 +419,74 @@ class APIConfig(BaseModel):
                     RETURN NEW;
                 END;
                 $$ LANGUAGE plpgsql;

-            DROP TRIGGER IF EXISTS update_version_and_server_id_trigger ON "{table_name}";
-
-            CREATE TRIGGER update_version_and_server_id_trigger
-            BEFORE INSERT OR UPDATE ON "{table_name}"
-            FOR EACH ROW EXECUTE FUNCTION update_version_and_server_id();
-        """)
+            """)
+
+            # Check if the trigger exists and create it if it doesn't
+            trigger_exists = await conn.fetchval(f"""
+                SELECT EXISTS (
+                    SELECT 1
+                    FROM pg_trigger
+                    WHERE tgname = 'update_version_and_server_id_trigger'
+                    AND tgrelid = '{table_name}'::regclass
+                )
+            """)
+
+            if not trigger_exists:
+                await conn.execute(f"""
+                    CREATE TRIGGER update_version_and_server_id_trigger
+                    BEFORE INSERT OR UPDATE ON "{table_name}"
+                    FOR EACH ROW EXECUTE FUNCTION update_version_and_server_id();
+                """)
+
+            info(f"Successfully ensured sync columns and trigger for table {table_name}")
+            return primary_key
+        except Exception as e:
+            err(f"Error ensuring sync columns for table {table_name}: {str(e)}")
+            err(f"Traceback: {traceback.format_exc()}")

-    async def get_most_recent_source(self):
-        most_recent_source = None
-        max_version = -1
-        local_ts_id = os.environ.get('TS_ID')
-        for pool_entry in self.POOL:
-            if pool_entry['ts_id'] == local_ts_id:
-                continue  # Skip local database
-            if not await self.is_server_accessible(pool_entry['ts_ip'], pool_entry['db_port']):
-                warn(f"Server {pool_entry['ts_id']} ({pool_entry['ts_ip']}:{pool_entry['db_port']}) is not accessible. Skipping.")
-                continue
-            try:
-                async with self.get_connection(pool_entry) as conn:
-                    tables = await conn.fetch("""
-                        SELECT tablename FROM pg_tables
-                        WHERE schemaname = 'public'
-                    """)
-                    for table in tables:
-                        table_name = table['tablename']
-                        try:
-                            result = await conn.fetchrow(f"""
-                                SELECT MAX(version) as max_version, server_id
-                                FROM "{table_name}"
-                                WHERE version = (SELECT MAX(version) FROM "{table_name}")
-                                GROUP BY server_id
-                                ORDER BY MAX(version) DESC
-                                LIMIT 1
-                            """)
-                            if result:
-                                version, server_id = result['max_version'], result['server_id']
-                                info(f"Max version for {pool_entry['ts_id']}, table {table_name}: {version} (from server {server_id})")
-                                if version > max_version:
-                                    max_version = version
-                                    most_recent_source = pool_entry
-                            else:
-                                info(f"No data in table {table_name} for {pool_entry['ts_id']}")
-                        except asyncpg.exceptions.UndefinedColumnError:
-                            warn(f"Version or server_id column does not exist in table {table_name} for {pool_entry['ts_id']}. Attempting to add...")
-                            await self.ensure_sync_columns(conn, table_name)
-                        except Exception as e:
-                            err(f"Error checking version for {pool_entry['ts_id']}, table {table_name}: {str(e)}")
-            except asyncpg.exceptions.ConnectionFailureError as e:
-                err(f"Failed to establish database connection with {pool_entry['ts_id']} ({pool_entry['ts_ip']}:{pool_entry['db_port']}): {str(e)}")
-            except Exception as e:
-                err(f"Unexpected error occurred while checking version for {pool_entry['ts_id']}: {str(e)}")
-                err(f"Traceback: {traceback.format_exc()}")
-        return most_recent_source
-
-    async def is_server_accessible(self, host, port, timeout=2):
-        try:
-            future = asyncio.open_connection(host, port)
-            await asyncio.wait_for(future, timeout=timeout)
-            return True
-        except (asyncio.TimeoutError, ConnectionRefusedError, socket.gaierror):
-            return False
-
-    async def check_version_column_exists(self, conn):
-        try:
-            result = await conn.fetchval("""
-                SELECT EXISTS (
-                    SELECT 1
-                    FROM information_schema.columns
-                    WHERE table_schema = 'public'
-                    AND column_name = 'version'
-                    AND table_name IN (SELECT tablename FROM pg_tables WHERE schemaname = 'public')
-                )
-            """)
-            if not result:
-                tables_without_version = await conn.fetch("""
-                    SELECT tablename
-                    FROM pg_tables
-                    WHERE schemaname = 'public'
-                    AND tablename NOT IN (
-                        SELECT table_name
-                        FROM information_schema.columns
-                        WHERE table_schema = 'public' AND column_name = 'version'
-                    )
-                """)
-                table_names = ", ".join([t['tablename'] for t in tables_without_version])
-                warn(f"Tables without 'version' column: {table_names}")
-            return result
-        except Exception as e:
-            err(f"Error checking for 'version' column existence: {str(e)}")
-            return False
+    async def apply_batch_changes(self, conn, table_name, changes, primary_key):
+        if not changes:
+            return 0
+
+        try:
+            columns = list(changes[0].keys())
+            placeholders = [f'${i+1}' for i in range(len(columns))]
+
+            if primary_key:
+                insert_query = f"""
+                    INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)})
+                    VALUES ({', '.join(placeholders)})
+                    ON CONFLICT ("{primary_key}") DO UPDATE SET
+                    {', '.join(f'"{col}" = EXCLUDED."{col}"' for col in columns if col not in [primary_key, 'version', 'server_id'])},
+                    version = EXCLUDED.version,
+                    server_id = EXCLUDED.server_id
+                    WHERE "{table_name}".version < EXCLUDED.version
+                    OR ("{table_name}".version = EXCLUDED.version AND "{table_name}".server_id < EXCLUDED.server_id)
+                """
+            else:
+                # For tables without a primary key, we'll use all columns for conflict resolution
+                insert_query = f"""
+                    INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)})
+                    VALUES ({', '.join(placeholders)})
+                    ON CONFLICT DO NOTHING
+                """
+
+            debug(f"Generated insert query for {table_name}: {insert_query}")
+
+            affected_rows = 0
+            for change in tqdm(changes, desc=f"Syncing {table_name}", unit="row"):
+                values = [change[col] for col in columns]
+                debug(f"Executing query for {table_name} with values: {values}")
+                result = await conn.execute(insert_query, *values)
+                affected_rows += int(result.split()[-1])
+
+            return affected_rows
+        except Exception as e:
+            err(f"Error applying batch changes to {table_name}: {str(e)}")
+            err(f"Traceback: {traceback.format_exc()}")
+            return 0

     async def pull_changes(self, source_pool_entry, batch_size=10000):
         if source_pool_entry['ts_id'] == os.environ.get('TS_ID'):
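
Note: the ON CONFLICT clause above encodes a last-write-wins rule: an incoming row replaces the local one only if it carries a higher version, with server_id as the tiebreaker on equal versions. A standalone sketch of that comparison (function name and arguments are illustrative only, not part of the commit):

    # Hypothetical restatement of the WHERE clause used in the upsert.
    def incoming_wins(local_version, local_server_id,
                      incoming_version, incoming_server_id) -> bool:
        if incoming_version != local_version:
            return incoming_version > local_version
        # Equal versions: break the tie by comparing server IDs.
        return incoming_server_id > local_server_id
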
@@ -634,14 +509,14 @@ class APIConfig(BaseModel):
                     WHERE schemaname = 'public'
                 """)

-                for table in tables:
+                for table in tqdm(tables, desc="Syncing tables", unit="table"):
                     table_name = table['tablename']
                     try:
-                        debug(f"Processing table: {table_name}")
-                        has_primary_key = await self.ensure_sync_columns(dest_conn, table_name)
-                        debug(f"Table {table_name} has primary key: {has_primary_key}")
-                        last_synced_version = await self.get_last_synced_version(dest_conn, table_name, source_id)
-                        debug(f"Last synced version for {table_name}: {last_synced_version}")
+                        if table_name in self.SPECIAL_TABLES:
+                            await self.sync_special_table(source_conn, dest_conn, table_name)
+                        else:
+                            primary_key = await self.ensure_sync_columns(dest_conn, table_name)
+                            last_synced_version = await self.get_last_synced_version(dest_conn, table_name, source_id)

-                        changes = await source_conn.fetch(f"""
-                            SELECT * FROM "{table_name}"
+                            changes = await source_conn.fetch(f"""
+                                SELECT * FROM "{table_name}"
@@ -650,17 +525,11 @@ class APIConfig(BaseModel):
                                 LIMIT $3
                             """, last_synced_version, source_id, batch_size)

-                        debug(f"Number of changes for {table_name}: {len(changes)}")
-
-                        if changes:
-                            debug(f"Sample change for {table_name}: {changes[0]}")
-                            changes_count = await self.apply_batch_changes(dest_conn, table_name, changes, has_primary_key)
-                            total_changes += changes_count
-                            if changes_count > 0:
-                                last_synced_version = changes[-1]['version']
-                                await self.update_sync_status(dest_conn, table_name, source_id, last_synced_version)
-                                info(f"Synced batch for {table_name}: {changes_count} changes. Total so far: {total_changes}")
-                        else:
-                            info(f"No changes to sync for {table_name}")
+                            if changes:
+                                changes_count = await self.apply_batch_changes(dest_conn, table_name, changes, primary_key)
+                                total_changes += changes_count
+                                if changes_count > 0:
+                                    info(f"Synced batch for {table_name}: {changes_count} changes. Total so far: {total_changes}")
+                            else:
+                                info(f"No changes to sync for {table_name}")

                     except Exception as e:
                         err(f"Error syncing table {table_name}: {str(e)}")
                         err(f"Traceback: {traceback.format_exc()}")
-                        # Continue with the next table

             info(f"Sync complete from {source_id} ({source_ip}) to {dest_id} ({dest_ip}). Total changes: {total_changes}")
@@ -684,63 +552,91 @@ class APIConfig(BaseModel):
         return total_changes

-    async def apply_batch_changes(self, conn, table_name, changes, has_primary_key):
-        if not changes:
-            return 0
-
-        try:
-            columns = list(changes[0].keys())
-            placeholders = [f'${i}' for i in range(1, len(columns) + 1)]
-
-            if has_primary_key:
-                insert_query = f"""
-                    INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)})
-                    VALUES ({', '.join(placeholders)})
-                    ON CONFLICT ON CONSTRAINT {table_name}_pkey DO UPDATE SET
-                    {', '.join(f'"{col}" = EXCLUDED."{col}"' for col in columns if col not in ['version', 'server_id'])},
-                    version = EXCLUDED.version,
-                    server_id = EXCLUDED.server_id
-                    WHERE "{table_name}".version < EXCLUDED.version
-                    OR ("{table_name}".version = EXCLUDED.version AND "{table_name}".server_id < EXCLUDED.server_id)
-                """
-            else:
-                # For tables without a primary key, we'll use all columns for conflict detection
-                insert_query = f"""
-                    INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)})
-                    VALUES ({', '.join(placeholders)})
-                    ON CONFLICT ({', '.join(f'"{col}"' for col in columns if col not in ['version', 'server_id'])}) DO UPDATE SET
-                    version = EXCLUDED.version,
-                    server_id = EXCLUDED.server_id
-                    WHERE "{table_name}".version < EXCLUDED.version
-                    OR ("{table_name}".version = EXCLUDED.version AND "{table_name}".server_id < EXCLUDED.server_id)
-                """
-
-            debug(f"Generated insert query for {table_name}: {insert_query}")
-
-            # Prepare the statement
-            stmt = await conn.prepare(insert_query)
-
-            affected_rows = 0
-            for change in changes:
-                try:
-                    values = [change[col] for col in columns]
-                    result = await stmt.fetchval(*values)
-                    affected_rows += 1
-                except Exception as e:
-                    err(f"Error inserting row into {table_name}: {str(e)}")
-                    err(f"Row data: {change}")
-                    # Continue with the next row
-            return affected_rows
-        except Exception as e:
-            err(f"Error applying batch changes to {table_name}: {str(e)}")
-            err(f"Traceback: {traceback.format_exc()}")
-            return 0
+    async def get_online_hosts(self) -> List[Dict[str, Any]]:
+        online_hosts = []
+        for pool_entry in self.POOL:
+            try:
+                async with self.get_connection(pool_entry) as conn:
+                    if conn is not None:
+                        online_hosts.append(pool_entry)
+            except Exception as e:
+                err(f"Error checking host {pool_entry['ts_ip']}:{pool_entry['db_port']}: {str(e)}")
+        return online_hosts
+
+    async def push_changes_to_all(self):
+        for pool_entry in self.POOL:
+            if pool_entry['ts_id'] != os.environ.get('TS_ID'):
+                try:
+                    await self.push_changes_to_one(pool_entry)
+                except Exception as e:
+                    err(f"Error pushing changes to {pool_entry['ts_id']}: {str(e)}")
+
+    async def push_changes_to_one(self, pool_entry):
+        try:
+            async with self.get_connection() as local_conn:
+                async with self.get_connection(pool_entry) as remote_conn:
+                    tables = await local_conn.fetch("""
+                        SELECT tablename FROM pg_tables
+                        WHERE schemaname = 'public'
+                    """)
+                    for table in tables:
+                        table_name = table['tablename']
+                        try:
+                            if table_name in self.SPECIAL_TABLES:
+                                await self.sync_special_table(local_conn, remote_conn, table_name)
+                            else:
+                                primary_key = await self.ensure_sync_columns(remote_conn, table_name)
+                                last_synced_version = await self.get_last_synced_version(remote_conn, table_name, os.environ.get('TS_ID'))
+
+                                changes = await local_conn.fetch(f"""
+                                    SELECT * FROM "{table_name}"
+                                    WHERE version > $1 AND server_id = $2
+                                    ORDER BY version ASC
+                                """, last_synced_version, os.environ.get('TS_ID'))
+
+                                if changes:
+                                    changes_count = await self.apply_batch_changes(remote_conn, table_name, changes, primary_key)
+                                    if changes_count > 0:
+                                        info(f"Pushed {changes_count} changes for table {table_name} to {pool_entry['ts_id']}")
+                        except Exception as e:
+                            err(f"Error pushing changes for table {table_name} to {pool_entry['ts_id']}: {str(e)}")
+                            err(f"Traceback: {traceback.format_exc()}")
+
+            info(f"Successfully pushed changes to {pool_entry['ts_id']}")
+        except Exception as e:
+            err(f"Error pushing changes to {pool_entry['ts_id']}: {str(e)}")
+            err(f"Traceback: {traceback.format_exc()}")
+
+    async def get_last_synced_version(self, conn, table_name, server_id):
+        if table_name in self.SPECIAL_TABLES:
+            return 0  # Special handling for tables without version column
+        return await conn.fetchval(f"""
+            SELECT COALESCE(MAX(version), 0)
+            FROM "{table_name}"
+            WHERE server_id = $1
+        """, server_id)
+
+    async def check_postgis(self, conn):
+        try:
+            result = await conn.fetchval("SELECT PostGIS_version();")
+            if result:
+                info(f"PostGIS version: {result}")
+                return True
+            else:
+                warn("PostGIS is not installed or not working properly")
+                return False
+        except Exception as e:
+            err(f"Error checking PostGIS: {str(e)}")
+            return False
+
+    async def sync_special_table(self, source_conn, dest_conn, table_name):
+        if table_name == 'spatial_ref_sys':
+            return await self.sync_spatial_ref_sys(source_conn, dest_conn)
+        # Add more special cases as needed

     async def sync_spatial_ref_sys(self, source_conn, dest_conn):
         try:
@@ -803,125 +699,58 @@ class APIConfig(BaseModel):
             err(f"Traceback: {traceback.format_exc()}")
             return 0

-    async def push_changes_to_all(self):
-        for pool_entry in self.POOL:
-            if pool_entry['ts_id'] != os.environ.get('TS_ID'):
-                try:
-                    await self.push_changes_to_one(pool_entry)
-                except Exception as e:
-                    err(f"Error pushing changes to {pool_entry['ts_id']}: {str(e)}")
-
-    async def push_changes_to_one(self, pool_entry):
-        try:
-            async with self.get_connection() as local_conn:
-                async with self.get_connection(pool_entry) as remote_conn:
-                    tables = await local_conn.fetch("""
-                        SELECT tablename FROM pg_tables
-                        WHERE schemaname = 'public'
-                    """)
-                    for table in tables:
-                        table_name = table['tablename']
-                        try:
-                            last_synced_version = await self.get_last_synced_version(remote_conn, table_name, os.environ.get('TS_ID'))
-
-                            changes = await local_conn.fetch(f"""
-                                SELECT * FROM "{table_name}"
-                                WHERE version > $1 AND server_id = $2
-                                ORDER BY version ASC
-                            """, last_synced_version, os.environ.get('TS_ID'))
-
-                            if changes:
-                                debug(f"Pushing changes for table {table_name}")
-                                debug(f"Columns: {', '.join(changes[0].keys())}")
-
-                                columns = list(changes[0].keys())
-                                placeholders = [f'${i+1}' for i in range(len(columns))]
-
-                                insert_query = f"""
-                                    INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)})
-                                    VALUES ({', '.join(placeholders)})
-                                    ON CONFLICT (id) DO UPDATE SET
-                                    {', '.join(f'"{col}" = EXCLUDED."{col}"' for col in columns if col != 'id')}
-                                """
-                                debug(f"Insert query: {insert_query}")
-
-                                for change in changes:
-                                    values = [change[col] for col in columns]
-                                    await remote_conn.execute(insert_query, *values)
-
-                                await self.update_sync_status(remote_conn, table_name, os.environ.get('TS_ID'), changes[-1]['version'])
-                        except Exception as e:
-                            err(f"Error pushing changes for table {table_name}: {str(e)}")
-                            err(f"Traceback: {traceback.format_exc()}")
-
-                info(f"Successfully pushed changes to {pool_entry['ts_id']}")
-        except Exception as e:
-            err(f"Error pushing changes to {pool_entry['ts_id']}: {str(e)}")
-            err(f"Traceback: {traceback.format_exc()}")
-
-    async def update_sync_status(self, conn, table_name, server_id, version):
-        await conn.execute("""
-            INSERT INTO sync_status (table_name, server_id, last_synced_version, last_sync_time)
-            VALUES ($1, $2, $3, NOW())
-            ON CONFLICT (table_name, server_id) DO UPDATE
-            SET last_synced_version = EXCLUDED.last_synced_version,
-                last_sync_time = EXCLUDED.last_sync_time
-        """, table_name, server_id, version)
-
-    async def get_last_synced_version(self, conn, table_name, server_id):
-        return await conn.fetchval(f"""
-            SELECT COALESCE(MAX(version), 0)
-            FROM "{table_name}"
-            WHERE server_id = $1
-        """, server_id)
-
-    async def update_last_synced_version(self, conn, table_name, server_id, version):
-        await conn.execute(f"""
-            INSERT INTO "{table_name}" (server_id, version)
-            VALUES ($1, $2)
-            ON CONFLICT (server_id) DO UPDATE
-            SET version = EXCLUDED.version
-            WHERE "{table_name}".version < EXCLUDED.version
-        """, server_id, version)
-
-    async def get_schema_version(self, pool_entry):
-        async with self.get_connection(pool_entry) as conn:
-            return await conn.fetchval("""
-                SELECT COALESCE(MAX(version), 0) FROM (
-                    SELECT MAX(version) as version FROM pg_tables
-                    WHERE schemaname = 'public'
-                ) as subquery
-            """)
-
-    async def create_sequence_if_not_exists(self, conn, sequence_name):
-        await conn.execute(f"""
-            DO $$
-            BEGIN
-                IF NOT EXISTS (SELECT 1 FROM pg_sequences WHERE schemaname = 'public' AND sequencename = '{sequence_name}') THEN
-                    CREATE SEQUENCE {sequence_name};
-                END IF;
-            END $$;
-        """)
-
-    async def check_postgis(self, conn):
-        try:
-            result = await conn.fetchval("SELECT PostGIS_version();")
-            if result:
-                info(f"PostGIS version: {result}")
-                return True
-            else:
-                warn("PostGIS is not installed or not working properly")
-                return False
-        except Exception as e:
-            err(f"Error checking PostGIS: {str(e)}")
-            return False
+    async def get_most_recent_source(self):
+        most_recent_source = None
+        max_version = -1
+        local_ts_id = os.environ.get('TS_ID')
+        online_hosts = await self.get_online_hosts()
+
+        for pool_entry in online_hosts:
+            if pool_entry['ts_id'] == local_ts_id:
+                continue  # Skip local database
+            try:
+                async with self.get_connection(pool_entry) as conn:
+                    tables = await conn.fetch("""
+                        SELECT tablename FROM pg_tables
+                        WHERE schemaname = 'public'
+                    """)
+                    for table in tables:
+                        table_name = table['tablename']
+                        if table_name in self.SPECIAL_TABLES:
+                            continue  # Skip special tables for version comparison
+                        try:
+                            result = await conn.fetchrow(f"""
+                                SELECT MAX(version) as max_version, server_id
+                                FROM "{table_name}"
+                                WHERE version = (SELECT MAX(version) FROM "{table_name}")
+                                GROUP BY server_id
+                                ORDER BY MAX(version) DESC
+                                LIMIT 1
+                            """)
+                            if result:
+                                version, server_id = result['max_version'], result['server_id']
+                                info(f"Max version for {pool_entry['ts_id']}, table {table_name}: {version} (from server {server_id})")
+                                if version > max_version:
+                                    max_version = version
+                                    most_recent_source = pool_entry
+                            else:
+                                info(f"No data in table {table_name} for {pool_entry['ts_id']}")
+                        except asyncpg.exceptions.UndefinedColumnError:
+                            warn(f"Version or server_id column does not exist in table {table_name} for {pool_entry['ts_id']}. Skipping.")
+                        except Exception as e:
+                            err(f"Error checking version for {pool_entry['ts_id']}, table {table_name}: {str(e)}")
+            except asyncpg.exceptions.ConnectionFailureError:
+                warn(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}")
+            except Exception as e:
+                err(f"Unexpected error occurred while checking version for {pool_entry['ts_id']}: {str(e)}")
+                err(f"Traceback: {traceback.format_exc()}")
+
+        return most_recent_source
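
Note: get_online_hosts doubles as a liveness probe, since it attempts a real pooled connection to every entry in POOL. A sketch of a periodic push loop built on it (the periodic_push driver, its interval, and its logging are assumptions, not part of the commit; asyncio is imported elsewhere in this module):

    # Hypothetical driver; a real deployment might schedule this differently.
    async def periodic_push(api: APIConfig, interval_seconds: int = 300):
        while True:
            online = await api.get_online_hosts()
            info(f"{len(online)} of {len(api.POOL)} hosts reachable")
            await api.push_changes_to_all()
            await asyncio.sleep(interval_seconds)
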
@@ -1231,44 +1060,6 @@ class Geocoder:
     def __del__(self):
         self.executor.shutdown()

-class Database(BaseModel):
-    host: str = Field(..., description="Database host")
-    port: int = Field(5432, description="Database port")
-    user: str = Field(..., description="Database user")
-    password: str = Field(..., description="Database password")
-    database: str = Field(..., description="Database name")
-    db_schema: Optional[str] = Field(None, description="Database schema")
-
-    @asynccontextmanager
-    async def get_connection(self):
-        conn = await asyncpg.connect(
-            host=self.host,
-            port=self.port,
-            user=self.user,
-            password=self.password,
-            database=self.database
-        )
-        try:
-            if self.db_schema:
-                await conn.execute(f"SET search_path TO {self.db_schema}")
-            yield conn
-        finally:
-            await conn.close()
-
-    @classmethod
-    def from_env(cls):
-        import os
-        return cls(
-            host=os.getenv("DB_HOST", "localhost"),
-            port=int(os.getenv("DB_PORT", 5432)),
-            user=os.getenv("DB_USER"),
-            password=os.getenv("DB_PASSWORD"),
-            database=os.getenv("DB_NAME"),
-            db_schema=os.getenv("DB_SCHEMA")
-        )
-
-    def to_dict(self):
-        return self.dict(exclude_none=True)
-
 class IMAPConfig(BaseModel):
     username: str