Auto-update: Tue Aug 6 22:15:01 PDT 2024
This commit is contained in:
parent
e42a17e61f
commit
d8514b57a6
1 changed files with 67 additions and 44 deletions
|
@ -363,7 +363,8 @@ class APIConfig(BaseModel):
|
||||||
local_ts_id = os.environ.get('TS_ID')
|
local_ts_id = os.environ.get('TS_ID')
|
||||||
|
|
||||||
for pool_entry in self.POOL:
|
for pool_entry in self.POOL:
|
||||||
if pool_entry['ts_id'] != local_ts_id:
|
# omit self from online hosts
|
||||||
|
# if pool_entry['ts_id'] != local_ts_id:
|
||||||
pool_key = f"{pool_entry['ts_ip']}:{pool_entry['db_port']}"
|
pool_key = f"{pool_entry['ts_ip']}:{pool_entry['db_port']}"
|
||||||
if pool_key in self.offline_servers:
|
if pool_key in self.offline_servers:
|
||||||
if current_time - self.offline_servers[pool_key] < self.offline_timeout:
|
if current_time - self.offline_servers[pool_key] < self.offline_timeout:
|
||||||
|
@ -832,42 +833,50 @@ class APIConfig(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def execute_write_query(self, query: str, *args, table_name: str, max_retries=3, retry_delay=1):
|
async def execute_write_query(self, query: str, *args, table_name: str):
|
||||||
local_ts_id = os.environ.get('TS_ID')
|
local_ts_id = os.environ.get('TS_ID')
|
||||||
|
online_hosts = await self.get_online_hosts()
|
||||||
|
|
||||||
for attempt in range(max_retries):
|
for pool_entry in online_hosts:
|
||||||
try:
|
if pool_entry['ts_id'] != local_ts_id:
|
||||||
# Execute on local database
|
continue # Only write to the local database
|
||||||
local_conn = await self.get_connection()
|
|
||||||
if local_conn is None:
|
conn = await self.get_connection(pool_entry)
|
||||||
raise ConnectionError("Failed to connect to local database")
|
if conn is None:
|
||||||
|
err(f"Unable to connect to local database {pool_entry['ts_id']}. Write operation failed.")
|
||||||
|
return []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if table_name in self.SPECIAL_TABLES:
|
if table_name in self.SPECIAL_TABLES:
|
||||||
local_result = await self._execute_special_table_write(local_conn, query, *args, table_name=table_name)
|
result = await self._execute_special_table_write(conn, query, *args, table_name=table_name)
|
||||||
else:
|
else:
|
||||||
local_result = await local_conn.fetch(query, *args) if query.strip().upper().startswith("INSERT") and "RETURNING" in query.upper() else await local_conn.execute(query, *args)
|
if query.strip().upper().startswith("INSERT"):
|
||||||
|
table_columns = await self.get_table_columns(conn, table_name)
|
||||||
|
primary_key = await self.get_primary_key(conn, table_name)
|
||||||
|
|
||||||
# Start background task to sync with other databases
|
if primary_key:
|
||||||
asyncio.create_task(self._sync_write_to_other_dbs(query, args, table_name))
|
set_clause = ", ".join([f"{col} = EXCLUDED.{col}" for col in table_columns if col != primary_key])
|
||||||
|
query = f"""
|
||||||
|
INSERT INTO {table_name} ({', '.join(table_columns)})
|
||||||
|
VALUES ({', '.join(f'${i+1}' for i in range(len(args)))})
|
||||||
|
ON CONFLICT ({primary_key}) DO UPDATE SET
|
||||||
|
{set_clause}
|
||||||
|
"""
|
||||||
|
|
||||||
return local_result
|
result = await conn.fetch(query, *args)
|
||||||
finally:
|
|
||||||
await local_conn.close()
|
|
||||||
|
|
||||||
except (ConnectionError, asyncpg.exceptions.ConnectionDoesNotExistError) as e:
|
asyncio.create_task(self._sync_changes(table_name, await self.get_primary_key(conn, table_name)))
|
||||||
if attempt == max_retries - 1:
|
|
||||||
raise
|
return result
|
||||||
warn(f"Connection error on attempt {attempt + 1}. Retrying in {retry_delay} seconds...")
|
|
||||||
await asyncio.sleep(retry_delay)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
err(f"Error executing write query on table {table_name}: {str(e)}")
|
err(f"Error executing write query on {pool_entry['ts_id']}: {str(e)}")
|
||||||
err(f"Query: {query}")
|
err(f"Query: {query}")
|
||||||
err(f"Args: {args}")
|
err(f"Args: {args}")
|
||||||
err(f"Traceback: {traceback.format_exc()}")
|
err(f"Traceback: {traceback.format_exc()}")
|
||||||
raise
|
finally:
|
||||||
|
await conn.close()
|
||||||
|
|
||||||
raise ConnectionError(f"Failed to execute write query after {max_retries} attempts")
|
return []
|
||||||
|
|
||||||
|
|
||||||
async def _sync_write_to_other_dbs(self, query: str, args: tuple, table_name: str):
|
async def _sync_write_to_other_dbs(self, query: str, args: tuple, table_name: str):
|
||||||
|
@ -1156,11 +1165,25 @@ class APIConfig(BaseModel):
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
def get_primary_key(self, table_name: str) -> str:
|
async def get_table_columns(self, conn, table_name: str) -> List[str]:
|
||||||
# This method should return the primary key for the given table
|
query = """
|
||||||
# You might want to cache this information for performance
|
SELECT column_name
|
||||||
# For now, we'll assume it's always 'id', but you should implement proper logic here
|
FROM information_schema.columns
|
||||||
return 'id'
|
WHERE table_schema = 'public' AND table_name = $1
|
||||||
|
"""
|
||||||
|
columns = await conn.fetch(query, table_name)
|
||||||
|
return [col['column_name'] for col in columns]
|
||||||
|
|
||||||
|
async def get_primary_key(self, conn, table_name: str) -> str:
|
||||||
|
query = """
|
||||||
|
SELECT a.attname
|
||||||
|
FROM pg_index i
|
||||||
|
JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
|
||||||
|
WHERE i.indrelid = $1::regclass AND i.indisprimary
|
||||||
|
"""
|
||||||
|
result = await conn.fetchval(query, table_name)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def add_primary_keys_to_local_tables(self):
|
async def add_primary_keys_to_local_tables(self):
|
||||||
|
|
Loading…
Reference in a new issue