diff --git a/sijapi/classes.py b/sijapi/classes.py index f0749be..cbda9b4 100644 --- a/sijapi/classes.py +++ b/sijapi/classes.py @@ -363,18 +363,19 @@ class APIConfig(BaseModel): local_ts_id = os.environ.get('TS_ID') for pool_entry in self.POOL: - if pool_entry['ts_id'] != local_ts_id: - pool_key = f"{pool_entry['ts_ip']}:{pool_entry['db_port']}" - if pool_key in self.offline_servers: - if current_time - self.offline_servers[pool_key] < self.offline_timeout: - continue - else: - del self.offline_servers[pool_key] + # omit self from online hosts + # if pool_entry['ts_id'] != local_ts_id: + pool_key = f"{pool_entry['ts_ip']}:{pool_entry['db_port']}" + if pool_key in self.offline_servers: + if current_time - self.offline_servers[pool_key] < self.offline_timeout: + continue + else: + del self.offline_servers[pool_key] - conn = await self.get_connection(pool_entry) - if conn is not None: - online_hosts.append(pool_entry) - await conn.close() + conn = await self.get_connection(pool_entry) + if conn is not None: + online_hosts.append(pool_entry) + await conn.close() self.online_hosts_cache[cache_key] = (online_hosts, current_time) return online_hosts @@ -832,42 +833,50 @@ class APIConfig(BaseModel): - async def execute_write_query(self, query: str, *args, table_name: str, max_retries=3, retry_delay=1): + async def execute_write_query(self, query: str, *args, table_name: str): local_ts_id = os.environ.get('TS_ID') + online_hosts = await self.get_online_hosts() - for attempt in range(max_retries): + for pool_entry in online_hosts: + if pool_entry['ts_id'] != local_ts_id: + continue # Only write to the local database + + conn = await self.get_connection(pool_entry) + if conn is None: + err(f"Unable to connect to local database {pool_entry['ts_id']}. Write operation failed.") + return [] + try: - # Execute on local database - local_conn = await self.get_connection() - if local_conn is None: - raise ConnectionError("Failed to connect to local database") - - try: - if table_name in self.SPECIAL_TABLES: - local_result = await self._execute_special_table_write(local_conn, query, *args, table_name=table_name) - else: - local_result = await local_conn.fetch(query, *args) if query.strip().upper().startswith("INSERT") and "RETURNING" in query.upper() else await local_conn.execute(query, *args) + if table_name in self.SPECIAL_TABLES: + result = await self._execute_special_table_write(conn, query, *args, table_name=table_name) + else: + if query.strip().upper().startswith("INSERT"): + table_columns = await self.get_table_columns(conn, table_name) + primary_key = await self.get_primary_key(conn, table_name) + + if primary_key: + set_clause = ", ".join([f"{col} = EXCLUDED.{col}" for col in table_columns if col != primary_key]) + query = f""" + INSERT INTO {table_name} ({', '.join(table_columns)}) + VALUES ({', '.join(f'${i+1}' for i in range(len(args)))}) + ON CONFLICT ({primary_key}) DO UPDATE SET + {set_clause} + """ - # Start background task to sync with other databases - asyncio.create_task(self._sync_write_to_other_dbs(query, args, table_name)) - - return local_result - finally: - await local_conn.close() - - except (ConnectionError, asyncpg.exceptions.ConnectionDoesNotExistError) as e: - if attempt == max_retries - 1: - raise - warn(f"Connection error on attempt {attempt + 1}. Retrying in {retry_delay} seconds...") - await asyncio.sleep(retry_delay) + result = await conn.fetch(query, *args) + + asyncio.create_task(self._sync_changes(table_name, await self.get_primary_key(conn, table_name))) + + return result except Exception as e: - err(f"Error executing write query on table {table_name}: {str(e)}") + err(f"Error executing write query on {pool_entry['ts_id']}: {str(e)}") err(f"Query: {query}") err(f"Args: {args}") err(f"Traceback: {traceback.format_exc()}") - raise - - raise ConnectionError(f"Failed to execute write query after {max_retries} attempts") + finally: + await conn.close() + + return [] async def _sync_write_to_other_dbs(self, query: str, args: tuple, table_name: str): @@ -1156,11 +1165,25 @@ class APIConfig(BaseModel): return 0 - def get_primary_key(self, table_name: str) -> str: - # This method should return the primary key for the given table - # You might want to cache this information for performance - # For now, we'll assume it's always 'id', but you should implement proper logic here - return 'id' + async def get_table_columns(self, conn, table_name: str) -> List[str]: + query = """ + SELECT column_name + FROM information_schema.columns + WHERE table_schema = 'public' AND table_name = $1 + """ + columns = await conn.fetch(query, table_name) + return [col['column_name'] for col in columns] + + async def get_primary_key(self, conn, table_name: str) -> str: + query = """ + SELECT a.attname + FROM pg_index i + JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey) + WHERE i.indrelid = $1::regclass AND i.indisprimary + """ + result = await conn.fetchval(query, table_name) + return result + async def add_primary_keys_to_local_tables(self):