Auto-update: Tue Jul 30 13:18:29 PDT 2024

sanj 2024-07-30 13:18:29 -07:00
parent 6e958c5f1f
commit 0185e4d622


@@ -14,7 +14,7 @@ import reverse_geocoder as rg
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Union, TypeVar
 from dotenv import load_dotenv
-from pydantic import BaseModel, Field, create_model, PrivateAttr
+from pydantic import BaseModel, Field, create_model
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import asynccontextmanager
 from datetime import datetime, timedelta, timezone
@@ -165,32 +165,6 @@ class Configuration(BaseModel):
-class DatabasePool:
-    def __init__(self):
-        self.pools = {}
-
-    async def get_connection(self, pool_entry):
-        pool_key = f"{pool_entry['ts_ip']}:{pool_entry['db_port']}"
-        if pool_key not in self.pools:
-            self.pools[pool_key] = await asyncpg.create_pool(
-                host=pool_entry['ts_ip'],
-                port=pool_entry['db_port'],
-                user=pool_entry['db_user'],
-                password=pool_entry['db_pass'],
-                database=pool_entry['db_name'],
-                min_size=1,
-                max_size=10
-            )
-        return await self.pools[pool_key].acquire()
-
-    async def release_connection(self, pool_entry, connection):
-        pool_key = f"{pool_entry['ts_ip']}:{pool_entry['db_port']}"
-        await self.pools[pool_key].release(connection)
-
-    async def close_all(self):
-        for pool in self.pools.values():
-            await pool.close()
-
 class APIConfig(BaseModel):
     HOST: str
     PORT: int
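
Note: this commit removes the pooled-connection machinery; the sync code below now opens a short-lived connection per operation via asyncpg.connect. For reference, a minimal sketch of the pooled pattern being dropped, assuming asyncpg and a pool_entry dict shaped like the entries in POOL (the demo function itself is hypothetical):

    import asyncpg

    async def demo(pool_entry):
        # One lazily created pool per server; acquire/release per operation.
        pool = await asyncpg.create_pool(
            host=pool_entry['ts_ip'],
            port=pool_entry['db_port'],
            user=pool_entry['db_user'],
            password=pool_entry['db_pass'],
            database=pool_entry['db_name'],
            min_size=1,
            max_size=10,
        )
        async with pool.acquire() as conn:
            print(await conn.fetchval('SELECT 1'))
        await pool.close()

A pool amortizes TCP and auth setup across queries; per-call connects are simpler to reason about but pay that cost on every sync operation.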
@@ -204,18 +178,6 @@ class APIConfig(BaseModel):
     TZ: str
     KEYS: List[str]
     GARBAGE: Dict[str, Any]
-    _db_pool: DatabasePool = PrivateAttr(default_factory=DatabasePool)
-
-    class Config:
-        arbitrary_types_allowed = True
-
-    @property
-    def db_pool(self):
-        return self._db_pool
-
-    @property
-    def db_pool(self):
-        return self._db_pool

     @classmethod
     def load(cls, config_path: Union[str, Path], secrets_path: Union[str, Path]):
@@ -258,7 +220,6 @@ class APIConfig(BaseModel):
         config_data['EXTENSIONS'] = cls._create_dynamic_config(config_data.get('EXTENSIONS', {}), 'DynamicExtensionsConfig')
         return cls(**config_data)
-
     @classmethod
     def _create_dynamic_config(cls, data: Dict[str, Any], model_name: str):
         fields = {}
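
For context on the _create_dynamic_config helper kept above: pydantic's create_model builds a model class at runtime from (type, default) pairs. A minimal sketch with made-up field names (not taken from the commit):

    from pydantic import create_model

    # Build a config model whose fields are only known at load time.
    DynamicExtensionsConfig = create_model(
        'DynamicExtensionsConfig',
        example_flag=(bool, False),   # hypothetical field
        example_limit=(int, 100),     # hypothetical field
    )

    cfg = DynamicExtensionsConfig(example_flag=True)
    print(cfg.example_flag, cfg.example_limit)  # True 100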
@@ -337,7 +298,7 @@ class APIConfig(BaseModel):
         if pool_entry is None:
             pool_entry = self.local_db

-        for attempt in range(3):  # Retry up to 3 times
+        info(f"Attempting to connect to database: {pool_entry}")
         try:
             conn = await asyncpg.connect(
                 host=pool_entry['ts_ip'],
@@ -350,21 +311,14 @@ class APIConfig(BaseModel):
                 yield conn
             finally:
                 await conn.close()
-                return
-            except asyncpg.exceptions.CannotConnectNowError:
-                if attempt < 2:  # Don't sleep on the last attempt
-                    await asyncio.sleep(1)  # Wait before retrying
         except Exception as e:
-            err(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}")
+            warn(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}")
             err(f"Error: {str(e)}")
-            if attempt == 2:  # Raise the exception on the last attempt
             raise
-        raise Exception(f"Failed to connect to database after 3 attempts: {pool_entry['ts_ip']}:{pool_entry['db_port']}")

     async def initialize_sync(self):
         for pool_entry in self.POOL:
-            for attempt in range(3):  # Retry up to 3 times
             try:
                 async with self.get_connection(pool_entry) as conn:
                     tables = await conn.fetch("""
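
Note: get_connection now makes a single attempt, logs, and re-raises. If retry behavior is ever wanted again, it can live in a small wrapper rather than inline; a minimal sketch, assuming asyncpg (the helper name is hypothetical):

    import asyncio
    import asyncpg

    async def connect_with_retry(entry, attempts=3, delay=1.0):
        # Retry transient startup errors; re-raise on the final attempt.
        for attempt in range(attempts):
            try:
                return await asyncpg.connect(
                    host=entry['ts_ip'], port=entry['db_port'],
                    user=entry['db_user'], password=entry['db_pass'],
                    database=entry['db_name'],
                )
            except asyncpg.exceptions.CannotConnectNowError:
                if attempt == attempts - 1:
                    raise
                await asyncio.sleep(delay)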
@@ -374,22 +328,7 @@ class APIConfig(BaseModel):
                     for table in tables:
                         table_name = table['tablename']
-                        await self.ensure_sync_columns(conn, table_name)
-                        await self.create_sync_trigger(conn, table_name)
-                info(f"Sync initialization complete for {pool_entry['ts_ip']}. All tables now have version and server_id columns with appropriate triggers.")
-                break  # If successful, break the retry loop
-            except asyncpg.exceptions.ConnectionFailureError:
-                err(f"Failed to connect to database during initialization: {pool_entry['ts_ip']}:{pool_entry['db_port']}")
-                if attempt < 2:  # Don't sleep on the last attempt
-                    await asyncio.sleep(1)  # Wait before retrying
-            except Exception as e:
-                err(f"Error initializing sync for {pool_entry['ts_ip']}: {str(e)}")
-                err(f"Traceback: {traceback.format_exc()}")
-                break  # Don't retry for unexpected errors
-
-    async def ensure_sync_columns(self, conn, table_name):
-        try:
+                        # Add version and server_id columns if they don't exist
                         await conn.execute(f"""
                             DO $$
                             BEGIN
@@ -399,17 +338,12 @@ class APIConfig(BaseModel):
                                     ADD COLUMN IF NOT EXISTS server_id TEXT DEFAULT '{os.environ.get('TS_ID')}';
                                 EXCEPTION
                                     WHEN duplicate_column THEN
-                                        NULL; -- Silently handle duplicate column
+                                        -- Do nothing, column already exists
                                 END;
                             END $$;
                         """)
-            info(f"Ensured sync columns for table {table_name}")
-        except Exception as e:
-            err(f"Error ensuring sync columns for table {table_name}: {str(e)}")
-            err(f"Traceback: {traceback.format_exc()}")
-
-    async def create_sync_trigger(self, conn, table_name):
+                        # Create or replace the trigger function
                         await conn.execute(f"""
                             CREATE OR REPLACE FUNCTION update_version_and_server_id()
                             RETURNS TRIGGER AS $$
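
One caveat on the DO block above: plpgsql requires at least one statement after WHEN ... THEN, so replacing NULL; with a bare comment leaves the handler empty and the block will fail to parse. A self-contained sketch of the idempotent column-add with the no-op statement kept (assumes an asyncpg connection conn and a table_name string):

    await conn.execute(f'''
        DO $$
        BEGIN
            BEGIN
                ALTER TABLE "{table_name}"
                ADD COLUMN IF NOT EXISTS version INTEGER DEFAULT 1;
            EXCEPTION
                WHEN duplicate_column THEN
                    NULL;  -- no-op: the column already exists
            END;
        END $$;
    ''')

With ADD COLUMN IF NOT EXISTS the handler is mostly belt-and-suspenders on modern PostgreSQL; it only guards concurrent ALTERs.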
@@ -419,7 +353,10 @@ class APIConfig(BaseModel):
                                 RETURN NEW;
                             END;
                             $$ LANGUAGE plpgsql;
+                        """)
+                        # Create the trigger if it doesn't exist
+                        await conn.execute(f"""
                             DO $$
                             BEGIN
                                 IF NOT EXISTS (SELECT 1 FROM pg_trigger WHERE tgname = 'update_version_and_server_id_trigger' AND tgrelid = '{table_name}'::regclass) THEN
@@ -430,6 +367,10 @@ class APIConfig(BaseModel):
                             END $$;
                         """)
+
+                info(f"Sync initialization complete for {pool_entry['ts_ip']}. All tables now have version and server_id columns with appropriate triggers.")
+            except Exception as e:
+                err(f"Error initializing sync for {pool_entry['ts_ip']}: {str(e)}")

     async def get_most_recent_source(self):
         most_recent_source = None
         max_version = -1
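
The two statements installed above give every table a trigger that stamps rows via update_version_and_server_id(). Assuming that function bumps NEW.version and sets NEW.server_id (its body is only partially shown in this diff), a quick hypothetical smoke test would look like:

    # 'notes' is a made-up table that already has the sync columns.
    await conn.execute("INSERT INTO notes (id, body) VALUES (1, 'a')")
    await conn.execute("UPDATE notes SET body = 'b' WHERE id = 1")
    row = await conn.fetchrow('SELECT version, server_id FROM notes WHERE id = 1')
    print(row['version'], row['server_id'])  # expect a bumped version and this node's TS_ID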
@@ -438,52 +379,25 @@ class APIConfig(BaseModel):
             if pool_entry['ts_id'] == os.environ.get('TS_ID'):
                 continue

-            for _ in range(3):  # Retry up to 3 times
             try:
                 async with self.get_connection(pool_entry) as conn:
-                        # Check if the version column exists in any table
-                        version_exists = await conn.fetchval("""
-                            SELECT EXISTS (
-                                SELECT 1
-                                FROM information_schema.columns
-                                WHERE table_schema = 'public'
-                                AND column_name = 'version'
-                            )
-                        """)
-
-                        if not version_exists:
-                            info(f"Version column does not exist in any table for {pool_entry['ts_id']}")
-                            break  # Move to the next pool entry
-
                     version = await conn.fetchval("""
-                            SELECT COALESCE(MAX(version), -1)
-                            FROM (
-                                SELECT MAX(version) as version
-                                FROM information_schema.columns
-                                WHERE table_schema = 'public'
-                                AND column_name = 'version'
-                                AND is_updatable = 'YES'
+                        SELECT COALESCE(MAX(version), -1) FROM (
+                            SELECT MAX(version) as version FROM information_schema.columns
+                            WHERE table_schema = 'public' AND column_name = 'version'
                         ) as subquery
                     """)
-                        info(f"Max version for {pool_entry['ts_id']}: {version}")

                     if version > max_version:
                         max_version = version
                         most_recent_source = pool_entry
-                        break  # If successful, break the retry loop
-                except asyncpg.exceptions.PostgresError as e:
-                    err(f"Error checking version for {pool_entry['ts_id']}: {str(e)}")
-                    await asyncio.sleep(1)  # Wait before retrying
             except Exception as e:
-                err(f"Unexpected error for {pool_entry['ts_id']}: {str(e)}")
-                break  # Don't retry for unexpected errors
-
-        if most_recent_source:
-            info(f"Most recent source: {most_recent_source['ts_id']} with version {max_version}")
-        else:
-            info("No valid source found with version information")
+                err(f"Error checking version for {pool_entry['ts_id']}: {str(e)}")

         return most_recent_source

     async def pull_changes(self, source_pool_entry, batch_size=10000):
         if source_pool_entry['ts_id'] == os.environ.get('TS_ID'):
             info("Skipping self-sync")
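
Aside: information_schema.columns catalogs column metadata and has no column named version, so the rewritten query above raises rather than reading row versions. A sketch of one way to compute the actual maximum row version across public tables (assumes every table carries the version column that initialize_sync adds; helper name hypothetical):

    async def max_row_version(conn):
        # Find tables that have a 'version' column, then take the max across them.
        tables = await conn.fetch("""
            SELECT table_name FROM information_schema.columns
            WHERE table_schema = 'public' AND column_name = 'version'
        """)
        max_version = -1
        for t in tables:
            v = await conn.fetchval(f'SELECT COALESCE(MAX(version), -1) FROM "{t["table_name"]}"')
            max_version = max(max_version, v)
        return max_version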
@@ -500,10 +414,6 @@ class APIConfig(BaseModel):
         try:
             async with self.get_connection(source_pool_entry) as source_conn:
                 async with self.get_connection(self.local_db) as dest_conn:
-                    # Sync schema first
-                    schema_changes = await self.detect_schema_changes(source_conn, dest_conn)
-                    await self.apply_schema_changes(dest_conn, schema_changes)
-
                     tables = await source_conn.fetch("""
                         SELECT tablename FROM pg_tables
                         WHERE schemaname = 'public'
@@ -557,22 +467,12 @@ class APIConfig(BaseModel):
             columns = changes[0].keys()
             await conn.copy_records_to_table(temp_table_name, records=[tuple(change[col] for col in columns) for change in changes])

-            # Perform upsert with spatial awareness
+            # Perform upsert
             result = await conn.execute(f"""
                 INSERT INTO "{table_name}"
-                SELECT tc.*
-                FROM {temp_table_name} tc
-                LEFT JOIN "{table_name}" t ON t.id = tc.id
-                WHERE t.id IS NULL
+                SELECT * FROM {temp_table_name}
                 ON CONFLICT (id) DO UPDATE SET
                 {', '.join(f"{col} = EXCLUDED.{col}" for col in columns if col != 'id')}
-                WHERE (
-                    CASE
-                        WHEN "{table_name}".geometry IS NOT NULL AND EXCLUDED.geometry IS NOT NULL
-                        THEN NOT ST_Equals("{table_name}".geometry, EXCLUDED.geometry)
-                        ELSE FALSE
-                    END
-                ) OR {' OR '.join(f"COALESCE({col} <> EXCLUDED.{col}, TRUE)" for col in columns if col not in ['id', 'geometry'])}
             """)

             # Parse the result to get the number of affected rows
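
The batch path above stages rows with copy_records_to_table and merges them in one INSERT ... ON CONFLICT. A minimal self-contained sketch of that pattern against a hypothetical items(id, body) table:

    import asyncpg

    async def upsert_batch(conn: asyncpg.Connection, rows):
        async with conn.transaction():
            # Stage via COPY into a transaction-scoped temp table, then merge once.
            await conn.execute('CREATE TEMP TABLE tmp_items (LIKE items INCLUDING ALL) ON COMMIT DROP')
            await conn.copy_records_to_table('tmp_items', records=rows)
            return await conn.execute('''
                INSERT INTO items
                SELECT * FROM tmp_items
                ON CONFLICT (id) DO UPDATE SET body = EXCLUDED.body
            ''')

copy_records_to_table drives PostgreSQL's COPY protocol, which is much cheaper than per-row INSERTs for large batches.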
@@ -583,6 +483,7 @@ class APIConfig(BaseModel):
             # Ensure temporary table is dropped
             await conn.execute(f"DROP TABLE IF EXISTS {temp_table_name}")
+
     async def push_changes_to_all(self):
         for pool_entry in self.POOL:
             if pool_entry['ts_id'] != os.environ.get('TS_ID'):
@@ -591,14 +492,11 @@ class APIConfig(BaseModel):
             except Exception as e:
                 err(f"Error pushing changes to {pool_entry['ts_id']}: {str(e)}")

-    async def push_changes_to_one(self, pool_entry, batch_size=10000):
+    async def push_changes_to_one(self, pool_entry):
         try:
             async with self.get_connection() as local_conn:
                 async with self.get_connection(pool_entry) as remote_conn:
-                    # Sync schema first
-                    schema_changes = await self.detect_schema_changes(local_conn, remote_conn)
-                    await self.apply_schema_changes(remote_conn, schema_changes)
-
                     tables = await local_conn.fetch("""
                         SELECT tablename FROM pg_tables
                         WHERE schemaname = 'public'
@@ -608,29 +506,49 @@ class APIConfig(BaseModel):
                         table_name = table['tablename']
                         last_synced_version = await self.get_last_synced_version(remote_conn, table_name, os.environ.get('TS_ID'))

-                        while True:
                         changes = await local_conn.fetch(f"""
                             SELECT * FROM "{table_name}"
                             WHERE version > $1 AND server_id = $2
                             ORDER BY version ASC
-                                LIMIT $3
-                            """, last_synced_version, os.environ.get('TS_ID'), batch_size)
+                        """, last_synced_version, os.environ.get('TS_ID'))

-                            if not changes:
-                                break
-
-                            changes_count = await self.apply_batch_changes(remote_conn, table_name, changes)
-
-                            last_synced_version = changes[-1]['version']
-                            await self.update_last_synced_version(remote_conn, table_name, os.environ.get('TS_ID'), last_synced_version)
-
-                            info(f"Pushed batch for {table_name}: {changes_count} changes to {pool_entry['ts_id']}")
+                        for change in changes:
+                            columns = list(change.keys())
+                            values = [change[col] for col in columns]
+                            placeholders = [f'${i+1}' for i in range(len(columns))]
+
+                            insert_query = f"""
+                                INSERT INTO "{table_name}" ({', '.join(columns)})
+                                VALUES ({', '.join(placeholders)})
+                                ON CONFLICT (id) DO UPDATE SET
+                                {', '.join(f"{col} = EXCLUDED.{col}" for col in columns if col != 'id')}
+                            """
+                            await remote_conn.execute(insert_query, *values)
+
+                        if changes:
+                            await self.update_last_synced_version(remote_conn, table_name, os.environ.get('TS_ID'), changes[-1]['version'])

             info(f"Successfully pushed changes to {pool_entry['ts_id']}")
         except Exception as e:
             err(f"Error pushing changes to {pool_entry['ts_id']}: {str(e)}")
             err(f"Traceback: {traceback.format_exc()}")

+    async def ensure_sync_columns(self, conn, table_name):
+        await conn.execute(f"""
+            DO $$
+            BEGIN
+                BEGIN
+                    ALTER TABLE "{table_name}"
+                    ADD COLUMN IF NOT EXISTS version INTEGER DEFAULT 1,
+                    ADD COLUMN IF NOT EXISTS server_id TEXT DEFAULT '{os.environ.get('TS_ID')}';
+                EXCEPTION
+                    WHEN duplicate_column THEN
+                        -- Do nothing, column already exists
+                END;
+            END $$;
+        """)
+
     async def get_last_synced_version(self, conn, table_name, server_id):
         return await conn.fetchval(f"""
             SELECT COALESCE(MAX(version), 0)
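
The new push loop above issues one execute per changed row. Since all rows fetched from a given table share a column set, asyncpg's executemany could send them through a single prepared statement instead; a sketch under that assumption, using the same variables as the hunk above:

    columns = list(changes[0].keys())
    placeholders = ', '.join(f'${i+1}' for i in range(len(columns)))
    query = f'''
        INSERT INTO "{table_name}" ({', '.join(columns)})
        VALUES ({placeholders})
        ON CONFLICT (id) DO UPDATE SET
        {', '.join(f'{c} = EXCLUDED.{c}' for c in columns if c != 'id')}
    '''
    # One round-trip per batch rather than one per row.
    await remote_conn.executemany(query, [tuple(r[c] for c in columns) for r in changes])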
@@ -666,85 +584,7 @@ class APIConfig(BaseModel):
             END $$;
         """)

-    async def detect_schema_changes(self, source_conn, dest_conn):
-        schema_changes = {
-            'new_tables': [],
-            'new_columns': {}
-        }
-
-        # Detect new tables
-        source_tables = await source_conn.fetch("SELECT tablename FROM pg_tables WHERE schemaname = 'public'")
-        dest_tables = await dest_conn.fetch("SELECT tablename FROM pg_tables WHERE schemaname = 'public'")
-        source_table_names = set(table['tablename'] for table in source_tables)
-        dest_table_names = set(table['tablename'] for table in dest_tables)
-        new_tables = source_table_names - dest_table_names
-        schema_changes['new_tables'] = list(new_tables)
-
-        # Detect new columns
-        for table_name in source_table_names:
-            if table_name in dest_table_names:
-                source_columns = await source_conn.fetch(f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}'")
-                dest_columns = await dest_conn.fetch(f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}'")
-                source_column_names = set(column['column_name'] for column in source_columns)
-                dest_column_names = set(column['column_name'] for column in dest_columns)
-                new_columns = source_column_names - dest_column_names
-                if new_columns:
-                    schema_changes['new_columns'][table_name] = [
-                        {'name': column['column_name'], 'type': column['data_type']}
-                        for column in source_columns if column['column_name'] in new_columns
-                    ]
-
-        return schema_changes
-
-    async def apply_schema_changes(self, conn, schema_changes):
-        for table_name in schema_changes['new_tables']:
-            create_table_sql = await self.get_table_creation_sql(conn, table_name)
-            await conn.execute(create_table_sql)
-            info(f"Created new table: {table_name}")
-
-        for table_name, columns in schema_changes['new_columns'].items():
-            for column in columns:
-                await conn.execute(f"""
-                    ALTER TABLE "{table_name}"
-                    ADD COLUMN IF NOT EXISTS {column['name']} {column['type']}
-                """)
-                info(f"Added new column {column['name']} to table {table_name}")
-
-    async def get_table_creation_sql(self, conn, table_name):
-        create_table_sql = await conn.fetchval(f"""
-            SELECT pg_get_tabledef('{table_name}'::regclass::oid)
-        """)
-        return create_table_sql
-
-    async def table_exists(self, conn, table_name):
-        exists = await conn.fetchval(f"""
-            SELECT EXISTS (
-                SELECT FROM information_schema.tables
-                WHERE table_schema = 'public'
-                AND table_name = $1
-            )
-        """, table_name)
-        return exists
-
-    async def column_exists(self, conn, table_name, column_name):
-        exists = await conn.fetchval(f"""
-            SELECT EXISTS (
-                SELECT FROM information_schema.columns
-                WHERE table_schema = 'public'
-                AND table_name = $1
-                AND column_name = $2
-            )
-        """, table_name, column_name)
-        return exists
-
-    async def close_db_pools(self):
-        if self._db_pool:
-            await self._db_pool.close_all()
-
 class Location(BaseModel):