Auto-update: Tue Jul 30 14:20:47 PDT 2024

This commit is contained in:
sanj 2024-07-30 14:20:47 -07:00
parent 9a16e9f46b
commit cedbc4dfd2

View file

@ -1,5 +1,4 @@
# classes.py # classes.py
import asyncio
import json import json
import yaml import yaml
import math import math
@ -8,7 +7,9 @@ import re
import uuid import uuid
import aiofiles import aiofiles
import aiohttp import aiohttp
import asyncio
import asyncpg import asyncpg
import socket
import traceback import traceback
import reverse_geocoder as rg import reverse_geocoder as rg
from pathlib import Path from pathlib import Path
@ -330,6 +331,7 @@ class APIConfig(BaseModel):
err(f"Unexpected error when acquiring connection from pool for {pool_key}: {str(e)}") err(f"Unexpected error when acquiring connection from pool for {pool_key}: {str(e)}")
raise raise
async def close_db_pools(self): async def close_db_pools(self):
info("Closing database connection pools...") info("Closing database connection pools...")
for pool_key, pool in self._db_pools.items(): for pool_key, pool in self._db_pools.items():
@ -341,6 +343,7 @@ class APIConfig(BaseModel):
self._db_pools.clear() self._db_pools.clear()
info("All database connection pools closed.") info("All database connection pools closed.")
async def initialize_sync(self): async def initialize_sync(self):
local_ts_id = os.environ.get('TS_ID') local_ts_id = os.environ.get('TS_ID')
for pool_entry in self.POOL: for pool_entry in self.POOL:
@ -348,10 +351,13 @@ class APIConfig(BaseModel):
continue # Skip local database continue # Skip local database
try: try:
async with self.get_connection(pool_entry) as conn: async with self.get_connection(pool_entry) as conn:
info(f"Starting sync initialization for {pool_entry['ts_ip']}...")
await self.ensure_sync_structure(conn) await self.ensure_sync_structure(conn)
info(f"Sync initialization complete for {pool_entry['ts_ip']}. All tables now have version and server_id columns with appropriate triggers.") info(f"Sync initialization complete for {pool_entry['ts_ip']}. All tables should now have version and server_id columns with appropriate triggers.")
except Exception as e: except Exception as e:
err(f"Error initializing sync for {pool_entry['ts_ip']}: {str(e)}") err(f"Error initializing sync for {pool_entry['ts_ip']}: {str(e)}")
err(f"Traceback: {traceback.format_exc()}")
async def ensure_sync_structure(self, conn): async def ensure_sync_structure(self, conn):
tables = await conn.fetch(""" tables = await conn.fetch("""
@ -364,37 +370,45 @@ class APIConfig(BaseModel):
await self.ensure_sync_columns(conn, table_name) await self.ensure_sync_columns(conn, table_name)
await self.ensure_sync_trigger(conn, table_name) await self.ensure_sync_trigger(conn, table_name)
async def ensure_sync_columns(self, conn, table_name): async def ensure_sync_columns(self, conn, table_name):
await conn.execute(f""" try:
DO $$ await conn.execute(f"""
BEGIN DO $$
BEGIN BEGIN
ALTER TABLE "{table_name}" BEGIN
ADD COLUMN IF NOT EXISTS version INTEGER DEFAULT 1; ALTER TABLE "{table_name}"
EXCEPTION ADD COLUMN IF NOT EXISTS version INTEGER DEFAULT 1;
WHEN duplicate_column THEN EXCEPTION
-- Do nothing, column already exists WHEN duplicate_column THEN
END; NULL;
END;
BEGIN BEGIN
ALTER TABLE "{table_name}" ALTER TABLE "{table_name}"
ADD COLUMN IF NOT EXISTS server_id TEXT DEFAULT '{os.environ.get('TS_ID')}'; ADD COLUMN IF NOT EXISTS server_id TEXT DEFAULT '{os.environ.get('TS_ID')}';
EXCEPTION EXCEPTION
WHEN duplicate_column THEN WHEN duplicate_column THEN
-- Do nothing, column already exists NULL;
END; END;
END $$; END $$;
""") """)
# Verify that the columns were added # Verify that the columns were added
result = await conn.fetchrow(f""" result = await conn.fetchrow(f"""
SELECT SELECT
EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = '{table_name}' AND column_name = 'version') as has_version, EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = '{table_name}' AND column_name = 'version') as has_version,
EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = '{table_name}' AND column_name = 'server_id') as has_server_id EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = '{table_name}' AND column_name = 'server_id') as has_server_id
""") """)
if not (result['has_version'] and result['has_server_id']):
raise Exception(f"Failed to add version and/or server_id columns to table {table_name}")
else:
info(f"Successfully added/verified version and server_id columns for table {table_name}")
except Exception as e:
err(f"Error ensuring sync columns for table {table_name}: {str(e)}")
err(f"Traceback: {traceback.format_exc()}")
if not (result['has_version'] and result['has_server_id']):
raise Exception(f"Failed to add version and/or server_id columns to table {table_name}")
async def ensure_sync_trigger(self, conn, table_name): async def ensure_sync_trigger(self, conn, table_name):
await conn.execute(f""" await conn.execute(f"""
@ -423,11 +437,18 @@ class APIConfig(BaseModel):
if pool_entry['ts_id'] == local_ts_id: if pool_entry['ts_id'] == local_ts_id:
continue # Skip local database continue # Skip local database
if not await self.is_server_accessible(pool_entry['ts_ip'], pool_entry['db_port']):
warn(f"Server {pool_entry['ts_id']} ({pool_entry['ts_ip']}:{pool_entry['db_port']}) is not accessible. Skipping.")
continue
try: try:
async with self.get_connection(pool_entry) as conn: async with self.get_connection(pool_entry) as conn:
if not await self.check_version_column_exists(conn): if not await self.check_version_column_exists(conn):
warn(f"Version column does not exist in {pool_entry['ts_id']}. Skipping.") warn(f"Version column does not exist in some tables for {pool_entry['ts_id']}. Attempting to add...")
continue await self.ensure_sync_structure(conn)
if not await self.check_version_column_exists(conn):
warn(f"Failed to add version column to all tables in {pool_entry['ts_id']}. Skipping.")
continue
version = await conn.fetchval(""" version = await conn.fetchval("""
SELECT COALESCE(MAX(version), -1) SELECT COALESCE(MAX(version), -1)
@ -437,27 +458,61 @@ class APIConfig(BaseModel):
WHERE schemaname = 'public' WHERE schemaname = 'public'
) as subquery ) as subquery
""") """)
info(f"Max version for {pool_entry['ts_id']}: {version}")
if version > max_version: if version > max_version:
max_version = version max_version = version
most_recent_source = pool_entry most_recent_source = pool_entry
except asyncpg.exceptions.ConnectionFailureError: except asyncpg.exceptions.ConnectionFailureError as e:
warn(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}") err(f"Failed to establish database connection with {pool_entry['ts_id']} ({pool_entry['ts_ip']}:{pool_entry['db_port']}): {str(e)}")
except asyncpg.exceptions.PostgresError as e:
err(f"PostgreSQL error occurred while querying {pool_entry['ts_id']}: {str(e)}")
if "column \"version\" does not exist" in str(e):
err(f"The 'version' column is missing in one or more tables on {pool_entry['ts_id']}. This might indicate a synchronization issue.")
except Exception as e: except Exception as e:
warn(f"Error checking version for {pool_entry['ts_id']}: {str(e)}") err(f"Unexpected error occurred while checking version for {pool_entry['ts_id']}: {str(e)}")
err(f"Traceback: {traceback.format_exc()}")
return most_recent_source return most_recent_source
async def is_server_accessible(self, host, port, timeout=2):
try:
future = asyncio.open_connection(host, port)
await asyncio.wait_for(future, timeout=timeout)
return True
except (asyncio.TimeoutError, ConnectionRefusedError, socket.gaierror):
return False
async def check_version_column_exists(self, conn): async def check_version_column_exists(self, conn):
result = await conn.fetchval(""" try:
SELECT EXISTS ( result = await conn.fetchval("""
SELECT 1 SELECT EXISTS (
FROM information_schema.columns SELECT 1
WHERE table_schema = 'public' FROM information_schema.columns
AND column_name = 'version' WHERE table_schema = 'public'
AND table_name IN (SELECT tablename FROM pg_tables WHERE schemaname = 'public') AND column_name = 'version'
) AND table_name IN (SELECT tablename FROM pg_tables WHERE schemaname = 'public')
""") )
return result """)
if not result:
tables_without_version = await conn.fetch("""
SELECT tablename
FROM pg_tables
WHERE schemaname = 'public'
AND tablename NOT IN (
SELECT table_name
FROM information_schema.columns
WHERE table_schema = 'public' AND column_name = 'version'
)
""")
table_names = ", ".join([t['tablename'] for t in tables_without_version])
warn(f"Tables without 'version' column: {table_names}")
return result
except Exception as e:
err(f"Error checking for 'version' column existence: {str(e)}")
return False
async def pull_changes(self, source_pool_entry, batch_size=10000): async def pull_changes(self, source_pool_entry, batch_size=10000):