Much more efficient database sync method, initial attempt.
commit 6e03675fb6
parent 6c757c6556

3 changed files with 199 additions and 303 deletions
sijapi
@@ -134,6 +134,17 @@ async def handle_exception_middleware(request: Request, call_next):
         raise
     return response
 
+
+# This was removed on 7/31/2024 when we decided to instead use a targeted push sync approach.
+deprecated = '''
+async def push_changes_background():
+    try:
+        await API.push_changes_to_all()
+    except Exception as e:
+        err(f"Error pushing changes to other databases: {str(e)}")
+        err(f"Traceback: {traceback.format_exc()}")
+
+
 @app.middleware("http")
 async def sync_middleware(request: Request, call_next):
     response = await call_next(request)
@@ -147,6 +158,7 @@ async def sync_middleware(request: Request, call_next):
         err(f"Error pushing changes to other databases: {str(e)}")
 
     return response
+'''
 
 def load_router(router_name):
     router_file = ROUTER_DIR / f'{router_name}.py'
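The deprecated block above preserves the old model: after every request, sync_middleware pushed every table to every peer via push_changes_to_all(). The targeted approach this commit switches to routes each write through execute_write_query (added further down in this diff), which pushes only the affected row. A minimal caller-side sketch, assuming the shared APIConfig instance is exposed as API, as the routers do; the values are hypothetical:

    # One local write, then a push of just the affected row to each online peer.
    await API.execute_write_query(
        'INSERT INTO short_urls (short_code, long_url) VALUES ($1, $2)',
        'abc', 'https://example.com',  # hypothetical values
        table_name="short_urls",
    )
    # Internally: run the statement locally, read back the row's
    # (version, server_id), then push_change() upserts that one row
    # on every reachable host.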
@@ -167,7 +167,6 @@ class Configuration(BaseModel):
 
 
 
-
 class APIConfig(BaseModel):
     HOST: str
     PORT: int
@@ -338,55 +337,6 @@ class APIConfig(BaseModel):
             err(f"Failed to acquire connection from pool for {pool_key}: {str(e)}")
             yield None
 
-    async def push_changes_to_one(self, pool_entry):
-        try:
-            async with self.get_connection() as local_conn:
-                if local_conn is None:
-                    err(f"Failed to connect to local database. Skipping push to {pool_entry['ts_id']}")
-                    return
-
-                async with self.get_connection(pool_entry) as remote_conn:
-                    if remote_conn is None:
-                        err(f"Failed to connect to remote database {pool_entry['ts_id']}. Skipping push.")
-                        return
-
-                    tables = await local_conn.fetch("""
-                        SELECT tablename FROM pg_tables
-                        WHERE schemaname = 'public'
-                    """)
-
-                    for table in tables:
-                        table_name = table['tablename']
-                        try:
-                            if table_name in self.SPECIAL_TABLES:
-                                await self.sync_special_table(local_conn, remote_conn, table_name)
-                            else:
-                                primary_key = await self.ensure_sync_columns(remote_conn, table_name)
-                                last_synced_version = await self.get_last_synced_version(remote_conn, table_name, os.environ.get('TS_ID'))
-
-                                changes = await local_conn.fetch(f"""
-                                    SELECT * FROM "{table_name}"
-                                    WHERE version > $1 AND server_id = $2
-                                    ORDER BY version ASC
-                                """, last_synced_version, os.environ.get('TS_ID'))
-
-                                if changes:
-                                    changes_count = await self.apply_batch_changes(remote_conn, table_name, changes, primary_key)
-
-                                    if changes_count > 0:
-                                        debug(f"Pushed {changes_count} changes for table {table_name} to {pool_entry['ts_id']}")
-
-                        except Exception as e:
-                            err(f"Error pushing changes for table {table_name} to {pool_entry['ts_id']}: {str(e)}")
-                            err(f"Traceback: {traceback.format_exc()}")
-
-            info(f"Successfully pushed changes to {pool_entry['ts_id']}")
-
-        except Exception as e:
-            err(f"Error pushing changes to {pool_entry['ts_id']}: {str(e)}")
-            err(f"Traceback: {traceback.format_exc()}")
-
-
     async def close_db_pools(self):
         info("Closing database connection pools...")
         for pool_key, pool in self.db_pools.items():
@@ -500,114 +450,6 @@ class APIConfig(BaseModel):
             err(f"Error ensuring sync columns for table {table_name}: {str(e)}")
             err(f"Traceback: {traceback.format_exc()}")
 
-    async def apply_batch_changes(self, conn, table_name, changes, primary_key):
-        if conn is None or not changes:
-            debug(f"Skipping apply_batch_changes because conn is None or there are no changes.")
-            return 0
-
-        try:
-            columns = list(changes[0].keys())
-            placeholders = [f'${i+1}' for i in range(len(columns))]
-
-            if primary_key:
-                insert_query = f"""
-                    INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)})
-                    VALUES ({', '.join(placeholders)})
-                    ON CONFLICT ("{primary_key}") DO UPDATE SET
-                    {', '.join(f'"{col}" = EXCLUDED."{col}"' for col in columns if col not in [primary_key, 'version', 'server_id'])},
-                    version = EXCLUDED.version,
-                    server_id = EXCLUDED.server_id
-                    WHERE "{table_name}".version < EXCLUDED.version
-                    OR ("{table_name}".version = EXCLUDED.version AND "{table_name}".server_id < EXCLUDED.server_id)
-                """
-            else:
-                # For tables without a primary key, we'll use all columns for conflict resolution
-                insert_query = f"""
-                    INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)})
-                    VALUES ({', '.join(placeholders)})
-                    ON CONFLICT DO NOTHING
-                """
-
-            # debug(f"Generated insert query for {table_name}: {insert_query}")
-
-            affected_rows = 0
-            async for change in tqdm(changes, desc=f"Syncing {table_name}", unit="row"):
-                values = [change[col] for col in columns]
-                # debug(f"Executing query for {table_name} with values: {values}")
-                result = await conn.execute(insert_query, *values)
-                affected_rows += int(result.split()[-1])
-
-            return affected_rows
-
-        except Exception as e:
-            err(f"Error applying batch changes to {table_name}: {str(e)}")
-            err(f"Traceback: {traceback.format_exc()}")
-            return 0
-
-    async def pull_changes(self, source_pool_entry, batch_size=10000):
-        if source_pool_entry['ts_id'] == os.environ.get('TS_ID'):
-            debug("Skipping self-sync")
-            return 0
-
-        total_changes = 0
-        source_id = source_pool_entry['ts_id']
-        source_ip = source_pool_entry['ts_ip']
-        dest_id = os.environ.get('TS_ID')
-        dest_ip = self.local_db['ts_ip']
-
-        info(f"Starting sync from source {source_id} ({source_ip}) to destination {dest_id} ({dest_ip})")
-
-        try:
-            async with self.get_connection(source_pool_entry) as source_conn:
-                async with self.get_connection(self.local_db) as dest_conn:
-                    tables = await source_conn.fetch("""
-                        SELECT tablename FROM pg_tables
-                        WHERE schemaname = 'public'
-                    """)
-
-                    async for table in tqdm(tables, desc="Syncing tables", unit="table"):
-                        table_name = table['tablename']
-                        try:
-                            if table_name in self.SPECIAL_TABLES:
-                                await self.sync_special_table(source_conn, dest_conn, table_name)
-                            else:
-                                primary_key = await self.ensure_sync_columns(dest_conn, table_name)
-                                last_synced_version = await self.get_last_synced_version(dest_conn, table_name, source_id)
-
-                                changes = await source_conn.fetch(f"""
-                                    SELECT * FROM "{table_name}"
-                                    WHERE version > $1 AND server_id = $2
-                                    ORDER BY version ASC
-                                    LIMIT $3
-                                """, last_synced_version, source_id, batch_size)
-
-                                if changes:
-                                    changes_count = await self.apply_batch_changes(dest_conn, table_name, changes, primary_key)
-                                    total_changes += changes_count
-
-                                    if changes_count > 0:
-                                        info(f"Synced batch for {table_name}: {changes_count} changes. Total so far: {total_changes}")
-                                else:
-                                    debug(f"No changes to sync for {table_name}")
-
-                        except Exception as e:
-                            err(f"Error syncing table {table_name}: {str(e)}")
-                            err(f"Traceback: {traceback.format_exc()}")
-
-            info(f"Sync complete from {source_id} ({source_ip}) to {dest_id} ({dest_ip}). Total changes: {total_changes}")
-
-        except Exception as e:
-            err(f"Error during sync process: {str(e)}")
-            err(f"Traceback: {traceback.format_exc()}")
-
-        info(f"Sync summary:")
-        info(f"  Total changes: {total_changes}")
-        info(f"  Tables synced: {len(tables)}")
-        info(f"  Source: {source_id} ({source_ip})")
-        info(f"  Destination: {dest_id} ({dest_ip})")
-
-        return total_changes
-
     async def get_online_hosts(self) -> List[Dict[str, Any]]:
         online_hosts = []
         for pool_entry in self.POOL:
@@ -619,31 +461,6 @@ class APIConfig(BaseModel):
                 err(f"Error checking host {pool_entry['ts_ip']}:{pool_entry['db_port']}: {str(e)}")
         return online_hosts
 
-    async def push_changes_to_all(self):
-        for pool_entry in self.POOL:
-            if pool_entry['ts_id'] != os.environ.get('TS_ID'):
-                try:
-                    await self.push_changes_to_one(pool_entry)
-                except Exception as e:
-                    err(f"Error pushing changes to {pool_entry['ts_id']}: {str(e)}")
-
-
-    async def get_last_synced_version(self, conn, table_name, server_id):
-        if conn is None:
-            debug(f"Skipping offline server...")
-            return 0
-
-        if table_name in self.SPECIAL_TABLES:
-            debug(f"Skipping get_last_synced_version because {table_name} is special.")
-            return 0  # Special handling for tables without version column
-
-        return await conn.fetchval(f"""
-            SELECT COALESCE(MAX(version), 0)
-            FROM "{table_name}"
-            WHERE server_id = $1
-        """, server_id)
-
     async def check_postgis(self, conn):
         if conn is None:
             debug(f"Skipping offline server...")
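The removed get_last_synced_version shows the sync model both the old and new code share: every row carries a version counter and the server_id that wrote it, and a replica only requests rows whose version exceeds the highest one it has already recorded from that server. Conflicts resolve last-writer-wins with server_id as the tiebreaker, mirroring the ON CONFLICT guard in the sync queries above and below. A pure-Python restatement of that comparator, as an illustration rather than code from the commit:

    def incoming_row_wins(local_row: dict, incoming_row: dict) -> bool:
        # Mirrors: WHERE local.version < EXCLUDED.version
        #          OR (versions equal AND local.server_id < EXCLUDED.server_id)
        # Higher version wins; on a tie the lexically greater server_id wins,
        # so every replica converges on the same row.
        if incoming_row["version"] != local_row["version"]:
            return incoming_row["version"] > local_row["version"]
        return incoming_row["server_id"] > local_row["server_id"]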
@@ -661,11 +478,134 @@ class APIConfig(BaseModel):
             err(f"Error checking PostGIS: {str(e)}")
             return False
 
-    async def sync_special_table(self, source_conn, dest_conn, table_name):
+    async def execute_write_query(self, query: str, *args, table_name: str):
+        if table_name in self.SPECIAL_TABLES:
+            return await self._execute_special_table_write(query, *args, table_name=table_name)
+
+        async with self.get_connection() as conn:
+            if conn is None:
+                raise ConnectionError("Failed to connect to local database")
+
+            # Ensure sync columns exist
+            primary_key = await self.ensure_sync_columns(conn, table_name)
+
+            # Execute the query
+            result = await conn.execute(query, *args)
+
+            # Get the primary key and new version of the affected row
+            if primary_key:
+                affected_row = await conn.fetchrow(f"""
+                    SELECT "{primary_key}", version, server_id
+                    FROM "{table_name}"
+                    WHERE version = (SELECT MAX(version) FROM "{table_name}")
+                """)
+                if affected_row:
+                    await self.push_change(table_name, affected_row[primary_key], affected_row['version'], affected_row['server_id'])
+            else:
+                # For tables without a primary key, we'll push all rows
+                await self.push_all_changes(table_name)
+
+        return result
+
+    async def push_change(self, table_name: str, pk_value: Any, version: int, server_id: str):
+        online_hosts = await self.get_online_hosts()
+        for pool_entry in online_hosts:
+            if pool_entry['ts_id'] != os.environ.get('TS_ID'):
+                try:
+                    async with self.get_connection(pool_entry) as remote_conn:
+                        if remote_conn is None:
+                            continue
+
+                        # Fetch the updated row from the local database
+                        async with self.get_connection() as local_conn:
+                            updated_row = await local_conn.fetchrow(f'SELECT * FROM "{table_name}" WHERE "{self.get_primary_key(table_name)}" = $1', pk_value)
+
+                        if updated_row:
+                            columns = updated_row.keys()
+                            placeholders = [f'${i+1}' for i in range(len(columns))]
+                            primary_key = self.get_primary_key(table_name)
+
+                            insert_query = f"""
+                                INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)})
+                                VALUES ({', '.join(placeholders)})
+                                ON CONFLICT ("{primary_key}") DO UPDATE SET
+                                {', '.join(f'"{col}" = EXCLUDED."{col}"' for col in columns if col != primary_key)},
+                                version = EXCLUDED.version,
+                                server_id = EXCLUDED.server_id
+                                WHERE "{table_name}".version < EXCLUDED.version
+                                OR ("{table_name}".version = EXCLUDED.version AND "{table_name}".server_id < EXCLUDED.server_id)
+                            """
+                            await remote_conn.execute(insert_query, *updated_row.values())
+                except Exception as e:
+                    err(f"Error pushing change to {pool_entry['ts_id']}: {str(e)}")
+                    err(f"Traceback: {traceback.format_exc()}")
+
+    async def push_all_changes(self, table_name: str):
+        online_hosts = await self.get_online_hosts()
+        for pool_entry in online_hosts:
+            if pool_entry['ts_id'] != os.environ.get('TS_ID'):
+                try:
+                    async with self.get_connection(pool_entry) as remote_conn:
+                        if remote_conn is None:
+                            continue
+
+                        # Fetch all rows from the local database
+                        async with self.get_connection() as local_conn:
+                            all_rows = await local_conn.fetch(f'SELECT * FROM "{table_name}"')
+
+                        if all_rows:
+                            columns = all_rows[0].keys()
+                            placeholders = [f'${i+1}' for i in range(len(columns))]
+
+                            insert_query = f"""
+                                INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)})
+                                VALUES ({', '.join(placeholders)})
+                                ON CONFLICT DO NOTHING
+                            """
+
+                            # Use a transaction to insert all rows
+                            async with remote_conn.transaction():
+                                for row in all_rows:
+                                    await remote_conn.execute(insert_query, *row.values())
+                except Exception as e:
+                    err(f"Error pushing all changes to {pool_entry['ts_id']}: {str(e)}")
+                    err(f"Traceback: {traceback.format_exc()}")
+
+    async def _execute_special_table_write(self, query: str, *args, table_name: str):
         if table_name == 'spatial_ref_sys':
-            return await self.sync_spatial_ref_sys(source_conn, dest_conn)
+            return await self._execute_spatial_ref_sys_write(query, *args)
         # Add more special cases as needed
+
+    async def _execute_spatial_ref_sys_write(self, query: str, *args):
+        result = None
+        async with self.get_connection() as local_conn:
+            if local_conn is None:
+                raise ConnectionError("Failed to connect to local database")
+
+            # Execute the query locally
+            result = await local_conn.execute(query, *args)
+
+            # Sync the entire spatial_ref_sys table with all online hosts
+            online_hosts = await self.get_online_hosts()
+            for pool_entry in online_hosts:
+                if pool_entry['ts_id'] != os.environ.get('TS_ID'):
+                    try:
+                        async with self.get_connection(pool_entry) as remote_conn:
+                            if remote_conn is None:
+                                continue
+                            await self.sync_spatial_ref_sys(local_conn, remote_conn)
+                    except Exception as e:
+                        err(f"Error syncing spatial_ref_sys to {pool_entry['ts_id']}: {str(e)}")
+                        err(f"Traceback: {traceback.format_exc()}")
+
+        return result
+
+    def get_primary_key(self, table_name: str) -> str:
+        # This method should return the primary key for the given table
+        # You might want to cache this information for performance
+        # For now, we'll assume it's always 'id', but you should implement proper logic here
+        return 'id'
 
     async def sync_spatial_ref_sys(self, source_conn, dest_conn):
         try:
             # Get all entries from the source
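The get_primary_key added above is an acknowledged stub: it always returns 'id'. The "proper logic" its comments call for could query the Postgres catalog and cache the result per table. A sketch under those assumptions (asyncpg, single-column primary keys, and the typing imports the file already uses elsewhere); note the commit calls get_primary_key synchronously inside push_change, so adopting this would also mean awaiting it at those call sites:

    _pk_cache: Dict[str, Optional[str]] = {}

    async def get_primary_key(self, table_name: str) -> Optional[str]:
        # Look up (and cache) the table's primary-key column from the catalog.
        if table_name not in self._pk_cache:
            async with self.get_connection() as conn:
                self._pk_cache[table_name] = await conn.fetchval("""
                    SELECT a.attname
                    FROM pg_index i
                    JOIN pg_attribute a
                      ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
                    WHERE i.indrelid = $1::regclass AND i.indisprimary
                """, table_name)
        return self._pk_cache[table_name]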
@@ -696,7 +636,6 @@ class APIConfig(BaseModel):
                         INSERT INTO spatial_ref_sys ({', '.join(f'"{col}"' for col in columns)})
                         VALUES ({', '.join(placeholders)})
                     """
-                    # debug(f"Inserting new entry for srid {srid}: {insert_query}")
                     await dest_conn.execute(insert_query, *source_entry.values())
                     inserts += 1
                 elif source_entry != dest_dict[srid]:
@@ -709,7 +648,6 @@ class APIConfig(BaseModel):
                         proj4text = $4::text
                     WHERE srid = $5::integer
                 """
-                # debug(f"Updating entry for srid {srid}: {update_query}")
                 await dest_conn.execute(update_query,
                     source_entry['auth_name'],
                     source_entry['auth_srid'],
@@ -727,63 +665,6 @@ class APIConfig(BaseModel):
             err(f"Traceback: {traceback.format_exc()}")
             return 0
 
-    async def get_most_recent_source(self):
-        most_recent_source = None
-        max_version = -1
-        local_ts_id = os.environ.get('TS_ID')
-        online_hosts = await self.get_online_hosts()
-        num_online_hosts = len(online_hosts)
-        if num_online_hosts > 0:
-            online_ts_ids = [host['ts_id'] for host in online_hosts if host['ts_id'] != local_ts_id]
-            crit(f"Online hosts: {', '.join(online_ts_ids)}")
-
-            for pool_entry in online_hosts:
-                if pool_entry['ts_id'] == local_ts_id:
-                    continue  # Skip local database
-
-                try:
-                    async with self.get_connection(pool_entry) as conn:
-                        tables = await conn.fetch("""
-                            SELECT tablename FROM pg_tables
-                            WHERE schemaname = 'public'
-                        """)
-
-                        for table in tables:
-                            table_name = table['tablename']
-                            if table_name in self.SPECIAL_TABLES:
-                                continue  # Skip special tables for version comparison
-                            try:
-                                result = await conn.fetchrow(f"""
-                                    SELECT MAX(version) as max_version, server_id
-                                    FROM "{table_name}"
-                                    WHERE version = (SELECT MAX(version) FROM "{table_name}")
-                                    GROUP BY server_id
-                                    ORDER BY MAX(version) DESC
-                                    LIMIT 1
-                                """)
-                                if result:
-                                    version, server_id = result['max_version'], result['server_id']
-                                    info(f"Max version for {pool_entry['ts_id']}, table {table_name}: {version} (from server {server_id})")
-                                    if version > max_version:
-                                        max_version = version
-                                        most_recent_source = pool_entry
-                                else:
-                                    debug(f"No data in table {table_name} for {pool_entry['ts_id']}")
-                            except asyncpg.exceptions.UndefinedColumnError:
-                                warn(f"Version or server_id column does not exist in table {table_name} for {pool_entry['ts_id']}. Skipping.")
-                            except Exception as e:
-                                err(f"Error checking version for {pool_entry['ts_id']}, table {table_name}: {str(e)}")
-
-                except asyncpg.exceptions.ConnectionFailureError:
-                    warn(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}")
-                except Exception as e:
-                    err(f"Unexpected error occurred while checking version for {pool_entry['ts_id']}: {str(e)}")
-                    err(f"Traceback: {traceback.format_exc()}")
-
-            return most_recent_source
-        else:
-            warn(f"No other online hosts for sync")
-            return None
-
 
 class Location(BaseModel):
@@ -417,50 +417,49 @@ if API.EXTENSIONS.courtlistener == "on" or API.EXTENSIONS.courtlistener == True:
             await cl_download_file(download_url, target_path, session)
             debug(f"Downloaded {file_name} to {target_path}")
 
 
 @serve.get("/s", response_class=HTMLResponse)
 async def shortener_form(request: Request):
     return templates.TemplateResponse("shortener.html", {"request": request})
 
 @serve.post("/s")
 async def create_short_url(request: Request, long_url: str = Form(...), custom_code: Optional[str] = Form(None)):
-    async with API.get_connection() as conn:
-        await create_tables(conn)
+    await create_tables()
 
     if custom_code:
         if len(custom_code) != 3 or not custom_code.isalnum():
             return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code must be 3 alphanumeric characters"})
 
-        existing = await conn.fetchval('SELECT 1 FROM short_urls WHERE short_code = $1', custom_code)
+        existing = await API.execute_write_query('SELECT 1 FROM short_urls WHERE short_code = $1', custom_code, table_name="short_urls")
         if existing:
            return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code already in use"})
 
         short_code = custom_code
     else:
         chars = string.ascii_letters + string.digits
         while True:
             short_code = ''.join(random.choice(chars) for _ in range(3))
-            existing = await conn.fetchval('SELECT 1 FROM short_urls WHERE short_code = $1', short_code)
+            existing = await API.execute_write_query('SELECT 1 FROM short_urls WHERE short_code = $1', short_code, table_name="short_urls")
             if not existing:
                 break
 
-    await conn.execute(
+    await API.execute_write_query(
         'INSERT INTO short_urls (short_code, long_url) VALUES ($1, $2)',
-        short_code, long_url
+        short_code, long_url,
+        table_name="short_urls"
     )
 
     short_url = f"https://sij.ai/{short_code}"
     return templates.TemplateResponse("shortener.html", {"request": request, "short_url": short_url})
 
-async def create_tables(conn):
-    await conn.execute('''
+async def create_tables():
+    await API.execute_write_query('''
         CREATE TABLE IF NOT EXISTS short_urls (
             short_code VARCHAR(3) PRIMARY KEY,
             long_url TEXT NOT NULL,
             created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
         )
-    ''')
-    await conn.execute('''
+    ''', table_name="short_urls")
+    await API.execute_write_query('''
         CREATE TABLE IF NOT EXISTS click_logs (
             id SERIAL PRIMARY KEY,
             short_code VARCHAR(3) REFERENCES short_urls(short_code),
@@ -468,48 +467,51 @@ async def create_tables(conn):
             ip_address TEXT,
             user_agent TEXT
         )
-    ''')
+    ''', table_name="click_logs")
 
 @serve.get("/{short_code}", response_class=RedirectResponse, status_code=301)
 async def redirect_short_url(request: Request, short_code: str = PathParam(..., min_length=3, max_length=3)):
     if request.headers.get('host') != 'sij.ai':
         raise HTTPException(status_code=404, detail="Not Found")
 
-    async with API.get_connection() as conn:
-        result = await conn.fetchrow(
-            'SELECT long_url FROM short_urls WHERE short_code = $1',
-            short_code
-        )
-
-        if result:
-            await conn.execute(
-                'INSERT INTO click_logs (short_code, ip_address, user_agent) VALUES ($1, $2, $3)',
-                short_code, request.client.host, request.headers.get("user-agent")
-            )
-            return result['long_url']
-        else:
-            raise HTTPException(status_code=404, detail="Short URL not found")
+    result = await API.execute_write_query(
+        'SELECT long_url FROM short_urls WHERE short_code = $1',
+        short_code,
+        table_name="short_urls"
+    )
+
+    if result:
+        await API.execute_write_query(
+            'INSERT INTO click_logs (short_code, ip_address, user_agent) VALUES ($1, $2, $3)',
+            short_code, request.client.host, request.headers.get("user-agent"),
+            table_name="click_logs"
+        )
+        return result['long_url']
+    else:
+        raise HTTPException(status_code=404, detail="Short URL not found")
 
 @serve.get("/analytics/{short_code}")
 async def get_analytics(short_code: str):
-    async with API.get_connection() as conn:
-        url_info = await conn.fetchrow(
-            'SELECT long_url, created_at FROM short_urls WHERE short_code = $1',
-            short_code
-        )
-        if not url_info:
-            raise HTTPException(status_code=404, detail="Short URL not found")
-
-        click_count = await conn.fetchval(
-            'SELECT COUNT(*) FROM click_logs WHERE short_code = $1',
-            short_code
-        )
-        clicks = await conn.fetch(
-            'SELECT clicked_at, ip_address, user_agent FROM click_logs WHERE short_code = $1 ORDER BY clicked_at DESC LIMIT 100',
-            short_code
-        )
+    url_info = await API.execute_write_query(
+        'SELECT long_url, created_at FROM short_urls WHERE short_code = $1',
+        short_code,
+        table_name="short_urls"
+    )
+    if not url_info:
+        raise HTTPException(status_code=404, detail="Short URL not found")
+
+    click_count = await API.execute_write_query(
+        'SELECT COUNT(*) FROM click_logs WHERE short_code = $1',
+        short_code,
+        table_name="click_logs"
+    )
+    clicks = await API.execute_write_query(
+        'SELECT clicked_at, ip_address, user_agent FROM click_logs WHERE short_code = $1 ORDER BY clicked_at DESC LIMIT 100',
+        short_code,
+        table_name="click_logs"
+    )
 
     return {
         "short_code": short_code,
         "long_url": url_info['long_url'],
@@ -520,6 +522,7 @@ async def get_analytics(short_code: str):
 
 
 
+
 async def forward_traffic(reader: asyncio.StreamReader, writer: asyncio.StreamWriter, destination: str):
     dest_host, dest_port = destination.split(':')
     dest_port = int(dest_port)