Auto-update: Mon Jul 29 19:27:49 PDT 2024

This commit is contained in:
sanj 2024-07-29 19:27:49 -07:00
parent 05400d4fa4
commit cda6481a97
2 changed files with 33 additions and 338 deletions

View file

@ -62,11 +62,7 @@ async def lifespan(app: FastAPI):
# Initialize sync structures # Initialize sync structures
await API.initialize_sync() await API.initialize_sync()
# Sync schema across all databases # Now that tables are initialized, check for the most recent source
await API.sync_schema()
crit("Schema synchronization complete.")
# Check if other instances have more recent data
source = await API.get_most_recent_source() source = await API.get_most_recent_source()
if source: if source:
crit(f"Pulling changes from {source['ts_id']} ({source['ts_ip']})...") crit(f"Pulling changes from {source['ts_id']} ({source['ts_ip']})...")
@ -75,7 +71,6 @@ async def lifespan(app: FastAPI):
else: else:
crit("No instances with more recent data found.") crit("No instances with more recent data found.")
except Exception as e: except Exception as e:
crit(f"Error during startup: {str(e)}") crit(f"Error during startup: {str(e)}")
crit(f"Traceback: {traceback.format_exc()}") crit(f"Traceback: {traceback.format_exc()}")
@ -86,6 +81,7 @@ async def lifespan(app: FastAPI):
crit("Shutting down...") crit("Shutting down...")
# Perform any cleanup operations here if needed # Perform any cleanup operations here if needed
app = FastAPI(lifespan=lifespan) app = FastAPI(lifespan=lifespan)
app.add_middleware( app.add_middleware(

View file

@ -317,15 +317,6 @@ class APIConfig(BaseModel):
async def initialize_sync(self): async def initialize_sync(self):
async with self.get_connection() as conn: async with self.get_connection() as conn:
await conn.execute("""
CREATE TABLE IF NOT EXISTS sync_status (
table_name TEXT,
server_id TEXT,
last_synced_version INTEGER,
PRIMARY KEY (table_name, server_id)
)
""")
tables = await conn.fetch(""" tables = await conn.fetch("""
SELECT tablename FROM pg_tables SELECT tablename FROM pg_tables
WHERE schemaname = 'public' WHERE schemaname = 'public'
@ -351,12 +342,9 @@ class APIConfig(BaseModel):
CREATE TRIGGER update_version_and_server_id_trigger CREATE TRIGGER update_version_and_server_id_trigger
BEFORE INSERT OR UPDATE ON "{table_name}" BEFORE INSERT OR UPDATE ON "{table_name}"
FOR EACH ROW EXECUTE FUNCTION update_version_and_server_id(); FOR EACH ROW EXECUTE FUNCTION update_version_and_server_id();
INSERT INTO sync_status (table_name, server_id, last_synced_version)
VALUES ('{table_name}', '{os.environ.get('TS_ID')}', 0)
ON CONFLICT (table_name, server_id) DO NOTHING;
""") """)
async def get_most_recent_source(self): async def get_most_recent_source(self):
most_recent_source = None most_recent_source = None
max_version = -1 max_version = -1
@ -368,7 +356,10 @@ class APIConfig(BaseModel):
try: try:
async with self.get_connection(pool_entry) as conn: async with self.get_connection(pool_entry) as conn:
version = await conn.fetchval(""" version = await conn.fetchval("""
SELECT COALESCE(MAX(last_synced_version), -1) FROM sync_status SELECT COALESCE(MAX(version), -1) FROM (
SELECT MAX(version) as version FROM pg_tables
WHERE schemaname = 'public'
) as subquery
""") """)
if version > max_version: if version > max_version:
max_version = version max_version = version
@ -379,6 +370,7 @@ class APIConfig(BaseModel):
return most_recent_source return most_recent_source
async def pull_changes(self, source_pool_entry): async def pull_changes(self, source_pool_entry):
if source_pool_entry['ts_id'] == os.environ.get('TS_ID'): if source_pool_entry['ts_id'] == os.environ.get('TS_ID'):
info("Skipping self-sync") info("Skipping self-sync")
@ -405,7 +397,7 @@ class APIConfig(BaseModel):
for table in tables: for table in tables:
table_name = table['tablename'] table_name = table['tablename']
last_synced_version = await self.get_last_synced_version(table_name, source_id) last_synced_version = await self.get_last_synced_version(dest_conn, table_name, source_id)
changes = await source_conn.fetch(f""" changes = await source_conn.fetch(f"""
SELECT * FROM "{table_name}" SELECT * FROM "{table_name}"
@ -434,7 +426,7 @@ class APIConfig(BaseModel):
inserts += 1 inserts += 1
if changes: if changes:
await self.update_sync_status(table_name, source_id, changes[-1]['version']) await self.update_last_synced_version(dest_conn, table_name, source_id, changes[-1]['version'])
total_inserts += inserts total_inserts += inserts
total_updates += updates total_updates += updates
@ -454,119 +446,6 @@ class APIConfig(BaseModel):
return total_inserts + total_updates return total_inserts + total_updates
async def get_tables(self, conn):
tables = await conn.fetch("""
SELECT tablename FROM pg_tables
WHERE schemaname = 'public'
""")
return [table['tablename'] for table in tables]
async def compare_table_structure(self, source_conn, dest_conn, table_name):
source_columns = await self.get_table_structure(source_conn, table_name)
dest_columns = await self.get_table_structure(dest_conn, table_name)
columns_only_in_source = set(source_columns.keys()) - set(dest_columns.keys())
columns_only_in_dest = set(dest_columns.keys()) - set(source_columns.keys())
common_columns = set(source_columns.keys()) & set(dest_columns.keys())
info(f"Table {table_name}:")
info(f" Columns only in source: {columns_only_in_source}")
info(f" Columns only in destination: {columns_only_in_dest}")
info(f" Common columns: {common_columns}")
for col in common_columns:
if source_columns[col] != dest_columns[col]:
warn(f" Column {col} has different types: source={source_columns[col]}, dest={dest_columns[col]}")
async def get_table_structure(self, conn, table_name):
columns = await conn.fetch("""
SELECT column_name, data_type
FROM information_schema.columns
WHERE table_name = $1
""", table_name)
return {col['column_name']: col['data_type'] for col in columns}
async def compare_and_sync_data(self, source_conn, dest_conn, table_name, source_id):
inserts = 0
updates = 0
error_count = 0
try:
primary_keys = await self.get_primary_keys(dest_conn, table_name)
if not primary_keys:
warn(f"Table {table_name} has no primary keys. Using all columns for comparison.")
columns = await self.get_table_columns(dest_conn, table_name)
primary_keys = columns # Use all columns if no primary key
last_synced_version = await self.get_last_synced_version(table_name, source_id)
changes = await source_conn.fetch(f"""
SELECT * FROM "{table_name}"
WHERE version > $1 AND server_id = $2
ORDER BY version ASC
""", last_synced_version, source_id)
for change in changes:
columns = list(change.keys())
values = [change[col] for col in columns]
conflict_clause = f"({', '.join(primary_keys)})"
update_clause = ', '.join(f"{col} = EXCLUDED.{col}" for col in columns if col not in primary_keys)
insert_query = f"""
INSERT INTO "{table_name}" ({', '.join(columns)})
VALUES ({', '.join(f'${i+1}' for i in range(len(columns)))})
ON CONFLICT {conflict_clause} DO UPDATE SET
{update_clause}
"""
try:
result = await dest_conn.execute(insert_query, *values)
if 'UPDATE' in result:
updates += 1
else:
inserts += 1
except Exception as e:
if error_count < 10: # Limit error logging
err(f"Error syncing data for table {table_name}: {str(e)}")
error_count += 1
elif error_count == 10:
err(f"Suppressing further errors for table {table_name}")
error_count += 1
if changes:
await self.update_sync_status(table_name, source_id, changes[-1]['version'])
info(f"Synced {table_name}: {inserts} inserts, {updates} updates")
if error_count > 10:
info(f"Total of {error_count} errors occurred for table {table_name}")
except Exception as e:
err(f"Error processing table {table_name}: {str(e)}")
return inserts, updates
async def get_table_columns(self, conn, table_name):
columns = await conn.fetch("""
SELECT column_name
FROM information_schema.columns
WHERE table_name = $1
ORDER BY ordinal_position
""", table_name)
return [col['column_name'] for col in columns]
async def get_primary_keys(self, conn, table_name):
primary_keys = await conn.fetch("""
SELECT a.attname
FROM pg_index i
JOIN pg_attribute a ON a.attrelid = i.indrelid
AND a.attnum = ANY(i.indkey)
WHERE i.indrelid = $1::regclass
AND i.indisprimary
""", table_name)
return [pk['attname'] for pk in primary_keys]
async def push_changes_to_all(self): async def push_changes_to_all(self):
for pool_entry in self.POOL: for pool_entry in self.POOL:
if pool_entry['ts_id'] != os.environ.get('TS_ID'): if pool_entry['ts_id'] != os.environ.get('TS_ID'):
@ -583,7 +462,7 @@ class APIConfig(BaseModel):
for table in tables: for table in tables:
table_name = table['tablename'] table_name = table['tablename']
last_synced_version = await self.get_last_synced_version(table_name, pool_entry['ts_id']) last_synced_version = await self.get_last_synced_version(remote_conn, table_name, os.environ.get('TS_ID'))
changes = await local_conn.fetch(f""" changes = await local_conn.fetch(f"""
SELECT * FROM "{table_name}" SELECT * FROM "{table_name}"
@ -606,218 +485,37 @@ class APIConfig(BaseModel):
await remote_conn.execute(insert_query, *values) await remote_conn.execute(insert_query, *values)
if changes: if changes:
await self.update_sync_status(table_name, pool_entry['ts_id'], changes[-1]['version']) await self.update_last_synced_version(remote_conn, table_name, os.environ.get('TS_ID'), changes[-1]['version'])
info(f"Successfully pushed changes to {pool_entry['ts_id']}") info(f"Successfully pushed changes to {pool_entry['ts_id']}")
except Exception as e: except Exception as e:
err(f"Error pushing changes to {pool_entry['ts_id']}: {str(e)}") err(f"Error pushing changes to {pool_entry['ts_id']}: {str(e)}")
err(f"Traceback: {traceback.format_exc()}") err(f"Traceback: {traceback.format_exc()}")
async def get_last_synced_version(self, conn, table_name, server_id):
return await conn.fetchval(f"""
SELECT COALESCE(MAX(version), 0)
FROM "{table_name}"
WHERE server_id = $1
""", server_id)
async def get_last_synced_version(self, table_name, server_id): async def update_last_synced_version(self, conn, table_name, server_id, version):
async with self.get_connection() as conn: await conn.execute(f"""
return await conn.fetchval(""" INSERT INTO "{table_name}" (server_id, version)
SELECT last_synced_version FROM sync_status VALUES ($1, $2)
WHERE table_name = $1 AND server_id = $2 ON CONFLICT (server_id) DO UPDATE
""", table_name, server_id) or 0 SET version = EXCLUDED.version
WHERE "{table_name}".version < EXCLUDED.version
""", server_id, version)
async def update_sync_status(self, table_name, server_id, version):
async with self.get_connection() as conn:
await conn.execute("""
INSERT INTO sync_status (table_name, server_id, last_synced_version)
VALUES ($1, $2, $3)
ON CONFLICT (table_name, server_id) DO UPDATE
SET last_synced_version = EXCLUDED.last_synced_version
""", table_name, server_id, version)
async def sync_schema(self):
local_schema_version = await self.get_schema_version(self.local_db)
for pool_entry in self.POOL:
if pool_entry['ts_id'] != os.environ.get('TS_ID'):
remote_schema_version = await self.get_schema_version(pool_entry)
if remote_schema_version != local_schema_version:
await self.apply_schema_changes(pool_entry)
async def get_schema(self, pool_entry: Dict[str, Any]):
async with self.get_connection(pool_entry) as conn:
tables = await conn.fetch("""
SELECT table_name, column_name, data_type, character_maximum_length,
is_nullable, column_default, ordinal_position
FROM information_schema.columns
WHERE table_schema = 'public'
ORDER BY table_name, ordinal_position
""")
indexes = await conn.fetch("""
SELECT indexname, indexdef
FROM pg_indexes
WHERE schemaname = 'public'
""")
constraints = await conn.fetch("""
SELECT conname, contype, conrelid::regclass::text as table_name,
pg_get_constraintdef(oid) as definition
FROM pg_constraint
WHERE connamespace = 'public'::regnamespace
""")
return {
'tables': tables,
'indexes': indexes,
'constraints': constraints
}
async def apply_schema_changes(self, pool_entry: Dict[str, Any], source_schema, target_schema):
async with self.get_connection(pool_entry) as conn:
# Check schema version
source_version = await self.get_schema_version(self.local_db)
target_version = await self.get_schema_version(pool_entry)
if source_version == target_version:
info(f"Schema versions match for {pool_entry['ts_ip']}. Skipping synchronization.")
return
source_tables = {t['table_name']: t for t in source_schema['tables']}
target_tables = {t['table_name']: t for t in target_schema['tables']}
def get_column_type(data_type):
if data_type == 'ARRAY':
return 'text[]'
elif data_type == 'USER-DEFINED':
return 'geometry'
else:
return data_type
for table_name, source_table in source_tables.items():
try:
if table_name not in target_tables:
columns = []
for t in source_schema['tables']:
if t['table_name'] == table_name:
col_type = get_column_type(t['data_type'])
col_def = f"\"{t['column_name']}\" {col_type}"
if t['character_maximum_length']:
col_def += f"({t['character_maximum_length']})"
if t['is_nullable'] == 'NO':
col_def += " NOT NULL"
if t['column_default']:
if 'nextval' in t['column_default']:
sequence_name = t['column_default'].split("'")[1]
await self.create_sequence_if_not_exists(conn, sequence_name)
col_def += f" DEFAULT {t['column_default']}"
columns.append(col_def)
primary_key_constraint = next(
(con['definition'] for con in source_schema['constraints'] if con['table_name'] == table_name and con['contype'] == 'p'),
None
)
sql = f'CREATE TABLE "{table_name}" ({", ".join(columns)}'
if primary_key_constraint:
sql += f', {primary_key_constraint}'
sql += ')'
info(f"Executing SQL: {sql}")
await conn.execute(sql)
else:
target_table = target_tables[table_name]
source_columns = {t['column_name']: t for t in source_schema['tables'] if t['table_name'] == table_name}
target_columns = {t['column_name']: t for t in target_schema['tables'] if t['table_name'] == table_name}
for col_name, source_col in source_columns.items():
if col_name not in target_columns:
col_type = get_column_type(source_col['data_type'])
col_def = f"\"{col_name}\" {col_type}" + \
(f"({source_col['character_maximum_length']})" if source_col['character_maximum_length'] else "") + \
(" NOT NULL" if source_col['is_nullable'] == 'NO' else "") + \
(f" DEFAULT {source_col['column_default']}" if source_col['column_default'] else "")
sql = f'ALTER TABLE "{table_name}" ADD COLUMN {col_def}'
debug(f"Executing SQL: {sql}")
await conn.execute(sql)
else:
target_col = target_columns[col_name]
if source_col != target_col:
col_type = get_column_type(source_col['data_type'])
sql = f'ALTER TABLE "{table_name}" ALTER COLUMN "{col_name}" TYPE {col_type}'
debug(f"Executing SQL: {sql}")
await conn.execute(sql)
if source_col['is_nullable'] != target_col['is_nullable']:
null_constraint = "DROP NOT NULL" if source_col['is_nullable'] == 'YES' else "SET NOT NULL"
sql = f'ALTER TABLE "{table_name}" ALTER COLUMN "{col_name}" {null_constraint}'
debug(f"Executing SQL: {sql}")
await conn.execute(sql)
if source_col['column_default'] != target_col['column_default']:
default_clause = f"SET DEFAULT {source_col['column_default']}" if source_col['column_default'] else "DROP DEFAULT"
sql = f'ALTER TABLE "{table_name}" ALTER COLUMN "{col_name}" {default_clause}'
debug(f"Executing SQL: {sql}")
await conn.execute(sql)
# Ensure primary key constraint exists
primary_key_constraint = next(
(con['definition'] for con in source_schema['constraints'] if con['table_name'] == table_name and con['contype'] == 'p'),
None
)
if primary_key_constraint:
constraint_name = f"{table_name}_pkey"
constraint_exists = await conn.fetchval(f"""
SELECT 1 FROM pg_constraint
WHERE conname = '{constraint_name}'
""")
if not constraint_exists:
sql = f'ALTER TABLE "{table_name}" ADD CONSTRAINT {constraint_name} {primary_key_constraint}'
debug(f"Executing SQL: {sql}")
await conn.execute(sql)
except Exception as e:
err(f"Error processing table {table_name}: {str(e)}")
try:
source_indexes = {idx['indexname']: idx['indexdef'] for idx in source_schema['indexes']}
target_indexes = {idx['indexname']: idx['indexdef'] for idx in target_schema['indexes']}
for idx_name, idx_def in source_indexes.items():
if idx_name not in target_indexes:
debug(f"Executing SQL: {idx_def}")
await conn.execute(idx_def)
elif idx_def != target_indexes[idx_name]:
sql = f'DROP INDEX IF EXISTS "{idx_name}"'
debug(f"Executing SQL: {sql}")
await conn.execute(sql)
debug(f"Executing SQL: {idx_def}")
await conn.execute(idx_def)
except Exception as e:
err(f"Error processing indexes: {str(e)}")
try:
source_constraints = {con['conname']: con for con in source_schema['constraints']}
target_constraints = {con['conname']: con for con in target_schema['constraints']}
for con_name, source_con in source_constraints.items():
if con_name not in target_constraints:
sql = f'ALTER TABLE "{source_con["table_name"]}" ADD CONSTRAINT "{con_name}" {source_con["definition"]}'
debug(f"Executing SQL: {sql}")
await conn.execute(sql)
elif source_con != target_constraints[con_name]:
sql = f'ALTER TABLE "{source_con["table_name"]}" DROP CONSTRAINT IF EXISTS "{con_name}"'
debug(f"Executing SQL: {sql}")
await conn.execute(sql)
sql = f'ALTER TABLE "{source_con["table_name"]}" ADD CONSTRAINT "{con_name}" {source_con["definition"]}'
debug(f"Executing SQL: {sql}")
await conn.execute(sql)
except Exception as e:
err(f"Error processing constraints: {str(e)}")
# Update schema version
await conn.execute("UPDATE schema_version SET version = $1", source_version)
info(f"Schema synchronization completed for {pool_entry['ts_ip']}")
async def get_schema_version(self, pool_entry): async def get_schema_version(self, pool_entry):
async with self.get_connection(pool_entry) as conn: async with self.get_connection(pool_entry) as conn:
return await conn.fetchval("SELECT version FROM schema_version") return await conn.fetchval("""
SELECT COALESCE(MAX(version), 0) FROM (
SELECT MAX(version) as version FROM pg_tables
WHERE schemaname = 'public'
) as subquery
""")
async def create_sequence_if_not_exists(self, conn, sequence_name): async def create_sequence_if_not_exists(self, conn, sequence_name):
await conn.execute(f""" await conn.execute(f"""
@ -831,6 +529,7 @@ class APIConfig(BaseModel):
class Location(BaseModel): class Location(BaseModel):
latitude: float latitude: float
longitude: float longitude: float