diff --git a/sijapi/classes.py b/sijapi/classes.py index cbda9b4..8162e77 100644 --- a/sijapi/classes.py +++ b/sijapi/classes.py @@ -832,7 +832,6 @@ class APIConfig(BaseModel): return [dict(r) for r in latest_result] # Convert Record objects to dictionaries - async def execute_write_query(self, query: str, *args, table_name: str): local_ts_id = os.environ.get('TS_ID') online_hosts = await self.get_online_hosts() @@ -850,11 +849,12 @@ class APIConfig(BaseModel): if table_name in self.SPECIAL_TABLES: result = await self._execute_special_table_write(conn, query, *args, table_name=table_name) else: + # Modify the query to be an UPSERT operation if query.strip().upper().startswith("INSERT"): table_columns = await self.get_table_columns(conn, table_name) primary_key = await self.get_primary_key(conn, table_name) - if primary_key: + if primary_key and len(table_columns) == len(args): set_clause = ", ".join([f"{col} = EXCLUDED.{col}" for col in table_columns if col != primary_key]) query = f""" INSERT INTO {table_name} ({', '.join(table_columns)}) @@ -862,11 +862,12 @@ class APIConfig(BaseModel): ON CONFLICT ({primary_key}) DO UPDATE SET {set_clause} """ + else: + err(f"Column count mismatch for table {table_name}. Columns: {len(table_columns)}, Values: {len(args)}") + return [] result = await conn.fetch(query, *args) - asyncio.create_task(self._sync_changes(table_name, await self.get_primary_key(conn, table_name))) - return result except Exception as e: err(f"Error executing write query on {pool_entry['ts_id']}: {str(e)}") @@ -877,6 +878,27 @@ class APIConfig(BaseModel): await conn.close() return [] + + async def get_table_columns(self, conn, table_name: str) -> List[str]: + query = """ + SELECT column_name + FROM information_schema.columns + WHERE table_schema = 'public' AND table_name = $1 + ORDER BY ordinal_position + """ + columns = await conn.fetch(query, table_name) + return [col['column_name'] for col in columns] + + async def get_primary_key(self, conn, table_name: str) -> str: + query = """ + SELECT a.attname + FROM pg_index i + JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey) + WHERE i.indrelid = $1::regclass AND i.indisprimary + """ + result = await conn.fetchval(query, table_name) + return result + async def _sync_write_to_other_dbs(self, query: str, args: tuple, table_name: str): @@ -1165,25 +1187,6 @@ class APIConfig(BaseModel): return 0 - async def get_table_columns(self, conn, table_name: str) -> List[str]: - query = """ - SELECT column_name - FROM information_schema.columns - WHERE table_schema = 'public' AND table_name = $1 - """ - columns = await conn.fetch(query, table_name) - return [col['column_name'] for col in columns] - - async def get_primary_key(self, conn, table_name: str) -> str: - query = """ - SELECT a.attname - FROM pg_index i - JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey) - WHERE i.indrelid = $1::regclass AND i.indisprimary - """ - result = await conn.fetchval(query, table_name) - return result - async def add_primary_keys_to_local_tables(self): @@ -1614,4 +1617,4 @@ class WidgetUpdate(BaseModel): color: Optional[str] = None url: Optional[str] = None shortcut: Optional[str] = None - graph: Optional[str] = None + graph: Optional[str] = None \ No newline at end of file