Auto-update: Thu Jul 25 09:06:06 PDT 2024

This commit is contained in:
sanj 2024-07-25 09:06:06 -07:00
parent 0945b1eb84
commit e3ef2781b0

View file

@ -164,6 +164,7 @@ class Configuration(BaseModel):
arbitrary_types_allowed = True arbitrary_types_allowed = True
class APIConfig(BaseModel): class APIConfig(BaseModel):
HOST: str HOST: str
PORT: int PORT: int
@ -286,9 +287,10 @@ class APIConfig(BaseModel):
@property @property
def local_db(self): def local_db(self):
local_db = next((db for db in self.POOL if db['ts_id'] == TS_ID), None) ts_id = os.environ.get('TS_ID')
local_db = next((db for db in self.POOL if db['ts_id'] == ts_id), None)
if local_db is None: if local_db is None:
raise ValueError(f"No database configuration found for TS_ID: {TS_ID}") raise ValueError(f"No database configuration found for TS_ID: {ts_id}")
return local_db return local_db
@asynccontextmanager @asynccontextmanager
@ -316,7 +318,6 @@ class APIConfig(BaseModel):
async def initialize_sync(self): async def initialize_sync(self):
async with self.get_connection() as conn: async with self.get_connection() as conn:
# Create sync_status table
await conn.execute(""" await conn.execute("""
CREATE TABLE IF NOT EXISTS sync_status ( CREATE TABLE IF NOT EXISTS sync_status (
table_name TEXT, table_name TEXT,
@ -326,25 +327,23 @@ class APIConfig(BaseModel):
) )
""") """)
# Get all tables
tables = await conn.fetch(""" tables = await conn.fetch("""
SELECT tablename FROM pg_tables SELECT tablename FROM pg_tables
WHERE schemaname = 'public' WHERE schemaname = 'public'
""") """)
# Add version and server_id columns to all tables, create triggers
for table in tables: for table in tables:
table_name = table['tablename'] table_name = table['tablename']
await conn.execute(f""" await conn.execute(f"""
ALTER TABLE "{table_name}" ALTER TABLE "{table_name}"
ADD COLUMN IF NOT EXISTS version INTEGER DEFAULT 1, ADD COLUMN IF NOT EXISTS version INTEGER DEFAULT 1,
ADD COLUMN IF NOT EXISTS server_id TEXT DEFAULT '{TS_ID}'; ADD COLUMN IF NOT EXISTS server_id TEXT DEFAULT '{os.environ.get('TS_ID')}';
CREATE OR REPLACE FUNCTION update_version_and_server_id() CREATE OR REPLACE FUNCTION update_version_and_server_id()
RETURNS TRIGGER AS $$ RETURNS TRIGGER AS $$
BEGIN BEGIN
NEW.version = COALESCE(OLD.version, 0) + 1; NEW.version = COALESCE(OLD.version, 0) + 1;
NEW.server_id = '{TS_ID}'; NEW.server_id = '{os.environ.get('TS_ID')}';
RETURN NEW; RETURN NEW;
END; END;
$$ LANGUAGE plpgsql; $$ LANGUAGE plpgsql;
@ -355,7 +354,7 @@ class APIConfig(BaseModel):
FOR EACH ROW EXECUTE FUNCTION update_version_and_server_id(); FOR EACH ROW EXECUTE FUNCTION update_version_and_server_id();
INSERT INTO sync_status (table_name, server_id, last_synced_version) INSERT INTO sync_status (table_name, server_id, last_synced_version)
VALUES ('{table_name}', '{TS_ID}', 0) VALUES ('{table_name}', '{os.environ.get('TS_ID')}', 0)
ON CONFLICT (table_name, server_id) DO NOTHING; ON CONFLICT (table_name, server_id) DO NOTHING;
""") """)
@ -364,13 +363,13 @@ class APIConfig(BaseModel):
max_version = -1 max_version = -1
for pool_entry in self.POOL: for pool_entry in self.POOL:
if pool_entry['ts_id'] == TS_ID: if pool_entry['ts_id'] == os.environ.get('TS_ID'):
continue continue
try: try:
async with self.get_connection(pool_entry) as conn: async with self.get_connection(pool_entry) as conn:
version = await conn.fetchval(""" version = await conn.fetchval("""
SELECT MAX(last_synced_version) FROM sync_status SELECT COALESCE(MAX(last_synced_version), -1) FROM sync_status
""") """)
if version > max_version: if version > max_version:
max_version = version max_version = version
@ -419,7 +418,7 @@ class APIConfig(BaseModel):
""") """)
for pool_entry in self.POOL: for pool_entry in self.POOL:
if pool_entry['ts_id'] == TS_ID: if pool_entry['ts_id'] == os.environ.get('TS_ID'):
continue continue
try: try:
@ -432,7 +431,7 @@ class APIConfig(BaseModel):
SELECT * FROM "{table_name}" SELECT * FROM "{table_name}"
WHERE version > $1 AND server_id = $2 WHERE version > $1 AND server_id = $2
ORDER BY version ASC ORDER BY version ASC
""", last_synced_version, TS_ID) """, last_synced_version, os.environ.get('TS_ID'))
for change in changes: for change in changes:
columns = change.keys() columns = change.keys()
@ -512,16 +511,6 @@ class APIConfig(BaseModel):
'constraints': constraints 'constraints': constraints
} }
async def create_sequence_if_not_exists(self, conn, sequence_name):
await conn.execute(f"""
DO $$
BEGIN
IF NOT EXISTS (SELECT 1 FROM pg_sequences WHERE schemaname = 'public' AND sequencename = '{sequence_name}') THEN
CREATE SEQUENCE {sequence_name};
END IF;
END $$;
""")
async def apply_schema_changes(self, pool_entry: Dict[str, Any], source_schema, target_schema): async def apply_schema_changes(self, pool_entry: Dict[str, Any], source_schema, target_schema):
async with self.get_connection(pool_entry) as conn: async with self.get_connection(pool_entry) as conn:
source_tables = {t['table_name']: t for t in source_schema['tables']} source_tables = {t['table_name']: t for t in source_schema['tables']}
@ -529,7 +518,7 @@ class APIConfig(BaseModel):
def get_column_type(data_type): def get_column_type(data_type):
if data_type == 'ARRAY': if data_type == 'ARRAY':
return 'text[]' # or another appropriate type return 'text[]'
elif data_type == 'USER-DEFINED': elif data_type == 'USER-DEFINED':
return 'geometry' return 'geometry'
else: else:
@ -630,6 +619,16 @@ class APIConfig(BaseModel):
info(f"Schema synchronization completed for {pool_entry['ts_ip']}") info(f"Schema synchronization completed for {pool_entry['ts_ip']}")
async def create_sequence_if_not_exists(self, conn, sequence_name):
await conn.execute(f"""
DO $$
BEGIN
IF NOT EXISTS (SELECT 1 FROM pg_sequences WHERE schemaname = 'public' AND sequencename = '{sequence_name}') THEN
CREATE SEQUENCE {sequence_name};
END IF;
END $$;
""")
class Location(BaseModel): class Location(BaseModel):
latitude: float latitude: float