Auto-update: Tue Jul 30 11:51:53 PDT 2024
This commit is contained in:
parent
72047b561c
commit
66775e0a82
2 changed files with 51 additions and 20 deletions
|
@ -31,17 +31,13 @@ args = parser.parse_args()
|
||||||
|
|
||||||
L.setup_from_args(args)
|
L.setup_from_args(args)
|
||||||
print(f"Debug modules after setup: {L.debug_modules}")
|
print(f"Debug modules after setup: {L.debug_modules}")
|
||||||
|
|
||||||
logger = L.get_module_logger("main")
|
logger = L.get_module_logger("main")
|
||||||
def debug(text: str): logger.debug(text)
|
def debug(text: str): logger.debug(text)
|
||||||
debug(f"Debug message.")
|
|
||||||
def info(text: str): logger.info(text)
|
def info(text: str): logger.info(text)
|
||||||
info(f"Info message.")
|
|
||||||
def warn(text: str): logger.warning(text)
|
def warn(text: str): logger.warning(text)
|
||||||
warn(f"Warning message.")
|
|
||||||
def err(text: str): logger.error(text)
|
def err(text: str): logger.error(text)
|
||||||
err(f"Error message.")
|
|
||||||
def crit(text: str): logger.critical(text)
|
def crit(text: str): logger.critical(text)
|
||||||
crit(f"Critical message.")
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
|
@ -79,8 +75,8 @@ async def lifespan(app: FastAPI):
|
||||||
|
|
||||||
# Shutdown
|
# Shutdown
|
||||||
crit("Shutting down...")
|
crit("Shutting down...")
|
||||||
# Perform any cleanup operations here if needed
|
await API.close_db_pools()
|
||||||
|
crit("Database pools closed.")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -165,6 +165,32 @@ class Configuration(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class DatabasePool:
|
||||||
|
def __init__(self):
|
||||||
|
self.pools = {}
|
||||||
|
|
||||||
|
async def get_connection(self, pool_entry):
|
||||||
|
pool_key = f"{pool_entry['ts_ip']}:{pool_entry['db_port']}"
|
||||||
|
if pool_key not in self.pools:
|
||||||
|
self.pools[pool_key] = await asyncpg.create_pool(
|
||||||
|
host=pool_entry['ts_ip'],
|
||||||
|
port=pool_entry['db_port'],
|
||||||
|
user=pool_entry['db_user'],
|
||||||
|
password=pool_entry['db_pass'],
|
||||||
|
database=pool_entry['db_name'],
|
||||||
|
min_size=1,
|
||||||
|
max_size=10
|
||||||
|
)
|
||||||
|
return await self.pools[pool_key].acquire()
|
||||||
|
|
||||||
|
async def release_connection(self, pool_entry, connection):
|
||||||
|
pool_key = f"{pool_entry['ts_ip']}:{pool_entry['db_port']}"
|
||||||
|
await self.pools[pool_key].release(connection)
|
||||||
|
|
||||||
|
async def close_all(self):
|
||||||
|
for pool in self.pools.values():
|
||||||
|
await pool.close()
|
||||||
|
|
||||||
class APIConfig(BaseModel):
|
class APIConfig(BaseModel):
|
||||||
HOST: str
|
HOST: str
|
||||||
PORT: int
|
PORT: int
|
||||||
|
@ -178,6 +204,11 @@ class APIConfig(BaseModel):
|
||||||
TZ: str
|
TZ: str
|
||||||
KEYS: List[str]
|
KEYS: List[str]
|
||||||
GARBAGE: Dict[str, Any]
|
GARBAGE: Dict[str, Any]
|
||||||
|
db_pool: DatabasePool = None
|
||||||
|
|
||||||
|
def __init__(self, **data):
|
||||||
|
super().__init__(**data)
|
||||||
|
self.db_pool = DatabasePool()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, config_path: Union[str, Path], secrets_path: Union[str, Path]):
|
def load(cls, config_path: Union[str, Path], secrets_path: Union[str, Path]):
|
||||||
|
@ -300,19 +331,20 @@ class APIConfig(BaseModel):
|
||||||
|
|
||||||
info(f"Attempting to connect to database: {pool_entry}")
|
info(f"Attempting to connect to database: {pool_entry}")
|
||||||
try:
|
try:
|
||||||
conn = await asyncpg.connect(
|
conn = await self.db_pool.get_connection(pool_entry)
|
||||||
host=pool_entry['ts_ip'],
|
|
||||||
port=pool_entry['db_port'],
|
|
||||||
user=pool_entry['db_user'],
|
|
||||||
password=pool_entry['db_pass'],
|
|
||||||
database=pool_entry['db_name']
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
yield conn
|
yield conn
|
||||||
finally:
|
finally:
|
||||||
await conn.close()
|
await self.db_pool.release_connection(pool_entry, conn)
|
||||||
|
except asyncpg.exceptions.ConnectionDoesNotExistError:
|
||||||
|
err(f"Connection to database {pool_entry['ts_ip']}:{pool_entry['db_port']} does not exist or has been closed")
|
||||||
|
raise
|
||||||
|
except asyncpg.exceptions.ConnectionFailureError as e:
|
||||||
|
err(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}")
|
||||||
|
err(f"Connection error: {str(e)}")
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
warn(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}")
|
err(f"Unexpected error connecting to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}")
|
||||||
err(f"Error: {str(e)}")
|
err(f"Error: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
@ -331,11 +363,12 @@ class APIConfig(BaseModel):
|
||||||
await self.create_sync_trigger(conn, table_name)
|
await self.create_sync_trigger(conn, table_name)
|
||||||
|
|
||||||
info(f"Sync initialization complete for {pool_entry['ts_ip']}. All tables now have version and server_id columns with appropriate triggers.")
|
info(f"Sync initialization complete for {pool_entry['ts_ip']}. All tables now have version and server_id columns with appropriate triggers.")
|
||||||
|
except asyncpg.exceptions.ConnectionFailureError:
|
||||||
|
err(f"Failed to connect to database during initialization: {pool_entry['ts_ip']}:{pool_entry['db_port']}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
err(f"Error initializing sync for {pool_entry['ts_ip']}: {str(e)}")
|
err(f"Error initializing sync for {pool_entry['ts_ip']}: {str(e)}")
|
||||||
err(f"Traceback: {traceback.format_exc()}")
|
err(f"Traceback: {traceback.format_exc()}")
|
||||||
|
|
||||||
|
|
||||||
async def ensure_sync_columns(self, conn, table_name):
|
async def ensure_sync_columns(self, conn, table_name):
|
||||||
try:
|
try:
|
||||||
await conn.execute(f"""
|
await conn.execute(f"""
|
||||||
|
@ -356,7 +389,6 @@ class APIConfig(BaseModel):
|
||||||
err(f"Error ensuring sync columns for table {table_name}: {str(e)}")
|
err(f"Error ensuring sync columns for table {table_name}: {str(e)}")
|
||||||
err(f"Traceback: {traceback.format_exc()}")
|
err(f"Traceback: {traceback.format_exc()}")
|
||||||
|
|
||||||
|
|
||||||
async def create_sync_trigger(self, conn, table_name):
|
async def create_sync_trigger(self, conn, table_name):
|
||||||
await conn.execute(f"""
|
await conn.execute(f"""
|
||||||
CREATE OR REPLACE FUNCTION update_version_and_server_id()
|
CREATE OR REPLACE FUNCTION update_version_and_server_id()
|
||||||
|
@ -416,6 +448,8 @@ class APIConfig(BaseModel):
|
||||||
if version > max_version:
|
if version > max_version:
|
||||||
max_version = version
|
max_version = version
|
||||||
most_recent_source = pool_entry
|
most_recent_source = pool_entry
|
||||||
|
except asyncpg.exceptions.ConnectionFailureError:
|
||||||
|
err(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
err(f"Error checking version for {pool_entry['ts_id']}: {str(e)}")
|
err(f"Error checking version for {pool_entry['ts_id']}: {str(e)}")
|
||||||
|
|
||||||
|
@ -426,7 +460,6 @@ class APIConfig(BaseModel):
|
||||||
|
|
||||||
return most_recent_source
|
return most_recent_source
|
||||||
|
|
||||||
|
|
||||||
async def pull_changes(self, source_pool_entry, batch_size=10000):
|
async def pull_changes(self, source_pool_entry, batch_size=10000):
|
||||||
if source_pool_entry['ts_id'] == os.environ.get('TS_ID'):
|
if source_pool_entry['ts_id'] == os.environ.get('TS_ID'):
|
||||||
info("Skipping self-sync")
|
info("Skipping self-sync")
|
||||||
|
@ -684,7 +717,9 @@ class APIConfig(BaseModel):
|
||||||
""", table_name, column_name)
|
""", table_name, column_name)
|
||||||
return exists
|
return exists
|
||||||
|
|
||||||
|
async def close_db_pools(self):
|
||||||
|
if self.db_pool:
|
||||||
|
await self.db_pool.close_all()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue