Auto-update: Thu Jul 25 01:27:56 PDT 2024

sanj 2024-07-25 01:27:56 -07:00
parent 8775c05927
commit 0be1915aeb
2 changed files with 186 additions and 169 deletions

View file

@@ -5,7 +5,6 @@ from fastapi import FastAPI, Request, HTTPException, Response
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import ClientDisconnect
from hypercorn.asyncio import serve
from hypercorn.config import Config as HypercornConfig
@@ -44,7 +43,6 @@ err(f"Error message.")
def crit(text: str): logger.critical(text)
crit(f"Critical message.")
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup
@@ -52,36 +50,30 @@ async def lifespan(app: FastAPI):
crit(f"Arguments: {args}")
# Load routers
if args.test:
load_router(args.test)
else:
for module_name in API.MODULES.__fields__:
if getattr(API.MODULES, module_name):
load_router(module_name)
crit("Starting database synchronization...")
try:
# Log the current TS_ID
crit(f"Current TS_ID: {os.environ.get('TS_ID', 'Not set')}")
# Log the local_db configuration
local_db = API.local_db
crit(f"Local DB configuration: {local_db}")
# Test local connection
async with API.get_connection() as conn:
version = await conn.fetchval("SELECT version()")
crit(f"Successfully connected to local database. PostgreSQL version: {version}")
# Initialize sync structures
await API.initialize_sync()
# Sync schema across all databases
await API.sync_schema()
crit("Schema synchronization complete.")
# Attempt to pull changes from another database
source = await API.get_default_source()
# Check if other instances have more recent data
source = await API.get_most_recent_source()
if source:
crit(f"Pulling changes from {source['ts_id']}...")
await API.pull_changes(source)
crit("Data pull complete.")
else:
crit("No available source for pulling changes. This might be the only active database.")
crit("No instances with more recent data found.")
except Exception as e:
crit(f"Error during startup: {str(e)}")
@@ -93,7 +85,6 @@ async def lifespan(app: FastAPI):
crit("Shutting down...")
# Perform any cleanup operations here if needed
app = FastAPI(lifespan=lifespan)
app.add_middleware(
@@ -124,7 +115,6 @@ class SimpleAPIKeyMiddleware(BaseHTTPMiddleware):
content={"detail": "Invalid or missing API key"}
)
response = await call_next(request)
# debug(f"Request from {client_ip} is complete")
return response
# Add the middleware to your FastAPI app
@@ -136,7 +126,6 @@ async def http_exception_handler(request: Request, exc: HTTPException):
err(f"Request: {request.method} {request.url}")
return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
@app.middleware("http")
async def handle_exception_middleware(request: Request, call_next):
    try:
@@ -149,6 +138,19 @@ async def handle_exception_middleware(request: Request, call_next):
        raise
    return response
@app.middleware("http")
async def sync_middleware(request: Request, call_next):
    response = await call_next(request)
    # Check if the request was a database write operation
    if request.method in ["POST", "PUT", "PATCH", "DELETE"]:
        try:
            # Push changes to other databases
            await API.push_changes_to_all()
        except Exception as e:
            err(f"Error pushing changes to other databases: {str(e)}")
    return response
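# Note: the push above runs inline, so every write request waits for
# replication to all peers before the client gets its response. A background
# task (e.g. asyncio.create_task(API.push_changes_to_all())) is a possible
# alternative, trading read-your-writes consistency across instances for
# lower request latency.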
def load_router(router_name):
    router_file = ROUTER_DIR / f'{router_name}.py'
@@ -160,24 +162,15 @@ def load_router(router_name):
            module = importlib.import_module(module_path)
            router = getattr(module, router_name)
            app.include_router(router)
        except (ImportError, AttributeError) as e:
            module_logger.critical(f"Failed to load router {router_name}: {e}")
    else:
        module_logger.error(f"Router file for {router_name} does not exist.")
def main(argv):
    crit(f"sijapi launched")
    crit(f"Arguments: {args}")
    config = HypercornConfig()
    config.bind = [API.BIND]
    config.startup_timeout = 3600  # 1 hour
    asyncio.run(serve(app, config))
if __name__ == "__main__":

View file

@@ -31,6 +31,8 @@ def warn(text: str): logger.warning(text)
def err(text: str): logger.error(text)
def crit(text: str): logger.critical(text)
TS_ID = os.getenv("TS_ID", "NULL")
T = TypeVar('T', bound='Configuration')
class Configuration(BaseModel):
    HOME: Path = Path.home()
@@ -156,6 +158,17 @@ class Configuration(BaseModel):
        arbitrary_types_allowed = True
from pydantic import BaseModel, Field
from typing import Any, Dict, List, Union
from pathlib import Path
import yaml
import re
import os
import asyncpg
from contextlib import asynccontextmanager
from sijapi import TS_ID
import traceback
class APIConfig(BaseModel):
    HOST: str
    PORT: int
@@ -183,7 +196,6 @@ class APIConfig(BaseModel):
        try:
            with open(secrets_path, 'r') as file:
                secrets_data = yaml.safe_load(file)
        except FileNotFoundError:
            err(f"Secrets file not found: {secrets_path}")
            secrets_data = {}
@@ -279,8 +291,7 @@ class APIConfig(BaseModel):
    @property
    def local_db(self):
        return next((db for db in self.POOL if db['ts_id'] == TS_ID), None)
    @asynccontextmanager
    async def get_connection(self, pool_entry: Dict[str, Any] = None):
@@ -305,104 +316,158 @@ class APIConfig(BaseModel):
crit(f"Error: {str(e)}")
raise
    async def initialize_sync(self):
        async with self.get_connection() as conn:
            # Create sync_status table
            await conn.execute("""
                CREATE TABLE IF NOT EXISTS sync_status (
                    table_name TEXT,
                    server_id TEXT,
                    last_synced_version INTEGER,
                    PRIMARY KEY (table_name, server_id)
                )
            """)

            # Get all tables
            tables = await conn.fetch("""
                SELECT tablename FROM pg_tables
                WHERE schemaname = 'public'
            """)

            # Add version and server_id columns to all tables, create triggers
            for table in tables:
                table_name = table['tablename']
                info(f"Processing table: {table_name}")
                await conn.execute(f"""
                    ALTER TABLE "{table_name}"
                    ADD COLUMN IF NOT EXISTS version INTEGER DEFAULT 1,
                    ADD COLUMN IF NOT EXISTS server_id TEXT DEFAULT '{TS_ID}';

                    CREATE OR REPLACE FUNCTION update_version_and_server_id()
                    RETURNS TRIGGER AS $$
                    BEGIN
                        NEW.version = COALESCE(OLD.version, 0) + 1;
                        NEW.server_id = '{TS_ID}';
                        RETURN NEW;
                    END;
                    $$ LANGUAGE plpgsql;

                    DROP TRIGGER IF EXISTS update_version_and_server_id_trigger ON "{table_name}";
                    CREATE TRIGGER update_version_and_server_id_trigger
                    BEFORE INSERT OR UPDATE ON "{table_name}"
                    FOR EACH ROW EXECUTE FUNCTION update_version_and_server_id();

                    INSERT INTO sync_status (table_name, server_id, last_synced_version)
                    VALUES ('{table_name}', '{TS_ID}', 0)
                    ON CONFLICT (table_name, server_id) DO NOTHING;
                """)
    async def get_most_recent_source(self):
        most_recent_source = None
        max_version = -1
        for pool_entry in self.POOL:
            if pool_entry['ts_id'] == TS_ID:
                continue
            try:
                async with self.get_connection(pool_entry) as conn:
                    version = await conn.fetchval("""
                        SELECT MAX(last_synced_version) FROM sync_status
                    """)
                    # fetchval returns None when sync_status is empty; guard the comparison
                    if version is not None and version > max_version:
                        max_version = version
                        most_recent_source = pool_entry
            except Exception as e:
                err(f"Error checking version for {pool_entry['ts_id']}: {str(e)}")
        return most_recent_source
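    # MAX(last_synced_version) across all tables is a coarse proxy for "most
    # recent data"; it assumes version counters advance comparably on every
    # instance.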
    async def pull_changes(self, source_pool_entry):
        async with self.get_connection(source_pool_entry) as source_conn:
            async with self.get_connection() as dest_conn:
                tables = await source_conn.fetch("""
                    SELECT tablename FROM pg_tables
                    WHERE schemaname = 'public'
                """)
                for table in tables:
                    table_name = table['tablename']
                    last_synced_version = await self.get_last_synced_version(table_name, source_pool_entry['ts_id'])
                    changes = await source_conn.fetch(f"""
                        SELECT * FROM "{table_name}"
                        WHERE version > $1 AND server_id = $2
                        ORDER BY version ASC
                    """, last_synced_version, source_pool_entry['ts_id'])
                    for change in changes:
                        columns = change.keys()
                        values = [change[col] for col in columns]
                        await dest_conn.execute(f"""
                            INSERT INTO "{table_name}" ({', '.join(columns)})
                            VALUES ({', '.join(f'${i+1}' for i in range(len(columns)))})
                            ON CONFLICT (id) DO UPDATE SET
                            {', '.join(f"{col} = EXCLUDED.{col}" for col in columns if col != 'id')}
                        """, *values)
                    if changes:
                        await self.update_sync_status(table_name, source_pool_entry['ts_id'], changes[-1]['version'])
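    # Assumption worth noting: ON CONFLICT (id) requires every synced table to
    # have a primary key column literally named "id". Tables with composite or
    # differently named keys would need the conflict target looked up from
    # pg_index instead.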
    async def push_changes_to_all(self):
        async with self.get_connection() as local_conn:
            tables = await local_conn.fetch("""
                SELECT tablename FROM pg_tables
                WHERE schemaname = 'public'
            """)
            for pool_entry in self.POOL:
                if pool_entry['ts_id'] == TS_ID:
                    continue
                try:
                    async with self.get_connection(pool_entry) as remote_conn:
                        for table in tables:
                            table_name = table['tablename']
                            last_synced_version = await self.get_last_synced_version(table_name, pool_entry['ts_id'])
                            changes = await local_conn.fetch(f"""
                                SELECT * FROM "{table_name}"
                                WHERE version > $1 AND server_id = $2
                                ORDER BY version ASC
                            """, last_synced_version, TS_ID)
                            for change in changes:
                                columns = change.keys()
                                values = [change[col] for col in columns]
                                await remote_conn.execute(f"""
                                    INSERT INTO "{table_name}" ({', '.join(columns)})
                                    VALUES ({', '.join(f'${i+1}' for i in range(len(columns)))})
                                    ON CONFLICT (id) DO UPDATE SET
                                    {', '.join(f"{col} = EXCLUDED.{col}" for col in columns if col != 'id')}
                                """, *values)
                            if changes:
                                await self.update_sync_status(table_name, pool_entry['ts_id'], changes[-1]['version'])
                    info(f"Successfully pushed changes to {pool_entry['ts_id']}")
                except Exception as e:
                    err(f"Error pushing changes to {pool_entry['ts_id']}: {str(e)}")
    async def get_last_synced_version(self, table_name, server_id):
        async with self.get_connection() as conn:
            return await conn.fetchval("""
                SELECT last_synced_version FROM sync_status
                WHERE table_name = $1 AND server_id = $2
            """, table_name, server_id) or 0
    async def update_sync_status(self, table_name, server_id, version):
        async with self.get_connection() as conn:
            await conn.execute("""
                INSERT INTO sync_status (table_name, server_id, last_synced_version)
                VALUES ($1, $2, $3)
                ON CONFLICT (table_name, server_id) DO UPDATE
                SET last_synced_version = EXCLUDED.last_synced_version
            """, table_name, server_id, version)
    async def sync_schema(self):
        for pool_entry in self.POOL:
@@ -459,45 +524,6 @@ class APIConfig(BaseModel):
                END $$;
            """)

    async def get_schema(self, pool_entry: Dict[str, Any]):
        async with self.get_connection(pool_entry) as conn:
            tables = await conn.fetch("""
                SELECT table_name, column_name, data_type, character_maximum_length,
                       is_nullable, column_default, ordinal_position
                FROM information_schema.columns
                WHERE table_schema = 'public'
                ORDER BY table_name, ordinal_position
            """)
            indexes = await conn.fetch("""
                SELECT indexname, indexdef
                FROM pg_indexes
                WHERE schemaname = 'public'
            """)
            constraints = await conn.fetch("""
                SELECT conname, contype, conrelid::regclass::text as table_name,
                       pg_get_constraintdef(oid) as definition
                FROM pg_constraint
                WHERE connamespace = 'public'::regnamespace
            """)
            return {
                'tables': tables,
                'indexes': indexes,
                'constraints': constraints
            }
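    # get_schema() snapshots columns, indexes and constraints so that
    # apply_schema_changes() below can diff a source instance against a target.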
    async def create_sequence_if_not_exists(self, conn, sequence_name):
        await conn.execute(f"""
            DO $$
            BEGIN
                IF NOT EXISTS (SELECT 1 FROM pg_sequences WHERE schemaname = 'public' AND sequencename = '{sequence_name}') THEN
                    CREATE SEQUENCE {sequence_name};
                END IF;
            END $$;
        """)
    async def apply_schema_changes(self, pool_entry: Dict[str, Any], source_schema, target_schema):
        async with self.get_connection(pool_entry) as conn:
            source_tables = {t['table_name']: t for t in source_schema['tables']}
@@ -567,8 +593,6 @@ class APIConfig(BaseModel):
                    await conn.execute(sql)
                except Exception as e:
                    err(f"Error processing table {table_name}: {str(e)}")

        try:
            source_indexes = {idx['indexname']: idx['indexdef'] for idx in source_schema['indexes']}