Auto-update: Tue Aug 6 22:19:38 PDT 2024
This commit is contained in:
parent
d8514b57a6
commit
e527b0e391
1 changed files with 27 additions and 24 deletions
|
@ -832,7 +832,6 @@ class APIConfig(BaseModel):
|
||||||
return [dict(r) for r in latest_result] # Convert Record objects to dictionaries
|
return [dict(r) for r in latest_result] # Convert Record objects to dictionaries
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def execute_write_query(self, query: str, *args, table_name: str):
|
async def execute_write_query(self, query: str, *args, table_name: str):
|
||||||
local_ts_id = os.environ.get('TS_ID')
|
local_ts_id = os.environ.get('TS_ID')
|
||||||
online_hosts = await self.get_online_hosts()
|
online_hosts = await self.get_online_hosts()
|
||||||
|
@ -850,11 +849,12 @@ class APIConfig(BaseModel):
|
||||||
if table_name in self.SPECIAL_TABLES:
|
if table_name in self.SPECIAL_TABLES:
|
||||||
result = await self._execute_special_table_write(conn, query, *args, table_name=table_name)
|
result = await self._execute_special_table_write(conn, query, *args, table_name=table_name)
|
||||||
else:
|
else:
|
||||||
|
# Modify the query to be an UPSERT operation
|
||||||
if query.strip().upper().startswith("INSERT"):
|
if query.strip().upper().startswith("INSERT"):
|
||||||
table_columns = await self.get_table_columns(conn, table_name)
|
table_columns = await self.get_table_columns(conn, table_name)
|
||||||
primary_key = await self.get_primary_key(conn, table_name)
|
primary_key = await self.get_primary_key(conn, table_name)
|
||||||
|
|
||||||
if primary_key:
|
if primary_key and len(table_columns) == len(args):
|
||||||
set_clause = ", ".join([f"{col} = EXCLUDED.{col}" for col in table_columns if col != primary_key])
|
set_clause = ", ".join([f"{col} = EXCLUDED.{col}" for col in table_columns if col != primary_key])
|
||||||
query = f"""
|
query = f"""
|
||||||
INSERT INTO {table_name} ({', '.join(table_columns)})
|
INSERT INTO {table_name} ({', '.join(table_columns)})
|
||||||
|
@ -862,11 +862,12 @@ class APIConfig(BaseModel):
|
||||||
ON CONFLICT ({primary_key}) DO UPDATE SET
|
ON CONFLICT ({primary_key}) DO UPDATE SET
|
||||||
{set_clause}
|
{set_clause}
|
||||||
"""
|
"""
|
||||||
|
else:
|
||||||
|
err(f"Column count mismatch for table {table_name}. Columns: {len(table_columns)}, Values: {len(args)}")
|
||||||
|
return []
|
||||||
|
|
||||||
result = await conn.fetch(query, *args)
|
result = await conn.fetch(query, *args)
|
||||||
|
|
||||||
asyncio.create_task(self._sync_changes(table_name, await self.get_primary_key(conn, table_name)))
|
|
||||||
|
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
err(f"Error executing write query on {pool_entry['ts_id']}: {str(e)}")
|
err(f"Error executing write query on {pool_entry['ts_id']}: {str(e)}")
|
||||||
|
@ -877,6 +878,27 @@ class APIConfig(BaseModel):
|
||||||
await conn.close()
|
await conn.close()
|
||||||
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
async def get_table_columns(self, conn, table_name: str) -> List[str]:
|
||||||
|
query = """
|
||||||
|
SELECT column_name
|
||||||
|
FROM information_schema.columns
|
||||||
|
WHERE table_schema = 'public' AND table_name = $1
|
||||||
|
ORDER BY ordinal_position
|
||||||
|
"""
|
||||||
|
columns = await conn.fetch(query, table_name)
|
||||||
|
return [col['column_name'] for col in columns]
|
||||||
|
|
||||||
|
async def get_primary_key(self, conn, table_name: str) -> str:
|
||||||
|
query = """
|
||||||
|
SELECT a.attname
|
||||||
|
FROM pg_index i
|
||||||
|
JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
|
||||||
|
WHERE i.indrelid = $1::regclass AND i.indisprimary
|
||||||
|
"""
|
||||||
|
result = await conn.fetchval(query, table_name)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def _sync_write_to_other_dbs(self, query: str, args: tuple, table_name: str):
|
async def _sync_write_to_other_dbs(self, query: str, args: tuple, table_name: str):
|
||||||
|
@ -1165,25 +1187,6 @@ class APIConfig(BaseModel):
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
async def get_table_columns(self, conn, table_name: str) -> List[str]:
|
|
||||||
query = """
|
|
||||||
SELECT column_name
|
|
||||||
FROM information_schema.columns
|
|
||||||
WHERE table_schema = 'public' AND table_name = $1
|
|
||||||
"""
|
|
||||||
columns = await conn.fetch(query, table_name)
|
|
||||||
return [col['column_name'] for col in columns]
|
|
||||||
|
|
||||||
async def get_primary_key(self, conn, table_name: str) -> str:
|
|
||||||
query = """
|
|
||||||
SELECT a.attname
|
|
||||||
FROM pg_index i
|
|
||||||
JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
|
|
||||||
WHERE i.indrelid = $1::regclass AND i.indisprimary
|
|
||||||
"""
|
|
||||||
result = await conn.fetchval(query, table_name)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def add_primary_keys_to_local_tables(self):
|
async def add_primary_keys_to_local_tables(self):
|
||||||
|
@ -1614,4 +1617,4 @@ class WidgetUpdate(BaseModel):
|
||||||
color: Optional[str] = None
|
color: Optional[str] = None
|
||||||
url: Optional[str] = None
|
url: Optional[str] = None
|
||||||
shortcut: Optional[str] = None
|
shortcut: Optional[str] = None
|
||||||
graph: Optional[str] = None
|
graph: Optional[str] = None
|
Loading…
Add table
Reference in a new issue