Auto-update: Mon Jul 29 19:27:49 PDT 2024
parent 05400d4fa4
commit cda6481a97
2 changed files with 33 additions and 338 deletions
@@ -62,11 +62,7 @@ async def lifespan(app: FastAPI):
        # Initialize sync structures
        await API.initialize_sync()

        # Sync schema across all databases
        await API.sync_schema()
        crit("Schema synchronization complete.")

        # Check if other instances have more recent data
        # Now that tables are initialized, check for the most recent source
        source = await API.get_most_recent_source()
        if source:
            crit(f"Pulling changes from {source['ts_id']} ({source['ts_ip']})...")
@@ -75,7 +71,6 @@ async def lifespan(app: FastAPI):
        else:
            crit("No instances with more recent data found.")

    except Exception as e:
        crit(f"Error during startup: {str(e)}")
        crit(f"Traceback: {traceback.format_exc()}")
@@ -86,6 +81,7 @@ async def lifespan(app: FastAPI):
    crit("Shutting down...")
    # Perform any cleanup operations here if needed


app = FastAPI(lifespan=lifespan)

app.add_middleware(
@@ -317,15 +317,6 @@ class APIConfig(BaseModel):

    async def initialize_sync(self):
        async with self.get_connection() as conn:
            await conn.execute("""
                CREATE TABLE IF NOT EXISTS sync_status (
                    table_name TEXT,
                    server_id TEXT,
                    last_synced_version INTEGER,
                    PRIMARY KEY (table_name, server_id)
                )
            """)

            tables = await conn.fetch("""
                SELECT tablename FROM pg_tables
                WHERE schemaname = 'public'
@@ -351,12 +342,9 @@ class APIConfig(BaseModel):
                CREATE TRIGGER update_version_and_server_id_trigger
                BEFORE INSERT OR UPDATE ON "{table_name}"
                FOR EACH ROW EXECUTE FUNCTION update_version_and_server_id();

                INSERT INTO sync_status (table_name, server_id, last_synced_version)
                VALUES ('{table_name}', '{os.environ.get('TS_ID')}', 0)
                ON CONFLICT (table_name, server_id) DO NOTHING;
            """)


    async def get_most_recent_source(self):
        most_recent_source = None
        max_version = -1
@@ -368,7 +356,10 @@ class APIConfig(BaseModel):
            try:
                async with self.get_connection(pool_entry) as conn:
                    version = await conn.fetchval("""
                        SELECT COALESCE(MAX(last_synced_version), -1) FROM sync_status
                        SELECT COALESCE(MAX(version), -1) FROM (
                            SELECT MAX(version) as version FROM pg_tables
                            WHERE schemaname = 'public'
                        ) as subquery
                    """)
                    if version > max_version:
                        max_version = version
@@ -379,6 +370,7 @@ class APIConfig(BaseModel):
        return most_recent_source


    async def pull_changes(self, source_pool_entry):
        if source_pool_entry['ts_id'] == os.environ.get('TS_ID'):
            info("Skipping self-sync")
@@ -405,7 +397,7 @@ class APIConfig(BaseModel):

                for table in tables:
                    table_name = table['tablename']
                    last_synced_version = await self.get_last_synced_version(table_name, source_id)
                    last_synced_version = await self.get_last_synced_version(dest_conn, table_name, source_id)

                    changes = await source_conn.fetch(f"""
                        SELECT * FROM "{table_name}"
@@ -434,7 +426,7 @@ class APIConfig(BaseModel):
                                inserts += 1

                    if changes:
                        await self.update_sync_status(table_name, source_id, changes[-1]['version'])
                        await self.update_last_synced_version(dest_conn, table_name, source_id, changes[-1]['version'])

                    total_inserts += inserts
                    total_updates += updates
@@ -454,119 +446,6 @@ class APIConfig(BaseModel):

        return total_inserts + total_updates

    async def get_tables(self, conn):
        tables = await conn.fetch("""
            SELECT tablename FROM pg_tables
            WHERE schemaname = 'public'
        """)
        return [table['tablename'] for table in tables]

    async def compare_table_structure(self, source_conn, dest_conn, table_name):
        source_columns = await self.get_table_structure(source_conn, table_name)
        dest_columns = await self.get_table_structure(dest_conn, table_name)

        columns_only_in_source = set(source_columns.keys()) - set(dest_columns.keys())
        columns_only_in_dest = set(dest_columns.keys()) - set(source_columns.keys())
        common_columns = set(source_columns.keys()) & set(dest_columns.keys())

        info(f"Table {table_name}:")
        info(f"  Columns only in source: {columns_only_in_source}")
        info(f"  Columns only in destination: {columns_only_in_dest}")
        info(f"  Common columns: {common_columns}")

        for col in common_columns:
            if source_columns[col] != dest_columns[col]:
                warn(f"  Column {col} has different types: source={source_columns[col]}, dest={dest_columns[col]}")

    async def get_table_structure(self, conn, table_name):
        columns = await conn.fetch("""
            SELECT column_name, data_type
            FROM information_schema.columns
            WHERE table_name = $1
        """, table_name)
        return {col['column_name']: col['data_type'] for col in columns}

    async def compare_and_sync_data(self, source_conn, dest_conn, table_name, source_id):
        inserts = 0
        updates = 0
        error_count = 0

        try:
            primary_keys = await self.get_primary_keys(dest_conn, table_name)
            if not primary_keys:
                warn(f"Table {table_name} has no primary keys. Using all columns for comparison.")
                columns = await self.get_table_columns(dest_conn, table_name)
                primary_keys = columns  # Use all columns if no primary key

            last_synced_version = await self.get_last_synced_version(table_name, source_id)

            changes = await source_conn.fetch(f"""
                SELECT * FROM "{table_name}"
                WHERE version > $1 AND server_id = $2
                ORDER BY version ASC
            """, last_synced_version, source_id)

            for change in changes:
                columns = list(change.keys())
                values = [change[col] for col in columns]

                conflict_clause = f"({', '.join(primary_keys)})"
                update_clause = ', '.join(f"{col} = EXCLUDED.{col}" for col in columns if col not in primary_keys)

                insert_query = f"""
                    INSERT INTO "{table_name}" ({', '.join(columns)})
                    VALUES ({', '.join(f'${i+1}' for i in range(len(columns)))})
                    ON CONFLICT {conflict_clause} DO UPDATE SET
                    {update_clause}
                """

                try:
                    result = await dest_conn.execute(insert_query, *values)
                    if 'UPDATE' in result:
                        updates += 1
                    else:
                        inserts += 1
                except Exception as e:
                    if error_count < 10:  # Limit error logging
                        err(f"Error syncing data for table {table_name}: {str(e)}")
                        error_count += 1
                    elif error_count == 10:
                        err(f"Suppressing further errors for table {table_name}")
                        error_count += 1

            if changes:
                await self.update_sync_status(table_name, source_id, changes[-1]['version'])

            info(f"Synced {table_name}: {inserts} inserts, {updates} updates")
            if error_count > 10:
                info(f"Total of {error_count} errors occurred for table {table_name}")

        except Exception as e:
            err(f"Error processing table {table_name}: {str(e)}")

        return inserts, updates

    async def get_table_columns(self, conn, table_name):
        columns = await conn.fetch("""
            SELECT column_name
            FROM information_schema.columns
            WHERE table_name = $1
            ORDER BY ordinal_position
        """, table_name)
        return [col['column_name'] for col in columns]

    async def get_primary_keys(self, conn, table_name):
        primary_keys = await conn.fetch("""
            SELECT a.attname
            FROM pg_index i
            JOIN pg_attribute a ON a.attrelid = i.indrelid
            AND a.attnum = ANY(i.indkey)
            WHERE i.indrelid = $1::regclass
            AND i.indisprimary
        """, table_name)
        return [pk['attname'] for pk in primary_keys]


    async def push_changes_to_all(self):
        for pool_entry in self.POOL:
            if pool_entry['ts_id'] != os.environ.get('TS_ID'):
@@ -583,7 +462,7 @@ class APIConfig(BaseModel):

                    for table in tables:
                        table_name = table['tablename']
                        last_synced_version = await self.get_last_synced_version(table_name, pool_entry['ts_id'])
                        last_synced_version = await self.get_last_synced_version(remote_conn, table_name, os.environ.get('TS_ID'))

                        changes = await local_conn.fetch(f"""
                            SELECT * FROM "{table_name}"
@@ -606,218 +485,37 @@ class APIConfig(BaseModel):
                            await remote_conn.execute(insert_query, *values)

                        if changes:
                            await self.update_sync_status(table_name, pool_entry['ts_id'], changes[-1]['version'])
                            await self.update_last_synced_version(remote_conn, table_name, os.environ.get('TS_ID'), changes[-1]['version'])

                    info(f"Successfully pushed changes to {pool_entry['ts_id']}")
            except Exception as e:
                err(f"Error pushing changes to {pool_entry['ts_id']}: {str(e)}")
                err(f"Traceback: {traceback.format_exc()}")

    async def get_last_synced_version(self, conn, table_name, server_id):
        return await conn.fetchval(f"""
            SELECT COALESCE(MAX(version), 0)
            FROM "{table_name}"
            WHERE server_id = $1
        """, server_id)

    async def get_last_synced_version(self, table_name, server_id):
        async with self.get_connection() as conn:
            return await conn.fetchval("""
                SELECT last_synced_version FROM sync_status
                WHERE table_name = $1 AND server_id = $2
            """, table_name, server_id) or 0


    async def update_sync_status(self, table_name, server_id, version):
        async with self.get_connection() as conn:
            await conn.execute("""
                INSERT INTO sync_status (table_name, server_id, last_synced_version)
                VALUES ($1, $2, $3)
                ON CONFLICT (table_name, server_id) DO UPDATE
                SET last_synced_version = EXCLUDED.last_synced_version
            """, table_name, server_id, version)


    async def sync_schema(self):
        local_schema_version = await self.get_schema_version(self.local_db)
        for pool_entry in self.POOL:
            if pool_entry['ts_id'] != os.environ.get('TS_ID'):
                remote_schema_version = await self.get_schema_version(pool_entry)
                if remote_schema_version != local_schema_version:
                    await self.apply_schema_changes(pool_entry)


    async def get_schema(self, pool_entry: Dict[str, Any]):
        async with self.get_connection(pool_entry) as conn:
            tables = await conn.fetch("""
                SELECT table_name, column_name, data_type, character_maximum_length,
                       is_nullable, column_default, ordinal_position
                FROM information_schema.columns
                WHERE table_schema = 'public'
                ORDER BY table_name, ordinal_position
            """)

            indexes = await conn.fetch("""
                SELECT indexname, indexdef
                FROM pg_indexes
                WHERE schemaname = 'public'
            """)

            constraints = await conn.fetch("""
                SELECT conname, contype, conrelid::regclass::text as table_name,
                       pg_get_constraintdef(oid) as definition
                FROM pg_constraint
                WHERE connamespace = 'public'::regnamespace
            """)

            return {
                'tables': tables,
                'indexes': indexes,
                'constraints': constraints
            }


    async def apply_schema_changes(self, pool_entry: Dict[str, Any], source_schema, target_schema):
        async with self.get_connection(pool_entry) as conn:
            # Check schema version
            source_version = await self.get_schema_version(self.local_db)
            target_version = await self.get_schema_version(pool_entry)
            if source_version == target_version:
                info(f"Schema versions match for {pool_entry['ts_ip']}. Skipping synchronization.")
                return

            source_tables = {t['table_name']: t for t in source_schema['tables']}
            target_tables = {t['table_name']: t for t in target_schema['tables']}

            def get_column_type(data_type):
                if data_type == 'ARRAY':
                    return 'text[]'
                elif data_type == 'USER-DEFINED':
                    return 'geometry'
                else:
                    return data_type

            for table_name, source_table in source_tables.items():
                try:
                    if table_name not in target_tables:
                        columns = []
                        for t in source_schema['tables']:
                            if t['table_name'] == table_name:
                                col_type = get_column_type(t['data_type'])
                                col_def = f"\"{t['column_name']}\" {col_type}"
                                if t['character_maximum_length']:
                                    col_def += f"({t['character_maximum_length']})"
                                if t['is_nullable'] == 'NO':
                                    col_def += " NOT NULL"
                                if t['column_default']:
                                    if 'nextval' in t['column_default']:
                                        sequence_name = t['column_default'].split("'")[1]
                                        await self.create_sequence_if_not_exists(conn, sequence_name)
                                    col_def += f" DEFAULT {t['column_default']}"
                                columns.append(col_def)

                        primary_key_constraint = next(
                            (con['definition'] for con in source_schema['constraints'] if con['table_name'] == table_name and con['contype'] == 'p'),
                            None
                        )

                        sql = f'CREATE TABLE "{table_name}" ({", ".join(columns)}'
                        if primary_key_constraint:
                            sql += f', {primary_key_constraint}'
                        sql += ')'

                        info(f"Executing SQL: {sql}")
                        await conn.execute(sql)
                    else:
                        target_table = target_tables[table_name]
                        source_columns = {t['column_name']: t for t in source_schema['tables'] if t['table_name'] == table_name}
                        target_columns = {t['column_name']: t for t in target_schema['tables'] if t['table_name'] == table_name}

                        for col_name, source_col in source_columns.items():
                            if col_name not in target_columns:
                                col_type = get_column_type(source_col['data_type'])
                                col_def = f"\"{col_name}\" {col_type}" + \
                                          (f"({source_col['character_maximum_length']})" if source_col['character_maximum_length'] else "") + \
                                          (" NOT NULL" if source_col['is_nullable'] == 'NO' else "") + \
                                          (f" DEFAULT {source_col['column_default']}" if source_col['column_default'] else "")
                                sql = f'ALTER TABLE "{table_name}" ADD COLUMN {col_def}'
                                debug(f"Executing SQL: {sql}")
                                await conn.execute(sql)
                            else:
                                target_col = target_columns[col_name]
                                if source_col != target_col:
                                    col_type = get_column_type(source_col['data_type'])
                                    sql = f'ALTER TABLE "{table_name}" ALTER COLUMN "{col_name}" TYPE {col_type}'
                                    debug(f"Executing SQL: {sql}")
                                    await conn.execute(sql)
                                    if source_col['is_nullable'] != target_col['is_nullable']:
                                        null_constraint = "DROP NOT NULL" if source_col['is_nullable'] == 'YES' else "SET NOT NULL"
                                        sql = f'ALTER TABLE "{table_name}" ALTER COLUMN "{col_name}" {null_constraint}'
                                        debug(f"Executing SQL: {sql}")
                                        await conn.execute(sql)
                                    if source_col['column_default'] != target_col['column_default']:
                                        default_clause = f"SET DEFAULT {source_col['column_default']}" if source_col['column_default'] else "DROP DEFAULT"
                                        sql = f'ALTER TABLE "{table_name}" ALTER COLUMN "{col_name}" {default_clause}'
                                        debug(f"Executing SQL: {sql}")
                                        await conn.execute(sql)

                    # Ensure primary key constraint exists
                    primary_key_constraint = next(
                        (con['definition'] for con in source_schema['constraints'] if con['table_name'] == table_name and con['contype'] == 'p'),
                        None
                    )
                    if primary_key_constraint:
                        constraint_name = f"{table_name}_pkey"
                        constraint_exists = await conn.fetchval(f"""
                            SELECT 1 FROM pg_constraint
                            WHERE conname = '{constraint_name}'
                        """)
                        if not constraint_exists:
                            sql = f'ALTER TABLE "{table_name}" ADD CONSTRAINT {constraint_name} {primary_key_constraint}'
                            debug(f"Executing SQL: {sql}")
                            await conn.execute(sql)
                except Exception as e:
                    err(f"Error processing table {table_name}: {str(e)}")

            try:
                source_indexes = {idx['indexname']: idx['indexdef'] for idx in source_schema['indexes']}
                target_indexes = {idx['indexname']: idx['indexdef'] for idx in target_schema['indexes']}

                for idx_name, idx_def in source_indexes.items():
                    if idx_name not in target_indexes:
                        debug(f"Executing SQL: {idx_def}")
                        await conn.execute(idx_def)
                    elif idx_def != target_indexes[idx_name]:
                        sql = f'DROP INDEX IF EXISTS "{idx_name}"'
                        debug(f"Executing SQL: {sql}")
                        await conn.execute(sql)
                        debug(f"Executing SQL: {idx_def}")
                        await conn.execute(idx_def)
            except Exception as e:
                err(f"Error processing indexes: {str(e)}")

            try:
                source_constraints = {con['conname']: con for con in source_schema['constraints']}
                target_constraints = {con['conname']: con for con in target_schema['constraints']}

                for con_name, source_con in source_constraints.items():
                    if con_name not in target_constraints:
                        sql = f'ALTER TABLE "{source_con["table_name"]}" ADD CONSTRAINT "{con_name}" {source_con["definition"]}'
                        debug(f"Executing SQL: {sql}")
                        await conn.execute(sql)
                    elif source_con != target_constraints[con_name]:
                        sql = f'ALTER TABLE "{source_con["table_name"]}" DROP CONSTRAINT IF EXISTS "{con_name}"'
                        debug(f"Executing SQL: {sql}")
                        await conn.execute(sql)
                        sql = f'ALTER TABLE "{source_con["table_name"]}" ADD CONSTRAINT "{con_name}" {source_con["definition"]}'
                        debug(f"Executing SQL: {sql}")
                        await conn.execute(sql)
            except Exception as e:
                err(f"Error processing constraints: {str(e)}")

            # Update schema version
            await conn.execute("UPDATE schema_version SET version = $1", source_version)
            info(f"Schema synchronization completed for {pool_entry['ts_ip']}")

    async def update_last_synced_version(self, conn, table_name, server_id, version):
        await conn.execute(f"""
            INSERT INTO "{table_name}" (server_id, version)
            VALUES ($1, $2)
            ON CONFLICT (server_id) DO UPDATE
            SET version = EXCLUDED.version
            WHERE "{table_name}".version < EXCLUDED.version
        """, server_id, version)

    async def get_schema_version(self, pool_entry):
        async with self.get_connection(pool_entry) as conn:
            return await conn.fetchval("SELECT version FROM schema_version")

            return await conn.fetchval("""
                SELECT COALESCE(MAX(version), 0) FROM (
                    SELECT MAX(version) as version FROM pg_tables
                    WHERE schemaname = 'public'
                ) as subquery
            """)

    async def create_sequence_if_not_exists(self, conn, sequence_name):
        await conn.execute(f"""
@@ -831,6 +529,7 @@ class APIConfig(BaseModel):




class Location(BaseModel):
    latitude: float
    longitude: float
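
Note on the change above: this commit moves last-synced-version tracking off the dedicated sync_status table and onto the trigger-maintained version / server_id columns of each synced table. A minimal sketch of that lookup-and-pull pattern, assuming asyncpg and the same column names as in the diff (the function names below are illustrative, not part of the commit):

import asyncpg

async def last_synced_version(conn: asyncpg.Connection, table_name: str, server_id: str) -> int:
    # Highest version already stored locally for rows authored by server_id;
    # COALESCE(..., 0) means "nothing synced yet", as in the new get_last_synced_version.
    return await conn.fetchval(
        f'SELECT COALESCE(MAX(version), 0) FROM "{table_name}" WHERE server_id = $1',
        server_id,
    )

async def fetch_newer_rows(source: asyncpg.Connection, table_name: str, server_id: str, since: int):
    # Pull only rows newer than what the destination already holds, oldest first,
    # mirroring the version > $1 AND server_id = $2 filter used in pull_changes.
    return await source.fetch(
        f'SELECT * FROM "{table_name}" WHERE version > $1 AND server_id = $2 ORDER BY version ASC',
        since,
        server_id,
    )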