From ea960f1719e410336ba985c0e19f0d3aac2e465f Mon Sep 17 00:00:00 2001
From: sanj <67624670+iodrift@users.noreply.github.com>
Date: Mon, 12 Aug 2024 18:01:13 -0700
Subject: [PATCH] Auto-update: Mon Aug 12 18:01:13 PDT 2024

---
 sijapi/database.py    | 297 +++++++++---------------------------------
 sijapi/routers/sys.py |  32 +----
 2 files changed, 60 insertions(+), 269 deletions(-)

diff --git a/sijapi/database.py b/sijapi/database.py
index a4a87a6..2eeebbf 100644
--- a/sijapi/database.py
+++ b/sijapi/database.py
@@ -52,7 +52,6 @@ class QueryTracking(Base):
     args = Column(JSONB)
     executed_at = Column(DateTime(timezone=True), server_default=func.now())
     completed_by = Column(JSONB, default={})
-    result_checksum = Column(String)
 
 class Database:
     @classmethod
@@ -65,6 +64,7 @@ class Database:
         self.sessions: Dict[str, Any] = {}
         self.online_servers: set = set()
         self.local_ts_id = self.get_local_ts_id()
+        self.last_sync_time = 0
 
     def load_config(self, config_path: str) -> Dict[str, Any]:
         base_path = Path(__file__).parent.parent
@@ -127,7 +127,6 @@ class Database:
                 l.error(f"Failed to execute read query: {str(e)}")
                 return None
 
-
     async def write(self, query: str, **kwargs):
         if self.local_ts_id not in self.sessions:
             l.error(f"No session found for local server {self.local_ts_id}. Database may not be properly initialized.")
@@ -135,27 +134,15 @@ class Database:
 
         async with self.sessions[self.local_ts_id]() as session:
             try:
-                # a. Execute the write query locally
+                # Execute the write query locally
                 serialized_kwargs = {key: serialize(value) for key, value in kwargs.items()}
                 result = await session.execute(text(query), serialized_kwargs)
-                
-                # b. Log the query in query_tracking table
-                new_query = QueryTracking(
-                    ts_id=self.local_ts_id,
-                    query=query,
-                    args=json_dumps(kwargs),
-                    completed_by={self.local_ts_id: True}
-                )
-                session.add(new_query)
-                await session.flush()
-                query_id = new_query.id
-
                 await session.commit()
                 
                 # Initiate async operations
-                asyncio.create_task(self._async_sync_operations(query_id, query, serialized_kwargs))
+                asyncio.create_task(self._async_sync_operations(query, kwargs))
 
-                # c. Return the result
+                # Return the result
                 return result
             
             except Exception as e:
@@ -166,55 +153,36 @@ class Database:
                 l.error(f"Traceback: {traceback.format_exc()}")
                 return None
 
-    async def _async_sync_operations(self, query_id: int, query: str, params: dict):
+    async def _async_sync_operations(self, query: str, kwargs: dict):
         try:
-            # a. Calculate and add checksum
-            checksum = await self._local_compute_checksum(query, params)
-            await self.update_query_checksum(query_id, checksum)
+            # Add the write query to the query_tracking table
+            await self.add_query_to_tracking(query, kwargs)
 
-            # b. Synchronize query_tracking table
-            await self.sync_query_tracking()
-
-            # c. Call /db/sync on all servers
+            # Call /db/sync on all online servers
             await self.call_db_sync_on_servers()
         except Exception as e:
             l.error(f"Error in async sync operations: {str(e)}")
             l.error(f"Traceback: {traceback.format_exc()}")
 
-    async def call_db_sync_on_servers(self):
-        """Call /db/sync on all online servers."""
-        online_servers = await self.get_online_servers()
-        for server in self.config['POOL']:
-            if server['ts_id'] in online_servers and server['ts_id'] != self.local_ts_id:
-                asyncio.create_task(self.call_db_sync(server))
-
-    async def call_db_sync(self, server):
-        url = f"http://{server['ts_ip']}:{server['app_port']}/db/sync"
-        headers = {
-            "Authorization": f"Bearer {server['api_key']}"
-        }
-        async with aiohttp.ClientSession() as session:
-            try:
-                async with session.post(url, headers=headers, timeout=30) as response:
-                    if response.status == 200:
-                        l.info(f"Successfully called /db/sync on {url}")
-                    else:
-                        l.warning(f"Failed to call /db/sync on {url}. Status: {response.status}")
-            except asyncio.TimeoutError:
-                l.debug(f"Timeout while calling /db/sync on {url}")
-            except Exception as e:
-                l.error(f"Error calling /db/sync on {url}: {str(e)}")
-
+    async def add_query_to_tracking(self, query: str, kwargs: dict):
+        async with self.sessions[self.local_ts_id]() as session:
+            new_query = QueryTracking(
+                ts_id=self.local_ts_id,
+                query=query,
+                args=json_dumps(kwargs),
+                completed_by={self.local_ts_id: True}
+            )
+            session.add(new_query)
+            await session.commit()
 
     async def get_primary_server(self) -> str:
-        url = urljoin(self.config['URL'], '/id')
-        
+        url = f"{self.config['URL']}/id"
         async with aiohttp.ClientSession() as session:
             try:
                 async with session.get(url) as response:
                     if response.status == 200:
                         primary_ts_id = await response.text()
-                        return primary_ts_id.strip()
+                        return primary_ts_id.strip().strip('"')
                     else:
                         l.error(f"Failed to get primary server. Status: {response.status}")
                         return None
@@ -222,207 +190,62 @@ class Database:
                 l.error(f"Error connecting to load balancer: {str(e)}")
                 return None
 
-    async def get_checksum_server(self) -> dict:
+    async def pull_query_tracking_from_primary(self):
         primary_ts_id = await self.get_primary_server()
-        online_servers = await self.get_online_servers()
-        
-        checksum_servers = [server for server in self.config['POOL'] if server['ts_id'] in online_servers and server['ts_id'] != primary_ts_id]
-        
-        if not checksum_servers:
-            return next(server for server in self.config['POOL'] if server['ts_id'] == primary_ts_id)
-        
-        return random.choice(checksum_servers)
+        if not primary_ts_id:
+            l.error("Failed to get primary server")
+            return
 
-    async def _local_compute_checksum(self, query: str, params: dict):
+        primary_server = next((s for s in self.config['POOL'] if s['ts_id'] == primary_ts_id), None)
+        if not primary_server:
+            l.error(f"Primary server {primary_ts_id} not found in config")
+            return
+
+        async with self.sessions[primary_ts_id]() as session:
+            queries = await session.execute(select(QueryTracking))
+            queries = queries.fetchall()
+
+        async with self.sessions[self.local_ts_id]() as local_session:
+            for query in queries:
+                existing = await local_session.get(QueryTracking, query.id)
+                if existing:
+                    existing.completed_by = {**existing.completed_by, **query.completed_by}
+                else:
+                    local_session.add(query)
+            await local_session.commit()
+
+    async def execute_unexecuted_queries(self):
         async with self.sessions[self.local_ts_id]() as session:
-            result = await session.execute(text(query), params)
-            if result.returns_rows:
-                data = result.fetchall()
-            else:
-                data = str(result.rowcount) + query + str(params)
-            checksum = hashlib.md5(str(data).encode()).hexdigest()
-            return checksum
-
-    async def _delegate_compute_checksum(self, server: Dict[str, Any], query: str, params: dict):
-        url = f"http://{server['ts_ip']}:{server['app_port']}/sync/checksum"
-        
-        async with aiohttp.ClientSession() as session:
-            try:
-                async with session.post(url, json={"query": query, "params": params}) as response:
-                    if response.status == 200:
-                        result = await response.json()
-                        return result['checksum']
-                    else:
-                        l.error(f"Failed to get checksum from {server['ts_id']}. Status: {response.status}")
-                        return await self._local_compute_checksum(query, params)
-            except aiohttp.ClientError as e:
-                l.error(f"Error connecting to {server['ts_id']} for checksum: {str(e)}")
-                return await self._local_compute_checksum(query, params)
-
-    async def update_query_checksum(self, query_id: int, checksum: str):
-        async with self.sessions[self.local_ts_id]() as session:
-            await session.execute(
-                text("UPDATE query_tracking SET result_checksum = :checksum WHERE id = :id"),
-                {"checksum": checksum, "id": query_id}
-            )
-            await session.commit()
-
-    async def _replicate_write(self, ts_id: str, query_id: int, query: str, params: dict, expected_checksum: str):
-        try:
-            async with self.sessions[ts_id]() as session:
-                await session.execute(text(query), params)
-                actual_checksum = await self._local_compute_checksum(query, params)
-                if actual_checksum != expected_checksum:
-                    raise ValueError(f"Checksum mismatch on {ts_id}")
-                await self.mark_query_completed(query_id, ts_id)
-                await session.commit()
-                l.info(f"Successfully replicated write to {ts_id}")
-        except Exception as e:
-            l.error(f"Failed to replicate write on {ts_id}: {str(e)}")
-            l.error(f"Traceback: {traceback.format_exc()}")
-
-
-    async def mark_query_completed(self, query_id: int, ts_id: str):
-        async with self.sessions[self.local_ts_id]() as session:
-            query = await session.get(QueryTracking, query_id)
-            if query:
-                completed_by = query.completed_by or {}
-                completed_by[ts_id] = True
-                query.completed_by = completed_by
-                await session.commit()
-
-    async def sync_local_server(self):
-        async with self.sessions[self.local_ts_id]() as session:
-            last_synced = await session.execute(
-                text("SELECT MAX(id) FROM query_tracking WHERE completed_by ? :ts_id"),
-                {"ts_id": self.local_ts_id}
-            )
-            last_synced_id = last_synced.scalar() or 0
-
             unexecuted_queries = await session.execute(
-                text("SELECT * FROM query_tracking WHERE id > :last_id ORDER BY id"),
-                {"last_id": last_synced_id}
+                select(QueryTracking).where(~QueryTracking.completed_by.has_key(self.local_ts_id)).order_by(QueryTracking.id)
             )
