Auto-update: Thu Jul 25 01:27:56 PDT 2024
This commit is contained in:
parent
8775c05927
commit
0be1915aeb
2 changed files with 186 additions and 169 deletions
|
@ -5,7 +5,6 @@ from fastapi import FastAPI, Request, HTTPException, Response
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
|
||||||
from starlette.requests import ClientDisconnect
|
from starlette.requests import ClientDisconnect
|
||||||
from hypercorn.asyncio import serve
|
from hypercorn.asyncio import serve
|
||||||
from hypercorn.config import Config as HypercornConfig
|
from hypercorn.config import Config as HypercornConfig
|
||||||
|
@ -44,7 +43,6 @@ err(f"Error message.")
|
||||||
def crit(text: str): logger.critical(text)
|
def crit(text: str): logger.critical(text)
|
||||||
crit(f"Critical message.")
|
crit(f"Critical message.")
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
# Startup
|
# Startup
|
||||||
|
@ -52,36 +50,30 @@ async def lifespan(app: FastAPI):
|
||||||
crit(f"Arguments: {args}")
|
crit(f"Arguments: {args}")
|
||||||
|
|
||||||
# Load routers
|
# Load routers
|
||||||
for module_name in API.MODULES.__fields__:
|
if args.test:
|
||||||
if getattr(API.MODULES, module_name):
|
load_router(args.test)
|
||||||
load_router(module_name)
|
else:
|
||||||
|
for module_name in API.MODULES.__fields__:
|
||||||
|
if getattr(API.MODULES, module_name):
|
||||||
|
load_router(module_name)
|
||||||
|
|
||||||
crit("Starting database synchronization...")
|
crit("Starting database synchronization...")
|
||||||
try:
|
try:
|
||||||
# Log the current TS_ID
|
# Initialize sync structures
|
||||||
crit(f"Current TS_ID: {os.environ.get('TS_ID', 'Not set')}")
|
await API.initialize_sync()
|
||||||
|
|
||||||
# Log the local_db configuration
|
|
||||||
local_db = API.local_db
|
|
||||||
crit(f"Local DB configuration: {local_db}")
|
|
||||||
|
|
||||||
# Test local connection
|
|
||||||
async with API.get_connection() as conn:
|
|
||||||
version = await conn.fetchval("SELECT version()")
|
|
||||||
crit(f"Successfully connected to local database. PostgreSQL version: {version}")
|
|
||||||
|
|
||||||
# Sync schema across all databases
|
# Sync schema across all databases
|
||||||
await API.sync_schema()
|
await API.sync_schema()
|
||||||
crit("Schema synchronization complete.")
|
crit("Schema synchronization complete.")
|
||||||
|
|
||||||
# Attempt to pull changes from another database
|
# Check if other instances have more recent data
|
||||||
source = await API.get_default_source()
|
source = await API.get_most_recent_source()
|
||||||
if source:
|
if source:
|
||||||
crit(f"Pulling changes from {source['ts_id']}...")
|
crit(f"Pulling changes from {source['ts_id']}...")
|
||||||
await API.pull_changes(source)
|
await API.pull_changes(source)
|
||||||
crit("Data pull complete.")
|
crit("Data pull complete.")
|
||||||
else:
|
else:
|
||||||
crit("No available source for pulling changes. This might be the only active database.")
|
crit("No instances with more recent data found.")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
crit(f"Error during startup: {str(e)}")
|
crit(f"Error during startup: {str(e)}")
|
||||||
|
@ -93,7 +85,6 @@ async def lifespan(app: FastAPI):
|
||||||
crit("Shutting down...")
|
crit("Shutting down...")
|
||||||
# Perform any cleanup operations here if needed
|
# Perform any cleanup operations here if needed
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(lifespan=lifespan)
|
app = FastAPI(lifespan=lifespan)
|
||||||
|
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
|
@ -124,7 +115,6 @@ class SimpleAPIKeyMiddleware(BaseHTTPMiddleware):
|
||||||
content={"detail": "Invalid or missing API key"}
|
content={"detail": "Invalid or missing API key"}
|
||||||
)
|
)
|
||||||
response = await call_next(request)
|
response = await call_next(request)
|
||||||
# debug(f"Request from {client_ip} is complete")
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
# Add the middleware to your FastAPI app
|
# Add the middleware to your FastAPI app
|
||||||
|
@ -136,7 +126,6 @@ async def http_exception_handler(request: Request, exc: HTTPException):
|
||||||
err(f"Request: {request.method} {request.url}")
|
err(f"Request: {request.method} {request.url}")
|
||||||
return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
|
return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
|
||||||
|
|
||||||
|
|
||||||
@app.middleware("http")
|
@app.middleware("http")
|
||||||
async def handle_exception_middleware(request: Request, call_next):
|
async def handle_exception_middleware(request: Request, call_next):
|
||||||
try:
|
try:
|
||||||
|
@ -149,6 +138,19 @@ async def handle_exception_middleware(request: Request, call_next):
|
||||||
raise
|
raise
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
@app.middleware("http")
|
||||||
|
async def sync_middleware(request: Request, call_next):
|
||||||
|
response = await call_next(request)
|
||||||
|
|
||||||
|
# Check if the request was a database write operation
|
||||||
|
if request.method in ["POST", "PUT", "PATCH", "DELETE"]:
|
||||||
|
try:
|
||||||
|
# Push changes to other databases
|
||||||
|
await API.push_changes_to_all()
|
||||||
|
except Exception as e:
|
||||||
|
err(f"Error pushing changes to other databases: {str(e)}")
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
def load_router(router_name):
|
def load_router(router_name):
|
||||||
router_file = ROUTER_DIR / f'{router_name}.py'
|
router_file = ROUTER_DIR / f'{router_name}.py'
|
||||||
|
@ -160,25 +162,16 @@ def load_router(router_name):
|
||||||
module = importlib.import_module(module_path)
|
module = importlib.import_module(module_path)
|
||||||
router = getattr(module, router_name)
|
router = getattr(module, router_name)
|
||||||
app.include_router(router)
|
app.include_router(router)
|
||||||
# module_logger.info(f"{router_name.capitalize()} router loaded.")
|
|
||||||
except (ImportError, AttributeError) as e:
|
except (ImportError, AttributeError) as e:
|
||||||
module_logger.critical(f"Failed to load router {router_name}: {e}")
|
module_logger.critical(f"Failed to load router {router_name}: {e}")
|
||||||
else:
|
else:
|
||||||
module_logger.error(f"Router file for {router_name} does not exist.")
|
module_logger.error(f"Router file for {router_name} does not exist.")
|
||||||
|
|
||||||
def main(argv):
|
def main(argv):
|
||||||
if args.test:
|
|
||||||
load_router(args.test)
|
|
||||||
else:
|
|
||||||
crit(f"sijapi launched")
|
|
||||||
crit(f"Arguments: {args}")
|
|
||||||
for module_name in API.MODULES.__fields__:
|
|
||||||
if getattr(API.MODULES, module_name):
|
|
||||||
load_router(module_name)
|
|
||||||
|
|
||||||
config = HypercornConfig()
|
config = HypercornConfig()
|
||||||
config.bind = [API.BIND]
|
config.bind = [API.BIND]
|
||||||
|
config.startup_timeout = 3600 # 1 hour
|
||||||
asyncio.run(serve(app, config))
|
asyncio.run(serve(app, config))
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main(sys.argv[1:])
|
main(sys.argv[1:])
|
||||||
|
|
|
@ -31,6 +31,8 @@ def warn(text: str): logger.warning(text)
|
||||||
def err(text: str): logger.error(text)
|
def err(text: str): logger.error(text)
|
||||||
def crit(text: str): logger.critical(text)
|
def crit(text: str): logger.critical(text)
|
||||||
|
|
||||||
|
TS_ID=os.getenv("TS_ID", "NULL")
|
||||||
|
|
||||||
T = TypeVar('T', bound='Configuration')
|
T = TypeVar('T', bound='Configuration')
|
||||||
class Configuration(BaseModel):
|
class Configuration(BaseModel):
|
||||||
HOME: Path = Path.home()
|
HOME: Path = Path.home()
|
||||||
|
@ -156,6 +158,17 @@ class Configuration(BaseModel):
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import Any, Dict, List, Union
|
||||||
|
from pathlib import Path
|
||||||
|
import yaml
|
||||||
|
import re
|
||||||
|
import os
|
||||||
|
import asyncpg
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from sijapi import TS_ID
|
||||||
|
import traceback
|
||||||
|
|
||||||
class APIConfig(BaseModel):
|
class APIConfig(BaseModel):
|
||||||
HOST: str
|
HOST: str
|
||||||
PORT: int
|
PORT: int
|
||||||
|
@ -183,7 +196,6 @@ class APIConfig(BaseModel):
|
||||||
try:
|
try:
|
||||||
with open(secrets_path, 'r') as file:
|
with open(secrets_path, 'r') as file:
|
||||||
secrets_data = yaml.safe_load(file)
|
secrets_data = yaml.safe_load(file)
|
||||||
# info(f"Loaded secrets: {secrets_data}")
|
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
err(f"Secrets file not found: {secrets_path}")
|
err(f"Secrets file not found: {secrets_path}")
|
||||||
secrets_data = {}
|
secrets_data = {}
|
||||||
|
@ -279,8 +291,7 @@ class APIConfig(BaseModel):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def local_db(self):
|
def local_db(self):
|
||||||
ts_id = os.environ.get('TS_ID')
|
return next((db for db in self.POOL if db['ts_id'] == TS_ID), None)
|
||||||
return next((db for db in self.POOL if db['ts_id'] == ts_id), None)
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def get_connection(self, pool_entry: Dict[str, Any] = None):
|
async def get_connection(self, pool_entry: Dict[str, Any] = None):
|
||||||
|
@ -305,104 +316,158 @@ class APIConfig(BaseModel):
|
||||||
crit(f"Error: {str(e)}")
|
crit(f"Error: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def push_changes(self, query: str, *args):
|
async def initialize_sync(self):
|
||||||
connections = []
|
async with self.get_connection() as conn:
|
||||||
try:
|
# Create sync_status table
|
||||||
for pool_entry in self.POOL[1:]: # Skip the first (local) database
|
await conn.execute("""
|
||||||
conn = await self.get_connection(pool_entry).__aenter__()
|
CREATE TABLE IF NOT EXISTS sync_status (
|
||||||
connections.append(conn)
|
table_name TEXT,
|
||||||
|
server_id TEXT,
|
||||||
|
last_synced_version INTEGER,
|
||||||
|
PRIMARY KEY (table_name, server_id)
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Get all tables
|
||||||
|
tables = await conn.fetch("""
|
||||||
|
SELECT tablename FROM pg_tables
|
||||||
|
WHERE schemaname = 'public'
|
||||||
|
""")
|
||||||
|
|
||||||
|
# Add version and server_id columns to all tables, create triggers
|
||||||
|
for table in tables:
|
||||||
|
table_name = table['tablename']
|
||||||
|
await conn.execute(f"""
|
||||||
|
ALTER TABLE "{table_name}"
|
||||||
|
ADD COLUMN IF NOT EXISTS version INTEGER DEFAULT 1,
|
||||||
|
ADD COLUMN IF NOT EXISTS server_id TEXT DEFAULT '{TS_ID}';
|
||||||
|
|
||||||
results = await asyncio.gather(
|
CREATE OR REPLACE FUNCTION update_version_and_server_id()
|
||||||
*[conn.execute(query, *args) for conn in connections],
|
RETURNS TRIGGER AS $$
|
||||||
return_exceptions=True
|
BEGIN
|
||||||
)
|
NEW.version = COALESCE(OLD.version, 0) + 1;
|
||||||
|
NEW.server_id = '{TS_ID}';
|
||||||
|
RETURN NEW;
|
||||||
|
END;
|
||||||
|
$$ LANGUAGE plpgsql;
|
||||||
|
|
||||||
for pool_entry, result in zip(self.POOL[1:], results):
|
DROP TRIGGER IF EXISTS update_version_and_server_id_trigger ON "{table_name}";
|
||||||
if isinstance(result, Exception):
|
CREATE TRIGGER update_version_and_server_id_trigger
|
||||||
err(f"Failed to push to {pool_entry['ts_ip']}: {str(result)}")
|
BEFORE INSERT OR UPDATE ON "{table_name}"
|
||||||
else:
|
FOR EACH ROW EXECUTE FUNCTION update_version_and_server_id();
|
||||||
info(f"Successfully pushed to {pool_entry['ts_ip']}")
|
|
||||||
|
|
||||||
finally:
|
INSERT INTO sync_status (table_name, server_id, last_synced_version)
|
||||||
for conn in connections:
|
VALUES ('{table_name}', '{TS_ID}', 0)
|
||||||
await conn.__aexit__(None, None, None)
|
ON CONFLICT (table_name, server_id) DO NOTHING;
|
||||||
|
""")
|
||||||
|
|
||||||
async def get_default_source(self):
|
async def get_most_recent_source(self):
|
||||||
local = self.local_db
|
most_recent_source = None
|
||||||
for db in self.POOL:
|
max_version = -1
|
||||||
if db != local:
|
|
||||||
try:
|
for pool_entry in self.POOL:
|
||||||
async with self.get_connection(db):
|
if pool_entry['ts_id'] == TS_ID:
|
||||||
return db
|
continue
|
||||||
except:
|
|
||||||
|
try:
|
||||||
|
async with self.get_connection(pool_entry) as conn:
|
||||||
|
version = await conn.fetchval("""
|
||||||
|
SELECT MAX(last_synced_version) FROM sync_status
|
||||||
|
""")
|
||||||
|
if version > max_version:
|
||||||
|
max_version = version
|
||||||
|
most_recent_source = pool_entry
|
||||||
|
except Exception as e:
|
||||||
|
err(f"Error checking version for {pool_entry['ts_id']}: {str(e)}")
|
||||||
|
|
||||||
|
return most_recent_source
|
||||||
|
|
||||||
|
async def pull_changes(self, source_pool_entry):
|
||||||
|
async with self.get_connection(source_pool_entry) as source_conn:
|
||||||
|
async with self.get_connection() as dest_conn:
|
||||||
|
tables = await source_conn.fetch("""
|
||||||
|
SELECT tablename FROM pg_tables
|
||||||
|
WHERE schemaname = 'public'
|
||||||
|
""")
|
||||||
|
|
||||||
|
for table in tables:
|
||||||
|
table_name = table['tablename']
|
||||||
|
last_synced_version = await self.get_last_synced_version(table_name, source_pool_entry['ts_id'])
|
||||||
|
|
||||||
|
changes = await source_conn.fetch(f"""
|
||||||
|
SELECT * FROM "{table_name}"
|
||||||
|
WHERE version > $1 AND server_id = $2
|
||||||
|
ORDER BY version ASC
|
||||||
|
""", last_synced_version, source_pool_entry['ts_id'])
|
||||||
|
|
||||||
|
for change in changes:
|
||||||
|
columns = change.keys()
|
||||||
|
values = [change[col] for col in columns]
|
||||||
|
await dest_conn.execute(f"""
|
||||||
|
INSERT INTO "{table_name}" ({', '.join(columns)})
|
||||||
|
VALUES ({', '.join(f'${i+1}' for i in range(len(columns)))})
|
||||||
|
ON CONFLICT (id) DO UPDATE SET
|
||||||
|
{', '.join(f"{col} = EXCLUDED.{col}" for col in columns if col != 'id')}
|
||||||
|
""", *values)
|
||||||
|
|
||||||
|
if changes:
|
||||||
|
await self.update_sync_status(table_name, source_pool_entry['ts_id'], changes[-1]['version'])
|
||||||
|
|
||||||
|
async def push_changes_to_all(self):
|
||||||
|
async with self.get_connection() as local_conn:
|
||||||
|
tables = await local_conn.fetch("""
|
||||||
|
SELECT tablename FROM pg_tables
|
||||||
|
WHERE schemaname = 'public'
|
||||||
|
""")
|
||||||
|
|
||||||
|
for pool_entry in self.POOL:
|
||||||
|
if pool_entry['ts_id'] == TS_ID:
|
||||||
continue
|
continue
|
||||||
return None
|
|
||||||
|
try:
|
||||||
|
async with self.get_connection(pool_entry) as remote_conn:
|
||||||
async def pull_changes(self, source_pool_entry: Dict[str, Any] = None):
|
for table in tables:
|
||||||
try:
|
table_name = table['tablename']
|
||||||
if source_pool_entry is None:
|
last_synced_version = await self.get_last_synced_version(table_name, pool_entry['ts_id'])
|
||||||
source_pool_entry = await self.get_default_source()
|
|
||||||
|
changes = await local_conn.fetch(f"""
|
||||||
if source_pool_entry is None:
|
SELECT * FROM "{table_name}"
|
||||||
err("No available source for pulling changes")
|
WHERE version > $1 AND server_id = $2
|
||||||
return
|
ORDER BY version ASC
|
||||||
|
""", last_synced_version, TS_ID)
|
||||||
async with self.get_connection(source_pool_entry) as source_conn:
|
|
||||||
async with self.get_connection() as dest_conn:
|
for change in changes:
|
||||||
tables = await source_conn.fetch(
|
columns = change.keys()
|
||||||
"SELECT tablename FROM pg_tables WHERE schemaname = 'public'"
|
values = [change[col] for col in columns]
|
||||||
)
|
await remote_conn.execute(f"""
|
||||||
for table in tables:
|
INSERT INTO "{table_name}" ({', '.join(columns)})
|
||||||
table_name = table['tablename']
|
VALUES ({', '.join(f'${i+1}' for i in range(len(columns)))})
|
||||||
info(f"Processing table: {table_name}")
|
ON CONFLICT (id) DO UPDATE SET
|
||||||
|
{', '.join(f"{col} = EXCLUDED.{col}" for col in columns if col != 'id')}
|
||||||
# Get primary key column(s)
|
""", *values)
|
||||||
pk_columns = await source_conn.fetch("""
|
|
||||||
SELECT a.attname
|
if changes:
|
||||||
FROM pg_index i
|
await self.update_sync_status(table_name, pool_entry['ts_id'], changes[-1]['version'])
|
||||||
JOIN pg_attribute a ON a.attrelid = i.indrelid
|
|
||||||
AND a.attnum = ANY(i.indkey)
|
info(f"Successfully pushed changes to {pool_entry['ts_id']}")
|
||||||
WHERE i.indrelid = $1::regclass
|
except Exception as e:
|
||||||
AND i.indisprimary;
|
err(f"Error pushing changes to {pool_entry['ts_id']}: {str(e)}")
|
||||||
""", table_name)
|
|
||||||
|
|
||||||
pk_cols = [col['attname'] for col in pk_columns]
|
|
||||||
info(f"Primary key columns for {table_name}: {pk_cols}")
|
|
||||||
if not pk_cols:
|
|
||||||
warn(f"No primary key found for table {table_name}. Skipping.")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Fetch all rows from the source table
|
|
||||||
rows = await source_conn.fetch(f"SELECT * FROM {table_name}")
|
|
||||||
info(f"Fetched {len(rows)} rows from {table_name}")
|
|
||||||
if rows:
|
|
||||||
columns = list(rows[0].keys())
|
|
||||||
info(f"Columns for {table_name}: {columns}")
|
|
||||||
# Upsert records to the destination table
|
|
||||||
for row in rows:
|
|
||||||
try:
|
|
||||||
query = f"""
|
|
||||||
INSERT INTO {table_name} ({', '.join(columns)})
|
|
||||||
VALUES ({', '.join(f'${i+1}' for i in range(len(columns)))})
|
|
||||||
ON CONFLICT ({', '.join(pk_cols)}) DO UPDATE SET
|
|
||||||
{', '.join(f"{col} = EXCLUDED.{col}" for col in columns if col not in pk_cols)}
|
|
||||||
"""
|
|
||||||
info(f"Executing query: {query}")
|
|
||||||
info(f"With values: {[row[col] for col in columns]}")
|
|
||||||
await dest_conn.execute(query, *[row[col] for col in columns])
|
|
||||||
except Exception as e:
|
|
||||||
err(f"Error processing row in {table_name}: {str(e)}")
|
|
||||||
err(f"Problematic row: {row}")
|
|
||||||
|
|
||||||
info(f"Completed processing table: {table_name}")
|
|
||||||
|
|
||||||
info(f"Successfully pulled changes from {source_pool_entry['ts_ip']}")
|
|
||||||
except Exception as e:
|
|
||||||
err(f"Unexpected error in pull_changes: {str(e)}")
|
|
||||||
err(f"Traceback: {traceback.format_exc()}")
|
|
||||||
|
|
||||||
|
async def get_last_synced_version(self, table_name, server_id):
|
||||||
|
async with self.get_connection() as conn:
|
||||||
|
return await conn.fetchval("""
|
||||||
|
SELECT last_synced_version FROM sync_status
|
||||||
|
WHERE table_name = $1 AND server_id = $2
|
||||||
|
""", table_name, server_id) or 0
|
||||||
|
|
||||||
|
async def update_sync_status(self, table_name, server_id, version):
|
||||||
|
async with self.get_connection() as conn:
|
||||||
|
await conn.execute("""
|
||||||
|
INSERT INTO sync_status (table_name, server_id, last_synced_version)
|
||||||
|
VALUES ($1, $2, $3)
|
||||||
|
ON CONFLICT (table_name, server_id) DO UPDATE
|
||||||
|
SET last_synced_version = EXCLUDED.last_synced_version
|
||||||
|
""", table_name, server_id, version)
|
||||||
|
|
||||||
async def sync_schema(self):
|
async def sync_schema(self):
|
||||||
for pool_entry in self.POOL:
|
for pool_entry in self.POOL:
|
||||||
|
@ -459,45 +524,6 @@ class APIConfig(BaseModel):
|
||||||
END $$;
|
END $$;
|
||||||
""")
|
""")
|
||||||
|
|
||||||
async def get_schema(self, pool_entry: Dict[str, Any]):
|
|
||||||
async with self.get_connection(pool_entry) as conn:
|
|
||||||
tables = await conn.fetch("""
|
|
||||||
SELECT table_name, column_name, data_type, character_maximum_length,
|
|
||||||
is_nullable, column_default, ordinal_position
|
|
||||||
FROM information_schema.columns
|
|
||||||
WHERE table_schema = 'public'
|
|
||||||
ORDER BY table_name, ordinal_position
|
|
||||||
""")
|
|
||||||
|
|
||||||
indexes = await conn.fetch("""
|
|
||||||
SELECT indexname, indexdef
|
|
||||||
FROM pg_indexes
|
|
||||||
WHERE schemaname = 'public'
|
|
||||||
""")
|
|
||||||
|
|
||||||
constraints = await conn.fetch("""
|
|
||||||
SELECT conname, contype, conrelid::regclass::text as table_name,
|
|
||||||
pg_get_constraintdef(oid) as definition
|
|
||||||
FROM pg_constraint
|
|
||||||
WHERE connamespace = 'public'::regnamespace
|
|
||||||
""")
|
|
||||||
|
|
||||||
return {
|
|
||||||
'tables': tables,
|
|
||||||
'indexes': indexes,
|
|
||||||
'constraints': constraints
|
|
||||||
}
|
|
||||||
|
|
||||||
async def create_sequence_if_not_exists(self, conn, sequence_name):
|
|
||||||
await conn.execute(f"""
|
|
||||||
DO $$
|
|
||||||
BEGIN
|
|
||||||
IF NOT EXISTS (SELECT 1 FROM pg_sequences WHERE schemaname = 'public' AND sequencename = '{sequence_name}') THEN
|
|
||||||
CREATE SEQUENCE {sequence_name};
|
|
||||||
END IF;
|
|
||||||
END $$;
|
|
||||||
""")
|
|
||||||
|
|
||||||
async def apply_schema_changes(self, pool_entry: Dict[str, Any], source_schema, target_schema):
|
async def apply_schema_changes(self, pool_entry: Dict[str, Any], source_schema, target_schema):
|
||||||
async with self.get_connection(pool_entry) as conn:
|
async with self.get_connection(pool_entry) as conn:
|
||||||
source_tables = {t['table_name']: t for t in source_schema['tables']}
|
source_tables = {t['table_name']: t for t in source_schema['tables']}
|
||||||
|
@ -567,8 +593,6 @@ class APIConfig(BaseModel):
|
||||||
await conn.execute(sql)
|
await conn.execute(sql)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
err(f"Error processing table {table_name}: {str(e)}")
|
err(f"Error processing table {table_name}: {str(e)}")
|
||||||
# Optionally, you might want to raise this exception if you want to stop the entire process
|
|
||||||
# raise
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
source_indexes = {idx['indexname']: idx['indexdef'] for idx in source_schema['indexes']}
|
source_indexes = {idx['indexname']: idx['indexdef'] for idx in source_schema['indexes']}
|
||||||
|
|
Loading…
Reference in a new issue