diff --git a/sijapi/classes.py b/sijapi/classes.py index 5a23d45..f0749be 100644 --- a/sijapi/classes.py +++ b/sijapi/classes.py @@ -832,36 +832,63 @@ class APIConfig(BaseModel): - async def execute_write_query(self, query: str, *args, table_name: str): - conn = await self.get_connection() - if conn is None: - raise ConnectionError("Failed to connect to local database") + async def execute_write_query(self, query: str, *args, table_name: str, max_retries=3, retry_delay=1): + local_ts_id = os.environ.get('TS_ID') - try: - if table_name in self.SPECIAL_TABLES: - return await self._execute_special_table_write(conn, query, *args, table_name=table_name) + for attempt in range(max_retries): + try: + # Execute on local database + local_conn = await self.get_connection() + if local_conn is None: + raise ConnectionError("Failed to connect to local database") - primary_key = await self.ensure_sync_columns(conn, table_name) + try: + if table_name in self.SPECIAL_TABLES: + local_result = await self._execute_special_table_write(local_conn, query, *args, table_name=table_name) + else: + local_result = await local_conn.fetch(query, *args) if query.strip().upper().startswith("INSERT") and "RETURNING" in query.upper() else await local_conn.execute(query, *args) + + # Start background task to sync with other databases + asyncio.create_task(self._sync_write_to_other_dbs(query, args, table_name)) + + return local_result + finally: + await local_conn.close() - if query.strip().upper().startswith("INSERT") and "RETURNING" in query.upper(): - result = await conn.fetch(query, *args) - else: - result = await conn.execute(query, *args) + except (ConnectionError, asyncpg.exceptions.ConnectionDoesNotExistError) as e: + if attempt == max_retries - 1: + raise + warn(f"Connection error on attempt {attempt + 1}. Retrying in {retry_delay} seconds...") + await asyncio.sleep(retry_delay) + except Exception as e: + err(f"Error executing write query on table {table_name}: {str(e)}") + err(f"Query: {query}") + err(f"Args: {args}") + err(f"Traceback: {traceback.format_exc()}") + raise - # Schedule the sync task - asyncio.create_task(self._sync_changes(table_name, primary_key)) - - return result - except Exception as e: - err(f"Error executing write query on table {table_name}: {str(e)}") - err(f"Query: {query}") - err(f"Args: {args}") - err(f"Traceback: {traceback.format_exc()}") - raise - finally: - await conn.close() + raise ConnectionError(f"Failed to execute write query after {max_retries} attempts") + async def _sync_write_to_other_dbs(self, query: str, args: tuple, table_name: str): + local_ts_id = os.environ.get('TS_ID') + online_hosts = await self.get_online_hosts() + + for pool_entry in online_hosts: + if pool_entry['ts_id'] != local_ts_id: + remote_conn = await self.get_connection(pool_entry) + if remote_conn is None: + warn(f"Unable to connect to {pool_entry['ts_id']}. Skipping write operation.") + continue + + try: + await remote_conn.execute(query, *args) + debug(f"Successfully synced write operation to {pool_entry['ts_id']} for table {table_name}") + except Exception as e: + err(f"Error executing write query on {pool_entry['ts_id']}: {str(e)}") + finally: + await remote_conn.close() + async def _run_sync_tasks(self, tasks): for task in tasks: