From 5978f02c66ee2e5e7bb1b878eb3427753d9bb148 Mon Sep 17 00:00:00 2001
From: sanj <67624670+iodrift@users.noreply.github.com>
Date: Thu, 25 Jul 2024 09:06:06 -0700
Subject: [PATCH] Auto-update: Thu Jul 25 09:06:06 PDT 2024

---
 sijapi/classes.py | 45 ++++++++++++++++++++++-----------------------
 1 file changed, 22 insertions(+), 23 deletions(-)

diff --git a/sijapi/classes.py b/sijapi/classes.py
index b659c46..fa63327 100644
--- a/sijapi/classes.py
+++ b/sijapi/classes.py
@@ -164,6 +164,7 @@ class Configuration(BaseModel):
         arbitrary_types_allowed = True
 
 
+
 class APIConfig(BaseModel):
     HOST: str
     PORT: int
@@ -286,9 +287,10 @@ class APIConfig(BaseModel):
 
     @property
     def local_db(self):
-        local_db = next((db for db in self.POOL if db['ts_id'] == TS_ID), None)
+        ts_id = os.environ.get('TS_ID')
+        local_db = next((db for db in self.POOL if db['ts_id'] == ts_id), None)
         if local_db is None:
-            raise ValueError(f"No database configuration found for TS_ID: {TS_ID}")
+            raise ValueError(f"No database configuration found for TS_ID: {ts_id}")
         return local_db
 
     @asynccontextmanager
@@ -316,7 +318,6 @@ class APIConfig(BaseModel):
 
     async def initialize_sync(self):
         async with self.get_connection() as conn:
-            # Create sync_status table
             await conn.execute("""
                 CREATE TABLE IF NOT EXISTS sync_status (
                     table_name TEXT,
@@ -326,25 +327,23 @@ class APIConfig(BaseModel):
                 )
             """)
             
-            # Get all tables
             tables = await conn.fetch("""
                 SELECT tablename FROM pg_tables 
                 WHERE schemaname = 'public'
             """)
             
-            # Add version and server_id columns to all tables, create triggers
             for table in tables:
                 table_name = table['tablename']
                 await conn.execute(f"""
                     ALTER TABLE "{table_name}" 
                     ADD COLUMN IF NOT EXISTS version INTEGER DEFAULT 1,
-                    ADD COLUMN IF NOT EXISTS server_id TEXT DEFAULT '{TS_ID}';
+                    ADD COLUMN IF NOT EXISTS server_id TEXT DEFAULT '{os.environ.get('TS_ID')}';
 
                     CREATE OR REPLACE FUNCTION update_version_and_server_id()
                     RETURNS TRIGGER AS $$
                     BEGIN
                         NEW.version = COALESCE(OLD.version, 0) + 1;
-                        NEW.server_id = '{TS_ID}';
+                        NEW.server_id = '{os.environ.get('TS_ID')}';
                         RETURN NEW;
                     END;
                     $$ LANGUAGE plpgsql;
@@ -355,7 +354,7 @@ class APIConfig(BaseModel):
                     FOR EACH ROW EXECUTE FUNCTION update_version_and_server_id();
 
                     INSERT INTO sync_status (table_name, server_id, last_synced_version)
-                    VALUES ('{table_name}', '{TS_ID}', 0)
+                    VALUES ('{table_name}', '{os.environ.get('TS_ID')}', 0)
                     ON CONFLICT (table_name, server_id) DO NOTHING;
                 """)
 
@@ -364,13 +363,13 @@ class APIConfig(BaseModel):
         max_version = -1
         
         for pool_entry in self.POOL:
-            if pool_entry['ts_id'] == TS_ID:
+            if pool_entry['ts_id'] == os.environ.get('TS_ID'):
                 continue
             
             try:
                 async with self.get_connection(pool_entry) as conn:
                     version = await conn.fetchval("""
-                        SELECT MAX(last_synced_version) FROM sync_status
+                        SELECT COALESCE(MAX(last_synced_version), -1) FROM sync_status
                     """)
                     if version > max_version:
                         max_version = version
@@ -419,7 +418,7 @@ class APIConfig(BaseModel):
             """)
             
             for pool_entry in self.POOL:
-                if pool_entry['ts_id'] == TS_ID:
+                if pool_entry['ts_id'] == os.environ.get('TS_ID'):
                     continue
                 
                 try:
@@ -432,7 +431,7 @@ class APIConfig(BaseModel):
                                 SELECT * FROM "{table_name}"
                                 WHERE version > $1 AND server_id = $2
                                 ORDER BY version ASC
-                            """, last_synced_version, TS_ID)
+                            """, last_synced_version, os.environ.get('TS_ID'))
                             
                             for change in changes:
                                 columns = change.keys()
@@ -512,16 +511,6 @@ class APIConfig(BaseModel):
                 'constraints': constraints
             }
 
-    async def create_sequence_if_not_exists(self, conn, sequence_name):
-        await conn.execute(f"""
-        DO $$
-        BEGIN
-            IF NOT EXISTS (SELECT 1 FROM pg_sequences WHERE schemaname = 'public' AND sequencename = '{sequence_name}') THEN
-                CREATE SEQUENCE {sequence_name};
-            END IF;
-        END $$;
-        """)
-
     async def apply_schema_changes(self, pool_entry: Dict[str, Any], source_schema, target_schema):
         async with self.get_connection(pool_entry) as conn:
             source_tables = {t['table_name']: t for t in source_schema['tables']}
@@ -529,7 +518,7 @@ class APIConfig(BaseModel):
 
             def get_column_type(data_type):
                 if data_type == 'ARRAY':
-                    return 'text[]'  # or another appropriate type
+                    return 'text[]'
                 elif data_type == 'USER-DEFINED':
                     return 'geometry'
                 else:
@@ -630,6 +619,16 @@ class APIConfig(BaseModel):
 
         info(f"Schema synchronization completed for {pool_entry['ts_ip']}")
 
+    async def create_sequence_if_not_exists(self, conn, sequence_name):
+        await conn.execute(f"""
+        DO $$
+        BEGIN
+            IF NOT EXISTS (SELECT 1 FROM pg_sequences WHERE schemaname = 'public' AND sequencename = '{sequence_name}') THEN
+                CREATE SEQUENCE {sequence_name};
+            END IF;
+        END $$;
+        """)
+
 
 class Location(BaseModel):
     latitude: float