Auto-update: Tue Jul 30 23:10:36 PDT 2024
parent 8f38810626  commit 1ae0f506cf
2 changed files with 242 additions and 451 deletions
@@ -6,7 +6,7 @@ import multiprocessing
 from dotenv import load_dotenv
 from dateutil import tz
 from pathlib import Path
-from .classes import Database, Geocoder, APIConfig, Configuration
+from .classes import Geocoder, APIConfig, Configuration
 from .logs import Logger
 
 # INITIALization
@@ -167,6 +167,7 @@ class Configuration(BaseModel):
 
 
 
+
 class APIConfig(BaseModel):
     HOST: str
     PORT: int
@@ -182,6 +183,8 @@ class APIConfig(BaseModel):
     GARBAGE: Dict[str, Any]
     _db_pools: Dict[str, asyncpg.Pool] = {}
 
+    SPECIAL_TABLES = ['spatial_ref_sys'] # Tables that can't have server_id and version columns
+
     @classmethod
     def load(cls, config_path: Union[str, Path], secrets_path: Union[str, Path]):
        config_path = cls._resolve_path(config_path, 'config')
@@ -301,36 +304,37 @@ class APIConfig(BaseModel):
         if pool_entry is None:
             pool_entry = self.local_db
 
-        info(f"Attempting to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}")
-        try:
-            conn = await asyncpg.connect(
-                host=pool_entry['ts_ip'],
-                port=pool_entry['db_port'],
-                user=pool_entry['db_user'],
-                password=pool_entry['db_pass'],
-                database=pool_entry['db_name'],
-                timeout=5 # Add a timeout to prevent hanging
-            )
-            info(f"Successfully connected to {pool_entry['ts_ip']}:{pool_entry['db_port']}")
-            try:
-                yield conn
-            finally:
-                await conn.close()
-                info(f"Closed connection to {pool_entry['ts_ip']}:{pool_entry['db_port']}")
-        except asyncpg.exceptions.ConnectionDoesNotExistError:
-            err(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']} - Connection does not exist")
-            raise
-        except asyncpg.exceptions.ConnectionFailureError:
-            err(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']} - Connection failure")
-            raise
-        except asyncpg.exceptions.PostgresError as e:
-            err(f"PostgreSQL error when connecting to {pool_entry['ts_ip']}:{pool_entry['db_port']}: {str(e)}")
-            raise
-        except Exception as e:
-            err(f"Unexpected error when connecting to {pool_entry['ts_ip']}:{pool_entry['db_port']}: {str(e)}")
-            raise
+        pool_key = f"{pool_entry['ts_ip']}:{pool_entry['db_port']}"
+
+        if pool_key not in self._db_pools:
+            try:
+                self._db_pools[pool_key] = await asyncpg.create_pool(
+                    host=pool_entry['ts_ip'],
+                    port=pool_entry['db_port'],
+                    user=pool_entry['db_user'],
+                    password=pool_entry['db_pass'],
+                    database=pool_entry['db_name'],
+                    min_size=1,
+                    max_size=10, # adjust as needed
+                    timeout=5 # connection timeout in seconds
+                )
+            except Exception as e:
+                err(f"Failed to create connection pool for {pool_key}: {str(e)}")
+                yield None
+                return
+
+        try:
+            async with self._db_pools[pool_key].acquire() as conn:
+                yield conn
+        except asyncpg.exceptions.ConnectionDoesNotExistError:
+            err(f"Failed to acquire connection from pool for {pool_key}: Connection does not exist")
+            yield None
+        except asyncpg.exceptions.ConnectionFailureError:
+            err(f"Failed to acquire connection from pool for {pool_key}: Connection failure")
+            yield None
+        except Exception as e:
+            err(f"Unexpected error when acquiring connection from pool for {pool_key}: {str(e)}")
+            yield None
 
     async def close_db_pools(self):
         info("Closing database connection pools...")
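
A minimal usage sketch for the pooled get_connection context manager in the hunk above, assuming an APIConfig instance named api_config (a hypothetical caller; the field names are the pool-entry keys shown in this diff). Because failures now yield None instead of raising, callers are expected to check for it:

    # Hypothetical caller: counts rows in one table on the local database.
    async def count_rows(api_config, table_name: str):
        async with api_config.get_connection() as conn:  # no argument -> local pool entry
            if conn is None:
                return None  # pool could not be created or connection failed; skip
            return await conn.fetchval(f'SELECT COUNT(*) FROM "{table_name}"')
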
@@ -343,14 +347,18 @@ class APIConfig(BaseModel):
         self._db_pools.clear()
         info("All database connection pools closed.")
 
 
     async def initialize_sync(self):
         local_ts_id = os.environ.get('TS_ID')
-        for pool_entry in self.POOL:
+        online_hosts = await self.get_online_hosts()
+
+        for pool_entry in online_hosts:
             if pool_entry['ts_id'] == local_ts_id:
                 continue # Skip local database
             try:
                 async with self.get_connection(pool_entry) as conn:
+                    if conn is None:
+                        continue # Skip this database if connection failed
+
                     info(f"Starting sync initialization for {pool_entry['ts_ip']}...")
 
                     # Check PostGIS installation
@@ -358,130 +366,67 @@ class APIConfig(BaseModel):
                     if not postgis_installed:
                         warn(f"PostGIS is not installed on {pool_entry['ts_id']} ({pool_entry['ts_ip']}). Some spatial operations may fail.")
 
-                    # Initialize sync_status table
-                    await self.initialize_sync_status_table(conn)
-
-                    # Continue with sync initialization
                     tables = await conn.fetch("""
                         SELECT tablename FROM pg_tables
                         WHERE schemaname = 'public'
                     """)
 
-                    all_tables_synced = True
                     for table in tables:
                         table_name = table['tablename']
-                        if not await self.ensure_sync_columns(conn, table_name):
-                            all_tables_synced = False
-
-                    if all_tables_synced:
-                        info(f"Sync initialization complete for {pool_entry['ts_ip']}. All tables now have version and server_id columns with appropriate triggers.")
-                    else:
-                        warn(f"Sync initialization partially complete for {pool_entry['ts_ip']}. Some tables may be missing version or server_id columns.")
+                        await self.ensure_sync_columns(conn, table_name)
+
+                    info(f"Sync initialization complete for {pool_entry['ts_ip']}. All tables now have necessary sync columns and triggers.")
 
             except Exception as e:
                 err(f"Error initializing sync for {pool_entry['ts_ip']}: {str(e)}")
                 err(f"Traceback: {traceback.format_exc()}")
 
-    async def initialize_sync_status_table(self, conn):
-        await conn.execute("""
-            CREATE TABLE IF NOT EXISTS sync_status (
-                table_name TEXT,
-                server_id TEXT,
-                last_synced_version INTEGER,
-                last_sync_time TIMESTAMP WITH TIME ZONE,
-                PRIMARY KEY (table_name, server_id)
-            )
-        """)
-
-        # Check if the last_sync_time column exists, and add it if it doesn't
-        column_exists = await conn.fetchval("""
-            SELECT EXISTS (
-                SELECT 1
-                FROM information_schema.columns
-                WHERE table_name = 'sync_status' AND column_name = 'last_sync_time'
-            )
-        """)
-
-        if not column_exists:
-            await conn.execute("""
-                ALTER TABLE sync_status
-                ADD COLUMN last_sync_time TIMESTAMP WITH TIME ZONE
-            """)
-
-    async def ensure_sync_structure(self, conn):
-        tables = await conn.fetch("""
-            SELECT tablename FROM pg_tables
-            WHERE schemaname = 'public'
-        """)
-
-        for table in tables:
-            table_name = table['tablename']
-            await self.ensure_sync_columns(conn, table_name)
-            await self.ensure_sync_trigger(conn, table_name)
-
     async def ensure_sync_columns(self, conn, table_name):
+        if table_name in self.SPECIAL_TABLES:
+            info(f"Skipping sync columns for special table: {table_name}")
+            return None
+
         try:
-            # Check if the table has a primary key
-            has_primary_key = await conn.fetchval(f"""
-                SELECT EXISTS (
-                    SELECT 1
-                    FROM information_schema.table_constraints
-                    WHERE table_name = '{table_name}'
-                    AND constraint_type = 'PRIMARY KEY'
-                )
+            # Get primary key information
+            primary_key = await conn.fetchval(f"""
+                SELECT a.attname
+                FROM pg_index i
+                JOIN pg_attribute a ON a.attrelid = i.indrelid
+                AND a.attnum = ANY(i.indkey)
+                WHERE i.indrelid = '{table_name}'::regclass
+                AND i.indisprimary;
             """)
 
-            # Check if version column exists
-            version_exists = await conn.fetchval(f"""
-                SELECT EXISTS (
-                    SELECT 1 FROM information_schema.columns
-                    WHERE table_name = '{table_name}' AND column_name = 'version'
-                )
+            # Ensure version column exists
+            await conn.execute(f"""
+                ALTER TABLE "{table_name}"
+                ADD COLUMN IF NOT EXISTS version INTEGER DEFAULT 1;
             """)
 
-            # Check if server_id column exists
-            server_id_exists = await conn.fetchval(f"""
-                SELECT EXISTS (
-                    SELECT 1 FROM information_schema.columns
-                    WHERE table_name = '{table_name}' AND column_name = 'server_id'
-                )
+            # Ensure server_id column exists
+            await conn.execute(f"""
+                ALTER TABLE "{table_name}"
+                ADD COLUMN IF NOT EXISTS server_id TEXT DEFAULT '{os.environ.get('TS_ID')}';
             """)
 
-            # Add version column if it doesn't exist
-            if not version_exists:
-                await conn.execute(f"""
-                    ALTER TABLE "{table_name}" ADD COLUMN version INTEGER DEFAULT 1
-                """)
-
-            # Add server_id column if it doesn't exist
-            if not server_id_exists:
-                await conn.execute(f"""
-                    ALTER TABLE "{table_name}" ADD COLUMN server_id TEXT DEFAULT '{os.environ.get('TS_ID')}'
-                """)
-
             # Create or replace the trigger function
-            await conn.execute("""
+            await conn.execute(f"""
                 CREATE OR REPLACE FUNCTION update_version_and_server_id()
                 RETURNS TRIGGER AS $$
                 BEGIN
                     NEW.version = COALESCE(OLD.version, 0) + 1;
-                    NEW.server_id = $1;
+                    NEW.server_id = '{os.environ.get('TS_ID')}';
                     RETURN NEW;
                 END;
                 $$ LANGUAGE plpgsql;
             """)
 
-            # Create the trigger if it doesn't exist
+            # Check if the trigger exists and create it if it doesn't
             trigger_exists = await conn.fetchval(f"""
                 SELECT EXISTS (
-                    SELECT 1 FROM pg_trigger
-                    WHERE tgname = 'update_version_and_server_id_trigger'
+                    SELECT 1
+                    FROM pg_trigger
+                    WHERE tgname = 'update_version_and_server_id_trigger'
                     AND tgrelid = '{table_name}'::regclass
                 )
             """)
@@ -490,128 +435,58 @@ class APIConfig(BaseModel):
                 await conn.execute(f"""
                     CREATE TRIGGER update_version_and_server_id_trigger
                     BEFORE INSERT OR UPDATE ON "{table_name}"
-                    FOR EACH ROW EXECUTE FUNCTION update_version_and_server_id('{os.environ.get('TS_ID')}')
+                    FOR EACH ROW EXECUTE FUNCTION update_version_and_server_id();
                 """)
 
-            info(f"Successfully ensured sync columns and trigger for table {table_name}. Has primary key: {has_primary_key}")
-            return has_primary_key
+            info(f"Successfully ensured sync columns and trigger for table {table_name}")
+            return primary_key
 
         except Exception as e:
             err(f"Error ensuring sync columns for table {table_name}: {str(e)}")
             err(f"Traceback: {traceback.format_exc()}")
-            return False
-
-    async def ensure_sync_trigger(self, conn, table_name):
-        await conn.execute(f"""
-            CREATE OR REPLACE FUNCTION update_version_and_server_id()
-            RETURNS TRIGGER AS $$
-            BEGIN
-                NEW.version = COALESCE(OLD.version, 0) + 1;
-                NEW.server_id = '{os.environ.get('TS_ID')}';
-                RETURN NEW;
-            END;
-            $$ LANGUAGE plpgsql;
-
-            DROP TRIGGER IF EXISTS update_version_and_server_id_trigger ON "{table_name}";
-
-            CREATE TRIGGER update_version_and_server_id_trigger
-            BEFORE INSERT OR UPDATE ON "{table_name}"
-            FOR EACH ROW EXECUTE FUNCTION update_version_and_server_id();
-        """)
-
-    async def get_most_recent_source(self):
-        most_recent_source = None
-        max_version = -1
-        local_ts_id = os.environ.get('TS_ID')
-
-        for pool_entry in self.POOL:
-            if pool_entry['ts_id'] == local_ts_id:
-                continue # Skip local database
-
-            if not await self.is_server_accessible(pool_entry['ts_ip'], pool_entry['db_port']):
-                warn(f"Server {pool_entry['ts_id']} ({pool_entry['ts_ip']}:{pool_entry['db_port']}) is not accessible. Skipping.")
-                continue
-
-            try:
-                async with self.get_connection(pool_entry) as conn:
-                    tables = await conn.fetch("""
-                        SELECT tablename FROM pg_tables
-                        WHERE schemaname = 'public'
-                    """)
-
-                    for table in tables:
-                        table_name = table['tablename']
-                        try:
-                            result = await conn.fetchrow(f"""
-                                SELECT MAX(version) as max_version, server_id
-                                FROM "{table_name}"
-                                WHERE version = (SELECT MAX(version) FROM "{table_name}")
-                                GROUP BY server_id
-                                ORDER BY MAX(version) DESC
-                                LIMIT 1
-                            """)
-                            if result:
-                                version, server_id = result['max_version'], result['server_id']
-                                info(f"Max version for {pool_entry['ts_id']}, table {table_name}: {version} (from server {server_id})")
-                                if version > max_version:
-                                    max_version = version
-                                    most_recent_source = pool_entry
-                            else:
-                                info(f"No data in table {table_name} for {pool_entry['ts_id']}")
-                        except asyncpg.exceptions.UndefinedColumnError:
-                            warn(f"Version or server_id column does not exist in table {table_name} for {pool_entry['ts_id']}. Attempting to add...")
-                            await self.ensure_sync_columns(conn, table_name)
-                        except Exception as e:
-                            err(f"Error checking version for {pool_entry['ts_id']}, table {table_name}: {str(e)}")
-
-            except asyncpg.exceptions.ConnectionFailureError as e:
-                err(f"Failed to establish database connection with {pool_entry['ts_id']} ({pool_entry['ts_ip']}:{pool_entry['db_port']}): {str(e)}")
-            except Exception as e:
-                err(f"Unexpected error occurred while checking version for {pool_entry['ts_id']}: {str(e)}")
-                err(f"Traceback: {traceback.format_exc()}")
-
-        return most_recent_source
-
-    async def is_server_accessible(self, host, port, timeout=2):
-        try:
-            future = asyncio.open_connection(host, port)
-            await asyncio.wait_for(future, timeout=timeout)
-            return True
-        except (asyncio.TimeoutError, ConnectionRefusedError, socket.gaierror):
-            return False
-
-    async def check_version_column_exists(self, conn):
-        try:
-            result = await conn.fetchval("""
-                SELECT EXISTS (
-                    SELECT 1
-                    FROM information_schema.columns
-                    WHERE table_schema = 'public'
-                    AND column_name = 'version'
-                    AND table_name IN (SELECT tablename FROM pg_tables WHERE schemaname = 'public')
-                )
-            """)
-            if not result:
-                tables_without_version = await conn.fetch("""
-                    SELECT tablename
-                    FROM pg_tables
-                    WHERE schemaname = 'public'
-                    AND tablename NOT IN (
-                        SELECT table_name
-                        FROM information_schema.columns
-                        WHERE table_schema = 'public' AND column_name = 'version'
-                    )
-                """)
-                table_names = ", ".join([t['tablename'] for t in tables_without_version])
-                warn(f"Tables without 'version' column: {table_names}")
-            return result
-        except Exception as e:
-            err(f"Error checking for 'version' column existence: {str(e)}")
-            return False
+
+    async def apply_batch_changes(self, conn, table_name, changes, primary_key):
+        if not changes:
+            return 0
+
+        try:
+            columns = list(changes[0].keys())
+            placeholders = [f'${i+1}' for i in range(len(columns))]
+
+            if primary_key:
+                insert_query = f"""
+                    INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)})
+                    VALUES ({', '.join(placeholders)})
+                    ON CONFLICT ("{primary_key}") DO UPDATE SET
+                    {', '.join(f'"{col}" = EXCLUDED."{col}"' for col in columns if col not in [primary_key, 'version', 'server_id'])},
+                    version = EXCLUDED.version,
+                    server_id = EXCLUDED.server_id
+                    WHERE "{table_name}".version < EXCLUDED.version
+                    OR ("{table_name}".version = EXCLUDED.version AND "{table_name}".server_id < EXCLUDED.server_id)
+                """
+            else:
+                # For tables without a primary key, we'll use all columns for conflict resolution
+                insert_query = f"""
+                    INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)})
+                    VALUES ({', '.join(placeholders)})
+                    ON CONFLICT DO NOTHING
+                """
+
+            debug(f"Generated insert query for {table_name}: {insert_query}")
+
+            affected_rows = 0
+            async for change in tqdm(changes, desc=f"Syncing {table_name}", unit="row"):
+                values = [change[col] for col in columns]
+                debug(f"Executing query for {table_name} with values: {values}")
+                result = await conn.execute(insert_query, *values)
+                affected_rows += int(result.split()[-1])
+
+            return affected_rows
+
+        except Exception as e:
+            err(f"Error applying batch changes to {table_name}: {str(e)}")
+            err(f"Traceback: {traceback.format_exc()}")
+            return 0
 
     async def pull_changes(self, source_pool_entry, batch_size=10000):
         if source_pool_entry['ts_id'] == os.environ.get('TS_ID'):
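
The ON CONFLICT ... DO UPDATE ... WHERE clause that apply_batch_changes builds in the hunk above encodes a last-write-wins rule: an incoming row only overwrites an existing one when its version is higher, with ties broken by server_id. A hypothetical helper expressing the same comparison in plain Python:

    # Sketch only: mirrors the WHERE clause of the generated upsert.
    def incoming_row_wins(existing_version, existing_server_id, incoming_version, incoming_server_id):
        # Higher version wins; on equal versions, the lexically larger server_id wins.
        return (incoming_version, incoming_server_id) > (existing_version, existing_server_id)
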
@@ -634,41 +509,34 @@ class APIConfig(BaseModel):
                     WHERE schemaname = 'public'
                 """)
 
-                for table in tables:
+                async for table in tqdm(tables, desc="Syncing tables", unit="table"):
                     table_name = table['tablename']
                     try:
-                        debug(f"Processing table: {table_name}")
-                        has_primary_key = await self.ensure_sync_columns(dest_conn, table_name)
-                        debug(f"Table {table_name} has primary key: {has_primary_key}")
-                        last_synced_version = await self.get_last_synced_version(dest_conn, table_name, source_id)
-                        debug(f"Last synced version for {table_name}: {last_synced_version}")
-
-                        changes = await source_conn.fetch(f"""
-                            SELECT * FROM "{table_name}"
-                            WHERE version > $1 AND server_id = $2
-                            ORDER BY version ASC
-                            LIMIT $3
-                        """, last_synced_version, source_id, batch_size)
-
-                        debug(f"Number of changes for {table_name}: {len(changes)}")
-
-                        if changes:
-                            debug(f"Sample change for {table_name}: {changes[0]}")
-                            changes_count = await self.apply_batch_changes(dest_conn, table_name, changes, has_primary_key)
-                            total_changes += changes_count
-
-                            if changes_count > 0:
-                                last_synced_version = changes[-1]['version']
-                                await self.update_sync_status(dest_conn, table_name, source_id, last_synced_version)
-
-                                info(f"Synced batch for {table_name}: {changes_count} changes. Total so far: {total_changes}")
-                        else:
-                            info(f"No changes to sync for {table_name}")
+                        if table_name in self.SPECIAL_TABLES:
+                            await self.sync_special_table(source_conn, dest_conn, table_name)
+                        else:
+                            primary_key = await self.ensure_sync_columns(dest_conn, table_name)
+                            last_synced_version = await self.get_last_synced_version(dest_conn, table_name, source_id)
+
+                            changes = await source_conn.fetch(f"""
+                                SELECT * FROM "{table_name}"
+                                WHERE version > $1 AND server_id = $2
+                                ORDER BY version ASC
+                                LIMIT $3
+                            """, last_synced_version, source_id, batch_size)
+
+                            if changes:
+                                changes_count = await self.apply_batch_changes(dest_conn, table_name, changes, primary_key)
+                                total_changes += changes_count
+
+                                if changes_count > 0:
+                                    info(f"Synced batch for {table_name}: {changes_count} changes. Total so far: {total_changes}")
+                            else:
+                                info(f"No changes to sync for {table_name}")
 
                     except Exception as e:
                         err(f"Error syncing table {table_name}: {str(e)}")
                         err(f"Traceback: {traceback.format_exc()}")
-                        # Continue with the next table
 
         info(f"Sync complete from {source_id} ({source_ip}) to {dest_id} ({dest_ip}). Total changes: {total_changes}")
 
@@ -684,63 +552,91 @@ class APIConfig(BaseModel):
 
         return total_changes
 
+    async def get_online_hosts(self) -> List[Dict[str, Any]]:
+        online_hosts = []
+        for pool_entry in self.POOL:
+            try:
+                async with self.get_connection(pool_entry) as conn:
+                    if conn is not None:
+                        online_hosts.append(pool_entry)
+            except Exception as e:
+                err(f"Error checking host {pool_entry['ts_ip']}:{pool_entry['db_port']}: {str(e)}")
+        return online_hosts
+
-    async def apply_batch_changes(self, conn, table_name, changes, has_primary_key):
-        if not changes:
-            return 0
-
-        try:
-            columns = list(changes[0].keys())
-            placeholders = [f'${i}' for i in range(1, len(columns) + 1)]
-
-            if has_primary_key:
-                insert_query = f"""
-                    INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)})
-                    VALUES ({', '.join(placeholders)})
-                    ON CONFLICT ON CONSTRAINT {table_name}_pkey DO UPDATE SET
-                    {', '.join(f'"{col}" = EXCLUDED."{col}"' for col in columns if col not in ['version', 'server_id'])},
-                    version = EXCLUDED.version,
-                    server_id = EXCLUDED.server_id
-                    WHERE "{table_name}".version < EXCLUDED.version
-                    OR ("{table_name}".version = EXCLUDED.version AND "{table_name}".server_id < EXCLUDED.server_id)
-                """
-            else:
-                # For tables without a primary key, we'll use all columns for conflict detection
-                insert_query = f"""
-                    INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)})
-                    VALUES ({', '.join(placeholders)})
-                    ON CONFLICT ({', '.join(f'"{col}"' for col in columns if col not in ['version', 'server_id'])}) DO UPDATE SET
-                    version = EXCLUDED.version,
-                    server_id = EXCLUDED.server_id
-                    WHERE "{table_name}".version < EXCLUDED.version
-                    OR ("{table_name}".version = EXCLUDED.version AND "{table_name}".server_id < EXCLUDED.server_id)
-                """
-
-            debug(f"Generated insert query for {table_name}: {insert_query}")
-
-            # Prepare the statement
-            stmt = await conn.prepare(insert_query)
-
-            affected_rows = 0
-            for change in changes:
-                try:
-                    values = [change[col] for col in columns]
-                    result = await stmt.fetchval(*values)
-                    affected_rows += 1
-                except Exception as e:
-                    err(f"Error inserting row into {table_name}: {str(e)}")
-                    err(f"Row data: {change}")
-                    # Continue with the next row
-
-            return affected_rows
-
-        except Exception as e:
-            err(f"Error applying batch changes to {table_name}: {str(e)}")
-            err(f"Traceback: {traceback.format_exc()}")
-            return 0
+    async def push_changes_to_all(self):
+        for pool_entry in self.POOL:
+            if pool_entry['ts_id'] != os.environ.get('TS_ID'):
+                try:
+                    await self.push_changes_to_one(pool_entry)
+                except Exception as e:
+                    err(f"Error pushing changes to {pool_entry['ts_id']}: {str(e)}")
+
+    async def push_changes_to_one(self, pool_entry):
+        try:
+            async with self.get_connection() as local_conn:
+                async with self.get_connection(pool_entry) as remote_conn:
+                    tables = await local_conn.fetch("""
+                        SELECT tablename FROM pg_tables
+                        WHERE schemaname = 'public'
+                    """)
+
+                    for table in tables:
+                        table_name = table['tablename']
+                        try:
+                            if table_name in self.SPECIAL_TABLES:
+                                await self.sync_special_table(local_conn, remote_conn, table_name)
+                            else:
+                                primary_key = await self.ensure_sync_columns(remote_conn, table_name)
+                                last_synced_version = await self.get_last_synced_version(remote_conn, table_name, os.environ.get('TS_ID'))
+
+                                changes = await local_conn.fetch(f"""
+                                    SELECT * FROM "{table_name}"
+                                    WHERE version > $1 AND server_id = $2
+                                    ORDER BY version ASC
+                                """, last_synced_version, os.environ.get('TS_ID'))
+
+                                if changes:
+                                    changes_count = await self.apply_batch_changes(remote_conn, table_name, changes, primary_key)
+
+                                    if changes_count > 0:
+                                        info(f"Pushed {changes_count} changes for table {table_name} to {pool_entry['ts_id']}")
+
+                        except Exception as e:
+                            err(f"Error pushing changes for table {table_name} to {pool_entry['ts_id']}: {str(e)}")
+                            err(f"Traceback: {traceback.format_exc()}")
+
+            info(f"Successfully pushed changes to {pool_entry['ts_id']}")
+        except Exception as e:
+            err(f"Error pushing changes to {pool_entry['ts_id']}: {str(e)}")
+            err(f"Traceback: {traceback.format_exc()}")
+
+    async def get_last_synced_version(self, conn, table_name, server_id):
+        if table_name in self.SPECIAL_TABLES:
+            return 0 # Special handling for tables without version column
+
+        return await conn.fetchval(f"""
+            SELECT COALESCE(MAX(version), 0)
+            FROM "{table_name}"
+            WHERE server_id = $1
+        """, server_id)
+
+    async def check_postgis(self, conn):
+        try:
+            result = await conn.fetchval("SELECT PostGIS_version();")
+            if result:
+                info(f"PostGIS version: {result}")
+                return True
+            else:
+                warn("PostGIS is not installed or not working properly")
+                return False
+        except Exception as e:
+            err(f"Error checking PostGIS: {str(e)}")
+            return False
+
+    async def sync_special_table(self, source_conn, dest_conn, table_name):
+        if table_name == 'spatial_ref_sys':
+            return await self.sync_spatial_ref_sys(source_conn, dest_conn)
+        # Add more special cases as needed
+
     async def sync_spatial_ref_sys(self, source_conn, dest_conn):
         try:
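
One plausible way the entry points added in the hunk above fit together at sync time, assuming an APIConfig instance named api_config (a hypothetical driver function; the method names are the ones introduced in this diff):

    # Hypothetical sync driver: pull from the most up-to-date peer, then push local changes out.
    async def sync_once(api_config):
        source = await api_config.get_most_recent_source()
        if source is not None:
            await api_config.pull_changes(source)
        await api_config.push_changes_to_all()
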
@@ -803,126 +699,59 @@ class APIConfig(BaseModel):
             err(f"Traceback: {traceback.format_exc()}")
             return 0
 
+    async def get_most_recent_source(self):
+        most_recent_source = None
+        max_version = -1
+        local_ts_id = os.environ.get('TS_ID')
+        online_hosts = await self.get_online_hosts()
+
+        for pool_entry in online_hosts:
+            if pool_entry['ts_id'] == local_ts_id:
+                continue # Skip local database
+
+            try:
+                async with self.get_connection(pool_entry) as conn:
+                    tables = await conn.fetch("""
+                        SELECT tablename FROM pg_tables
+                        WHERE schemaname = 'public'
+                    """)
+
+                    for table in tables:
+                        table_name = table['tablename']
+                        if table_name in self.SPECIAL_TABLES:
+                            continue # Skip special tables for version comparison
+                        try:
+                            result = await conn.fetchrow(f"""
+                                SELECT MAX(version) as max_version, server_id
+                                FROM "{table_name}"
+                                WHERE version = (SELECT MAX(version) FROM "{table_name}")
+                                GROUP BY server_id
+                                ORDER BY MAX(version) DESC
+                                LIMIT 1
+                            """)
+                            if result:
+                                version, server_id = result['max_version'], result['server_id']
+                                info(f"Max version for {pool_entry['ts_id']}, table {table_name}: {version} (from server {server_id})")
+                                if version > max_version:
+                                    max_version = version
+                                    most_recent_source = pool_entry
+                            else:
+                                info(f"No data in table {table_name} for {pool_entry['ts_id']}")
+                        except asyncpg.exceptions.UndefinedColumnError:
+                            warn(f"Version or server_id column does not exist in table {table_name} for {pool_entry['ts_id']}. Skipping.")
+                        except Exception as e:
+                            err(f"Error checking version for {pool_entry['ts_id']}, table {table_name}: {str(e)}")
+
+            except asyncpg.exceptions.ConnectionFailureError:
+                warn(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}")
+            except Exception as e:
+                err(f"Unexpected error occurred while checking version for {pool_entry['ts_id']}: {str(e)}")
+                err(f"Traceback: {traceback.format_exc()}")
+
+        return most_recent_source
+
-    async def push_changes_to_all(self):
-        for pool_entry in self.POOL:
-            if pool_entry['ts_id'] != os.environ.get('TS_ID'):
-                try:
-                    await self.push_changes_to_one(pool_entry)
-                except Exception as e:
-                    err(f"Error pushing changes to {pool_entry['ts_id']}: {str(e)}")
-
-    async def push_changes_to_one(self, pool_entry):
-        try:
-            async with self.get_connection() as local_conn:
-                async with self.get_connection(pool_entry) as remote_conn:
-                    tables = await local_conn.fetch("""
-                        SELECT tablename FROM pg_tables
-                        WHERE schemaname = 'public'
-                    """)
-
-                    for table in tables:
-                        table_name = table['tablename']
-                        try:
-                            last_synced_version = await self.get_last_synced_version(remote_conn, table_name, os.environ.get('TS_ID'))
-
-                            changes = await local_conn.fetch(f"""
-                                SELECT * FROM "{table_name}"
-                                WHERE version > $1 AND server_id = $2
-                                ORDER BY version ASC
-                            """, last_synced_version, os.environ.get('TS_ID'))
-
-                            if changes:
-                                debug(f"Pushing changes for table {table_name}")
-                                debug(f"Columns: {', '.join(changes[0].keys())}")
-
-                                columns = list(changes[0].keys())
-                                placeholders = [f'${i+1}' for i in range(len(columns))]
-
-                                insert_query = f"""
-                                    INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)})
-                                    VALUES ({', '.join(placeholders)})
-                                    ON CONFLICT (id) DO UPDATE SET
-                                    {', '.join(f'"{col}" = EXCLUDED."{col}"' for col in columns if col != 'id')}
-                                """
-
-                                debug(f"Insert query: {insert_query}")
-
-                                for change in changes:
-                                    values = [change[col] for col in columns]
-                                    await remote_conn.execute(insert_query, *values)
-
-                                await self.update_sync_status(remote_conn, table_name, os.environ.get('TS_ID'), changes[-1]['version'])
-
-                        except Exception as e:
-                            err(f"Error pushing changes for table {table_name}: {str(e)}")
-                            err(f"Traceback: {traceback.format_exc()}")
-
-            info(f"Successfully pushed changes to {pool_entry['ts_id']}")
-        except Exception as e:
-            err(f"Error pushing changes to {pool_entry['ts_id']}: {str(e)}")
-            err(f"Traceback: {traceback.format_exc()}")
-
-    async def update_sync_status(self, conn, table_name, server_id, version):
-        await conn.execute("""
-            INSERT INTO sync_status (table_name, server_id, last_synced_version, last_sync_time)
-            VALUES ($1, $2, $3, NOW())
-            ON CONFLICT (table_name, server_id) DO UPDATE
-            SET last_synced_version = EXCLUDED.last_synced_version,
-                last_sync_time = EXCLUDED.last_sync_time
-        """, table_name, server_id, version)
-
-    async def get_last_synced_version(self, conn, table_name, server_id):
-        return await conn.fetchval(f"""
-            SELECT COALESCE(MAX(version), 0)
-            FROM "{table_name}"
-            WHERE server_id = $1
-        """, server_id)
-
-    async def update_last_synced_version(self, conn, table_name, server_id, version):
-        await conn.execute(f"""
-            INSERT INTO "{table_name}" (server_id, version)
-            VALUES ($1, $2)
-            ON CONFLICT (server_id) DO UPDATE
-            SET version = EXCLUDED.version
-            WHERE "{table_name}".version < EXCLUDED.version
-        """, server_id, version)
-
-    async def get_schema_version(self, pool_entry):
-        async with self.get_connection(pool_entry) as conn:
-            return await conn.fetchval("""
-                SELECT COALESCE(MAX(version), 0) FROM (
-                    SELECT MAX(version) as version FROM pg_tables
-                    WHERE schemaname = 'public'
-                ) as subquery
-            """)
-
-    async def create_sequence_if_not_exists(self, conn, sequence_name):
-        await conn.execute(f"""
-            DO $$
-            BEGIN
-                IF NOT EXISTS (SELECT 1 FROM pg_sequences WHERE schemaname = 'public' AND sequencename = '{sequence_name}') THEN
-                    CREATE SEQUENCE {sequence_name};
-                END IF;
-            END $$;
-        """)
-
-    async def check_postgis(self, conn):
-        try:
-            result = await conn.fetchval("SELECT PostGIS_version();")
-            if result:
-                info(f"PostGIS version: {result}")
-                return True
-            else:
-                warn("PostGIS is not installed or not working properly")
-                return False
-        except Exception as e:
-            err(f"Error checking PostGIS: {str(e)}")
-            return False
 
@@ -1231,44 +1060,6 @@ class Geocoder:
     def __del__(self):
         self.executor.shutdown()
 
-class Database(BaseModel):
-    host: str = Field(..., description="Database host")
-    port: int = Field(5432, description="Database port")
-    user: str = Field(..., description="Database user")
-    password: str = Field(..., description="Database password")
-    database: str = Field(..., description="Database name")
-    db_schema: Optional[str] = Field(None, description="Database schema")
-
-    @asynccontextmanager
-    async def get_connection(self):
-        conn = await asyncpg.connect(
-            host=self.host,
-            port=self.port,
-            user=self.user,
-            password=self.password,
-            database=self.database
-        )
-        try:
-            if self.db_schema:
-                await conn.execute(f"SET search_path TO {self.db_schema}")
-            yield conn
-        finally:
-            await conn.close()
-
-    @classmethod
-    def from_env(cls):
-        import os
-        return cls(
-            host=os.getenv("DB_HOST", "localhost"),
-            port=int(os.getenv("DB_PORT", 5432)),
-            user=os.getenv("DB_USER"),
-            password=os.getenv("DB_PASSWORD"),
-            database=os.getenv("DB_NAME"),
-            db_schema=os.getenv("DB_SCHEMA")
-        )
-
-    def to_dict(self):
-        return self.dict(exclude_none=True)
-
 class IMAPConfig(BaseModel):
     username: str