From d01f24ad45f23f883e9cf1cdb66692d7bdbe2c33 Mon Sep 17 00:00:00 2001
From: sanj <67624670+iodrift@users.noreply.github.com>
Date: Wed, 31 Jul 2024 14:32:26 -0700
Subject: [PATCH] Much more efficient database sync method, initial attempt.
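
Writes now go through APIConfig.execute_write_query(), which runs the
statement against the local database and then pushes just the affected row
to the other online hosts, instead of the old request middleware that
re-scanned every table after each request. Tables without a primary key are
mirrored wholesale with ON CONFLICT DO NOTHING, and special tables such as
spatial_ref_sys (which has no version/server_id columns) are synced in full.

Illustrative call from a router (argument values here are hypothetical):

    await API.execute_write_query(
        'INSERT INTO short_urls (short_code, long_url) VALUES ($1, $2)',
        short_code, long_url,
        table_name="short_urls"
    )

Reads can stay on API.get_connection(); only writes need the targeted push.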
Skipping push.") - return - - tables = await local_conn.fetch(""" - SELECT tablename FROM pg_tables - WHERE schemaname = 'public' - """) - - for table in tables: - table_name = table['tablename'] - try: - if table_name in self.SPECIAL_TABLES: - await self.sync_special_table(local_conn, remote_conn, table_name) - else: - primary_key = await self.ensure_sync_columns(remote_conn, table_name) - last_synced_version = await self.get_last_synced_version(remote_conn, table_name, os.environ.get('TS_ID')) - - changes = await local_conn.fetch(f""" - SELECT * FROM "{table_name}" - WHERE version > $1 AND server_id = $2 - ORDER BY version ASC - """, last_synced_version, os.environ.get('TS_ID')) - - if changes: - changes_count = await self.apply_batch_changes(remote_conn, table_name, changes, primary_key) - - if changes_count > 0: - debug(f"Pushed {changes_count} changes for table {table_name} to {pool_entry['ts_id']}") - - except Exception as e: - err(f"Error pushing changes for table {table_name} to {pool_entry['ts_id']}: {str(e)}") - err(f"Traceback: {traceback.format_exc()}") - - info(f"Successfully pushed changes to {pool_entry['ts_id']}") - - except Exception as e: - err(f"Error pushing changes to {pool_entry['ts_id']}: {str(e)}") - err(f"Traceback: {traceback.format_exc()}") - - async def close_db_pools(self): info("Closing database connection pools...") for pool_key, pool in self.db_pools.items(): @@ -500,114 +450,6 @@ class APIConfig(BaseModel): err(f"Error ensuring sync columns for table {table_name}: {str(e)}") err(f"Traceback: {traceback.format_exc()}") - async def apply_batch_changes(self, conn, table_name, changes, primary_key): - if conn is None or not changes: - debug(f"Skipping apply_batch_changes because conn is none or there are no changes.") - return 0 - - try: - columns = list(changes[0].keys()) - placeholders = [f'${i+1}' for i in range(len(columns))] - - if primary_key: - insert_query = f""" - INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)}) - VALUES ({', '.join(placeholders)}) - ON CONFLICT ("{primary_key}") DO UPDATE SET - {', '.join(f'"{col}" = EXCLUDED."{col}"' for col in columns if col not in [primary_key, 'version', 'server_id'])}, - version = EXCLUDED.version, - server_id = EXCLUDED.server_id - WHERE "{table_name}".version < EXCLUDED.version - OR ("{table_name}".version = EXCLUDED.version AND "{table_name}".server_id < EXCLUDED.server_id) - """ - else: - # For tables without a primary key, we'll use all columns for conflict resolution - insert_query = f""" - INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)}) - VALUES ({', '.join(placeholders)}) - ON CONFLICT DO NOTHING - """ - - # debug(f"Generated insert query for {table_name}: {insert_query}") - - affected_rows = 0 - async for change in tqdm(changes, desc=f"Syncing {table_name}", unit="row"): - values = [change[col] for col in columns] - # debug(f"Executing query for {table_name} with values: {values}") - result = await conn.execute(insert_query, *values) - affected_rows += int(result.split()[-1]) - - return affected_rows - - except Exception as e: - err(f"Error applying batch changes to {table_name}: {str(e)}") - err(f"Traceback: {traceback.format_exc()}") - return 0 - - async def pull_changes(self, source_pool_entry, batch_size=10000): - if source_pool_entry['ts_id'] == os.environ.get('TS_ID'): - debug("Skipping self-sync") - return 0 - - total_changes = 0 - source_id = source_pool_entry['ts_id'] - source_ip = source_pool_entry['ts_ip'] - dest_id = os.environ.get('TS_ID') - 
-        dest_ip = self.local_db['ts_ip']
-
-        info(f"Starting sync from source {source_id} ({source_ip}) to destination {dest_id} ({dest_ip})")
-
-        try:
-            async with self.get_connection(source_pool_entry) as source_conn:
-                async with self.get_connection(self.local_db) as dest_conn:
-                    tables = await source_conn.fetch("""
-                        SELECT tablename FROM pg_tables
-                        WHERE schemaname = 'public'
-                    """)
-
-                    async for table in tqdm(tables, desc="Syncing tables", unit="table"):
-                        table_name = table['tablename']
-                        try:
-                            if table_name in self.SPECIAL_TABLES:
-                                await self.sync_special_table(source_conn, dest_conn, table_name)
-                            else:
-                                primary_key = await self.ensure_sync_columns(dest_conn, table_name)
-                                last_synced_version = await self.get_last_synced_version(dest_conn, table_name, source_id)
-
-                                changes = await source_conn.fetch(f"""
-                                    SELECT * FROM "{table_name}"
-                                    WHERE version > $1 AND server_id = $2
-                                    ORDER BY version ASC
-                                    LIMIT $3
-                                """, last_synced_version, source_id, batch_size)
-
-                                if changes:
-                                    changes_count = await self.apply_batch_changes(dest_conn, table_name, changes, primary_key)
-                                    total_changes += changes_count
-
-                                    if changes_count > 0:
-                                        info(f"Synced batch for {table_name}: {changes_count} changes. Total so far: {total_changes}")
-                                else:
-                                    debug(f"No changes to sync for {table_name}")
-
-                        except Exception as e:
-                            err(f"Error syncing table {table_name}: {str(e)}")
-                            err(f"Traceback: {traceback.format_exc()}")
-
-            info(f"Sync complete from {source_id} ({source_ip}) to {dest_id} ({dest_ip}). Total changes: {total_changes}")
-
-        except Exception as e:
-            err(f"Error during sync process: {str(e)}")
-            err(f"Traceback: {traceback.format_exc()}")
-
-        info(f"Sync summary:")
-        info(f"  Total changes: {total_changes}")
-        info(f"  Tables synced: {len(tables)}")
-        info(f"  Source: {source_id} ({source_ip})")
-        info(f"  Destination: {dest_id} ({dest_ip})")
-
-        return total_changes
-
     async def get_online_hosts(self) -> List[Dict[str, Any]]:
         online_hosts = []
         for pool_entry in self.POOL:
@@ -619,31 +461,6 @@ class APIConfig(BaseModel):
                 err(f"Error checking host {pool_entry['ts_ip']}:{pool_entry['db_port']}: {str(e)}")
         return online_hosts
 
-    async def push_changes_to_all(self):
-        for pool_entry in self.POOL:
-            if pool_entry['ts_id'] != os.environ.get('TS_ID'):
-                try:
-                    await self.push_changes_to_one(pool_entry)
-                except Exception as e:
-                    err(f"Error pushing changes to {pool_entry['ts_id']}: {str(e)}")
-
-
-
-    async def get_last_synced_version(self, conn, table_name, server_id):
-        if conn is None:
-            debug(f"Skipping offline server...")
-            return 0
-
-        if table_name in self.SPECIAL_TABLES:
-            debug(f"Skipping get_last_synced_version becaue {table_name} is special.")
-            return 0  # Special handling for tables without version column
-
-        return await conn.fetchval(f"""
-            SELECT COALESCE(MAX(version), 0)
-            FROM "{table_name}"
-            WHERE server_id = $1
-        """, server_id)
-
     async def check_postgis(self, conn):
         if conn is None:
             debug(f"Skipping offline server...")
@@ -661,11 +478,134 @@ class APIConfig(BaseModel):
             err(f"Error checking PostGIS: {str(e)}")
             return False
 
-    async def sync_special_table(self, source_conn, dest_conn, table_name):
+    async def execute_write_query(self, query: str, *args, table_name: str):
+        if table_name in self.SPECIAL_TABLES:
+            return await self._execute_special_table_write(query, *args, table_name=table_name)
+
+        async with self.get_connection() as conn:
+            if conn is None:
+                raise ConnectionError("Failed to connect to local database")
+
+            # Ensure sync columns exist
+            primary_key = await self.ensure_sync_columns(conn, table_name)
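+            # Write locally, then push: the statement runs against the local
+            # database and only the row it touched is replicated to online
+            # peers (via push_change below), rather than re-scanning every
+            # table on every request as the old middleware sync did.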
+
+            # Execute the query
+            result = await conn.execute(query, *args)
+
+            # Heuristic: treat the row with the highest version as the row
+            # this write just touched
+            if primary_key:
+                affected_row = await conn.fetchrow(f"""
+                    SELECT "{primary_key}", version, server_id
+                    FROM "{table_name}"
+                    WHERE version = (SELECT MAX(version) FROM "{table_name}")
+                """)
+                if affected_row:
+                    await self.push_change(table_name, affected_row[primary_key], affected_row['version'], affected_row['server_id'])
+            else:
+                # For tables without a primary key, we'll push all rows
+                await self.push_all_changes(table_name)
+
+        return result
+
+    async def push_change(self, table_name: str, pk_value: Any, version: int, server_id: str):
+        primary_key = self.get_primary_key(table_name)
+
+        # Fetch the updated row from the local database once, then fan it out
+        async with self.get_connection() as local_conn:
+            updated_row = await local_conn.fetchrow(f'SELECT * FROM "{table_name}" WHERE "{primary_key}" = $1', pk_value)
+
+        if not updated_row:
+            return
+
+        columns = list(updated_row.keys())
+        placeholders = [f'${i+1}' for i in range(len(columns))]
+        insert_query = f"""
+            INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)})
+            VALUES ({', '.join(placeholders)})
+            ON CONFLICT ("{primary_key}") DO UPDATE SET
+            {', '.join(f'"{col}" = EXCLUDED."{col}"' for col in columns if col not in [primary_key, 'version', 'server_id'])},
+            version = EXCLUDED.version,
+            server_id = EXCLUDED.server_id
+            WHERE "{table_name}".version < EXCLUDED.version
+            OR ("{table_name}".version = EXCLUDED.version AND "{table_name}".server_id < EXCLUDED.server_id)
+        """
+
+        online_hosts = await self.get_online_hosts()
+        for pool_entry in online_hosts:
+            if pool_entry['ts_id'] != os.environ.get('TS_ID'):
+                try:
+                    async with self.get_connection(pool_entry) as remote_conn:
+                        if remote_conn is None:
+                            continue
+                        await remote_conn.execute(insert_query, *updated_row.values())
+                except Exception as e:
+                    err(f"Error pushing change to {pool_entry['ts_id']}: {str(e)}")
+                    err(f"Traceback: {traceback.format_exc()}")
+
+    async def push_all_changes(self, table_name: str):
+        # For tables without a primary key: mirror every local row to the
+        # online peers and let ON CONFLICT DO NOTHING skip duplicates
+        async with self.get_connection() as local_conn:
+            all_rows = await local_conn.fetch(f'SELECT * FROM "{table_name}"')
+
+        if not all_rows:
+            return
+
+        columns = list(all_rows[0].keys())
+        placeholders = [f'${i+1}' for i in range(len(columns))]
+        insert_query = f"""
+            INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)})
+            VALUES ({', '.join(placeholders)})
+            ON CONFLICT DO NOTHING
+        """
+
+        online_hosts = await self.get_online_hosts()
+        for pool_entry in online_hosts:
+            if pool_entry['ts_id'] != os.environ.get('TS_ID'):
+                try:
+                    async with self.get_connection(pool_entry) as remote_conn:
+                        if remote_conn is None:
+                            continue
+
+                        # Use a transaction so each peer gets all rows or none
+                        async with remote_conn.transaction():
+                            for row in all_rows:
+                                await remote_conn.execute(insert_query, *row.values())
+                except Exception as e:
+                    err(f"Error pushing all changes to {pool_entry['ts_id']}: {str(e)}")
+                    err(f"Traceback: {traceback.format_exc()}")
+
+    async def _execute_special_table_write(self, query: str, *args, table_name: str):
         if table_name == 'spatial_ref_sys':
-            return await self.sync_spatial_ref_sys(source_conn, dest_conn)
+            return await self._execute_spatial_ref_sys_write(query, *args)
         # Add more special cases as needed
+
+    async def _execute_spatial_ref_sys_write(self, query: str, *args):
+        result = None
+        async with self.get_connection() as local_conn:
+            if local_conn is None:
+                raise ConnectionError("Failed to connect to local database")
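+
+            # spatial_ref_sys has no version/server_id columns, so a row-level
+            # push is not possible; instead, run the write locally and then
+            # mirror the whole table to each online peer.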
ConnectionError("Failed to connect to local database") + + # Execute the query locally + result = await local_conn.execute(query, *args) + + # Sync the entire spatial_ref_sys table with all online hosts + online_hosts = await self.get_online_hosts() + for pool_entry in online_hosts: + if pool_entry['ts_id'] != os.environ.get('TS_ID'): + try: + async with self.get_connection(pool_entry) as remote_conn: + if remote_conn is None: + continue + await self.sync_spatial_ref_sys(local_conn, remote_conn) + except Exception as e: + err(f"Error syncing spatial_ref_sys to {pool_entry['ts_id']}: {str(e)}") + err(f"Traceback: {traceback.format_exc()}") + + return result + + def get_primary_key(self, table_name: str) -> str: + # This method should return the primary key for the given table + # You might want to cache this information for performance + # For now, we'll assume it's always 'id', but you should implement proper logic here + return 'id' + async def sync_spatial_ref_sys(self, source_conn, dest_conn): try: # Get all entries from the source @@ -696,7 +636,6 @@ class APIConfig(BaseModel): INSERT INTO spatial_ref_sys ({', '.join(f'"{col}"' for col in columns)}) VALUES ({', '.join(placeholders)}) """ - # debug(f"Inserting new entry for srid {srid}: {insert_query}") await dest_conn.execute(insert_query, *source_entry.values()) inserts += 1 elif source_entry != dest_dict[srid]: @@ -709,7 +648,6 @@ class APIConfig(BaseModel): proj4text = $4::text WHERE srid = $5::integer """ - # debug(f"Updating entry for srid {srid}: {update_query}") await dest_conn.execute(update_query, source_entry['auth_name'], source_entry['auth_srid'], @@ -727,63 +665,6 @@ class APIConfig(BaseModel): err(f"Traceback: {traceback.format_exc()}") return 0 - async def get_most_recent_source(self): - most_recent_source = None - max_version = -1 - local_ts_id = os.environ.get('TS_ID') - online_hosts = await self.get_online_hosts() - num_online_hosts = len(online_hosts) - if num_online_hosts > 0: - online_ts_ids = [host['ts_id'] for host in online_hosts if host['ts_id'] != local_ts_id] - crit(f"Online hosts: {', '.join(online_ts_ids)}") - - for pool_entry in online_hosts: - if pool_entry['ts_id'] == local_ts_id: - continue # Skip local database - - try: - async with self.get_connection(pool_entry) as conn: - tables = await conn.fetch(""" - SELECT tablename FROM pg_tables - WHERE schemaname = 'public' - """) - - for table in tables: - table_name = table['tablename'] - if table_name in self.SPECIAL_TABLES: - continue # Skip special tables for version comparison - try: - result = await conn.fetchrow(f""" - SELECT MAX(version) as max_version, server_id - FROM "{table_name}" - WHERE version = (SELECT MAX(version) FROM "{table_name}") - GROUP BY server_id - ORDER BY MAX(version) DESC - LIMIT 1 - """) - if result: - version, server_id = result['max_version'], result['server_id'] - info(f"Max version for {pool_entry['ts_id']}, table {table_name}: {version} (from server {server_id})") - if version > max_version: - max_version = version - most_recent_source = pool_entry - else: - debug(f"No data in table {table_name} for {pool_entry['ts_id']}") - except asyncpg.exceptions.UndefinedColumnError: - warn(f"Version or server_id column does not exist in table {table_name} for {pool_entry['ts_id']}. 
Skipping.") - except Exception as e: - err(f"Error checking version for {pool_entry['ts_id']}, table {table_name}: {str(e)}") - - except asyncpg.exceptions.ConnectionFailureError: - warn(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}") - except Exception as e: - err(f"Unexpected error occurred while checking version for {pool_entry['ts_id']}: {str(e)}") - err(f"Traceback: {traceback.format_exc()}") - - return most_recent_source - else: - warn(f"No other online hosts for sync") - return None class Location(BaseModel): diff --git a/sijapi/routers/serve.py b/sijapi/routers/serve.py index 7709d0b..cb596a4 100644 --- a/sijapi/routers/serve.py +++ b/sijapi/routers/serve.py @@ -417,50 +417,49 @@ if API.EXTENSIONS.courtlistener == "on" or API.EXTENSIONS.courtlistener == True: await cl_download_file(download_url, target_path, session) debug(f"Downloaded {file_name} to {target_path}") - @serve.get("/s", response_class=HTMLResponse) async def shortener_form(request: Request): return templates.TemplateResponse("shortener.html", {"request": request}) @serve.post("/s") async def create_short_url(request: Request, long_url: str = Form(...), custom_code: Optional[str] = Form(None)): - async with API.get_connection() as conn: - await create_tables(conn) + await create_tables() - if custom_code: - if len(custom_code) != 3 or not custom_code.isalnum(): - return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code must be 3 alphanumeric characters"}) - - existing = await conn.fetchval('SELECT 1 FROM short_urls WHERE short_code = $1', custom_code) - if existing: - return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code already in use"}) - - short_code = custom_code - else: - chars = string.ascii_letters + string.digits - while True: - short_code = ''.join(random.choice(chars) for _ in range(3)) - existing = await conn.fetchval('SELECT 1 FROM short_urls WHERE short_code = $1', short_code) - if not existing: - break + if custom_code: + if len(custom_code) != 3 or not custom_code.isalnum(): + return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code must be 3 alphanumeric characters"}) + + existing = await API.execute_write_query('SELECT 1 FROM short_urls WHERE short_code = $1', custom_code, table_name="short_urls") + if existing: + return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code already in use"}) + + short_code = custom_code + else: + chars = string.ascii_letters + string.digits + while True: + short_code = ''.join(random.choice(chars) for _ in range(3)) + existing = await API.execute_write_query('SELECT 1 FROM short_urls WHERE short_code = $1', short_code, table_name="short_urls") + if not existing: + break - await conn.execute( - 'INSERT INTO short_urls (short_code, long_url) VALUES ($1, $2)', - short_code, long_url - ) + await API.execute_write_query( + 'INSERT INTO short_urls (short_code, long_url) VALUES ($1, $2)', + short_code, long_url, + table_name="short_urls" + ) short_url = f"https://sij.ai/{short_code}" return templates.TemplateResponse("shortener.html", {"request": request, "short_url": short_url}) -async def create_tables(conn): - await conn.execute(''' +async def create_tables(): + await API.execute_write_query(''' CREATE TABLE IF NOT EXISTS short_urls ( short_code VARCHAR(3) PRIMARY KEY, long_url TEXT NOT NULL, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ) - ''') - await conn.execute(''' + ''', 
table_name="short_urls") + await API.execute_write_query(''' CREATE TABLE IF NOT EXISTS click_logs ( id SERIAL PRIMARY KEY, short_code VARCHAR(3) REFERENCES short_urls(short_code), @@ -468,48 +467,51 @@ async def create_tables(conn): ip_address TEXT, user_agent TEXT ) - ''') + ''', table_name="click_logs") @serve.get("/{short_code}", response_class=RedirectResponse, status_code=301) async def redirect_short_url(request: Request, short_code: str = PathParam(..., min_length=3, max_length=3)): if request.headers.get('host') != 'sij.ai': raise HTTPException(status_code=404, detail="Not Found") - async with API.get_connection() as conn: - result = await conn.fetchrow( - 'SELECT long_url FROM short_urls WHERE short_code = $1', - short_code - ) + result = await API.execute_write_query( + 'SELECT long_url FROM short_urls WHERE short_code = $1', + short_code, + table_name="short_urls" + ) - if result: - await conn.execute( - 'INSERT INTO click_logs (short_code, ip_address, user_agent) VALUES ($1, $2, $3)', - short_code, request.client.host, request.headers.get("user-agent") - ) - return result['long_url'] - else: - raise HTTPException(status_code=404, detail="Short URL not found") + if result: + await API.execute_write_query( + 'INSERT INTO click_logs (short_code, ip_address, user_agent) VALUES ($1, $2, $3)', + short_code, request.client.host, request.headers.get("user-agent"), + table_name="click_logs" + ) + return result['long_url'] + else: + raise HTTPException(status_code=404, detail="Short URL not found") @serve.get("/analytics/{short_code}") async def get_analytics(short_code: str): - async with API.get_connection() as conn: - url_info = await conn.fetchrow( - 'SELECT long_url, created_at FROM short_urls WHERE short_code = $1', - short_code - ) - if not url_info: - raise HTTPException(status_code=404, detail="Short URL not found") - - click_count = await conn.fetchval( - 'SELECT COUNT(*) FROM click_logs WHERE short_code = $1', - short_code - ) - - clicks = await conn.fetch( - 'SELECT clicked_at, ip_address, user_agent FROM click_logs WHERE short_code = $1 ORDER BY clicked_at DESC LIMIT 100', - short_code - ) - + url_info = await API.execute_write_query( + 'SELECT long_url, created_at FROM short_urls WHERE short_code = $1', + short_code, + table_name="short_urls" + ) + if not url_info: + raise HTTPException(status_code=404, detail="Short URL not found") + + click_count = await API.execute_write_query( + 'SELECT COUNT(*) FROM click_logs WHERE short_code = $1', + short_code, + table_name="click_logs" + ) + + clicks = await API.execute_write_query( + 'SELECT clicked_at, ip_address, user_agent FROM click_logs WHERE short_code = $1 ORDER BY clicked_at DESC LIMIT 100', + short_code, + table_name="click_logs" + ) + return { "short_code": short_code, "long_url": url_info['long_url'], @@ -520,6 +522,7 @@ async def get_analytics(short_code: str): + async def forward_traffic(reader: asyncio.StreamReader, writer: asyncio.StreamWriter, destination: str): dest_host, dest_port = destination.split(':') dest_port = int(dest_port)