Auto-update: Mon Aug 12 21:44:51 PDT 2024
This commit is contained in:
parent
6c6c3a7b65
commit
e76a059f60
2 changed files with 65 additions and 125 deletions
|
@ -1,5 +1,3 @@
|
||||||
# database.py
|
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import yaml
|
import yaml
|
||||||
import time
|
import time
|
||||||
|
@ -20,8 +18,6 @@ from zoneinfo import ZoneInfo
|
||||||
from srtm import get_data
|
from srtm import get_data
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from sqlalchemy.dialects.postgresql import UUID
|
|
||||||
import uuid
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from sqlalchemy import text, select, func, and_
|
from sqlalchemy import text, select, func, and_
|
||||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
||||||
|
@ -45,18 +41,15 @@ ENV_PATH = CONFIG_DIR / ".env"
|
||||||
load_dotenv(ENV_PATH)
|
load_dotenv(ENV_PATH)
|
||||||
TS_ID = os.environ.get('TS_ID')
|
TS_ID = os.environ.get('TS_ID')
|
||||||
|
|
||||||
|
|
||||||
class QueryTracking(Base):
|
class QueryTracking(Base):
|
||||||
__tablename__ = 'query_tracking'
|
__tablename__ = 'query_tracking'
|
||||||
|
|
||||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
id = Column(Integer, primary_key=True)
|
||||||
ts_id = Column(String, nullable=False)
|
ts_id = Column(String, nullable=False)
|
||||||
query = Column(Text, nullable=False)
|
query = Column(Text, nullable=False)
|
||||||
args = Column(JSON)
|
args = Column(JSONB)
|
||||||
executed_at = Column(DateTime(timezone=True), server_default=func.now())
|
executed_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
completed_by = Column(JSON, default={})
|
completed_by = Column(JSONB, default={})
|
||||||
result_checksum = Column(String(32)) # MD5 checksum
|
|
||||||
|
|
||||||
|
|
||||||
class Database:
|
class Database:
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -109,9 +102,11 @@ class Database:
|
||||||
async with engine.connect() as conn:
|
async with engine.connect() as conn:
|
||||||
await conn.execute(text("SELECT 1"))
|
await conn.execute(text("SELECT 1"))
|
||||||
online_servers.append(ts_id)
|
online_servers.append(ts_id)
|
||||||
|
l.debug(f"Server {ts_id} is online")
|
||||||
except OperationalError:
|
except OperationalError:
|
||||||
pass
|
l.warning(f"Server {ts_id} is offline")
|
||||||
self.online_servers = set(online_servers)
|
self.online_servers = set(online_servers)
|
||||||
|
l.info(f"Online servers: {', '.join(online_servers)}")
|
||||||
return online_servers
|
return online_servers
|
||||||
|
|
||||||
async def read(self, query: str, **kwargs):
|
async def read(self, query: str, **kwargs):
|
||||||
|
@ -144,11 +139,8 @@ class Database:
|
||||||
result = await session.execute(text(query), serialized_kwargs)
|
result = await session.execute(text(query), serialized_kwargs)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
# Add the write query to the query_tracking table
|
|
||||||
await self.add_query_to_tracking(query, kwargs)
|
|
||||||
|
|
||||||
# Initiate async operations
|
# Initiate async operations
|
||||||
asyncio.create_task(self._async_sync_operations())
|
asyncio.create_task(self._async_sync_operations(query, kwargs))
|
||||||
|
|
||||||
# Return the result
|
# Return the result
|
||||||
return result
|
return result
|
||||||
|
@ -161,16 +153,17 @@ class Database:
|
||||||
l.error(f"Traceback: {traceback.format_exc()}")
|
l.error(f"Traceback: {traceback.format_exc()}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def _async_sync_operations(self, query: str, kwargs: dict):
|
||||||
async def _async_sync_operations(self):
|
|
||||||
try:
|
try:
|
||||||
|
# Add the write query to the query_tracking table
|
||||||
|
await self.add_query_to_tracking(query, kwargs)
|
||||||
|
|
||||||
# Call /db/sync on all online servers
|
# Call /db/sync on all online servers
|
||||||
await self.call_db_sync_on_servers()
|
await self.call_db_sync_on_servers()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
l.error(f"Error in async sync operations: {str(e)}")
|
l.error(f"Error in async sync operations: {str(e)}")
|
||||||
l.error(f"Traceback: {traceback.format_exc()}")
|
l.error(f"Traceback: {traceback.format_exc()}")
|
||||||
|
|
||||||
|
|
||||||
async def add_query_to_tracking(self, query: str, kwargs: dict):
|
async def add_query_to_tracking(self, query: str, kwargs: dict):
|
||||||
async with self.sessions[self.local_ts_id]() as session:
|
async with self.sessions[self.local_ts_id]() as session:
|
||||||
new_query = QueryTracking(
|
new_query = QueryTracking(
|
||||||
|
@ -181,34 +174,53 @@ class Database:
|
||||||
)
|
)
|
||||||
session.add(new_query)
|
session.add(new_query)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
l.info(f"Added query to tracking: {query[:50]}...")
|
||||||
|
|
||||||
|
async def sync_db(self):
|
||||||
|
current_time = time.time()
|
||||||
|
if current_time - self.last_sync_time < 30:
|
||||||
async def pull_query_tracking_from_primary(self):
|
l.info("Skipping sync, last sync was less than 30 seconds ago")
|
||||||
primary_ts_id = await self.get_primary_server()
|
|
||||||
if not primary_ts_id:
|
|
||||||
l.error("Failed to get primary server")
|
|
||||||
return
|
return
|
||||||
|
|
||||||
primary_server = next((s for s in self.config['POOL'] if s['ts_id'] == primary_ts_id), None)
|
try:
|
||||||
if not primary_server:
|
l.info("Starting database synchronization")
|
||||||
l.error(f"Primary server {primary_ts_id} not found in config")
|
await self.pull_query_tracking_from_all_servers()
|
||||||
return
|
await self.execute_unexecuted_queries()
|
||||||
|
self.last_sync_time = current_time
|
||||||
|
l.info("Database synchronization completed successfully")
|
||||||
|
except Exception as e:
|
||||||
|
l.error(f"Error during database sync: {str(e)}")
|
||||||
|
l.error(f"Traceback: {traceback.format_exc()}")
|
||||||
|
|
||||||
async with self.sessions[primary_ts_id]() as session:
|
async def pull_query_tracking_from_all_servers(self):
|
||||||
queries = await session.execute(select(QueryTracking))
|
online_servers = await self.get_online_servers()
|
||||||
|
l.info(f"Pulling query tracking from {len(online_servers)} online servers")
|
||||||
|
|
||||||
|
for server_id in online_servers:
|
||||||
|
if server_id == self.local_ts_id:
|
||||||
|
continue # Skip local server
|
||||||
|
|
||||||
|
l.info(f"Pulling queries from server: {server_id}")
|
||||||
|
async with self.sessions[server_id]() as remote_session:
|
||||||
|
queries = await remote_session.execute(select(QueryTracking))
|
||||||
queries = queries.fetchall()
|
queries = queries.fetchall()
|
||||||
|
|
||||||
|
l.info(f"Retrieved {len(queries)} queries from server {server_id}")
|
||||||
async with self.sessions[self.local_ts_id]() as local_session:
|
async with self.sessions[self.local_ts_id]() as local_session:
|
||||||
for query in queries:
|
for query in queries:
|
||||||
existing = await local_session.get(QueryTracking, query.id)
|
existing = await local_session.execute(
|
||||||
|
select(QueryTracking).where(QueryTracking.id == query.id)
|
||||||
|
)
|
||||||
|
existing = existing.scalar_one_or_none()
|
||||||
|
|
||||||
if existing:
|
if existing:
|
||||||
existing.completed_by = {**existing.completed_by, **query.completed_by}
|
existing.completed_by = {**existing.completed_by, **query.completed_by}
|
||||||
|
l.debug(f"Updated existing query: {query.id}")
|
||||||
else:
|
else:
|
||||||
local_session.add(query)
|
local_session.add(query)
|
||||||
|
l.debug(f"Added new query: {query.id}")
|
||||||
await local_session.commit()
|
await local_session.commit()
|
||||||
|
l.info("Finished pulling queries from all servers")
|
||||||
|
|
||||||
async def execute_unexecuted_queries(self):
|
async def execute_unexecuted_queries(self):
|
||||||
async with self.sessions[self.local_ts_id]() as session:
|
async with self.sessions[self.local_ts_id]() as session:
|
||||||
|
@ -217,35 +229,30 @@ class Database:
|
||||||
)
|
)
|
||||||
unexecuted_queries = unexecuted_queries.fetchall()
|
unexecuted_queries = unexecuted_queries.fetchall()
|
||||||
|
|
||||||
|
l.info(f"Executing {len(unexecuted_queries)} unexecuted queries")
|
||||||
for query in unexecuted_queries:
|
for query in unexecuted_queries:
|
||||||
try:
|
try:
|
||||||
params = json.loads(query.args)
|
params = json.loads(query.args)
|
||||||
result = await session.execute(text(query.query), params)
|
await session.execute(text(query.query), params)
|
||||||
|
|
||||||
# Validate result checksum
|
|
||||||
result_str = str(result.fetchall())
|
|
||||||
result_checksum = hashlib.md5(result_str.encode()).hexdigest()
|
|
||||||
|
|
||||||
if result_checksum == query.result_checksum:
|
|
||||||
query.completed_by[self.local_ts_id] = True
|
query.completed_by[self.local_ts_id] = True
|
||||||
await session.commit()
|
await session.commit()
|
||||||
l.info(f"Successfully executed query ID {query.id}")
|
l.info(f"Successfully executed query ID {query.id}")
|
||||||
else:
|
|
||||||
l.error(f"Checksum mismatch for query ID {query.id}")
|
|
||||||
await session.rollback()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
l.error(f"Failed to execute query ID {query.id}: {str(e)}")
|
l.error(f"Failed to execute query ID {query.id}: {str(e)}")
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
|
l.info("Finished executing unexecuted queries")
|
||||||
|
|
||||||
async def call_db_sync_on_servers(self):
|
async def call_db_sync_on_servers(self):
|
||||||
"""Call /db/sync on all online servers."""
|
"""Call /db/sync on all online servers."""
|
||||||
online_servers = await self.get_online_servers()
|
online_servers = await self.get_online_servers()
|
||||||
|
l.info(f"Calling /db/sync on {len(online_servers)} online servers")
|
||||||
for server in self.config['POOL']:
|
for server in self.config['POOL']:
|
||||||
if server['ts_id'] in online_servers and server['ts_id'] != self.local_ts_id:
|
if server['ts_id'] in online_servers and server['ts_id'] != self.local_ts_id:
|
||||||
try:
|
try:
|
||||||
await self.call_db_sync(server)
|
await self.call_db_sync(server)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
l.error(f"Failed to call /db/sync on {server['ts_id']}: {str(e)}")
|
l.error(f"Failed to call /db/sync on {server['ts_id']}: {str(e)}")
|
||||||
|
l.info("Finished calling /db/sync on all servers")
|
||||||
|
|
||||||
async def call_db_sync(self, server):
|
async def call_db_sync(self, server):
|
||||||
url = f"http://{server['ts_ip']}:{server['app_port']}/db/sync"
|
url = f"http://{server['ts_ip']}:{server['app_port']}/db/sync"
|
||||||
|
@ -264,20 +271,6 @@ class Database:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
l.error(f"Error calling /db/sync on {url}: {str(e)}")
|
l.error(f"Error calling /db/sync on {url}: {str(e)}")
|
||||||
|
|
||||||
async def sync_db(self):
|
|
||||||
current_time = time.time()
|
|
||||||
if current_time - self.last_sync_time < 30:
|
|
||||||
l.info("Skipping sync, last sync was less than 30 seconds ago")
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
await self.pull_query_tracking_from_all_servers()
|
|
||||||
await self.execute_unexecuted_queries()
|
|
||||||
self.last_sync_time = current_time
|
|
||||||
except Exception as e:
|
|
||||||
l.error(f"Error during database sync: {str(e)}")
|
|
||||||
l.error(f"Traceback: {traceback.format_exc()}")
|
|
||||||
|
|
||||||
async def ensure_query_tracking_table(self):
|
async def ensure_query_tracking_table(self):
|
||||||
for ts_id, engine in self.engines.items():
|
for ts_id, engine in self.engines.items():
|
||||||
try:
|
try:
|
||||||
|
@ -287,32 +280,7 @@ class Database:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
l.error(f"Failed to create query_tracking table for {ts_id}: {str(e)}")
|
l.error(f"Failed to create query_tracking table for {ts_id}: {str(e)}")
|
||||||
|
|
||||||
async def pull_query_tracking_from_all_servers(self):
|
|
||||||
online_servers = await self.get_online_servers()
|
|
||||||
|
|
||||||
for server_id in online_servers:
|
|
||||||
if server_id == self.local_ts_id:
|
|
||||||
continue # Skip local server
|
|
||||||
|
|
||||||
async with self.sessions[server_id]() as remote_session:
|
|
||||||
queries = await remote_session.execute(select(QueryTracking))
|
|
||||||
queries = queries.fetchall()
|
|
||||||
|
|
||||||
async with self.sessions[self.local_ts_id]() as local_session:
|
|
||||||
for query in queries:
|
|
||||||
existing = await local_session.execute(
|
|
||||||
select(QueryTracking).where(QueryTracking.id == query.id)
|
|
||||||
)
|
|
||||||
existing = existing.scalar_one_or_none()
|
|
||||||
|
|
||||||
if existing:
|
|
||||||
existing.completed_by = {**existing.completed_by, **query.completed_by}
|
|
||||||
else:
|
|
||||||
local_session.add(query)
|
|
||||||
await local_session.commit()
|
|
||||||
|
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
for engine in self.engines.values():
|
for engine in self.engines.values():
|
||||||
await engine.dispose()
|
await engine.dispose()
|
||||||
|
l.info("Closed all database connections")
|
||||||
|
|
|
@ -67,37 +67,9 @@ async def get_tailscale_ip():
|
||||||
else:
|
else:
|
||||||
return "No devices found"
|
return "No devices found"
|
||||||
|
|
||||||
async def sync_process():
|
|
||||||
async with Db.sessions[TS_ID]() as session:
|
|
||||||
# Find unexecuted queries
|
|
||||||
unexecuted_queries = await session.execute(
|
|
||||||
select(QueryTracking).where(~QueryTracking.completed_by.has_key(TS_ID)).order_by(QueryTracking.id)
|
|
||||||
)
|
|
||||||
|
|
||||||
for query in unexecuted_queries:
|
|
||||||
try:
|
|
||||||
params = json_loads(query.args)
|
|
||||||
await session.execute(text(query.query), params)
|
|
||||||
actual_checksum = await Db._local_compute_checksum(query.query, params)
|
|
||||||
if actual_checksum != query.result_checksum:
|
|
||||||
l.error(f"Checksum mismatch for query ID {query.id}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Update the completed_by field
|
|
||||||
query.completed_by[TS_ID] = True
|
|
||||||
await session.commit()
|
|
||||||
|
|
||||||
l.info(f"Successfully executed and verified query ID {query.id}")
|
|
||||||
except Exception as e:
|
|
||||||
l.error(f"Failed to execute query ID {query.id} during sync: {str(e)}")
|
|
||||||
await session.rollback()
|
|
||||||
|
|
||||||
l.info(f"Sync process completed. Executed {unexecuted_queries.rowcount} queries.")
|
|
||||||
|
|
||||||
# After executing all queries, perform combinatorial sync
|
|
||||||
await Db.sync_query_tracking()
|
|
||||||
|
|
||||||
@sys.post("/db/sync")
|
@sys.post("/db/sync")
|
||||||
async def db_sync(background_tasks: BackgroundTasks):
|
async def db_sync(background_tasks: BackgroundTasks):
|
||||||
background_tasks.add_task(sync_process)
|
l.info(f"Received request to /db/sync")
|
||||||
|
background_tasks.add_task(Db.sync_db)
|
||||||
return {"message": "Sync process initiated"}
|
return {"message": "Sync process initiated"}
|
Loading…
Reference in a new issue