Auto-update: Mon Aug 12 22:34:19 PDT 2024

This commit is contained in:
sanj 2024-08-12 22:34:19 -07:00
parent 27f4beb50c
commit 6896592356

View file

@ -45,18 +45,21 @@ load_dotenv(ENV_PATH)
TS_ID = os.environ.get('TS_ID') TS_ID = os.environ.get('TS_ID')
class QueryTracking(Base): class QueryTracking(Base):
__tablename__ = 'query_tracking' __tablename__ = 'query_tracking'
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
origin_ts_id = Column(String, nullable=False) origin_ts_id = Column(String, nullable=False)
query = Column(Text, nullable=False) query = Column(Text, nullable=False)
args = Column(JSONB) args = Column(JSON)
executed_at = Column(DateTime(timezone=True), server_default=func.now()) executed_at = Column(DateTime(timezone=True), server_default=func.now())
completed_by = Column(JSONB, default={}) completed_by = Column(ARRAY(String), default=[])
result_checksum = Column(String(32)) result_checksum = Column(String(32))
class Database: class Database:
@classmethod @classmethod
def init(cls, config_name: str): def init(cls, config_name: str):
@ -178,7 +181,7 @@ class Database:
origin_ts_id=self.local_ts_id, origin_ts_id=self.local_ts_id,
query=query, query=query,
args=json_dumps(kwargs), args=json_dumps(kwargs),
completed_by={self.local_ts_id: True}, completed_by=[self.local_ts_id],
result_checksum=result_checksum result_checksum=result_checksum
) )
session.add(new_query) session.add(new_query)
@ -202,6 +205,7 @@ class Database:
l.error(f"Error during database sync: {str(e)}") l.error(f"Error during database sync: {str(e)}")
l.error(f"Traceback: {traceback.format_exc()}") l.error(f"Traceback: {traceback.format_exc()}")
async def pull_query_tracking_from_all_servers(self): async def pull_query_tracking_from_all_servers(self):
online_servers = await self.get_online_servers() online_servers = await self.get_online_servers()
l.info(f"Pulling query tracking from {len(online_servers)} online servers") l.info(f"Pulling query tracking from {len(online_servers)} online servers")
@ -212,45 +216,51 @@ class Database:
l.info(f"Pulling queries from server: {server_id}") l.info(f"Pulling queries from server: {server_id}")
async with self.sessions[server_id]() as remote_session: async with self.sessions[server_id]() as remote_session:
queries = await remote_session.execute(select(QueryTracking)) try:
queries = queries.fetchall() result = await remote_session.execute(select(QueryTracking))
queries = result.scalars().all()
l.info(f"Retrieved {len(queries)} queries from server {server_id}")
l.info(f"Retrieved {len(queries)} queries from server {server_id}") async with self.sessions[self.local_ts_id]() as local_session:
async with self.sessions[self.local_ts_id]() as local_session: for query in queries:
for query in queries: existing = await local_session.execute(
existing = await local_session.execute( select(QueryTracking).where(QueryTracking.id == query.id)
select(QueryTracking).where(QueryTracking.id == query.id) )
) existing = existing.scalar_one_or_none()
existing = existing.scalar_one_or_none()
if existing: if existing:
existing.completed_by = {**existing.completed_by, **query.completed_by} existing.completed_by = list(set(existing.completed_by + query.completed_by))
l.debug(f"Updated existing query: {query.id}") l.debug(f"Updated existing query: {query.id}")
else: else:
local_session.add(query) local_session.add(query)
l.debug(f"Added new query: {query.id}") l.debug(f"Added new query: {query.id}")
await local_session.commit() await local_session.commit()
except Exception as e:
l.error(f"Error pulling queries from server {server_id}: {str(e)}")
l.error(f"Traceback: {traceback.format_exc()}")
l.info("Finished pulling queries from all servers") l.info("Finished pulling queries from all servers")
async def execute_unexecuted_queries(self): async def execute_unexecuted_queries(self):
async with self.sessions[self.local_ts_id]() as session: async with self.sessions[self.local_ts_id]() as session:
unexecuted_queries = await session.execute( unexecuted_queries = await session.execute(
select(QueryTracking).where(~QueryTracking.completed_by.has_key(self.local_ts_id)).order_by(QueryTracking.executed_at) select(QueryTracking).where(~QueryTracking.completed_by.any(self.local_ts_id)).order_by(QueryTracking.executed_at)
) )
unexecuted_queries = unexecuted_queries.fetchall() unexecuted_queries = unexecuted_queries.scalars().all()
l.info(f"Executing {len(unexecuted_queries)} unexecuted queries") l.info(f"Executing {len(unexecuted_queries)} unexecuted queries")
for query in unexecuted_queries: for query in unexecuted_queries:
try: try:
params = json.loads(query.args) params = json.loads(query.args)
await session.execute(text(query.query), params) await session.execute(text(query.query), params)
query.completed_by[self.local_ts_id] = True query.completed_by = list(set(query.completed_by + [self.local_ts_id]))
await session.commit() await session.commit()
l.info(f"Successfully executed query ID {query.id}") l.info(f"Successfully executed query ID {query.id}")
except Exception as e: except Exception as e:
l.error(f"Failed to execute query ID {query.id}: {str(e)}") l.error(f"Failed to execute query ID {query.id}: {str(e)}")
await session.rollback() await session.rollback()
l.info("Finished executing unexecuted queries") l.info("Finished executing unexecuted queries")
async def call_db_sync_on_servers(self): async def call_db_sync_on_servers(self):
"""Call /db/sync on all online servers.""" """Call /db/sync on all online servers."""