From d3930ca85f7f6b1580b537ec5359d4c3de74be00 Mon Sep 17 00:00:00 2001 From: sanj <67624670+iodrift@users.noreply.github.com> Date: Mon, 12 Aug 2024 18:01:13 -0700 Subject: [PATCH] Auto-update: Mon Aug 12 18:01:13 PDT 2024 --- sijapi/database.py | 297 +++++++++--------------------------------- sijapi/routers/sys.py | 32 +---- 2 files changed, 60 insertions(+), 269 deletions(-) diff --git a/sijapi/database.py b/sijapi/database.py index a4a87a6..2eeebbf 100644 --- a/sijapi/database.py +++ b/sijapi/database.py @@ -52,7 +52,6 @@ class QueryTracking(Base): args = Column(JSONB) executed_at = Column(DateTime(timezone=True), server_default=func.now()) completed_by = Column(JSONB, default={}) - result_checksum = Column(String) class Database: @classmethod @@ -65,6 +64,7 @@ class Database: self.sessions: Dict[str, Any] = {} self.online_servers: set = set() self.local_ts_id = self.get_local_ts_id() + self.last_sync_time = 0 def load_config(self, config_path: str) -> Dict[str, Any]: base_path = Path(__file__).parent.parent @@ -127,7 +127,6 @@ class Database: l.error(f"Failed to execute read query: {str(e)}") return None - async def write(self, query: str, **kwargs): if self.local_ts_id not in self.sessions: l.error(f"No session found for local server {self.local_ts_id}. Database may not be properly initialized.") @@ -135,27 +134,15 @@ class Database: async with self.sessions[self.local_ts_id]() as session: try: - # a. Execute the write query locally + # Execute the write query locally serialized_kwargs = {key: serialize(value) for key, value in kwargs.items()} result = await session.execute(text(query), serialized_kwargs) - - # b. Log the query in query_tracking table - new_query = QueryTracking( - ts_id=self.local_ts_id, - query=query, - args=json_dumps(kwargs), - completed_by={self.local_ts_id: True} - ) - session.add(new_query) - await session.flush() - query_id = new_query.id - await session.commit() # Initiate async operations - asyncio.create_task(self._async_sync_operations(query_id, query, serialized_kwargs)) + asyncio.create_task(self._async_sync_operations(query, kwargs)) - # c. Return the result + # Return the result return result except Exception as e: @@ -166,55 +153,36 @@ class Database: l.error(f"Traceback: {traceback.format_exc()}") return None - async def _async_sync_operations(self, query_id: int, query: str, params: dict): + async def _async_sync_operations(self, query: str, kwargs: dict): try: - # a. Calculate and add checksum - checksum = await self._local_compute_checksum(query, params) - await self.update_query_checksum(query_id, checksum) + # Add the write query to the query_tracking table + await self.add_query_to_tracking(query, kwargs) - # b. Synchronize query_tracking table - await self.sync_query_tracking() - - # c. Call /db/sync on all servers + # Call /db/sync on all online servers await self.call_db_sync_on_servers() except Exception as e: l.error(f"Error in async sync operations: {str(e)}") l.error(f"Traceback: {traceback.format_exc()}") - async def call_db_sync_on_servers(self): - """Call /db/sync on all online servers.""" - online_servers = await self.get_online_servers() - for server in self.config['POOL']: - if server['ts_id'] in online_servers and server['ts_id'] != self.local_ts_id: - asyncio.create_task(self.call_db_sync(server)) - - async def call_db_sync(self, server): - url = f"http://{server['ts_ip']}:{server['app_port']}/db/sync" - headers = { - "Authorization": f"Bearer {server['api_key']}" - } - async with aiohttp.ClientSession() as session: - try: - async with session.post(url, headers=headers, timeout=30) as response: - if response.status == 200: - l.info(f"Successfully called /db/sync on {url}") - else: - l.warning(f"Failed to call /db/sync on {url}. Status: {response.status}") - except asyncio.TimeoutError: - l.debug(f"Timeout while calling /db/sync on {url}") - except Exception as e: - l.error(f"Error calling /db/sync on {url}: {str(e)}") - + async def add_query_to_tracking(self, query: str, kwargs: dict): + async with self.sessions[self.local_ts_id]() as session: + new_query = QueryTracking( + ts_id=self.local_ts_id, + query=query, + args=json_dumps(kwargs), + completed_by={self.local_ts_id: True} + ) + session.add(new_query) + await session.commit() async def get_primary_server(self) -> str: - url = urljoin(self.config['URL'], '/id') - + url = f"{self.config['URL']}/id" async with aiohttp.ClientSession() as session: try: async with session.get(url) as response: if response.status == 200: primary_ts_id = await response.text() - return primary_ts_id.strip() + return primary_ts_id.strip().strip('"') else: l.error(f"Failed to get primary server. Status: {response.status}") return None @@ -222,207 +190,62 @@ class Database: l.error(f"Error connecting to load balancer: {str(e)}") return None - async def get_checksum_server(self) -> dict: + async def pull_query_tracking_from_primary(self): primary_ts_id = await self.get_primary_server() - online_servers = await self.get_online_servers() - - checksum_servers = [server for server in self.config['POOL'] if server['ts_id'] in online_servers and server['ts_id'] != primary_ts_id] - - if not checksum_servers: - return next(server for server in self.config['POOL'] if server['ts_id'] == primary_ts_id) - - return random.choice(checksum_servers) + if not primary_ts_id: + l.error("Failed to get primary server") + return - async def _local_compute_checksum(self, query: str, params: dict): + primary_server = next((s for s in self.config['POOL'] if s['ts_id'] == primary_ts_id), None) + if not primary_server: + l.error(f"Primary server {primary_ts_id} not found in config") + return + + async with self.sessions[primary_ts_id]() as session: + queries = await session.execute(select(QueryTracking)) + queries = queries.fetchall() + + async with self.sessions[self.local_ts_id]() as local_session: + for query in queries: + existing = await local_session.get(QueryTracking, query.id) + if existing: + existing.completed_by = {**existing.completed_by, **query.completed_by} + else: + local_session.add(query) + await local_session.commit() + + async def execute_unexecuted_queries(self): async with self.sessions[self.local_ts_id]() as session: - result = await session.execute(text(query), params) - if result.returns_rows: - data = result.fetchall() - else: - data = str(result.rowcount) + query + str(params) - checksum = hashlib.md5(str(data).encode()).hexdigest() - return checksum - - async def _delegate_compute_checksum(self, server: Dict[str, Any], query: str, params: dict): - url = f"http://{server['ts_ip']}:{server['app_port']}/sync/checksum" - - async with aiohttp.ClientSession() as session: - try: - async with session.post(url, json={"query": query, "params": params}) as response: - if response.status == 200: - result = await response.json() - return result['checksum'] - else: - l.error(f"Failed to get checksum from {server['ts_id']}. Status: {response.status}") - return await self._local_compute_checksum(query, params) - except aiohttp.ClientError as e: - l.error(f"Error connecting to {server['ts_id']} for checksum: {str(e)}") - return await self._local_compute_checksum(query, params) - - async def update_query_checksum(self, query_id: int, checksum: str): - async with self.sessions[self.local_ts_id]() as session: - await session.execute( - text("UPDATE query_tracking SET result_checksum = :checksum WHERE id = :id"), - {"checksum": checksum, "id": query_id} - ) - await session.commit() - - async def _replicate_write(self, ts_id: str, query_id: int, query: str, params: dict, expected_checksum: str): - try: - async with self.sessions[ts_id]() as session: - await session.execute(text(query), params) - actual_checksum = await self._local_compute_checksum(query, params) - if actual_checksum != expected_checksum: - raise ValueError(f"Checksum mismatch on {ts_id}") - await self.mark_query_completed(query_id, ts_id) - await session.commit() - l.info(f"Successfully replicated write to {ts_id}") - except Exception as e: - l.error(f"Failed to replicate write on {ts_id}: {str(e)}") - l.error(f"Traceback: {traceback.format_exc()}") - - - async def mark_query_completed(self, query_id: int, ts_id: str): - async with self.sessions[self.local_ts_id]() as session: - query = await session.get(QueryTracking, query_id) - if query: - completed_by = query.completed_by or {} - completed_by[ts_id] = True - query.completed_by = completed_by - await session.commit() - - async def sync_local_server(self): - async with self.sessions[self.local_ts_id]() as session: - last_synced = await session.execute( - text("SELECT MAX(id) FROM query_tracking WHERE completed_by ? :ts_id"), - {"ts_id": self.local_ts_id} - ) - last_synced_id = last_synced.scalar() or 0 - unexecuted_queries = await session.execute( - text("SELECT * FROM query_tracking WHERE id > :last_id ORDER BY id"), - {"last_id": last_synced_id} + select(QueryTracking).where(~QueryTracking.completed_by.has_key(self.local_ts_id)).order_by(QueryTracking.id) ) + unexecuted_queries = unexecuted_queries.fetchall() for query in unexecuted_queries: try: params = json.loads(query.args) await session.execute(text(query.query), params) - actual_checksum = await self._local_compute_checksum(query.query, params) - if actual_checksum != query.result_checksum: - raise ValueError(f"Checksum mismatch for query ID {query.id}") - await self.mark_query_completed(query.id, self.local_ts_id) + query.completed_by[self.local_ts_id] = True + await session.commit() + l.info(f"Successfully executed query ID {query.id}") except Exception as e: - l.error(f"Failed to execute query ID {query.id} during local sync: {str(e)}") + l.error(f"Failed to execute query ID {query.id}: {str(e)}") + await session.rollback() - await session.commit() - l.info(f"Local server sync completed. Executed {unexecuted_queries.rowcount} queries.") + async def sync_db(self): + current_time = time.time() + if current_time - self.last_sync_time < 30: + l.info("Skipping sync, last sync was less than 30 seconds ago") + return - async def purge_completed_queries(self): - async with self.sessions[self.local_ts_id]() as session: - all_ts_ids = [db['ts_id'] for db in self.config['POOL']] - - result = await session.execute( - text(""" - DELETE FROM query_tracking - WHERE id <= ( - SELECT MAX(id) - FROM query_tracking - WHERE completed_by ?& :ts_ids - ) - """), - {"ts_ids": all_ts_ids} - ) - await session.commit() - - deleted_count = result.rowcount - l.info(f"Purged {deleted_count} completed queries.") - - - async def sync_query_tracking(self): - """Combinatorial sync method for the query_tracking table.""" try: - online_servers = await self.get_online_servers() - - for ts_id in online_servers: - if ts_id == self.local_ts_id: - continue - - try: - async with self.sessions[ts_id]() as remote_session: - local_max_id = await self.get_max_query_id(self.local_ts_id) - remote_max_id = await self.get_max_query_id(ts_id) - - # Sync from remote to local - remote_new_queries = await remote_session.execute( - select(QueryTracking).where(QueryTracking.id > local_max_id) - ) - remote_new_queries = remote_new_queries.fetchall() - for query in remote_new_queries: - await self.add_or_update_query(query[0]) - - # Sync from local to remote - async with self.sessions[self.local_ts_id]() as local_session: - local_new_queries = await local_session.execute( - select(QueryTracking).where(QueryTracking.id > remote_max_id) - ) - local_new_queries = local_new_queries.fetchall() - for query in local_new_queries: - await self.add_or_update_query_remote(ts_id, query[0]) - except Exception as e: - l.error(f"Error syncing with {ts_id}: {str(e)}") - l.error(f"Traceback: {traceback.format_exc()}") + await self.pull_query_tracking_from_primary() + await self.execute_unexecuted_queries() + self.last_sync_time = current_time except Exception as e: - l.error(f"Error in sync_query_tracking: {str(e)}") + l.error(f"Error during database sync: {str(e)}") l.error(f"Traceback: {traceback.format_exc()}") - - async def get_max_query_id(self, ts_id): - async with self.sessions[ts_id]() as session: - result = await session.execute(select(func.max(QueryTracking.id))) - return result.scalar() or 0 - - async def add_or_update_query(self, query): - async with self.sessions[self.local_ts_id]() as session: - existing_query = await session.get(QueryTracking, query.id) - if existing_query: - existing_query.completed_by = {**existing_query.completed_by, **query.completed_by} - else: - session.add(query) - await session.commit() - - async def add_or_update_query_remote(self, ts_id, query): - async with self.sessions[ts_id]() as session: - existing_query = await session.get(QueryTracking, query.id) - if existing_query: - existing_query.completed_by = {**existing_query.completed_by, **query.completed_by} - else: - new_query = QueryTracking( - id=query.id, - ts_id=query.ts_id, - query=query.query, - args=query.args, - executed_at=query.executed_at, - completed_by=query.completed_by, - result_checksum=query.result_checksum - ) - session.add(new_query) - try: - await session.commit() - except Exception as e: - l.error(f"Failed to add or update query on {ts_id}: {str(e)}") - await session.rollback() - - async def ensure_query_tracking_table(self): - for ts_id, engine in self.engines.items(): - try: - async with engine.begin() as conn: - await conn.run_sync(Base.metadata.create_all) - l.info(f"Ensured query_tracking table exists for {ts_id}") - except Exception as e: - l.error(f"Failed to create query_tracking table for {ts_id}: {str(e)}") - - async def call_db_sync_on_servers(self): """Call /db/sync on all online servers.""" online_servers = await self.get_online_servers() @@ -433,7 +256,6 @@ class Database: except Exception as e: l.error(f"Failed to call /db/sync on {server['ts_id']}: {str(e)}") - async def call_db_sync(self, server): url = f"http://{server['ts_ip']}:{server['app_port']}/db/sync" headers = { @@ -451,7 +273,6 @@ class Database: except Exception as e: l.error(f"Error calling /db/sync on {url}: {str(e)}") - async def close(self): for engine in self.engines.values(): await engine.dispose() diff --git a/sijapi/routers/sys.py b/sijapi/routers/sys.py index 6a33d99..9ab2643 100644 --- a/sijapi/routers/sys.py +++ b/sijapi/routers/sys.py @@ -67,37 +67,7 @@ async def get_tailscale_ip(): else: return "No devices found" -async def sync_process(): - async with Db.sessions[TS_ID]() as session: - # Find unexecuted queries - unexecuted_queries = await session.execute( - select(QueryTracking).where(~QueryTracking.completed_by.has_key(TS_ID)).order_by(QueryTracking.id) - ) - - for query in unexecuted_queries: - try: - params = json_loads(query.args) - await session.execute(text(query.query), params) - actual_checksum = await Db._local_compute_checksum(query.query, params) - if actual_checksum != query.result_checksum: - l.error(f"Checksum mismatch for query ID {query.id}") - continue - - # Update the completed_by field - query.completed_by[TS_ID] = True - await session.commit() - - l.info(f"Successfully executed and verified query ID {query.id}") - except Exception as e: - l.error(f"Failed to execute query ID {query.id} during sync: {str(e)}") - await session.rollback() - - l.info(f"Sync process completed. Executed {unexecuted_queries.rowcount} queries.") - - # After executing all queries, perform combinatorial sync - await Db.sync_query_tracking() - @sys.post("/db/sync") async def db_sync(background_tasks: BackgroundTasks): - background_tasks.add_task(sync_process) + background_tasks.add_task(Db.sync_db) return {"message": "Sync process initiated"}