Auto-update: Mon Aug 12 18:01:13 PDT 2024
parent e4db7a0f88
commit d3930ca85f
2 changed files with 60 additions and 269 deletions
@@ -52,7 +52,6 @@ class QueryTracking(Base):
     args = Column(JSONB)
     executed_at = Column(DateTime(timezone=True), server_default=func.now())
     completed_by = Column(JSONB, default={})
-    result_checksum = Column(String)
 
 class Database:
     @classmethod
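
Note (pre-existing, not changed by this commit): Column(JSONB, default={}) hands every new row the same shared dict instance, which is easy to mutate by accident on the Python side. The usual idiom is a callable default, invoked once per insert; a minimal sketch:

    completed_by = Column(JSONB, default=dict)
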
@@ -65,6 +64,7 @@ class Database:
         self.sessions: Dict[str, Any] = {}
         self.online_servers: set = set()
         self.local_ts_id = self.get_local_ts_id()
+        self.last_sync_time = 0
 
     def load_config(self, config_path: str) -> Dict[str, Any]:
         base_path = Path(__file__).parent.parent
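
Note: self.last_sync_time feeds the 30-second guard in sync_db further down, which compares against time.time(). For measuring intervals, time.monotonic() is the safer clock, since wall time can jump under NTP adjustments. A minimal sketch of that variant:

    self.last_sync_time = 0.0
    ...
    if time.monotonic() - self.last_sync_time < 30:
        return
    self.last_sync_time = time.monotonic()
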
@@ -127,7 +127,6 @@ class Database:
             l.error(f"Failed to execute read query: {str(e)}")
             return None
 
-
     async def write(self, query: str, **kwargs):
         if self.local_ts_id not in self.sessions:
             l.error(f"No session found for local server {self.local_ts_id}. Database may not be properly initialized.")
@@ -135,27 +134,15 @@ class Database:
         async with self.sessions[self.local_ts_id]() as session:
             try:
-                # a. Execute the write query locally
+                # Execute the write query locally
                 serialized_kwargs = {key: serialize(value) for key, value in kwargs.items()}
                 result = await session.execute(text(query), serialized_kwargs)
 
-                # b. Log the query in query_tracking table
-                new_query = QueryTracking(
-                    ts_id=self.local_ts_id,
-                    query=query,
-                    args=json_dumps(kwargs),
-                    completed_by={self.local_ts_id: True}
-                )
-                session.add(new_query)
-                await session.flush()
-                query_id = new_query.id
-
                 await session.commit()
 
                 # Initiate async operations
-                asyncio.create_task(self._async_sync_operations(query_id, query, serialized_kwargs))
+                asyncio.create_task(self._async_sync_operations(query, kwargs))
 
-                # c. Return the result
+                # Return the result
                 return result
 
             except Exception as e:
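
Note: both the old and new code call asyncio.create_task() without keeping a reference. The event loop only holds a weak reference to scheduled tasks, so the task can be garbage-collected before it finishes and any exception inside it is silently dropped. A minimal sketch of the usual fix, assuming a hypothetical self._bg_tasks set initialized in __init__:

    task = asyncio.create_task(self._async_sync_operations(query, kwargs))
    self._bg_tasks.add(task)
    task.add_done_callback(self._bg_tasks.discard)
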
@@ -166,55 +153,36 @@ class Database:
                 l.error(f"Traceback: {traceback.format_exc()}")
                 return None
 
-    async def _async_sync_operations(self, query_id: int, query: str, params: dict):
+    async def _async_sync_operations(self, query: str, kwargs: dict):
         try:
-            # a. Calculate and add checksum
-            checksum = await self._local_compute_checksum(query, params)
-            await self.update_query_checksum(query_id, checksum)
+            # Add the write query to the query_tracking table
+            await self.add_query_to_tracking(query, kwargs)
 
-            # b. Synchronize query_tracking table
-            await self.sync_query_tracking()
-
-            # c. Call /db/sync on all servers
+            # Call /db/sync on all online servers
             await self.call_db_sync_on_servers()
         except Exception as e:
             l.error(f"Error in async sync operations: {str(e)}")
             l.error(f"Traceback: {traceback.format_exc()}")
 
-    async def call_db_sync_on_servers(self):
-        """Call /db/sync on all online servers."""
-        online_servers = await self.get_online_servers()
-        for server in self.config['POOL']:
-            if server['ts_id'] in online_servers and server['ts_id'] != self.local_ts_id:
-                asyncio.create_task(self.call_db_sync(server))
-
-    async def call_db_sync(self, server):
-        url = f"http://{server['ts_ip']}:{server['app_port']}/db/sync"
-        headers = {
-            "Authorization": f"Bearer {server['api_key']}"
-        }
-        async with aiohttp.ClientSession() as session:
-            try:
-                async with session.post(url, headers=headers, timeout=30) as response:
-                    if response.status == 200:
-                        l.info(f"Successfully called /db/sync on {url}")
-                    else:
-                        l.warning(f"Failed to call /db/sync on {url}. Status: {response.status}")
-            except asyncio.TimeoutError:
-                l.debug(f"Timeout while calling /db/sync on {url}")
-            except Exception as e:
-                l.error(f"Error calling /db/sync on {url}: {str(e)}")
-
+    async def add_query_to_tracking(self, query: str, kwargs: dict):
+        async with self.sessions[self.local_ts_id]() as session:
+            new_query = QueryTracking(
+                ts_id=self.local_ts_id,
+                query=query,
+                args=json_dumps(kwargs),
+                completed_by={self.local_ts_id: True}
+            )
+            session.add(new_query)
+            await session.commit()
 
     async def get_primary_server(self) -> str:
-        url = urljoin(self.config['URL'], '/id')
+        url = f"{self.config['URL']}/id"
         async with aiohttp.ClientSession() as session:
             try:
                 async with session.get(url) as response:
                     if response.status == 200:
                         primary_ts_id = await response.text()
-                        return primary_ts_id.strip()
+                        return primary_ts_id.strip().strip('"')
                     else:
                         l.error(f"Failed to get primary server. Status: {response.status}")
                         return None
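
Note: the added .strip('"') implies the /id endpoint returns a JSON-encoded (quoted) string. Parsing the body as JSON is a touch more robust than stripping quotes; a minimal sketch against the same response object:

    body = await response.text()
    try:
        return json.loads(body)    # handles a quoted "ts-id" payload
    except json.JSONDecodeError:
        return body.strip()        # plain-text fallback
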
@@ -222,207 +190,62 @@ class Database:
             l.error(f"Error connecting to load balancer: {str(e)}")
             return None
 
-    async def get_checksum_server(self) -> dict:
-        primary_ts_id = await self.get_primary_server()
-        online_servers = await self.get_online_servers()
-
-        checksum_servers = [server for server in self.config['POOL'] if server['ts_id'] in online_servers and server['ts_id'] != primary_ts_id]
-
-        if not checksum_servers:
-            return next(server for server in self.config['POOL'] if server['ts_id'] == primary_ts_id)
-
-        return random.choice(checksum_servers)
-
-    async def _local_compute_checksum(self, query: str, params: dict):
-        async with self.sessions[self.local_ts_id]() as session:
-            result = await session.execute(text(query), params)
-            if result.returns_rows:
-                data = result.fetchall()
-            else:
-                data = str(result.rowcount) + query + str(params)
-            checksum = hashlib.md5(str(data).encode()).hexdigest()
-            return checksum
-
-    async def _delegate_compute_checksum(self, server: Dict[str, Any], query: str, params: dict):
-        url = f"http://{server['ts_ip']}:{server['app_port']}/sync/checksum"
-
-        async with aiohttp.ClientSession() as session:
-            try:
-                async with session.post(url, json={"query": query, "params": params}) as response:
-                    if response.status == 200:
-                        result = await response.json()
-                        return result['checksum']
-                    else:
-                        l.error(f"Failed to get checksum from {server['ts_id']}. Status: {response.status}")
-                        return await self._local_compute_checksum(query, params)
-            except aiohttp.ClientError as e:
-                l.error(f"Error connecting to {server['ts_id']} for checksum: {str(e)}")
-                return await self._local_compute_checksum(query, params)
-
-    async def update_query_checksum(self, query_id: int, checksum: str):
-        async with self.sessions[self.local_ts_id]() as session:
-            await session.execute(
-                text("UPDATE query_tracking SET result_checksum = :checksum WHERE id = :id"),
-                {"checksum": checksum, "id": query_id}
-            )
-            await session.commit()
-
-    async def _replicate_write(self, ts_id: str, query_id: int, query: str, params: dict, expected_checksum: str):
-        try:
-            async with self.sessions[ts_id]() as session:
-                await session.execute(text(query), params)
-                actual_checksum = await self._local_compute_checksum(query, params)
-                if actual_checksum != expected_checksum:
-                    raise ValueError(f"Checksum mismatch on {ts_id}")
-                await self.mark_query_completed(query_id, ts_id)
-                await session.commit()
-                l.info(f"Successfully replicated write to {ts_id}")
-        except Exception as e:
-            l.error(f"Failed to replicate write on {ts_id}: {str(e)}")
-            l.error(f"Traceback: {traceback.format_exc()}")
-
-    async def mark_query_completed(self, query_id: int, ts_id: str):
-        async with self.sessions[self.local_ts_id]() as session:
-            query = await session.get(QueryTracking, query_id)
-            if query:
-                completed_by = query.completed_by or {}
-                completed_by[ts_id] = True
-                query.completed_by = completed_by
-                await session.commit()
-
-    async def sync_local_server(self):
-        async with self.sessions[self.local_ts_id]() as session:
-            last_synced = await session.execute(
-                text("SELECT MAX(id) FROM query_tracking WHERE completed_by ? :ts_id"),
-                {"ts_id": self.local_ts_id}
-            )
-            last_synced_id = last_synced.scalar() or 0
-
-            unexecuted_queries = await session.execute(
-                text("SELECT * FROM query_tracking WHERE id > :last_id ORDER BY id"),
-                {"last_id": last_synced_id}
-            )
-            unexecuted_queries = unexecuted_queries.fetchall()
-
-            for query in unexecuted_queries:
-                try:
-                    params = json.loads(query.args)
-                    await session.execute(text(query.query), params)
-                    actual_checksum = await self._local_compute_checksum(query.query, params)
-                    if actual_checksum != query.result_checksum:
-                        raise ValueError(f"Checksum mismatch for query ID {query.id}")
-                    await self.mark_query_completed(query.id, self.local_ts_id)
-                    await session.commit()
-                    l.info(f"Successfully executed query ID {query.id}")
-                except Exception as e:
-                    l.error(f"Failed to execute query ID {query.id} during local sync: {str(e)}")
-                    await session.rollback()
-
-            await session.commit()
-            l.info(f"Local server sync completed. Executed {unexecuted_queries.rowcount} queries.")
-
-    async def purge_completed_queries(self):
-        async with self.sessions[self.local_ts_id]() as session:
-            all_ts_ids = [db['ts_id'] for db in self.config['POOL']]
-
-            result = await session.execute(
-                text("""
-                    DELETE FROM query_tracking
-                    WHERE id <= (
-                        SELECT MAX(id)
-                        FROM query_tracking
-                        WHERE completed_by ?& :ts_ids
-                    )
-                """),
-                {"ts_ids": all_ts_ids}
-            )
-            await session.commit()
-
-            deleted_count = result.rowcount
-            l.info(f"Purged {deleted_count} completed queries.")
-
-    async def sync_query_tracking(self):
-        """Combinatorial sync method for the query_tracking table."""
-        try:
-            online_servers = await self.get_online_servers()
-
-            for ts_id in online_servers:
-                if ts_id == self.local_ts_id:
-                    continue
-
-                try:
-                    async with self.sessions[ts_id]() as remote_session:
-                        local_max_id = await self.get_max_query_id(self.local_ts_id)
-                        remote_max_id = await self.get_max_query_id(ts_id)
-
-                        # Sync from remote to local
-                        remote_new_queries = await remote_session.execute(
-                            select(QueryTracking).where(QueryTracking.id > local_max_id)
-                        )
-                        remote_new_queries = remote_new_queries.fetchall()
-                        for query in remote_new_queries:
-                            await self.add_or_update_query(query[0])
-
-                        # Sync from local to remote
-                        async with self.sessions[self.local_ts_id]() as local_session:
-                            local_new_queries = await local_session.execute(
-                                select(QueryTracking).where(QueryTracking.id > remote_max_id)
-                            )
-                            local_new_queries = local_new_queries.fetchall()
-                            for query in local_new_queries:
-                                await self.add_or_update_query_remote(ts_id, query[0])
-                except Exception as e:
-                    l.error(f"Error syncing with {ts_id}: {str(e)}")
-                    l.error(f"Traceback: {traceback.format_exc()}")
-        except Exception as e:
-            l.error(f"Error in sync_query_tracking: {str(e)}")
-            l.error(f"Traceback: {traceback.format_exc()}")
-
-    async def get_max_query_id(self, ts_id):
-        async with self.sessions[ts_id]() as session:
-            result = await session.execute(select(func.max(QueryTracking.id)))
-            return result.scalar() or 0
-
-    async def add_or_update_query(self, query):
-        async with self.sessions[self.local_ts_id]() as session:
-            existing_query = await session.get(QueryTracking, query.id)
-            if existing_query:
-                existing_query.completed_by = {**existing_query.completed_by, **query.completed_by}
-            else:
-                session.add(query)
-            await session.commit()
-
-    async def add_or_update_query_remote(self, ts_id, query):
-        async with self.sessions[ts_id]() as session:
-            existing_query = await session.get(QueryTracking, query.id)
-            if existing_query:
-                existing_query.completed_by = {**existing_query.completed_by, **query.completed_by}
-            else:
-                new_query = QueryTracking(
-                    id=query.id,
-                    ts_id=query.ts_id,
-                    query=query.query,
-                    args=query.args,
-                    executed_at=query.executed_at,
-                    completed_by=query.completed_by,
-                    result_checksum=query.result_checksum
-                )
-                session.add(new_query)
-            try:
-                await session.commit()
-            except Exception as e:
-                l.error(f"Failed to add or update query on {ts_id}: {str(e)}")
-                await session.rollback()
-
-    async def ensure_query_tracking_table(self):
-        for ts_id, engine in self.engines.items():
-            try:
-                async with engine.begin() as conn:
-                    await conn.run_sync(Base.metadata.create_all)
-                l.info(f"Ensured query_tracking table exists for {ts_id}")
-            except Exception as e:
-                l.error(f"Failed to create query_tracking table for {ts_id}: {str(e)}")
+    async def pull_query_tracking_from_primary(self):
+        primary_ts_id = await self.get_primary_server()
+        if not primary_ts_id:
+            l.error("Failed to get primary server")
+            return
+
+        primary_server = next((s for s in self.config['POOL'] if s['ts_id'] == primary_ts_id), None)
+        if not primary_server:
+            l.error(f"Primary server {primary_ts_id} not found in config")
+            return
+
+        async with self.sessions[primary_ts_id]() as session:
+            queries = await session.execute(select(QueryTracking))
+            queries = queries.fetchall()
+
+        async with self.sessions[self.local_ts_id]() as local_session:
+            for query in queries:
+                existing = await local_session.get(QueryTracking, query.id)
+                if existing:
+                    existing.completed_by = {**existing.completed_by, **query.completed_by}
+                else:
+                    local_session.add(query)
+            await local_session.commit()
+
+    async def execute_unexecuted_queries(self):
+        async with self.sessions[self.local_ts_id]() as session:
+            unexecuted_queries = await session.execute(
+                select(QueryTracking).where(~QueryTracking.completed_by.has_key(self.local_ts_id)).order_by(QueryTracking.id)
+            )
+            unexecuted_queries = unexecuted_queries.fetchall()
+
+            for query in unexecuted_queries:
+                try:
+                    params = json.loads(query.args)
+                    await session.execute(text(query.query), params)
+                    query.completed_by[self.local_ts_id] = True
+                    await session.commit()
+                    l.info(f"Successfully executed query ID {query.id}")
+                except Exception as e:
+                    l.error(f"Failed to execute query ID {query.id}: {str(e)}")
+                    await session.rollback()
+
+    async def sync_db(self):
+        current_time = time.time()
+        if current_time - self.last_sync_time < 30:
+            l.info("Skipping sync, last sync was less than 30 seconds ago")
+            return
+
+        try:
+            await self.pull_query_tracking_from_primary()
+            await self.execute_unexecuted_queries()
+            self.last_sync_time = current_time
+        except Exception as e:
+            l.error(f"Error during database sync: {str(e)}")
+            l.error(f"Traceback: {traceback.format_exc()}")
 
     async def call_db_sync_on_servers(self):
         """Call /db/sync on all online servers."""
         online_servers = await self.get_online_servers()
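
Note: execute_unexecuted_queries (like the removed sync_process) flips query.completed_by[...] in place. A plain Column(JSONB) attribute is not mutation-tracked by SQLAlchemy, so an in-place dict update can be silently skipped at commit. A minimal sketch of the two usual fixes, assuming the ORM entity is at hand:

    from sqlalchemy.orm.attributes import flag_modified

    query.completed_by[self.local_ts_id] = True
    flag_modified(query, "completed_by")   # mark the JSONB column dirty

    # or: reassign a new dict, which is always detected
    query.completed_by = {**query.completed_by, self.local_ts_id: True}
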
@@ -433,7 +256,6 @@ class Database:
         except Exception as e:
             l.error(f"Failed to call /db/sync on {server['ts_id']}: {str(e)}")
 
-
     async def call_db_sync(self, server):
         url = f"http://{server['ts_ip']}:{server['app_port']}/db/sync"
         headers = {
@@ -451,7 +273,6 @@ class Database:
         except Exception as e:
             l.error(f"Error calling /db/sync on {url}: {str(e)}")
 
-
     async def close(self):
         for engine in self.engines.values():
             await engine.dispose()
@@ -67,37 +67,7 @@ async def get_tailscale_ip():
         else:
             return "No devices found"
 
-async def sync_process():
-    async with Db.sessions[TS_ID]() as session:
-        # Find unexecuted queries
-        unexecuted_queries = await session.execute(
-            select(QueryTracking).where(~QueryTracking.completed_by.has_key(TS_ID)).order_by(QueryTracking.id)
-        )
-
-        for query in unexecuted_queries:
-            try:
-                params = json_loads(query.args)
-                await session.execute(text(query.query), params)
-                actual_checksum = await Db._local_compute_checksum(query.query, params)
-                if actual_checksum != query.result_checksum:
-                    l.error(f"Checksum mismatch for query ID {query.id}")
-                    continue
-
-                # Update the completed_by field
-                query.completed_by[TS_ID] = True
-                await session.commit()
-
-                l.info(f"Successfully executed and verified query ID {query.id}")
-            except Exception as e:
-                l.error(f"Failed to execute query ID {query.id} during sync: {str(e)}")
-                await session.rollback()
-
-        l.info(f"Sync process completed. Executed {unexecuted_queries.rowcount} queries.")
-
-        # After executing all queries, perform combinatorial sync
-        await Db.sync_query_tracking()
-
 @sys.post("/db/sync")
 async def db_sync(background_tasks: BackgroundTasks):
-    background_tasks.add_task(sync_process)
+    background_tasks.add_task(Db.sync_db)
     return {"message": "Sync process initiated"}