diff --git a/sijapi/__main__.py b/sijapi/__main__.py index caea1c1..c70c4f6 100755 --- a/sijapi/__main__.py +++ b/sijapi/__main__.py @@ -62,11 +62,7 @@ async def lifespan(app: FastAPI): # Initialize sync structures await API.initialize_sync() - # Sync schema across all databases - await API.sync_schema() - crit("Schema synchronization complete.") - - # Check if other instances have more recent data + # Now that tables are initialized, check for the most recent source source = await API.get_most_recent_source() if source: crit(f"Pulling changes from {source['ts_id']} ({source['ts_ip']})...") @@ -75,7 +71,6 @@ async def lifespan(app: FastAPI): else: crit("No instances with more recent data found.") - except Exception as e: crit(f"Error during startup: {str(e)}") crit(f"Traceback: {traceback.format_exc()}") @@ -86,6 +81,7 @@ async def lifespan(app: FastAPI): crit("Shutting down...") # Perform any cleanup operations here if needed + app = FastAPI(lifespan=lifespan) app.add_middleware( diff --git a/sijapi/classes.py b/sijapi/classes.py index f61e60c..85db29b 100644 --- a/sijapi/classes.py +++ b/sijapi/classes.py @@ -317,15 +317,6 @@ class APIConfig(BaseModel): async def initialize_sync(self): async with self.get_connection() as conn: - await conn.execute(""" - CREATE TABLE IF NOT EXISTS sync_status ( - table_name TEXT, - server_id TEXT, - last_synced_version INTEGER, - PRIMARY KEY (table_name, server_id) - ) - """) - tables = await conn.fetch(""" SELECT tablename FROM pg_tables WHERE schemaname = 'public' @@ -351,12 +342,9 @@ class APIConfig(BaseModel): CREATE TRIGGER update_version_and_server_id_trigger BEFORE INSERT OR UPDATE ON "{table_name}" FOR EACH ROW EXECUTE FUNCTION update_version_and_server_id(); - - INSERT INTO sync_status (table_name, server_id, last_synced_version) - VALUES ('{table_name}', '{os.environ.get('TS_ID')}', 0) - ON CONFLICT (table_name, server_id) DO NOTHING; """) + async def get_most_recent_source(self): most_recent_source = None max_version = -1 @@ -368,7 +356,10 @@ class APIConfig(BaseModel): try: async with self.get_connection(pool_entry) as conn: version = await conn.fetchval(""" - SELECT COALESCE(MAX(last_synced_version), -1) FROM sync_status + SELECT COALESCE(MAX(version), -1) FROM ( + SELECT MAX(version) as version FROM pg_tables + WHERE schemaname = 'public' + ) as subquery """) if version > max_version: max_version = version @@ -379,6 +370,7 @@ class APIConfig(BaseModel): return most_recent_source + async def pull_changes(self, source_pool_entry): if source_pool_entry['ts_id'] == os.environ.get('TS_ID'): info("Skipping self-sync") @@ -405,7 +397,7 @@ class APIConfig(BaseModel): for table in tables: table_name = table['tablename'] - last_synced_version = await self.get_last_synced_version(table_name, source_id) + last_synced_version = await self.get_last_synced_version(dest_conn, table_name, source_id) changes = await source_conn.fetch(f""" SELECT * FROM "{table_name}" @@ -434,7 +426,7 @@ class APIConfig(BaseModel): inserts += 1 if changes: - await self.update_sync_status(table_name, source_id, changes[-1]['version']) + await self.update_last_synced_version(dest_conn, table_name, source_id, changes[-1]['version']) total_inserts += inserts total_updates += updates @@ -454,119 +446,6 @@ class APIConfig(BaseModel): return total_inserts + total_updates - async def get_tables(self, conn): - tables = await conn.fetch(""" - SELECT tablename FROM pg_tables - WHERE schemaname = 'public' - """) - return [table['tablename'] for table in tables] - - async def compare_table_structure(self, source_conn, dest_conn, table_name): - source_columns = await self.get_table_structure(source_conn, table_name) - dest_columns = await self.get_table_structure(dest_conn, table_name) - - columns_only_in_source = set(source_columns.keys()) - set(dest_columns.keys()) - columns_only_in_dest = set(dest_columns.keys()) - set(source_columns.keys()) - common_columns = set(source_columns.keys()) & set(dest_columns.keys()) - - info(f"Table {table_name}:") - info(f" Columns only in source: {columns_only_in_source}") - info(f" Columns only in destination: {columns_only_in_dest}") - info(f" Common columns: {common_columns}") - - for col in common_columns: - if source_columns[col] != dest_columns[col]: - warn(f" Column {col} has different types: source={source_columns[col]}, dest={dest_columns[col]}") - - async def get_table_structure(self, conn, table_name): - columns = await conn.fetch(""" - SELECT column_name, data_type - FROM information_schema.columns - WHERE table_name = $1 - """, table_name) - return {col['column_name']: col['data_type'] for col in columns} - - async def compare_and_sync_data(self, source_conn, dest_conn, table_name, source_id): - inserts = 0 - updates = 0 - error_count = 0 - - try: - primary_keys = await self.get_primary_keys(dest_conn, table_name) - if not primary_keys: - warn(f"Table {table_name} has no primary keys. Using all columns for comparison.") - columns = await self.get_table_columns(dest_conn, table_name) - primary_keys = columns # Use all columns if no primary key - - last_synced_version = await self.get_last_synced_version(table_name, source_id) - - changes = await source_conn.fetch(f""" - SELECT * FROM "{table_name}" - WHERE version > $1 AND server_id = $2 - ORDER BY version ASC - """, last_synced_version, source_id) - - for change in changes: - columns = list(change.keys()) - values = [change[col] for col in columns] - - conflict_clause = f"({', '.join(primary_keys)})" - update_clause = ', '.join(f"{col} = EXCLUDED.{col}" for col in columns if col not in primary_keys) - - insert_query = f""" - INSERT INTO "{table_name}" ({', '.join(columns)}) - VALUES ({', '.join(f'${i+1}' for i in range(len(columns)))}) - ON CONFLICT {conflict_clause} DO UPDATE SET - {update_clause} - """ - - try: - result = await dest_conn.execute(insert_query, *values) - if 'UPDATE' in result: - updates += 1 - else: - inserts += 1 - except Exception as e: - if error_count < 10: # Limit error logging - err(f"Error syncing data for table {table_name}: {str(e)}") - error_count += 1 - elif error_count == 10: - err(f"Suppressing further errors for table {table_name}") - error_count += 1 - - if changes: - await self.update_sync_status(table_name, source_id, changes[-1]['version']) - - info(f"Synced {table_name}: {inserts} inserts, {updates} updates") - if error_count > 10: - info(f"Total of {error_count} errors occurred for table {table_name}") - - except Exception as e: - err(f"Error processing table {table_name}: {str(e)}") - - return inserts, updates - - async def get_table_columns(self, conn, table_name): - columns = await conn.fetch(""" - SELECT column_name - FROM information_schema.columns - WHERE table_name = $1 - ORDER BY ordinal_position - """, table_name) - return [col['column_name'] for col in columns] - - async def get_primary_keys(self, conn, table_name): - primary_keys = await conn.fetch(""" - SELECT a.attname - FROM pg_index i - JOIN pg_attribute a ON a.attrelid = i.indrelid - AND a.attnum = ANY(i.indkey) - WHERE i.indrelid = $1::regclass - AND i.indisprimary - """, table_name) - return [pk['attname'] for pk in primary_keys] - - async def push_changes_to_all(self): for pool_entry in self.POOL: if pool_entry['ts_id'] != os.environ.get('TS_ID'): @@ -583,7 +462,7 @@ class APIConfig(BaseModel): for table in tables: table_name = table['tablename'] - last_synced_version = await self.get_last_synced_version(table_name, pool_entry['ts_id']) + last_synced_version = await self.get_last_synced_version(remote_conn, table_name, os.environ.get('TS_ID')) changes = await local_conn.fetch(f""" SELECT * FROM "{table_name}" @@ -606,218 +485,37 @@ class APIConfig(BaseModel): await remote_conn.execute(insert_query, *values) if changes: - await self.update_sync_status(table_name, pool_entry['ts_id'], changes[-1]['version']) + await self.update_last_synced_version(remote_conn, table_name, os.environ.get('TS_ID'), changes[-1]['version']) info(f"Successfully pushed changes to {pool_entry['ts_id']}") except Exception as e: err(f"Error pushing changes to {pool_entry['ts_id']}: {str(e)}") err(f"Traceback: {traceback.format_exc()}") + async def get_last_synced_version(self, conn, table_name, server_id): + return await conn.fetchval(f""" + SELECT COALESCE(MAX(version), 0) + FROM "{table_name}" + WHERE server_id = $1 + """, server_id) - async def get_last_synced_version(self, table_name, server_id): - async with self.get_connection() as conn: - return await conn.fetchval(""" - SELECT last_synced_version FROM sync_status - WHERE table_name = $1 AND server_id = $2 - """, table_name, server_id) or 0 - - - async def update_sync_status(self, table_name, server_id, version): - async with self.get_connection() as conn: - await conn.execute(""" - INSERT INTO sync_status (table_name, server_id, last_synced_version) - VALUES ($1, $2, $3) - ON CONFLICT (table_name, server_id) DO UPDATE - SET last_synced_version = EXCLUDED.last_synced_version - """, table_name, server_id, version) - - - async def sync_schema(self): - local_schema_version = await self.get_schema_version(self.local_db) - for pool_entry in self.POOL: - if pool_entry['ts_id'] != os.environ.get('TS_ID'): - remote_schema_version = await self.get_schema_version(pool_entry) - if remote_schema_version != local_schema_version: - await self.apply_schema_changes(pool_entry) - - - async def get_schema(self, pool_entry: Dict[str, Any]): - async with self.get_connection(pool_entry) as conn: - tables = await conn.fetch(""" - SELECT table_name, column_name, data_type, character_maximum_length, - is_nullable, column_default, ordinal_position - FROM information_schema.columns - WHERE table_schema = 'public' - ORDER BY table_name, ordinal_position - """) - - indexes = await conn.fetch(""" - SELECT indexname, indexdef - FROM pg_indexes - WHERE schemaname = 'public' - """) - - constraints = await conn.fetch(""" - SELECT conname, contype, conrelid::regclass::text as table_name, - pg_get_constraintdef(oid) as definition - FROM pg_constraint - WHERE connamespace = 'public'::regnamespace - """) - - return { - 'tables': tables, - 'indexes': indexes, - 'constraints': constraints - } - - - async def apply_schema_changes(self, pool_entry: Dict[str, Any], source_schema, target_schema): - async with self.get_connection(pool_entry) as conn: - # Check schema version - source_version = await self.get_schema_version(self.local_db) - target_version = await self.get_schema_version(pool_entry) - if source_version == target_version: - info(f"Schema versions match for {pool_entry['ts_ip']}. Skipping synchronization.") - return - - source_tables = {t['table_name']: t for t in source_schema['tables']} - target_tables = {t['table_name']: t for t in target_schema['tables']} - - def get_column_type(data_type): - if data_type == 'ARRAY': - return 'text[]' - elif data_type == 'USER-DEFINED': - return 'geometry' - else: - return data_type - - for table_name, source_table in source_tables.items(): - try: - if table_name not in target_tables: - columns = [] - for t in source_schema['tables']: - if t['table_name'] == table_name: - col_type = get_column_type(t['data_type']) - col_def = f"\"{t['column_name']}\" {col_type}" - if t['character_maximum_length']: - col_def += f"({t['character_maximum_length']})" - if t['is_nullable'] == 'NO': - col_def += " NOT NULL" - if t['column_default']: - if 'nextval' in t['column_default']: - sequence_name = t['column_default'].split("'")[1] - await self.create_sequence_if_not_exists(conn, sequence_name) - col_def += f" DEFAULT {t['column_default']}" - columns.append(col_def) - - primary_key_constraint = next( - (con['definition'] for con in source_schema['constraints'] if con['table_name'] == table_name and con['contype'] == 'p'), - None - ) - - sql = f'CREATE TABLE "{table_name}" ({", ".join(columns)}' - if primary_key_constraint: - sql += f', {primary_key_constraint}' - sql += ')' - - info(f"Executing SQL: {sql}") - await conn.execute(sql) - else: - target_table = target_tables[table_name] - source_columns = {t['column_name']: t for t in source_schema['tables'] if t['table_name'] == table_name} - target_columns = {t['column_name']: t for t in target_schema['tables'] if t['table_name'] == table_name} - - for col_name, source_col in source_columns.items(): - if col_name not in target_columns: - col_type = get_column_type(source_col['data_type']) - col_def = f"\"{col_name}\" {col_type}" + \ - (f"({source_col['character_maximum_length']})" if source_col['character_maximum_length'] else "") + \ - (" NOT NULL" if source_col['is_nullable'] == 'NO' else "") + \ - (f" DEFAULT {source_col['column_default']}" if source_col['column_default'] else "") - sql = f'ALTER TABLE "{table_name}" ADD COLUMN {col_def}' - debug(f"Executing SQL: {sql}") - await conn.execute(sql) - else: - target_col = target_columns[col_name] - if source_col != target_col: - col_type = get_column_type(source_col['data_type']) - sql = f'ALTER TABLE "{table_name}" ALTER COLUMN "{col_name}" TYPE {col_type}' - debug(f"Executing SQL: {sql}") - await conn.execute(sql) - if source_col['is_nullable'] != target_col['is_nullable']: - null_constraint = "DROP NOT NULL" if source_col['is_nullable'] == 'YES' else "SET NOT NULL" - sql = f'ALTER TABLE "{table_name}" ALTER COLUMN "{col_name}" {null_constraint}' - debug(f"Executing SQL: {sql}") - await conn.execute(sql) - if source_col['column_default'] != target_col['column_default']: - default_clause = f"SET DEFAULT {source_col['column_default']}" if source_col['column_default'] else "DROP DEFAULT" - sql = f'ALTER TABLE "{table_name}" ALTER COLUMN "{col_name}" {default_clause}' - debug(f"Executing SQL: {sql}") - await conn.execute(sql) - - # Ensure primary key constraint exists - primary_key_constraint = next( - (con['definition'] for con in source_schema['constraints'] if con['table_name'] == table_name and con['contype'] == 'p'), - None - ) - if primary_key_constraint: - constraint_name = f"{table_name}_pkey" - constraint_exists = await conn.fetchval(f""" - SELECT 1 FROM pg_constraint - WHERE conname = '{constraint_name}' - """) - if not constraint_exists: - sql = f'ALTER TABLE "{table_name}" ADD CONSTRAINT {constraint_name} {primary_key_constraint}' - debug(f"Executing SQL: {sql}") - await conn.execute(sql) - except Exception as e: - err(f"Error processing table {table_name}: {str(e)}") - - try: - source_indexes = {idx['indexname']: idx['indexdef'] for idx in source_schema['indexes']} - target_indexes = {idx['indexname']: idx['indexdef'] for idx in target_schema['indexes']} - - for idx_name, idx_def in source_indexes.items(): - if idx_name not in target_indexes: - debug(f"Executing SQL: {idx_def}") - await conn.execute(idx_def) - elif idx_def != target_indexes[idx_name]: - sql = f'DROP INDEX IF EXISTS "{idx_name}"' - debug(f"Executing SQL: {sql}") - await conn.execute(sql) - debug(f"Executing SQL: {idx_def}") - await conn.execute(idx_def) - except Exception as e: - err(f"Error processing indexes: {str(e)}") - - try: - source_constraints = {con['conname']: con for con in source_schema['constraints']} - target_constraints = {con['conname']: con for con in target_schema['constraints']} - - for con_name, source_con in source_constraints.items(): - if con_name not in target_constraints: - sql = f'ALTER TABLE "{source_con["table_name"]}" ADD CONSTRAINT "{con_name}" {source_con["definition"]}' - debug(f"Executing SQL: {sql}") - await conn.execute(sql) - elif source_con != target_constraints[con_name]: - sql = f'ALTER TABLE "{source_con["table_name"]}" DROP CONSTRAINT IF EXISTS "{con_name}"' - debug(f"Executing SQL: {sql}") - await conn.execute(sql) - sql = f'ALTER TABLE "{source_con["table_name"]}" ADD CONSTRAINT "{con_name}" {source_con["definition"]}' - debug(f"Executing SQL: {sql}") - await conn.execute(sql) - except Exception as e: - err(f"Error processing constraints: {str(e)}") - - # Update schema version - await conn.execute("UPDATE schema_version SET version = $1", source_version) - info(f"Schema synchronization completed for {pool_entry['ts_ip']}") - + async def update_last_synced_version(self, conn, table_name, server_id, version): + await conn.execute(f""" + INSERT INTO "{table_name}" (server_id, version) + VALUES ($1, $2) + ON CONFLICT (server_id) DO UPDATE + SET version = EXCLUDED.version + WHERE "{table_name}".version < EXCLUDED.version + """, server_id, version) async def get_schema_version(self, pool_entry): async with self.get_connection(pool_entry) as conn: - return await conn.fetchval("SELECT version FROM schema_version") - + return await conn.fetchval(""" + SELECT COALESCE(MAX(version), 0) FROM ( + SELECT MAX(version) as version FROM pg_tables + WHERE schemaname = 'public' + ) as subquery + """) async def create_sequence_if_not_exists(self, conn, sequence_name): await conn.execute(f""" @@ -831,6 +529,7 @@ class APIConfig(BaseModel): + class Location(BaseModel): latitude: float longitude: float