From e76a059f60da5daaa59dcca63d5d8d88726129a8 Mon Sep 17 00:00:00 2001 From: sanj <67624670+iodrift@users.noreply.github.com> Date: Mon, 12 Aug 2024 21:44:51 -0700 Subject: [PATCH] Auto-update: Mon Aug 12 21:44:51 PDT 2024 --- sijapi/database.py | 156 +++++++++++++++++------------------------- sijapi/routers/sys.py | 34 +-------- 2 files changed, 65 insertions(+), 125 deletions(-) diff --git a/sijapi/database.py b/sijapi/database.py index f581bf2..f623603 100644 --- a/sijapi/database.py +++ b/sijapi/database.py @@ -1,5 +1,3 @@ -# database.py - import json import yaml import time @@ -20,8 +18,6 @@ from zoneinfo import ZoneInfo from srtm import get_data import os import sys -from sqlalchemy.dialects.postgresql import UUID -import uuid from loguru import logger from sqlalchemy import text, select, func, and_ from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession @@ -45,18 +41,15 @@ ENV_PATH = CONFIG_DIR / ".env" load_dotenv(ENV_PATH) TS_ID = os.environ.get('TS_ID') - class QueryTracking(Base): __tablename__ = 'query_tracking' - id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + id = Column(Integer, primary_key=True) ts_id = Column(String, nullable=False) query = Column(Text, nullable=False) - args = Column(JSON) + args = Column(JSONB) executed_at = Column(DateTime(timezone=True), server_default=func.now()) - completed_by = Column(JSON, default={}) - result_checksum = Column(String(32)) # MD5 checksum - + completed_by = Column(JSONB, default={}) class Database: @classmethod @@ -109,9 +102,11 @@ class Database: async with engine.connect() as conn: await conn.execute(text("SELECT 1")) online_servers.append(ts_id) + l.debug(f"Server {ts_id} is online") except OperationalError: - pass + l.warning(f"Server {ts_id} is offline") self.online_servers = set(online_servers) + l.info(f"Online servers: {', '.join(online_servers)}") return online_servers async def read(self, query: str, **kwargs): @@ -144,11 +139,8 @@ class Database: result = await session.execute(text(query), serialized_kwargs) await session.commit() - # Add the write query to the query_tracking table - await self.add_query_to_tracking(query, kwargs) - # Initiate async operations - asyncio.create_task(self._async_sync_operations()) + asyncio.create_task(self._async_sync_operations(query, kwargs)) # Return the result return result @@ -161,16 +153,17 @@ class Database: l.error(f"Traceback: {traceback.format_exc()}") return None - - async def _async_sync_operations(self): + async def _async_sync_operations(self, query: str, kwargs: dict): try: + # Add the write query to the query_tracking table + await self.add_query_to_tracking(query, kwargs) + # Call /db/sync on all online servers await self.call_db_sync_on_servers() except Exception as e: l.error(f"Error in async sync operations: {str(e)}") l.error(f"Traceback: {traceback.format_exc()}") - async def add_query_to_tracking(self, query: str, kwargs: dict): async with self.sessions[self.local_ts_id]() as session: new_query = QueryTracking( @@ -181,34 +174,53 @@ class Database: ) session.add(new_query) await session.commit() + l.info(f"Added query to tracking: {query[:50]}...") - - - - async def pull_query_tracking_from_primary(self): - primary_ts_id = await self.get_primary_server() - if not primary_ts_id: - l.error("Failed to get primary server") + async def sync_db(self): + current_time = time.time() + if current_time - self.last_sync_time < 30: + l.info("Skipping sync, last sync was less than 30 seconds ago") return - primary_server = next((s for s in self.config['POOL'] if s['ts_id'] == primary_ts_id), None) - if not primary_server: - l.error(f"Primary server {primary_ts_id} not found in config") - return + try: + l.info("Starting database synchronization") + await self.pull_query_tracking_from_all_servers() + await self.execute_unexecuted_queries() + self.last_sync_time = current_time + l.info("Database synchronization completed successfully") + except Exception as e: + l.error(f"Error during database sync: {str(e)}") + l.error(f"Traceback: {traceback.format_exc()}") - async with self.sessions[primary_ts_id]() as session: - queries = await session.execute(select(QueryTracking)) - queries = queries.fetchall() - - async with self.sessions[self.local_ts_id]() as local_session: - for query in queries: - existing = await local_session.get(QueryTracking, query.id) - if existing: - existing.completed_by = {**existing.completed_by, **query.completed_by} - else: - local_session.add(query) - await local_session.commit() + async def pull_query_tracking_from_all_servers(self): + online_servers = await self.get_online_servers() + l.info(f"Pulling query tracking from {len(online_servers)} online servers") + + for server_id in online_servers: + if server_id == self.local_ts_id: + continue # Skip local server + + l.info(f"Pulling queries from server: {server_id}") + async with self.sessions[server_id]() as remote_session: + queries = await remote_session.execute(select(QueryTracking)) + queries = queries.fetchall() + l.info(f"Retrieved {len(queries)} queries from server {server_id}") + async with self.sessions[self.local_ts_id]() as local_session: + for query in queries: + existing = await local_session.execute( + select(QueryTracking).where(QueryTracking.id == query.id) + ) + existing = existing.scalar_one_or_none() + + if existing: + existing.completed_by = {**existing.completed_by, **query.completed_by} + l.debug(f"Updated existing query: {query.id}") + else: + local_session.add(query) + l.debug(f"Added new query: {query.id}") + await local_session.commit() + l.info("Finished pulling queries from all servers") async def execute_unexecuted_queries(self): async with self.sessions[self.local_ts_id]() as session: @@ -217,35 +229,30 @@ class Database: ) unexecuted_queries = unexecuted_queries.fetchall() + l.info(f"Executing {len(unexecuted_queries)} unexecuted queries") for query in unexecuted_queries: try: params = json.loads(query.args) - result = await session.execute(text(query.query), params) - - # Validate result checksum - result_str = str(result.fetchall()) - result_checksum = hashlib.md5(result_str.encode()).hexdigest() - - if result_checksum == query.result_checksum: - query.completed_by[self.local_ts_id] = True - await session.commit() - l.info(f"Successfully executed query ID {query.id}") - else: - l.error(f"Checksum mismatch for query ID {query.id}") - await session.rollback() + await session.execute(text(query.query), params) + query.completed_by[self.local_ts_id] = True + await session.commit() + l.info(f"Successfully executed query ID {query.id}") except Exception as e: l.error(f"Failed to execute query ID {query.id}: {str(e)}") await session.rollback() + l.info("Finished executing unexecuted queries") async def call_db_sync_on_servers(self): """Call /db/sync on all online servers.""" online_servers = await self.get_online_servers() + l.info(f"Calling /db/sync on {len(online_servers)} online servers") for server in self.config['POOL']: if server['ts_id'] in online_servers and server['ts_id'] != self.local_ts_id: try: await self.call_db_sync(server) except Exception as e: l.error(f"Failed to call /db/sync on {server['ts_id']}: {str(e)}") + l.info("Finished calling /db/sync on all servers") async def call_db_sync(self, server): url = f"http://{server['ts_ip']}:{server['app_port']}/db/sync" @@ -264,20 +271,6 @@ class Database: except Exception as e: l.error(f"Error calling /db/sync on {url}: {str(e)}") - async def sync_db(self): - current_time = time.time() - if current_time - self.last_sync_time < 30: - l.info("Skipping sync, last sync was less than 30 seconds ago") - return - - try: - await self.pull_query_tracking_from_all_servers() - await self.execute_unexecuted_queries() - self.last_sync_time = current_time - except Exception as e: - l.error(f"Error during database sync: {str(e)}") - l.error(f"Traceback: {traceback.format_exc()}") - async def ensure_query_tracking_table(self): for ts_id, engine in self.engines.items(): try: @@ -286,33 +279,8 @@ class Database: l.info(f"Ensured query_tracking table exists for {ts_id}") except Exception as e: l.error(f"Failed to create query_tracking table for {ts_id}: {str(e)}") - - async def pull_query_tracking_from_all_servers(self): - online_servers = await self.get_online_servers() - - for server_id in online_servers: - if server_id == self.local_ts_id: - continue # Skip local server - - async with self.sessions[server_id]() as remote_session: - queries = await remote_session.execute(select(QueryTracking)) - queries = queries.fetchall() - - async with self.sessions[self.local_ts_id]() as local_session: - for query in queries: - existing = await local_session.execute( - select(QueryTracking).where(QueryTracking.id == query.id) - ) - existing = existing.scalar_one_or_none() - - if existing: - existing.completed_by = {**existing.completed_by, **query.completed_by} - else: - local_session.add(query) - await local_session.commit() - async def close(self): for engine in self.engines.values(): await engine.dispose() - + l.info("Closed all database connections") diff --git a/sijapi/routers/sys.py b/sijapi/routers/sys.py index 6a33d99..4954af8 100644 --- a/sijapi/routers/sys.py +++ b/sijapi/routers/sys.py @@ -67,37 +67,9 @@ async def get_tailscale_ip(): else: return "No devices found" -async def sync_process(): - async with Db.sessions[TS_ID]() as session: - # Find unexecuted queries - unexecuted_queries = await session.execute( - select(QueryTracking).where(~QueryTracking.completed_by.has_key(TS_ID)).order_by(QueryTracking.id) - ) - - for query in unexecuted_queries: - try: - params = json_loads(query.args) - await session.execute(text(query.query), params) - actual_checksum = await Db._local_compute_checksum(query.query, params) - if actual_checksum != query.result_checksum: - l.error(f"Checksum mismatch for query ID {query.id}") - continue - - # Update the completed_by field - query.completed_by[TS_ID] = True - await session.commit() - - l.info(f"Successfully executed and verified query ID {query.id}") - except Exception as e: - l.error(f"Failed to execute query ID {query.id} during sync: {str(e)}") - await session.rollback() - - l.info(f"Sync process completed. Executed {unexecuted_queries.rowcount} queries.") - - # After executing all queries, perform combinatorial sync - await Db.sync_query_tracking() @sys.post("/db/sync") async def db_sync(background_tasks: BackgroundTasks): - background_tasks.add_task(sync_process) - return {"message": "Sync process initiated"} + l.info(f"Received request to /db/sync") + background_tasks.add_task(Db.sync_db) + return {"message": "Sync process initiated"} \ No newline at end of file