Auto-update: Mon Jul 29 17:16:23 PDT 2024

This commit is contained in:
sanj 2024-07-29 17:16:23 -07:00
parent 2354fb1588
commit 57ea1db0b2

View file

@ -8,6 +8,7 @@ import re
import aiofiles import aiofiles
import aiohttp import aiohttp
import asyncpg import asyncpg
import traceback
import reverse_geocoder as rg import reverse_geocoder as rg
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union, TypeVar from typing import Any, Dict, List, Optional, Tuple, Union, TypeVar
@ -391,7 +392,7 @@ class APIConfig(BaseModel):
try: try:
async with self.get_connection(source_pool_entry) as source_conn: async with self.get_connection(source_pool_entry) as source_conn:
async with self.get_connection() as dest_conn: async with self.get_connection(self.local_db) as dest_conn: # Connect to local DB explicitly
# Compare tables # Compare tables
source_tables = await self.get_tables(source_conn) source_tables = await self.get_tables(source_conn)
dest_tables = await self.get_tables(dest_conn) dest_tables = await self.get_tables(dest_conn)
@ -412,6 +413,10 @@ class APIConfig(BaseModel):
total_updates += updates total_updates += updates
table_changes[table] = {'inserts': inserts, 'updates': updates} table_changes[table] = {'inserts': inserts, 'updates': updates}
# Optionally, handle tables only in source
for table in tables_only_in_source:
warn(f"Table '{table}' exists in source but not in destination. Consider manual migration.")
info(f"Comprehensive sync complete from {source_id} ({source_ip}) to {dest_id} ({dest_ip})") info(f"Comprehensive sync complete from {source_id} ({source_ip}) to {dest_id} ({dest_ip})")
info(f"Total changes: {total_inserts} inserts, {total_updates} updates") info(f"Total changes: {total_inserts} inserts, {total_updates} updates")
info("Changes by table:") info("Changes by table:")
@ -420,10 +425,12 @@ class APIConfig(BaseModel):
except Exception as e: except Exception as e:
err(f"Error during sync process: {str(e)}") err(f"Error during sync process: {str(e)}")
err(f"Traceback: {traceback.format_exc()}")
return total_inserts + total_updates return total_inserts + total_updates
async def get_tables(self, conn): async def get_tables(self, conn):
tables = await conn.fetch(""" tables = await conn.fetch("""
SELECT tablename FROM pg_tables SELECT tablename FROM pg_tables
@ -464,8 +471,9 @@ class APIConfig(BaseModel):
try: try:
primary_keys = await self.get_primary_keys(dest_conn, table_name) primary_keys = await self.get_primary_keys(dest_conn, table_name)
if not primary_keys: if not primary_keys:
warn(f"Table {table_name} has no primary keys. Skipping data sync.") warn(f"Table {table_name} has no primary keys. Using all columns for comparison.")
return inserts, updates columns = await self.get_table_columns(dest_conn, table_name)
primary_keys = columns # Use all columns if no primary key
last_synced_version = await self.get_last_synced_version(table_name, source_id) last_synced_version = await self.get_last_synced_version(table_name, source_id)
@ -479,11 +487,14 @@ class APIConfig(BaseModel):
columns = list(change.keys()) columns = list(change.keys())
values = [change[col] for col in columns] values = [change[col] for col in columns]
conflict_clause = f"({', '.join(primary_keys)})"
update_clause = ', '.join(f"{col} = EXCLUDED.{col}" for col in columns if col not in primary_keys)
insert_query = f""" insert_query = f"""
INSERT INTO "{table_name}" ({', '.join(columns)}) INSERT INTO "{table_name}" ({', '.join(columns)})
VALUES ({', '.join(f'${i+1}' for i in range(len(columns)))}) VALUES ({', '.join(f'${i+1}' for i in range(len(columns)))})
ON CONFLICT ({', '.join(primary_keys)}) DO UPDATE SET ON CONFLICT {conflict_clause} DO UPDATE SET
{', '.join(f"{col} = EXCLUDED.{col}" for col in columns if col not in primary_keys)} {update_clause}
""" """
try: try:
@ -512,6 +523,14 @@ class APIConfig(BaseModel):
return inserts, updates return inserts, updates
async def get_table_columns(self, conn, table_name):
columns = await conn.fetch("""
SELECT column_name
FROM information_schema.columns
WHERE table_name = $1
ORDER BY ordinal_position
""", table_name)
return [col['column_name'] for col in columns]
async def get_primary_keys(self, conn, table_name): async def get_primary_keys(self, conn, table_name):
primary_keys = await conn.fetch(""" primary_keys = await conn.fetch("""
@ -584,10 +603,12 @@ class APIConfig(BaseModel):
async def sync_schema(self): async def sync_schema(self):
local_id = os.environ.get('TS_ID')
source_entry = self.local_db source_entry = self.local_db
source_schema = await self.get_schema(source_entry) source_schema = await self.get_schema(source_entry)
for pool_entry in self.POOL[1:]: for pool_entry in self.POOL:
if pool_entry['ts_id'] != local_id: # Skip the local instance
try: try:
target_schema = await self.get_schema(pool_entry) target_schema = await self.get_schema(pool_entry)
await self.apply_schema_changes(pool_entry, source_schema, target_schema) await self.apply_schema_changes(pool_entry, source_schema, target_schema)