From 38a7b36f7d7b8755e15fc9ddc48594ccb90ce195 Mon Sep 17 00:00:00 2001
From: sanj <67624670+iodrift@users.noreply.github.com>
Date: Mon, 12 Aug 2024 17:30:08 -0700
Subject: [PATCH] Auto-update: Mon Aug 12 17:30:08 PDT 2024

---
 sijapi/database.py | 67 +++++++++++++++++++++++++++-------------------
 1 file changed, 40 insertions(+), 27 deletions(-)

diff --git a/sijapi/database.py b/sijapi/database.py
index 28ea5a0..a4a87a6 100644
--- a/sijapi/database.py
+++ b/sijapi/database.py
@@ -135,33 +135,27 @@ class Database:
 
         async with self.sessions[self.local_ts_id]() as session:
             try:
-                # Serialize the kwargs
+                # a. Execute the write query locally
                 serialized_kwargs = {key: serialize(value) for key, value in kwargs.items()}
-
-                # Execute the write query
                 result = await session.execute(text(query), serialized_kwargs)
                 
-                # Log the query
+                # b. Log the query in query_tracking table
                 new_query = QueryTracking(
                     ts_id=self.local_ts_id,
                     query=query,
-                    args=json_dumps(kwargs)  # Use json_dumps for logging
+                    args=json_dumps(kwargs),
+                    completed_by={self.local_ts_id: True}
                 )
                 session.add(new_query)
                 await session.flush()
                 query_id = new_query.id
 
                 await session.commit()
-                l.info(f"Successfully executed write query: {query[:50]}...")
-
-                checksum = await self._local_compute_checksum(query, serialized_kwargs)
-
-                # Update query_tracking with checksum
-                await self.update_query_checksum(query_id, checksum)
-
-                # Perform sync operations asynchronously
-                asyncio.create_task(self._async_sync_operations(query_id, query, serialized_kwargs, checksum))
+                
+                # Initiate async operations
+                asyncio.create_task(self._async_sync_operations(query_id, query, serialized_kwargs))
 
+                # c. Return the result
                 return result
             
             except Exception as e:
@@ -172,25 +166,44 @@ class Database:
                 l.error(f"Traceback: {traceback.format_exc()}")
                 return None
 
-    async def _async_sync_operations(self, query_id: int, query: str, params: dict, checksum: str):
+    async def _async_sync_operations(self, query_id: int, query: str, params: dict):
         try:
-            await self.sync_query_tracking()
-        except Exception as e:
-            l.error(f"Failed to sync query_tracking: {str(e)}")
+            # a. Calculate and add checksum
+            checksum = await self._local_compute_checksum(query, params)
+            await self.update_query_checksum(query_id, checksum)
 
-        try:
+            # b. Synchronize query_tracking table
+            await self.sync_query_tracking()
+
+            # c. Call /db/sync on all servers
             await self.call_db_sync_on_servers()
         except Exception as e:
-            l.error(f"Failed to call db_sync on other servers: {str(e)}")
+            l.error(f"Error in async sync operations: {str(e)}")
+            l.error(f"Traceback: {traceback.format_exc()}")
 
-        # Replicate write to other servers
+    async def call_db_sync_on_servers(self):
+        """Call /db/sync on all online servers."""
         online_servers = await self.get_online_servers()
-        for ts_id in online_servers:
-            if ts_id != self.local_ts_id:
-                try:
-                    await self._replicate_write(ts_id, query_id, query, params, checksum)
-                except Exception as e:
-                    l.error(f"Failed to replicate write to {ts_id}: {str(e)}")
+        for server in self.config['POOL']:
+            if server['ts_id'] in online_servers and server['ts_id'] != self.local_ts_id:
+                asyncio.create_task(self.call_db_sync(server))
+
+    async def call_db_sync(self, server):
+        url = f"http://{server['ts_ip']}:{server['app_port']}/db/sync"
+        headers = {
+            "Authorization": f"Bearer {server['api_key']}"
+        }
+        async with aiohttp.ClientSession() as session:
+            try:
+                async with session.post(url, headers=headers, timeout=30) as response:
+                    if response.status == 200:
+                        l.info(f"Successfully called /db/sync on {url}")
+                    else:
+                        l.warning(f"Failed to call /db/sync on {url}. Status: {response.status}")
+            except asyncio.TimeoutError:
+                l.debug(f"Timeout while calling /db/sync on {url}")
+            except Exception as e:
+                l.error(f"Error calling /db/sync on {url}: {str(e)}")
 
 
     async def get_primary_server(self) -> str: