Auto-update: Mon Aug 12 20:45:04 PDT 2024
This commit is contained in:
parent
9863965976
commit
301e8e1408
2 changed files with 103 additions and 42 deletions
|
@ -20,6 +20,8 @@ from zoneinfo import ZoneInfo
|
|||
from srtm import get_data
|
||||
import os
|
||||
import sys
|
||||
import uuid
|
||||
from uuid import UUID
|
||||
from loguru import logger
|
||||
from sqlalchemy import text, select, func, and_
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
||||
|
@ -43,15 +45,18 @@ ENV_PATH = CONFIG_DIR / ".env"
|
|||
load_dotenv(ENV_PATH)
|
||||
TS_ID = os.environ.get('TS_ID')
|
||||
|
||||
|
||||
class QueryTracking(Base):
|
||||
__tablename__ = 'query_tracking'
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
ts_id = Column(String, nullable=False)
|
||||
query = Column(Text, nullable=False)
|
||||
args = Column(JSONB)
|
||||
args = Column(JSON)
|
||||
executed_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
completed_by = Column(JSONB, default={})
|
||||
completed_by = Column(JSON, default={})
|
||||
result_checksum = Column(String(32)) # MD5 checksum
|
||||
|
||||
|
||||
class Database:
|
||||
@classmethod
|
||||
|
@ -139,8 +144,15 @@ class Database:
|
|||
result = await session.execute(text(query), serialized_kwargs)
|
||||
await session.commit()
|
||||
|
||||
# Calculate result checksum
|
||||
result_str = str(result.fetchall())
|
||||
result_checksum = hashlib.md5(result_str.encode()).hexdigest()
|
||||
|
||||
# Add the write query to the query_tracking table
|
||||
await self.add_query_to_tracking(query, kwargs, result_checksum)
|
||||
|
||||
# Initiate async operations
|
||||
asyncio.create_task(self._async_sync_operations(query, kwargs))
|
||||
asyncio.create_task(self._async_sync_operations())
|
||||
|
||||
# Return the result
|
||||
return result
|
||||
|
@ -153,11 +165,8 @@ class Database:
|
|||
l.error(f"Traceback: {traceback.format_exc()}")
|
||||
return None
|
||||
|
||||
async def _async_sync_operations(self, query: str, kwargs: dict):
|
||||
async def _async_sync_operations(self):
|
||||
try:
|
||||
# Add the write query to the query_tracking table
|
||||
await self.add_query_to_tracking(query, kwargs)
|
||||
|
||||
# Call /db/sync on all online servers
|
||||
await self.call_db_sync_on_servers()
|
||||
except Exception as e:
|
||||
|
@ -175,20 +184,6 @@ class Database:
|
|||
session.add(new_query)
|
||||
await session.commit()
|
||||
|
||||
async def get_primary_server(self) -> str:
|
||||
url = f"{self.config['URL']}/id"
|
||||
async with aiohttp.ClientSession() as session:
|
||||
try:
|
||||
async with session.get(url) as response:
|
||||
if response.status == 200:
|
||||
primary_ts_id = await response.text()
|
||||
return primary_ts_id.strip().strip('"')
|
||||
else:
|
||||
l.error(f"Failed to get primary server. Status: {response.status}")
|
||||
return None
|
||||
except aiohttp.ClientError as e:
|
||||
l.error(f"Error connecting to load balancer: {str(e)}")
|
||||
return None
|
||||
|
||||
async def pull_query_tracking_from_primary(self):
|
||||
primary_ts_id = await self.get_primary_server()
|
||||
|
@ -214,38 +209,34 @@ class Database:
|
|||
local_session.add(query)
|
||||
await local_session.commit()
|
||||
|
||||
|
||||
async def execute_unexecuted_queries(self):
|
||||
async with self.sessions[self.local_ts_id]() as session:
|
||||
unexecuted_queries = await session.execute(
|
||||
select(QueryTracking).where(~QueryTracking.completed_by.has_key(self.local_ts_id)).order_by(QueryTracking.id)
|
||||
select(QueryTracking).where(~QueryTracking.completed_by.has_key(self.local_ts_id)).order_by(QueryTracking.executed_at)
|
||||
)
|
||||
unexecuted_queries = unexecuted_queries.fetchall()
|
||||
|
||||
for query in unexecuted_queries:
|
||||
try:
|
||||
params = json.loads(query.args)
|
||||
await session.execute(text(query.query), params)
|
||||
query.completed_by[self.local_ts_id] = True
|
||||
await session.commit()
|
||||
l.info(f"Successfully executed query ID {query.id}")
|
||||
result = await session.execute(text(query.query), params)
|
||||
|
||||
# Validate result checksum
|
||||
result_str = str(result.fetchall())
|
||||
result_checksum = hashlib.md5(result_str.encode()).hexdigest()
|
||||
|
||||
if result_checksum == query.result_checksum:
|
||||
query.completed_by[self.local_ts_id] = True
|
||||
await session.commit()
|
||||
l.info(f"Successfully executed query ID {query.id}")
|
||||
else:
|
||||
l.error(f"Checksum mismatch for query ID {query.id}")
|
||||
await session.rollback()
|
||||
except Exception as e:
|
||||
l.error(f"Failed to execute query ID {query.id}: {str(e)}")
|
||||
await session.rollback()
|
||||
|
||||
async def sync_db(self):
|
||||
current_time = time.time()
|
||||
if current_time - self.last_sync_time < 30:
|
||||
l.info("Skipping sync, last sync was less than 30 seconds ago")
|
||||
return
|
||||
|
||||
try:
|
||||
await self.pull_query_tracking_from_primary()
|
||||
await self.execute_unexecuted_queries()
|
||||
self.last_sync_time = current_time
|
||||
except Exception as e:
|
||||
l.error(f"Error during database sync: {str(e)}")
|
||||
l.error(f"Traceback: {traceback.format_exc()}")
|
||||
|
||||
async def call_db_sync_on_servers(self):
|
||||
"""Call /db/sync on all online servers."""
|
||||
online_servers = await self.get_online_servers()
|
||||
|
@ -273,6 +264,20 @@ class Database:
|
|||
except Exception as e:
|
||||
l.error(f"Error calling /db/sync on {url}: {str(e)}")
|
||||
|
||||
async def sync_db(self):
|
||||
current_time = time.time()
|
||||
if current_time - self.last_sync_time < 30:
|
||||
l.info("Skipping sync, last sync was less than 30 seconds ago")
|
||||
return
|
||||
|
||||
try:
|
||||
await self.pull_query_tracking_from_all_servers()
|
||||
await self.execute_unexecuted_queries()
|
||||
self.last_sync_time = current_time
|
||||
except Exception as e:
|
||||
l.error(f"Error during database sync: {str(e)}")
|
||||
l.error(f"Traceback: {traceback.format_exc()}")
|
||||
|
||||
async def ensure_query_tracking_table(self):
|
||||
for ts_id, engine in self.engines.items():
|
||||
try:
|
||||
|
@ -281,7 +286,33 @@ class Database:
|
|||
l.info(f"Ensured query_tracking table exists for {ts_id}")
|
||||
except Exception as e:
|
||||
l.error(f"Failed to create query_tracking table for {ts_id}: {str(e)}")
|
||||
|
||||
async def pull_query_tracking_from_all_servers(self):
|
||||
online_servers = await self.get_online_servers()
|
||||
|
||||
for server_id in online_servers:
|
||||
if server_id == self.local_ts_id:
|
||||
continue # Skip local server
|
||||
|
||||
async with self.sessions[server_id]() as remote_session:
|
||||
queries = await remote_session.execute(select(QueryTracking))
|
||||
queries = queries.fetchall()
|
||||
|
||||
async with self.sessions[self.local_ts_id]() as local_session:
|
||||
for query in queries:
|
||||
existing = await local_session.execute(
|
||||
select(QueryTracking).where(QueryTracking.id == query.id)
|
||||
)
|
||||
existing = existing.scalar_one_or_none()
|
||||
|
||||
if existing:
|
||||
existing.completed_by = {**existing.completed_by, **query.completed_by}
|
||||
else:
|
||||
local_session.add(query)
|
||||
await local_session.commit()
|
||||
|
||||
|
||||
async def close(self):
|
||||
for engine in self.engines.values():
|
||||
await engine.dispose()
|
||||
|
||||
|
|
|
@ -67,7 +67,37 @@ async def get_tailscale_ip():
|
|||
else:
|
||||
return "No devices found"
|
||||
|
||||
async def sync_process():
|
||||
async with Db.sessions[TS_ID]() as session:
|
||||
# Find unexecuted queries
|
||||
unexecuted_queries = await session.execute(
|
||||
select(QueryTracking).where(~QueryTracking.completed_by.has_key(TS_ID)).order_by(QueryTracking.id)
|
||||
)
|
||||
|
||||
for query in unexecuted_queries:
|
||||
try:
|
||||
params = json_loads(query.args)
|
||||
await session.execute(text(query.query), params)
|
||||
actual_checksum = await Db._local_compute_checksum(query.query, params)
|
||||
if actual_checksum != query.result_checksum:
|
||||
l.error(f"Checksum mismatch for query ID {query.id}")
|
||||
continue
|
||||
|
||||
# Update the completed_by field
|
||||
query.completed_by[TS_ID] = True
|
||||
await session.commit()
|
||||
|
||||
l.info(f"Successfully executed and verified query ID {query.id}")
|
||||
except Exception as e:
|
||||
l.error(f"Failed to execute query ID {query.id} during sync: {str(e)}")
|
||||
await session.rollback()
|
||||
|
||||
l.info(f"Sync process completed. Executed {unexecuted_queries.rowcount} queries.")
|
||||
|
||||
# After executing all queries, perform combinatorial sync
|
||||
await Db.sync_query_tracking()
|
||||
|
||||
@sys.post("/db/sync")
|
||||
async def db_sync(background_tasks: BackgroundTasks):
|
||||
background_tasks.add_task(Db.sync_db)
|
||||
background_tasks.add_task(sync_process)
|
||||
return {"message": "Sync process initiated"}
|
||||
|
|
Loading…
Reference in a new issue