+            unexecuted_queries = unexecuted_queries.fetchall()
 
             for query in unexecuted_queries:
                 try:
                     params = json.loads(query.args)
                     await session.execute(text(query.query), params)
-                    actual_checksum = await self._local_compute_checksum(query.query, params)
-                    if actual_checksum != query.result_checksum:
-                        raise ValueError(f"Checksum mismatch for query ID {query.id}")
-                    await self.mark_query_completed(query.id, self.local_ts_id)
+                    query.completed_by[self.local_ts_id] = True
+                    await session.commit()
+                    l.info(f"Successfully executed query ID {query.id}")
                 except Exception as e:
-                    l.error(f"Failed to execute query ID {query.id} during local sync: {str(e)}")
+                    l.error(f"Failed to execute query ID {query.id}: {str(e)}")
+                    await session.rollback()
 
-            await session.commit()
-            l.info(f"Local server sync completed. Executed {unexecuted_queries.rowcount} queries.")
+    async def sync_db(self):
+        current_time = time.time()
+        if current_time - self.last_sync_time < 30:
+            l.info("Skipping sync, last sync was less than 30 seconds ago")
+            return
 
-    async def purge_completed_queries(self):
-        async with self.sessions[self.local_ts_id]() as session:
-            all_ts_ids = [db['ts_id'] for db in self.config['POOL']]
-            
-            result = await session.execute(
-                text("""
-                    DELETE FROM query_tracking
-                    WHERE id <= (
-                        SELECT MAX(id)
-                        FROM query_tracking
-                        WHERE completed_by ?& :ts_ids
-                    )
-                """),
-                {"ts_ids": all_ts_ids}
-            )
-            await session.commit()
-            
-            deleted_count = result.rowcount
-            l.info(f"Purged {deleted_count} completed queries.")
-
-
-    async def sync_query_tracking(self):
-        """Combinatorial sync method for the query_tracking table."""
         try:
-            online_servers = await self.get_online_servers()
-            
-            for ts_id in online_servers:
-                if ts_id == self.local_ts_id:
-                    continue
-                
-                try:
-                    async with self.sessions[ts_id]() as remote_session:
-                        local_max_id = await self.get_max_query_id(self.local_ts_id)
-                        remote_max_id = await self.get_max_query_id(ts_id)
-                        
-                        # Sync from remote to local
-                        remote_new_queries = await remote_session.execute(
-                            select(QueryTracking).where(QueryTracking.id > local_max_id)
-                        )
-                        remote_new_queries = remote_new_queries.fetchall()
-                        for query in remote_new_queries:
-                            await self.add_or_update_query(query[0])
-                        
-                        # Sync from local to remote
-                        async with self.sessions[self.local_ts_id]() as local_session:
-                            local_new_queries = await local_session.execute(
-                                select(QueryTracking).where(QueryTracking.id > remote_max_id)
-                            )
-                            local_new_queries = local_new_queries.fetchall()
-                            for query in local_new_queries:
-                                await self.add_or_update_query_remote(ts_id, query[0])
-                except Exception as e:
-                    l.error(f"Error syncing with {ts_id}: {str(e)}")
-                    l.error(f"Traceback: {traceback.format_exc()}")
+            await self.pull_query_tracking_from_primary()
+            await self.execute_unexecuted_queries()
+            self.last_sync_time = current_time
         except Exception as e:
