From 68965923565fc0a1c80d455ce257ba2d26d6a2f3 Mon Sep 17 00:00:00 2001 From: sanj <67624670+iodrift@users.noreply.github.com> Date: Mon, 12 Aug 2024 22:34:19 -0700 Subject: [PATCH] Auto-update: Mon Aug 12 22:34:19 PDT 2024 --- sijapi/database.py | 58 +++++++++++++++++++++++++++------------------- 1 file changed, 34 insertions(+), 24 deletions(-) diff --git a/sijapi/database.py b/sijapi/database.py index 78a58f6..f9f52d3 100644 --- a/sijapi/database.py +++ b/sijapi/database.py @@ -45,18 +45,21 @@ load_dotenv(ENV_PATH) TS_ID = os.environ.get('TS_ID') + + class QueryTracking(Base): __tablename__ = 'query_tracking' id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) origin_ts_id = Column(String, nullable=False) query = Column(Text, nullable=False) - args = Column(JSONB) + args = Column(JSON) executed_at = Column(DateTime(timezone=True), server_default=func.now()) - completed_by = Column(JSONB, default={}) + completed_by = Column(ARRAY(String), default=[]) result_checksum = Column(String(32)) + class Database: @classmethod def init(cls, config_name: str): @@ -178,7 +181,7 @@ class Database: origin_ts_id=self.local_ts_id, query=query, args=json_dumps(kwargs), - completed_by={self.local_ts_id: True}, + completed_by=[self.local_ts_id], result_checksum=result_checksum ) session.add(new_query) @@ -202,6 +205,7 @@ class Database: l.error(f"Error during database sync: {str(e)}") l.error(f"Traceback: {traceback.format_exc()}") + async def pull_query_tracking_from_all_servers(self): online_servers = await self.get_online_servers() l.info(f"Pulling query tracking from {len(online_servers)} online servers") @@ -212,45 +216,51 @@ class Database: l.info(f"Pulling queries from server: {server_id}") async with self.sessions[server_id]() as remote_session: - queries = await remote_session.execute(select(QueryTracking)) - queries = queries.fetchall() + try: + result = await remote_session.execute(select(QueryTracking)) + queries = result.scalars().all() + l.info(f"Retrieved {len(queries)} queries from server {server_id}") - l.info(f"Retrieved {len(queries)} queries from server {server_id}") - async with self.sessions[self.local_ts_id]() as local_session: - for query in queries: - existing = await local_session.execute( - select(QueryTracking).where(QueryTracking.id == query.id) - ) - existing = existing.scalar_one_or_none() - - if existing: - existing.completed_by = {**existing.completed_by, **query.completed_by} - l.debug(f"Updated existing query: {query.id}") - else: - local_session.add(query) - l.debug(f"Added new query: {query.id}") - await local_session.commit() + async with self.sessions[self.local_ts_id]() as local_session: + for query in queries: + existing = await local_session.execute( + select(QueryTracking).where(QueryTracking.id == query.id) + ) + existing = existing.scalar_one_or_none() + + if existing: + existing.completed_by = list(set(existing.completed_by + query.completed_by)) + l.debug(f"Updated existing query: {query.id}") + else: + local_session.add(query) + l.debug(f"Added new query: {query.id}") + await local_session.commit() + except Exception as e: + l.error(f"Error pulling queries from server {server_id}: {str(e)}") + l.error(f"Traceback: {traceback.format_exc()}") l.info("Finished pulling queries from all servers") + async def execute_unexecuted_queries(self): async with self.sessions[self.local_ts_id]() as session: unexecuted_queries = await session.execute( - select(QueryTracking).where(~QueryTracking.completed_by.has_key(self.local_ts_id)).order_by(QueryTracking.executed_at) + select(QueryTracking).where(~QueryTracking.completed_by.any(self.local_ts_id)).order_by(QueryTracking.executed_at) ) - unexecuted_queries = unexecuted_queries.fetchall() + unexecuted_queries = unexecuted_queries.scalars().all() l.info(f"Executing {len(unexecuted_queries)} unexecuted queries") for query in unexecuted_queries: try: params = json.loads(query.args) await session.execute(text(query.query), params) - query.completed_by[self.local_ts_id] = True + query.completed_by = list(set(query.completed_by + [self.local_ts_id])) await session.commit() l.info(f"Successfully executed query ID {query.id}") except Exception as e: l.error(f"Failed to execute query ID {query.id}: {str(e)}") await session.rollback() - l.info("Finished executing unexecuted queries") + l.info("Finished executing unexecuted queries") + async def call_db_sync_on_servers(self): """Call /db/sync on all online servers."""