From a2bbee6d53fc32121c6589426fce5b7c04c7be48 Mon Sep 17 00:00:00 2001
From: sanj <67624670+iodrift@users.noreply.github.com>
Date: Mon, 12 Aug 2024 21:44:51 -0700
Subject: [PATCH] Auto-update: Mon Aug 12 21:44:51 PDT 2024

---
 sijapi/database.py    | 156 +++++++++++++++++-------------------------
 sijapi/routers/sys.py |  34 +--------
 2 files changed, 65 insertions(+), 125 deletions(-)

diff --git a/sijapi/database.py b/sijapi/database.py
index f581bf2..f623603 100644
--- a/sijapi/database.py
+++ b/sijapi/database.py
@@ -1,5 +1,3 @@
-# database.py
-
 import json
 import yaml
 import time
@@ -20,8 +18,6 @@ from zoneinfo import ZoneInfo
 from srtm import get_data
 import os
 import sys
-from sqlalchemy.dialects.postgresql import UUID
-import uuid
 from loguru import logger
 from sqlalchemy import text, select, func, and_
 from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
@@ -45,18 +41,15 @@ ENV_PATH = CONFIG_DIR / ".env"
 load_dotenv(ENV_PATH)
 TS_ID = os.environ.get('TS_ID')
 
-
 class QueryTracking(Base):
     __tablename__ = 'query_tracking'
 
-    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
+    id = Column(Integer, primary_key=True)
     ts_id = Column(String, nullable=False)
     query = Column(Text, nullable=False)
-    args = Column(JSON)
+    args = Column(JSONB)
     executed_at = Column(DateTime(timezone=True), server_default=func.now())
-    completed_by = Column(JSON, default={})
-    result_checksum = Column(String(32))  # MD5 checksum
-
+    completed_by = Column(JSONB, default={})
 
 class Database:
     @classmethod
@@ -109,9 +102,11 @@ class Database:
                 async with engine.connect() as conn:
                     await conn.execute(text("SELECT 1"))
                 online_servers.append(ts_id)
+                l.debug(f"Server {ts_id} is online")
             except OperationalError:
-                pass
+                l.warning(f"Server {ts_id} is offline")
         self.online_servers = set(online_servers)
+        l.info(f"Online servers: {', '.join(online_servers)}")
         return online_servers
 
     async def read(self, query: str, **kwargs):
@@ -144,11 +139,8 @@ class Database:
                 result = await session.execute(text(query), serialized_kwargs)
                 await session.commit()
                 
-                # Add the write query to the query_tracking table
-                await self.add_query_to_tracking(query, kwargs)
-
                 # Initiate async operations
-                asyncio.create_task(self._async_sync_operations())
+                asyncio.create_task(self._async_sync_operations(query, kwargs))
 
                 # Return the result
                 return result
@@ -161,16 +153,17 @@ class Database:
                 l.error(f"Traceback: {traceback.format_exc()}")
                 return None
 
-
-    async def _async_sync_operations(self):
+    async def _async_sync_operations(self, query: str, kwargs: dict):
         try:
+            # Add the write query to the query_tracking table
+            await self.add_query_to_tracking(query, kwargs)
+
             # Call /db/sync on all online servers
             await self.call_db_sync_on_servers()
         except Exception as e:
             l.error(f"Error in async sync operations: {str(e)}")
             l.error(f"Traceback: {traceback.format_exc()}")
 
-
     async def add_query_to_tracking(self, query: str, kwargs: dict):
         async with self.sessions[self.local_ts_id]() as session:
             new_query = QueryTracking(
@@ -181,34 +174,53 @@ class Database:
             )
             session.add(new_query)
             await session.commit()
+        l.info(f"Added query to tracking: {query[:50]}...")
 
-
-
-
-    async def pull_query_tracking_from_primary(self):
-        primary_ts_id = await self.get_primary_server()
-        if not primary_ts_id:
-            l.error("Failed to get primary server")
+    async def sync_db(self):
+        current_time = time.time()
+        if current_time - self.last_sync_time < 30:
+            l.info("Skipping sync, last sync was less than 30 seconds ago")
             return
 
-        primary_server = next((s for s in self.config['POOL'] if s['ts_id'] == primary_ts_id), None)
-        if not primary_server:
-            l.error(f"Primary server {primary_ts_id} not found in config")
-            return
+        try:
+            l.info("Starting database synchronization")
+            await self.pull_query_tracking_from_all_servers()
+            await self.execute_unexecuted_queries()
+            self.last_sync_time = current_time
+            l.info("Database synchronization completed successfully")
+        except Exception as e:
+            l.error(f"Error during database sync: {str(e)}")
+            l.error(f"Traceback: {traceback.format_exc()}")
 
-        async with self.sessions[primary_ts_id]() as session:
-            queries = await session.execute(select(QueryTracking))
-            queries = queries.fetchall()
-
-        async with self.sessions[self.local_ts_id]() as local_session:
-            for query in queries:
-                existing = await local_session.get(QueryTracking, query.id)
-                if existing:
-                    existing.completed_by = {**existing.completed_by, **query.completed_by}
-                else:
-                    local_session.add(query)
-            await local_session.commit()
+    async def pull_query_tracking_from_all_servers(self):
+        online_servers = await self.get_online_servers()
+        l.info(f"Pulling query tracking from {len(online_servers)} online servers")
+
+        for server_id in online_servers:
+            if server_id == self.local_ts_id:
+                continue  # Skip local server
+
+            l.info(f"Pulling queries from server: {server_id}")
+            async with self.sessions[server_id]() as remote_session:
+                queries = await remote_session.execute(select(QueryTracking))
+                queries = queries.scalars().all()  # ORM instances, not Row tuples
 
+            l.info(f"Retrieved {len(queries)} queries from server {server_id}")
+            async with self.sessions[self.local_ts_id]() as local_session:
+                for query in queries:
+                    existing = await local_session.execute(
+                        select(QueryTracking).where(QueryTracking.id == query.id)
+                    )
+                    existing = existing.scalar_one_or_none()
+
+                    if existing:
+                        existing.completed_by = {**existing.completed_by, **query.completed_by}
+                        l.debug(f"Updated existing query: {query.id}")
+                    else:
+                        await local_session.merge(query)  # merge() inserts the detached row if absent
+                        l.debug(f"Added new query: {query.id}")
+                await local_session.commit()
+        l.info("Finished pulling queries from all servers")
 
     async def execute_unexecuted_queries(self):
         async with self.sessions[self.local_ts_id]() as session:
@@ -217,35 +229,30 @@ class Database:
             )
             unexecuted_queries = unexecuted_queries.fetchall()
 
+            l.info(f"Executing {len(unexecuted_queries)} unexecuted queries")
             for query in unexecuted_queries:
                 try:
                     params = json.loads(query.args)
-                    result = await session.execute(text(query.query), params)
-                    
-                    # Validate result checksum
-                    result_str = str(result.fetchall())
-                    result_checksum = hashlib.md5(result_str.encode()).hexdigest()
-                    
-                    if result_checksum == query.result_checksum:
-                        query.completed_by[self.local_ts_id] = True
-                        await session.commit()
-                        l.info(f"Successfully executed query ID {query.id}")
-                    else:
-                        l.error(f"Checksum mismatch for query ID {query.id}")
-                        await session.rollback()
+                    await session.execute(text(query.query), params)
+                    query.completed_by[self.local_ts_id] = True
+                    await session.commit()
+                    l.info(f"Successfully executed query ID {query.id}")
                 except Exception as e:
                     l.error(f"Failed to execute query ID {query.id}: {str(e)}")
                     await session.rollback()
+        l.info("Finished executing unexecuted queries")
 
     async def call_db_sync_on_servers(self):
         """Call /db/sync on all online servers."""
         online_servers = await self.get_online_servers()
+        l.info(f"Calling /db/sync on {len(online_servers)} online servers")
         for server in self.config['POOL']:
             if server['ts_id'] in online_servers and server['ts_id'] != self.local_ts_id:
                 try:
                     await self.call_db_sync(server)
                 except Exception as e:
                     l.error(f"Failed to call /db/sync on {server['ts_id']}: {str(e)}")
