Auto-update: Thu Jul 25 09:06:06 PDT 2024
This commit is contained in:
parent
0945b1eb84
commit
e3ef2781b0
1 changed files with 22 additions and 23 deletions
|
@ -164,6 +164,7 @@ class Configuration(BaseModel):
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class APIConfig(BaseModel):
|
class APIConfig(BaseModel):
|
||||||
HOST: str
|
HOST: str
|
||||||
PORT: int
|
PORT: int
|
||||||
|
@ -286,9 +287,10 @@ class APIConfig(BaseModel):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def local_db(self):
|
def local_db(self):
|
||||||
local_db = next((db for db in self.POOL if db['ts_id'] == TS_ID), None)
|
ts_id = os.environ.get('TS_ID')
|
||||||
|
local_db = next((db for db in self.POOL if db['ts_id'] == ts_id), None)
|
||||||
if local_db is None:
|
if local_db is None:
|
||||||
raise ValueError(f"No database configuration found for TS_ID: {TS_ID}")
|
raise ValueError(f"No database configuration found for TS_ID: {ts_id}")
|
||||||
return local_db
|
return local_db
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
|
@ -316,7 +318,6 @@ class APIConfig(BaseModel):
|
||||||
|
|
||||||
async def initialize_sync(self):
|
async def initialize_sync(self):
|
||||||
async with self.get_connection() as conn:
|
async with self.get_connection() as conn:
|
||||||
# Create sync_status table
|
|
||||||
await conn.execute("""
|
await conn.execute("""
|
||||||
CREATE TABLE IF NOT EXISTS sync_status (
|
CREATE TABLE IF NOT EXISTS sync_status (
|
||||||
table_name TEXT,
|
table_name TEXT,
|
||||||
|
@ -326,25 +327,23 @@ class APIConfig(BaseModel):
|
||||||
)
|
)
|
||||||
""")
|
""")
|
||||||
|
|
||||||
# Get all tables
|
|
||||||
tables = await conn.fetch("""
|
tables = await conn.fetch("""
|
||||||
SELECT tablename FROM pg_tables
|
SELECT tablename FROM pg_tables
|
||||||
WHERE schemaname = 'public'
|
WHERE schemaname = 'public'
|
||||||
""")
|
""")
|
||||||
|
|
||||||
# Add version and server_id columns to all tables, create triggers
|
|
||||||
for table in tables:
|
for table in tables:
|
||||||
table_name = table['tablename']
|
table_name = table['tablename']
|
||||||
await conn.execute(f"""
|
await conn.execute(f"""
|
||||||
ALTER TABLE "{table_name}"
|
ALTER TABLE "{table_name}"
|
||||||
ADD COLUMN IF NOT EXISTS version INTEGER DEFAULT 1,
|
ADD COLUMN IF NOT EXISTS version INTEGER DEFAULT 1,
|
||||||
ADD COLUMN IF NOT EXISTS server_id TEXT DEFAULT '{TS_ID}';
|
ADD COLUMN IF NOT EXISTS server_id TEXT DEFAULT '{os.environ.get('TS_ID')}';
|
||||||
|
|
||||||
CREATE OR REPLACE FUNCTION update_version_and_server_id()
|
CREATE OR REPLACE FUNCTION update_version_and_server_id()
|
||||||
RETURNS TRIGGER AS $$
|
RETURNS TRIGGER AS $$
|
||||||
BEGIN
|
BEGIN
|
||||||
NEW.version = COALESCE(OLD.version, 0) + 1;
|
NEW.version = COALESCE(OLD.version, 0) + 1;
|
||||||
NEW.server_id = '{TS_ID}';
|
NEW.server_id = '{os.environ.get('TS_ID')}';
|
||||||
RETURN NEW;
|
RETURN NEW;
|
||||||
END;
|
END;
|
||||||
$$ LANGUAGE plpgsql;
|
$$ LANGUAGE plpgsql;
|
||||||
|
@ -355,7 +354,7 @@ class APIConfig(BaseModel):
|
||||||
FOR EACH ROW EXECUTE FUNCTION update_version_and_server_id();
|
FOR EACH ROW EXECUTE FUNCTION update_version_and_server_id();
|
||||||
|
|
||||||
INSERT INTO sync_status (table_name, server_id, last_synced_version)
|
INSERT INTO sync_status (table_name, server_id, last_synced_version)
|
||||||
VALUES ('{table_name}', '{TS_ID}', 0)
|
VALUES ('{table_name}', '{os.environ.get('TS_ID')}', 0)
|
||||||
ON CONFLICT (table_name, server_id) DO NOTHING;
|
ON CONFLICT (table_name, server_id) DO NOTHING;
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
@ -364,13 +363,13 @@ class APIConfig(BaseModel):
|
||||||
max_version = -1
|
max_version = -1
|
||||||
|
|
||||||
for pool_entry in self.POOL:
|
for pool_entry in self.POOL:
|
||||||
if pool_entry['ts_id'] == TS_ID:
|
if pool_entry['ts_id'] == os.environ.get('TS_ID'):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with self.get_connection(pool_entry) as conn:
|
async with self.get_connection(pool_entry) as conn:
|
||||||
version = await conn.fetchval("""
|
version = await conn.fetchval("""
|
||||||
SELECT MAX(last_synced_version) FROM sync_status
|
SELECT COALESCE(MAX(last_synced_version), -1) FROM sync_status
|
||||||
""")
|
""")
|
||||||
if version > max_version:
|
if version > max_version:
|
||||||
max_version = version
|
max_version = version
|
||||||
|
@ -419,7 +418,7 @@ class APIConfig(BaseModel):
|
||||||
""")
|
""")
|
||||||
|
|
||||||
for pool_entry in self.POOL:
|
for pool_entry in self.POOL:
|
||||||
if pool_entry['ts_id'] == TS_ID:
|
if pool_entry['ts_id'] == os.environ.get('TS_ID'):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -432,7 +431,7 @@ class APIConfig(BaseModel):
|
||||||
SELECT * FROM "{table_name}"
|
SELECT * FROM "{table_name}"
|
||||||
WHERE version > $1 AND server_id = $2
|
WHERE version > $1 AND server_id = $2
|
||||||
ORDER BY version ASC
|
ORDER BY version ASC
|
||||||
""", last_synced_version, TS_ID)
|
""", last_synced_version, os.environ.get('TS_ID'))
|
||||||
|
|
||||||
for change in changes:
|
for change in changes:
|
||||||
columns = change.keys()
|
columns = change.keys()
|
||||||
|
@ -512,16 +511,6 @@ class APIConfig(BaseModel):
|
||||||
'constraints': constraints
|
'constraints': constraints
|
||||||
}
|
}
|
||||||
|
|
||||||
async def create_sequence_if_not_exists(self, conn, sequence_name):
|
|
||||||
await conn.execute(f"""
|
|
||||||
DO $$
|
|
||||||
BEGIN
|
|
||||||
IF NOT EXISTS (SELECT 1 FROM pg_sequences WHERE schemaname = 'public' AND sequencename = '{sequence_name}') THEN
|
|
||||||
CREATE SEQUENCE {sequence_name};
|
|
||||||
END IF;
|
|
||||||
END $$;
|
|
||||||
""")
|
|
||||||
|
|
||||||
async def apply_schema_changes(self, pool_entry: Dict[str, Any], source_schema, target_schema):
|
async def apply_schema_changes(self, pool_entry: Dict[str, Any], source_schema, target_schema):
|
||||||
async with self.get_connection(pool_entry) as conn:
|
async with self.get_connection(pool_entry) as conn:
|
||||||
source_tables = {t['table_name']: t for t in source_schema['tables']}
|
source_tables = {t['table_name']: t for t in source_schema['tables']}
|
||||||
|
@ -529,7 +518,7 @@ class APIConfig(BaseModel):
|
||||||
|
|
||||||
def get_column_type(data_type):
|
def get_column_type(data_type):
|
||||||
if data_type == 'ARRAY':
|
if data_type == 'ARRAY':
|
||||||
return 'text[]' # or another appropriate type
|
return 'text[]'
|
||||||
elif data_type == 'USER-DEFINED':
|
elif data_type == 'USER-DEFINED':
|
||||||
return 'geometry'
|
return 'geometry'
|
||||||
else:
|
else:
|
||||||
|
@ -630,6 +619,16 @@ class APIConfig(BaseModel):
|
||||||
|
|
||||||
info(f"Schema synchronization completed for {pool_entry['ts_ip']}")
|
info(f"Schema synchronization completed for {pool_entry['ts_ip']}")
|
||||||
|
|
||||||
|
async def create_sequence_if_not_exists(self, conn, sequence_name):
|
||||||
|
await conn.execute(f"""
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (SELECT 1 FROM pg_sequences WHERE schemaname = 'public' AND sequencename = '{sequence_name}') THEN
|
||||||
|
CREATE SEQUENCE {sequence_name};
|
||||||
|
END IF;
|
||||||
|
END $$;
|
||||||
|
""")
|
||||||
|
|
||||||
|
|
||||||
class Location(BaseModel):
|
class Location(BaseModel):
|
||||||
latitude: float
|
latitude: float
|
||||||
|
|
Loading…
Reference in a new issue