-            l.error(f"Error in sync_query_tracking: {str(e)}")
+            l.error(f"Error during database sync: {str(e)}")
             l.error(f"Traceback: {traceback.format_exc()}")
 
-
-    async def get_max_query_id(self, ts_id):
-        async with self.sessions[ts_id]() as session:
-            result = await session.execute(select(func.max(QueryTracking.id)))
-            return result.scalar() or 0
-
-    async def add_or_update_query(self, query):
-        async with self.sessions[self.local_ts_id]() as session:
-            existing_query = await session.get(QueryTracking, query.id)
-            if existing_query:
-                existing_query.completed_by = {**existing_query.completed_by, **query.completed_by}
-            else:
-                session.add(query)
-            await session.commit()
-
-    async def add_or_update_query_remote(self, ts_id, query):
-        async with self.sessions[ts_id]() as session:
-            existing_query = await session.get(QueryTracking, query.id)
-            if existing_query:
-                existing_query.completed_by = {**existing_query.completed_by, **query.completed_by}
-            else:
-                new_query = QueryTracking(
-                    id=query.id,
-                    ts_id=query.ts_id,
-                    query=query.query,
-                    args=query.args,
-                    executed_at=query.executed_at,
-                    completed_by=query.completed_by,
-                    result_checksum=query.result_checksum
-                )
-                session.add(new_query)
-            try:
-                await session.commit()
-            except Exception as e:
-                l.error(f"Failed to add or update query on {ts_id}: {str(e)}")
-                await session.rollback()
-
-    async def ensure_query_tracking_table(self):
-        for ts_id, engine in self.engines.items():
-            try:
-                async with engine.begin() as conn:
-                    await conn.run_sync(Base.metadata.create_all)
-                l.info(f"Ensured query_tracking table exists for {ts_id}")
-            except Exception as e:
-                l.error(f"Failed to create query_tracking table for {ts_id}: {str(e)}")
-
-
     async def call_db_sync_on_servers(self):
         """Call /db/sync on all online servers."""
         online_servers = await self.get_online_servers()
@@ -433,7 +256,6 @@ class Database:
                 except Exception as e:
                     l.error(f"Failed to call /db/sync on {server['ts_id']}: {str(e)}")
 
-
     async def call_db_sync(self, server):
         url = f"http://{server['ts_ip']}:{server['app_port']}/db/sync"
         headers = {
@@ -451,7 +273,6 @@ class Database:
             except Exception as e:
                 l.error(f"Error calling /db/sync on {url}: {str(e)}")
 
-
     async def close(self):
         for engine in self.engines.values():
             await engine.dispose()
diff --git a/sijapi/routers/sys.py b/sijapi/routers/sys.py
index 6a33d99..9ab2643 100644
--- a/sijapi/routers/sys.py
+++ b/sijapi/routers/sys.py
@@ -67,37 +67,7 @@ async def get_tailscale_ip():
         else:
             return "No devices found"
 
-async def sync_process():
-    async with Db.sessions[TS_ID]() as session:
-        # Find unexecuted queries
-        unexecuted_queries = await session.execute(
-            select(QueryTracking).where(~QueryTracking.completed_by.has_key(TS_ID)).order_by(QueryTracking.id)
-        )
-
-        for query in unexecuted_queries:
-            try:
-                params = json_loads(query.args)
-                await session.execute(text(query.query), params)
-                actual_checksum = await Db._local_compute_checksum(query.query, params)
-                if actual_checksum != query.result_checksum:
-                    l.error(f"Checksum mismatch for query ID {query.id}")
-                    continue
-                
-                # Update the completed_by field
-                query.completed_by[TS_ID] = True
-                await session.commit()
-                
-                l.info(f"Successfully executed and verified query ID {query.id}")
-            except Exception as e:
-                l.error(f"Failed to execute query ID {query.id} during sync: {str(e)}")
-                await session.rollback()
-
-        l.info(f"Sync process completed. Executed {unexecuted_queries.rowcount} queries.")
-
-    # After executing all queries, perform combinatorial sync
-    await Db.sync_query_tracking()
-
 @sys.post("/db/sync")
 async def db_sync(background_tasks: BackgroundTasks):
-    background_tasks.add_task(sync_process)
+    background_tasks.add_task(Db.sync_db)
     return {"message": "Sync process initiated"}