+        l.info("Finished calling /db/sync on all servers")
 
     async def call_db_sync(self, server):
         url = f"http://{server['ts_ip']}:{server['app_port']}/db/sync"
@@ -264,20 +271,6 @@ class Database:
             except Exception as e:
                 l.error(f"Error calling /db/sync on {url}: {str(e)}")
 
-    async def sync_db(self):
-        current_time = time.time()
-        if current_time - self.last_sync_time < 30:
-            l.info("Skipping sync, last sync was less than 30 seconds ago")
-            return
-
-        try:
-            await self.pull_query_tracking_from_all_servers()
-            await self.execute_unexecuted_queries()
-            self.last_sync_time = current_time
-        except Exception as e:
-            l.error(f"Error during database sync: {str(e)}")
-            l.error(f"Traceback: {traceback.format_exc()}")
-
     async def ensure_query_tracking_table(self):
         for ts_id, engine in self.engines.items():
             try:
@@ -286,33 +279,8 @@ class Database:
                 l.info(f"Ensured query_tracking table exists for {ts_id}")
             except Exception as e:
                 l.error(f"Failed to create query_tracking table for {ts_id}: {str(e)}")
-
-    async def pull_query_tracking_from_all_servers(self):
-        online_servers = await self.get_online_servers()
-        
-        for server_id in online_servers:
-            if server_id == self.local_ts_id:
-                continue  # Skip local server
-            
-            async with self.sessions[server_id]() as remote_session:
-                queries = await remote_session.execute(select(QueryTracking))
-                queries = queries.fetchall()
-
-            async with self.sessions[self.local_ts_id]() as local_session:
-                for query in queries:
-                    existing = await local_session.execute(
-                        select(QueryTracking).where(QueryTracking.id == query.id)
-                    )
-                    existing = existing.scalar_one_or_none()
-                    
-                    if existing:
-                        existing.completed_by = {**existing.completed_by, **query.completed_by}
-                    else:
-                        local_session.add(query)
-                await local_session.commit()
-
     
     async def close(self):
         for engine in self.engines.values():
             await engine.dispose()
-
+        l.info("Closed all database connections")
diff --git a/sijapi/routers/sys.py b/sijapi/routers/sys.py
index 6a33d99..4954af8 100644
--- a/sijapi/routers/sys.py
+++ b/sijapi/routers/sys.py
@@ -67,37 +67,9 @@ async def get_tailscale_ip():
         else:
             return "No devices found"
 
-async def sync_process():
-    async with Db.sessions[TS_ID]() as session:
-        # Find unexecuted queries
-        unexecuted_queries = await session.execute(
-            select(QueryTracking).where(~QueryTracking.completed_by.has_key(TS_ID)).order_by(QueryTracking.id)
-        )
-
-        for query in unexecuted_queries:
-            try:
-                params = json_loads(query.args)
-                await session.execute(text(query.query), params)
-                actual_checksum = await Db._local_compute_checksum(query.query, params)
-                if actual_checksum != query.result_checksum:
-                    l.error(f"Checksum mismatch for query ID {query.id}")
-                    continue
-                
-                # Update the completed_by field
-                query.completed_by[TS_ID] = True
-                await session.commit()
-                
-                l.info(f"Successfully executed and verified query ID {query.id}")
-            except Exception as e:
-                l.error(f"Failed to execute query ID {query.id} during sync: {str(e)}")
-                await session.rollback()
-
-        l.info(f"Sync process completed. Executed {unexecuted_queries.rowcount} queries.")
-
-    # After executing all queries, perform combinatorial sync
-    await Db.sync_query_tracking()
 
 @sys.post("/db/sync")
 async def db_sync(background_tasks: BackgroundTasks):
-    background_tasks.add_task(sync_process)
-    return {"message": "Sync process initiated"}
+    l.info("Received request to /db/sync")
+    background_tasks.add_task(Db.sync_db)
+    return {"message": "Sync process initiated"}
\ No newline at end of file