Auto-update: Tue Jul 30 13:18:29 PDT 2024

sanj 2024-07-30 13:18:29 -07:00
parent 6e958c5f1f
commit 0185e4d622


@@ -14,7 +14,7 @@ import reverse_geocoder as rg
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Union, TypeVar
 from dotenv import load_dotenv
-from pydantic import BaseModel, Field, create_model, PrivateAttr
+from pydantic import BaseModel, Field, create_model
 from concurrent.futures import ThreadPoolExecutor
 from contextlib import asynccontextmanager
 from datetime import datetime, timedelta, timezone
@@ -165,32 +165,6 @@ class Configuration(BaseModel):
-class DatabasePool:
-    def __init__(self):
-        self.pools = {}
-
-    async def get_connection(self, pool_entry):
-        pool_key = f"{pool_entry['ts_ip']}:{pool_entry['db_port']}"
-        if pool_key not in self.pools:
-            self.pools[pool_key] = await asyncpg.create_pool(
-                host=pool_entry['ts_ip'],
-                port=pool_entry['db_port'],
-                user=pool_entry['db_user'],
-                password=pool_entry['db_pass'],
-                database=pool_entry['db_name'],
-                min_size=1,
-                max_size=10
-            )
-        return await self.pools[pool_key].acquire()
-
-    async def release_connection(self, pool_entry, connection):
-        pool_key = f"{pool_entry['ts_ip']}:{pool_entry['db_port']}"
-        await self.pools[pool_key].release(connection)
-
-    async def close_all(self):
-        for pool in self.pools.values():
-            await pool.close()
-
 class APIConfig(BaseModel):
     HOST: str
     PORT: int
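Note: this hunk deletes the shared asyncpg pool; together with the get_connection rewrite below, every database operation now opens and closes its own connection. A minimal sketch of the trade-off, assuming the pool_entry dict shape used throughout this file (helper names here are hypothetical):

    import asyncpg

    # The removed pattern: one long-lived pool per server, connections reused.
    async def make_pool(pool_entry) -> asyncpg.Pool:
        return await asyncpg.create_pool(
            host=pool_entry['ts_ip'], port=pool_entry['db_port'],
            user=pool_entry['db_user'], password=pool_entry['db_pass'],
            database=pool_entry['db_name'], min_size=1, max_size=10)

    async def query_via_pool(pool: asyncpg.Pool, sql: str):
        async with pool.acquire() as conn:  # borrowed, returned automatically
            return await conn.fetchval(sql)

    # The pattern this commit moves to: a fresh handshake per operation.
    async def query_per_call(pool_entry, sql: str):
        conn = await asyncpg.connect(
            host=pool_entry['ts_ip'], port=pool_entry['db_port'],
            user=pool_entry['db_user'], password=pool_entry['db_pass'],
            database=pool_entry['db_name'])
        try:
            return await conn.fetchval(sql)
        finally:
            await conn.close()  # nothing persists between calls

Per-call connections are easier to reason about across many remote servers, at the cost of connection-setup latency on every operation.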
@@ -204,18 +178,6 @@ class APIConfig(BaseModel):
     TZ: str
     KEYS: List[str]
     GARBAGE: Dict[str, Any]
-    _db_pool: DatabasePool = PrivateAttr(default_factory=DatabasePool)
-
-    class Config:
-        arbitrary_types_allowed = True
-
-    @property
-    def db_pool(self):
-        return self._db_pool
-
-    @property
-    def db_pool(self):
-        return self._db_pool
-
     @classmethod
     def load(cls, config_path: Union[str, Path], secrets_path: Union[str, Path]):
@@ -258,7 +220,6 @@ class APIConfig(BaseModel):
         config_data['EXTENSIONS'] = cls._create_dynamic_config(config_data.get('EXTENSIONS', {}), 'DynamicExtensionsConfig')
         return cls(**config_data)
-
     @classmethod
     def _create_dynamic_config(cls, data: Dict[str, Any], model_name: str):
         fields = {}
@@ -337,98 +298,78 @@ class APIConfig(BaseModel):
         if pool_entry is None:
             pool_entry = self.local_db

-        for attempt in range(3):  # Retry up to 3 times
-            try:
-                conn = await asyncpg.connect(
-                    host=pool_entry['ts_ip'],
-                    port=pool_entry['db_port'],
-                    user=pool_entry['db_user'],
-                    password=pool_entry['db_pass'],
-                    database=pool_entry['db_name']
-                )
-                try:
-                    yield conn
-                finally:
-                    await conn.close()
-                return
-            except asyncpg.exceptions.CannotConnectNowError:
-                if attempt < 2:  # Don't sleep on the last attempt
-                    await asyncio.sleep(1)  # Wait before retrying
-            except Exception as e:
-                err(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}")
-                err(f"Error: {str(e)}")
-                if attempt == 2:  # Raise the exception on the last attempt
-                    raise
-
-        raise Exception(f"Failed to connect to database after 3 attempts: {pool_entry['ts_ip']}:{pool_entry['db_port']}")
+        info(f"Attempting to connect to database: {pool_entry}")
+        try:
+            conn = await asyncpg.connect(
+                host=pool_entry['ts_ip'],
+                port=pool_entry['db_port'],
+                user=pool_entry['db_user'],
+                password=pool_entry['db_pass'],
+                database=pool_entry['db_name']
+            )
+            try:
+                yield conn
+            finally:
+                await conn.close()
+        except Exception as e:
+            warn(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}")
+            err(f"Error: {str(e)}")
+            raise

     async def initialize_sync(self):
         for pool_entry in self.POOL:
-            for attempt in range(3):  # Retry up to 3 times
-                try:
-                    async with self.get_connection(pool_entry) as conn:
-                        tables = await conn.fetch("""
-                            SELECT tablename FROM pg_tables
-                            WHERE schemaname = 'public'
-                        """)
-                        for table in tables:
-                            table_name = table['tablename']
-                            await self.ensure_sync_columns(conn, table_name)
-                            await self.create_sync_trigger(conn, table_name)
-                    info(f"Sync initialization complete for {pool_entry['ts_ip']}. All tables now have version and server_id columns with appropriate triggers.")
-                    break  # If successful, break the retry loop
-                except asyncpg.exceptions.ConnectionFailureError:
-                    err(f"Failed to connect to database during initialization: {pool_entry['ts_ip']}:{pool_entry['db_port']}")
-                    if attempt < 2:  # Don't sleep on the last attempt
-                        await asyncio.sleep(1)  # Wait before retrying
-                except Exception as e:
-                    err(f"Error initializing sync for {pool_entry['ts_ip']}: {str(e)}")
-                    err(f"Traceback: {traceback.format_exc()}")
-                    break  # Don't retry for unexpected errors
-
-    async def ensure_sync_columns(self, conn, table_name):
-        try:
-            await conn.execute(f"""
-                DO $$
-                BEGIN
-                    BEGIN
-                        ALTER TABLE "{table_name}"
-                        ADD COLUMN IF NOT EXISTS version INTEGER DEFAULT 1,
-                        ADD COLUMN IF NOT EXISTS server_id TEXT DEFAULT '{os.environ.get('TS_ID')}';
-                    EXCEPTION
-                        WHEN duplicate_column THEN
-                            NULL; -- Silently handle duplicate column
-                    END;
-                END $$;
-            """)
-            info(f"Ensured sync columns for table {table_name}")
-        except Exception as e:
-            err(f"Error ensuring sync columns for table {table_name}: {str(e)}")
-            err(f"Traceback: {traceback.format_exc()}")
-
-    async def create_sync_trigger(self, conn, table_name):
-        await conn.execute(f"""
-            CREATE OR REPLACE FUNCTION update_version_and_server_id()
-            RETURNS TRIGGER AS $$
-            BEGIN
-                NEW.version = COALESCE(OLD.version, 0) + 1;
-                NEW.server_id = '{os.environ.get('TS_ID')}';
-                RETURN NEW;
-            END;
-            $$ LANGUAGE plpgsql;
-
-            DO $$
-            BEGIN
-                IF NOT EXISTS (SELECT 1 FROM pg_trigger WHERE tgname = 'update_version_and_server_id_trigger' AND tgrelid = '{table_name}'::regclass) THEN
-                    CREATE TRIGGER update_version_and_server_id_trigger
-                    BEFORE INSERT OR UPDATE ON "{table_name}"
-                    FOR EACH ROW EXECUTE FUNCTION update_version_and_server_id();
-                END IF;
-            END $$;
-        """)
+            try:
+                async with self.get_connection(pool_entry) as conn:
+                    tables = await conn.fetch("""
+                        SELECT tablename FROM pg_tables
+                        WHERE schemaname = 'public'
+                    """)
+                    for table in tables:
+                        table_name = table['tablename']
+                        # Add version and server_id columns if they don't exist
+                        await conn.execute(f"""
+                            DO $$
+                            BEGIN
+                                BEGIN
+                                    ALTER TABLE "{table_name}"
+                                    ADD COLUMN IF NOT EXISTS version INTEGER DEFAULT 1,
+                                    ADD COLUMN IF NOT EXISTS server_id TEXT DEFAULT '{os.environ.get('TS_ID')}';
+                                EXCEPTION
+                                    WHEN duplicate_column THEN
+                                        -- Do nothing, column already exists
+                                END;
+                            END $$;
+                        """)
+                        # Create or replace the trigger function
+                        await conn.execute(f"""
+                            CREATE OR REPLACE FUNCTION update_version_and_server_id()
+                            RETURNS TRIGGER AS $$
+                            BEGIN
+                                NEW.version = COALESCE(OLD.version, 0) + 1;
+                                NEW.server_id = '{os.environ.get('TS_ID')}';
+                                RETURN NEW;
+                            END;
+                            $$ LANGUAGE plpgsql;
+                        """)
+                        # Create the trigger if it doesn't exist
+                        await conn.execute(f"""
+                            DO $$
+                            BEGIN
+                                IF NOT EXISTS (SELECT 1 FROM pg_trigger WHERE tgname = 'update_version_and_server_id_trigger' AND tgrelid = '{table_name}'::regclass) THEN
+                                    CREATE TRIGGER update_version_and_server_id_trigger
+                                    BEFORE INSERT OR UPDATE ON "{table_name}"
+                                    FOR EACH ROW EXECUTE FUNCTION update_version_and_server_id();
+                                END IF;
+                            END $$;
+                        """)
+                info(f"Sync initialization complete for {pool_entry['ts_ip']}. All tables now have version and server_id columns with appropriate triggers.")
+            except Exception as e:
+                err(f"Error initializing sync for {pool_entry['ts_ip']}: {str(e)}")

     async def get_most_recent_source(self):
         most_recent_source = None
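Note: the rewritten initialize_sync inlines what ensure_sync_columns and create_sync_trigger used to do (a trimmed ensure_sync_columns reappears in a later hunk). One caveat worth flagging: the trigger function reads OLD.version unconditionally, but in PL/pgSQL OLD is only assigned for UPDATE and DELETE, so the BEFORE INSERT case raises `record "old" is not assigned yet`. A sketch of a TG_OP-aware variant, offered as an editorial suggestion rather than what this commit ships:

    import os
    import asyncpg

    async def create_versioning_trigger_fn(conn: asyncpg.Connection) -> None:
        # Branch on TG_OP so the INSERT path never dereferences OLD.
        await conn.execute(f"""
            CREATE OR REPLACE FUNCTION update_version_and_server_id()
            RETURNS TRIGGER AS $$
            BEGIN
                IF TG_OP = 'UPDATE' THEN
                    NEW.version = COALESCE(OLD.version, 0) + 1;
                ELSE
                    NEW.version = 1;  -- fresh row on INSERT
                END IF;
                NEW.server_id = '{os.environ.get('TS_ID')}';
                RETURN NEW;
            END;
            $$ LANGUAGE plpgsql;
        """)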
@@ -438,52 +379,25 @@ class APIConfig(BaseModel):
             if pool_entry['ts_id'] == os.environ.get('TS_ID'):
                 continue

-            for _ in range(3):  # Retry up to 3 times
-                try:
-                    async with self.get_connection(pool_entry) as conn:
-                        # Check if the version column exists in any table
-                        version_exists = await conn.fetchval("""
-                            SELECT EXISTS (
-                                SELECT 1
-                                FROM information_schema.columns
-                                WHERE table_schema = 'public'
-                                AND column_name = 'version'
-                            )
-                        """)
-
-                        if not version_exists:
-                            info(f"Version column does not exist in any table for {pool_entry['ts_id']}")
-                            break  # Move to the next pool entry
-
-                        version = await conn.fetchval("""
-                            SELECT COALESCE(MAX(version), -1)
-                            FROM (
-                                SELECT MAX(version) as version
-                                FROM information_schema.columns
-                                WHERE table_schema = 'public'
-                                AND column_name = 'version'
-                                AND is_updatable = 'YES'
-                            ) as subquery
-                        """)
-
-                        info(f"Max version for {pool_entry['ts_id']}: {version}")
-
-                        if version > max_version:
-                            max_version = version
-                            most_recent_source = pool_entry
-
-                        break  # If successful, break the retry loop
-                except asyncpg.exceptions.PostgresError as e:
-                    err(f"Error checking version for {pool_entry['ts_id']}: {str(e)}")
-                    await asyncio.sleep(1)  # Wait before retrying
-                except Exception as e:
-                    err(f"Unexpected error for {pool_entry['ts_id']}: {str(e)}")
-                    break  # Don't retry for unexpected errors
-
-        if most_recent_source:
-            info(f"Most recent source: {most_recent_source['ts_id']} with version {max_version}")
-        else:
-            info("No valid source found with version information")
+            try:
+                async with self.get_connection(pool_entry) as conn:
+                    version = await conn.fetchval("""
+                        SELECT COALESCE(MAX(version), -1) FROM (
+                            SELECT MAX(version) as version FROM information_schema.columns
+                            WHERE table_schema = 'public' AND column_name = 'version'
+                        ) as subquery
+                    """)
+                    if version > max_version:
+                        max_version = version
+                        most_recent_source = pool_entry
+            except Exception as e:
+                err(f"Error checking version for {pool_entry['ts_id']}: {str(e)}")

         return most_recent_source

     async def pull_changes(self, source_pool_entry, batch_size=10000):
         if source_pool_entry['ts_id'] == os.environ.get('TS_ID'):
             info("Skipping self-sync")
@@ -500,10 +414,6 @@ class APIConfig(BaseModel):
         try:
             async with self.get_connection(source_pool_entry) as source_conn:
                 async with self.get_connection(self.local_db) as dest_conn:
-                    # Sync schema first
-                    schema_changes = await self.detect_schema_changes(source_conn, dest_conn)
-                    await self.apply_schema_changes(dest_conn, schema_changes)
-
                     tables = await source_conn.fetch("""
                         SELECT tablename FROM pg_tables
                         WHERE schemaname = 'public'
@@ -557,22 +467,12 @@ class APIConfig(BaseModel):
                 columns = changes[0].keys()
                 await conn.copy_records_to_table(temp_table_name, records=[tuple(change[col] for col in columns) for change in changes])

-                # Perform upsert with spatial awareness
+                # Perform upsert
                 result = await conn.execute(f"""
                     INSERT INTO "{table_name}"
-                    SELECT tc.*
-                    FROM {temp_table_name} tc
-                    LEFT JOIN "{table_name}" t ON t.id = tc.id
-                    WHERE t.id IS NULL
+                    SELECT * FROM {temp_table_name}
                     ON CONFLICT (id) DO UPDATE SET
                     {', '.join(f"{col} = EXCLUDED.{col}" for col in columns if col != 'id')}
-                    WHERE (
-                        CASE
-                            WHEN "{table_name}".geometry IS NOT NULL AND EXCLUDED.geometry IS NOT NULL
-                            THEN NOT ST_Equals("{table_name}".geometry, EXCLUDED.geometry)
-                            ELSE FALSE
-                        END
-                    ) OR {' OR '.join(f"COALESCE({col} <> EXCLUDED.{col}, TRUE)" for col in columns if col not in ['id', 'geometry'])}
                 """)

                 # Parse the result to get the number of affected rows
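# 
Note: the simplified upsert drops both the LEFT JOIN guard and the ST_Equals gating on geometry, so every conflicting row is rewritten even when its content is unchanged: simpler, at the cost of extra writes (and each rewrite fires the versioning trigger installed above). For reference, a self-contained sketch of the temp-table COPY plus upsert pattern this method relies on; the temp-table creation in the real method sits outside this hunk, and the names here are illustrative:

    import asyncpg

    async def bulk_upsert(conn: asyncpg.Connection, table_name: str, rows: list) -> str:
        # Assumes every row mapping shares one key set and the target
        # table has a primary key or unique index on id.
        columns = list(rows[0].keys())
        temp = f'temp_{table_name}'
        await conn.execute(f'CREATE TEMP TABLE {temp} (LIKE "{table_name}" INCLUDING DEFAULTS)')
        await conn.copy_records_to_table(
            temp, records=[tuple(r[c] for c in columns) for r in rows], columns=columns)
        try:
            return await conn.execute(f"""
                INSERT INTO "{table_name}" ({', '.join(columns)})
                SELECT {', '.join(columns)} FROM {temp}
                ON CONFLICT (id) DO UPDATE SET
                {', '.join(f'{c} = EXCLUDED.{c}' for c in columns if c != 'id')}
            """)  # status string such as 'INSERT 0 42'
        finally:
            await conn.execute(f'DROP TABLE IF EXISTS {temp}')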
@@ -583,6 +483,7 @@ class APIConfig(BaseModel):
                 # Ensure temporary table is dropped
                 await conn.execute(f"DROP TABLE IF EXISTS {temp_table_name}")
+
     async def push_changes_to_all(self):
         for pool_entry in self.POOL:
             if pool_entry['ts_id'] != os.environ.get('TS_ID'):
@@ -591,14 +492,11 @@ class APIConfig(BaseModel):
             except Exception as e:
                 err(f"Error pushing changes to {pool_entry['ts_id']}: {str(e)}")

-    async def push_changes_to_one(self, pool_entry, batch_size=10000):
+    async def push_changes_to_one(self, pool_entry):
         try:
             async with self.get_connection() as local_conn:
                 async with self.get_connection(pool_entry) as remote_conn:
-                    # Sync schema first
-                    schema_changes = await self.detect_schema_changes(local_conn, remote_conn)
-                    await self.apply_schema_changes(remote_conn, schema_changes)
-
                     tables = await local_conn.fetch("""
                         SELECT tablename FROM pg_tables
                         WHERE schemaname = 'public'
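Note: push_changes_to_one loses its batch_size parameter here, and the next hunk replaces the LIMIT-paged loop with a single fetch replayed one execute() per row. When change sets are large, asyncpg can run one uniform statement over many parameter tuples in a single executemany call; a sketch under the assumption that all rows share one column set (an alternative shape, not what the commit ships):

    import asyncpg

    async def replay_changes(remote_conn: asyncpg.Connection, table_name: str, changes) -> None:
        if not changes:
            return
        columns = list(changes[0].keys())
        placeholders = ', '.join(f'${i + 1}' for i in range(len(columns)))
        query = f"""
            INSERT INTO "{table_name}" ({', '.join(columns)})
            VALUES ({placeholders})
            ON CONFLICT (id) DO UPDATE SET
            {', '.join(f'{c} = EXCLUDED.{c}' for c in columns if c != 'id')}
        """
        # One prepared statement, all parameter tuples pipelined together.
        await remote_conn.executemany(query, [tuple(r[c] for c in columns) for r in changes])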
@@ -608,29 +506,49 @@ class APIConfig(BaseModel):
                         table_name = table['tablename']
                         last_synced_version = await self.get_last_synced_version(remote_conn, table_name, os.environ.get('TS_ID'))

-                        while True:
-                            changes = await local_conn.fetch(f"""
-                                SELECT * FROM "{table_name}"
-                                WHERE version > $1 AND server_id = $2
-                                ORDER BY version ASC
-                                LIMIT $3
-                            """, last_synced_version, os.environ.get('TS_ID'), batch_size)
-
-                            if not changes:
-                                break
-
-                            changes_count = await self.apply_batch_changes(remote_conn, table_name, changes)
-
-                            last_synced_version = changes[-1]['version']
-                            await self.update_last_synced_version(remote_conn, table_name, os.environ.get('TS_ID'), last_synced_version)
-
-                            info(f"Pushed batch for {table_name}: {changes_count} changes to {pool_entry['ts_id']}")
+                        changes = await local_conn.fetch(f"""
+                            SELECT * FROM "{table_name}"
+                            WHERE version > $1 AND server_id = $2
+                            ORDER BY version ASC
+                        """, last_synced_version, os.environ.get('TS_ID'))
+
+                        for change in changes:
+                            columns = list(change.keys())
+                            values = [change[col] for col in columns]
+                            placeholders = [f'${i+1}' for i in range(len(columns))]
+
+                            insert_query = f"""
+                                INSERT INTO "{table_name}" ({', '.join(columns)})
+                                VALUES ({', '.join(placeholders)})
+                                ON CONFLICT (id) DO UPDATE SET
+                                {', '.join(f"{col} = EXCLUDED.{col}" for col in columns if col != 'id')}
+                            """
+                            await remote_conn.execute(insert_query, *values)
+
+                        if changes:
+                            await self.update_last_synced_version(remote_conn, table_name, os.environ.get('TS_ID'), changes[-1]['version'])

             info(f"Successfully pushed changes to {pool_entry['ts_id']}")
         except Exception as e:
             err(f"Error pushing changes to {pool_entry['ts_id']}: {str(e)}")
             err(f"Traceback: {traceback.format_exc()}")

+    async def ensure_sync_columns(self, conn, table_name):
+        await conn.execute(f"""
+            DO $$
+            BEGIN
+                BEGIN
+                    ALTER TABLE "{table_name}"
+                    ADD COLUMN IF NOT EXISTS version INTEGER DEFAULT 1,
+                    ADD COLUMN IF NOT EXISTS server_id TEXT DEFAULT '{os.environ.get('TS_ID')}';
+                EXCEPTION
+                    WHEN duplicate_column THEN
+                        -- Do nothing, column already exists
+                END;
+            END $$;
+        """)
+
     async def get_last_synced_version(self, conn, table_name, server_id):
         return await conn.fetchval(f"""
             SELECT COALESCE(MAX(version), 0)
@@ -666,85 +584,7 @@ class APIConfig(BaseModel):
             END $$;
         """)

-    async def detect_schema_changes(self, source_conn, dest_conn):
-        schema_changes = {
-            'new_tables': [],
-            'new_columns': {}
-        }
-
-        # Detect new tables
-        source_tables = await source_conn.fetch("SELECT tablename FROM pg_tables WHERE schemaname = 'public'")
-        dest_tables = await dest_conn.fetch("SELECT tablename FROM pg_tables WHERE schemaname = 'public'")
-
-        source_table_names = set(table['tablename'] for table in source_tables)
-        dest_table_names = set(table['tablename'] for table in dest_tables)
-
-        new_tables = source_table_names - dest_table_names
-        schema_changes['new_tables'] = list(new_tables)
-
-        # Detect new columns
-        for table_name in source_table_names:
-            if table_name in dest_table_names:
-                source_columns = await source_conn.fetch(f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}'")
-                dest_columns = await dest_conn.fetch(f"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}'")
-
-                source_column_names = set(column['column_name'] for column in source_columns)
-                dest_column_names = set(column['column_name'] for column in dest_columns)
-
-                new_columns = source_column_names - dest_column_names
-                if new_columns:
-                    schema_changes['new_columns'][table_name] = [
-                        {'name': column['column_name'], 'type': column['data_type']}
-                        for column in source_columns if column['column_name'] in new_columns
-                    ]
-
-        return schema_changes
-
-    async def apply_schema_changes(self, conn, schema_changes):
-        for table_name in schema_changes['new_tables']:
-            create_table_sql = await self.get_table_creation_sql(conn, table_name)
-            await conn.execute(create_table_sql)
-            info(f"Created new table: {table_name}")
-
-        for table_name, columns in schema_changes['new_columns'].items():
-            for column in columns:
-                await conn.execute(f"""
-                    ALTER TABLE "{table_name}"
-                    ADD COLUMN IF NOT EXISTS {column['name']} {column['type']}
-                """)
-                info(f"Added new column {column['name']} to table {table_name}")
-
-    async def get_table_creation_sql(self, conn, table_name):
-        create_table_sql = await conn.fetchval(f"""
-            SELECT pg_get_tabledef('{table_name}'::regclass::oid)
-        """)
-        return create_table_sql
-
-    async def table_exists(self, conn, table_name):
-        exists = await conn.fetchval(f"""
-            SELECT EXISTS (
-                SELECT FROM information_schema.tables
-                WHERE table_schema = 'public'
-                AND table_name = $1
-            )
-        """, table_name)
-        return exists
-
-    async def column_exists(self, conn, table_name, column_name):
-        exists = await conn.fetchval(f"""
-            SELECT EXISTS (
-                SELECT FROM information_schema.columns
-                WHERE table_schema = 'public'
-                AND table_name = $1
-                AND column_name = $2
-            )
-        """, table_name, column_name)
-        return exists
-
-    async def close_db_pools(self):
-        if self._db_pool:
-            await self._db_pool.close_all()
-
 class Location(BaseModel):
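Note: this final hunk removes the schema-sync machinery outright. Worth recording: get_table_creation_sql called pg_get_tabledef(), which is not a built-in PostgreSQL function (the built-ins are pg_get_viewdef, pg_get_functiondef, pg_get_indexdef, and similar), so new-table cloning only ever worked where a custom pg_get_tabledef had been installed; dropping the path removes that latent failure. The deleted close_db_pools also referenced self._db_pool, which this same commit removes from APIConfig, so the cleanup is consistent. If schema cloning returns, a catalog-driven sketch covering column names and types only (constraints, defaults, and indexes ignored; an assumption, not the author's method):

    import asyncpg

    async def minimal_create_table_sql(conn: asyncpg.Connection, table_name: str) -> str:
        # Rebuild a bare CREATE TABLE statement from information_schema.
        cols = await conn.fetch("""
            SELECT column_name, data_type
            FROM information_schema.columns
            WHERE table_schema = 'public' AND table_name = $1
            ORDER BY ordinal_position
        """, table_name)
        body = ', '.join(f'"{c["column_name"]}" {c["data_type"]}' for c in cols)
        return f'CREATE TABLE IF NOT EXISTS "{table_name}" ({body})'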