From 25b12d86dc87fd7a2d88b6bf52edec02bedfded0 Mon Sep 17 00:00:00 2001 From: sanj <67624670+iodrift@users.noreply.github.com> Date: Thu, 1 Aug 2024 23:51:01 -0700 Subject: [PATCH] Much more efficient database sync method, first working version. --- sijapi/__init__.py | 1 + sijapi/__main__.py | 70 ++-- sijapi/classes.py | 801 ++++++++++++++++++++++++++++++-------- sijapi/routers/gis.py | 243 ++++++------ sijapi/routers/serve.py | 96 ++--- sijapi/routers/weather.py | 228 +++++------ 6 files changed, 950 insertions(+), 489 deletions(-) diff --git a/sijapi/__init__.py b/sijapi/__init__.py index bdb64a4..c2a14d3 100644 --- a/sijapi/__init__.py +++ b/sijapi/__init__.py @@ -20,6 +20,7 @@ L = Logger("Central", LOGS_DIR) # API essentials API = APIConfig.load('api', 'secrets') + Dir = Configuration.load('dirs') HOST = f"{API.BIND}:{API.PORT}" LOCAL_HOSTS = [ipaddress.ip_address(localhost.strip()) for localhost in os.getenv('LOCAL_HOSTS', '127.0.0.1').split(',')] + ['localhost'] diff --git a/sijapi/__main__.py b/sijapi/__main__.py index 27eae12..e609014 100755 --- a/sijapi/__main__.py +++ b/sijapi/__main__.py @@ -39,12 +39,11 @@ def warn(text: str): logger.warning(text) def err(text: str): logger.error(text) def crit(text: str): logger.critical(text) - @asynccontextmanager async def lifespan(app: FastAPI): # Startup crit("sijapi launched") - crit(f"Arguments: {args}") + info(f"Arguments: {args}") # Load routers if args.test: @@ -54,20 +53,10 @@ async def lifespan(app: FastAPI): if getattr(API.MODULES, module_name): load_router(module_name) - crit("Starting database synchronization...") try: # Initialize sync structures on all databases await API.initialize_sync() - - # Check if other instances have more recent data - source = await API.get_most_recent_source() - if source: - crit(f"Pulling changes from {source['ts_id']} ({source['ts_ip']})...") - total_changes = await API.pull_changes(source) - crit(f"Data pull complete. Total changes: {total_changes}") - else: - crit("No instances with more recent data found or all instances are offline.") - + except Exception as e: crit(f"Error during startup: {str(e)}") crit(f"Traceback: {traceback.format_exc()}") @@ -79,8 +68,6 @@ async def lifespan(app: FastAPI): await API.close_db_pools() crit("Database pools closed.") - - app = FastAPI(lifespan=lifespan) app.add_middleware( @@ -135,30 +122,41 @@ async def handle_exception_middleware(request: Request, call_next): return response -# This was removed on 7/31/2024 when we decided to instead use a targeted push sync approach. -deprecated = ''' -async def push_changes_background(): +@app.post("/sync/pull") +async def pull_changes(): try: - await API.push_changes_to_all() - except Exception as e: - err(f"Error pushing changes to other databases: {str(e)}") - err(f"Traceback: {traceback.format_exc()}") - - -@app.middleware("http") -async def sync_middleware(request: Request, call_next): - response = await call_next(request) - - # Check if the request was a database write operation - if request.method in ["POST", "PUT", "PATCH", "DELETE"]: + await API.add_primary_keys_to_local_tables() + await API.add_primary_keys_to_remote_tables() try: - # Push changes to other databases - await API.push_changes_to_all() + + source = await API.get_most_recent_source() + + if source: + # Pull changes from the source + total_changes = await API.pull_changes(source) + + return JSONResponse(content={ + "status": "success", + "message": f"Pull complete. 
Total changes: {total_changes}", + "source": f"{source['ts_id']} ({source['ts_ip']})", + "changes": total_changes + }) + else: + return JSONResponse(content={ + "status": "info", + "message": "No instances with more recent data found or all instances are offline." + }) + except Exception as e: - err(f"Error pushing changes to other databases: {str(e)}") - - return response -''' + err(f"Error during pull: {str(e)}") + err(f"Traceback: {traceback.format_exc()}") + raise HTTPException(status_code=500, detail=f"Error during pull: {str(e)}") + + except Exception as e: + err(f"Error while ensuring primary keys to tables: {str(e)}") + err(f"Traceback: {traceback.format_exc()}") + raise HTTPException(status_code=500, detail=f"Error during primary key insurance: {str(e)}") + def load_router(router_name): router_file = ROUTER_DIR / f'{router_name}.py' diff --git a/sijapi/classes.py b/sijapi/classes.py index 191a564..ed7e5fd 100644 --- a/sijapi/classes.py +++ b/sijapi/classes.py @@ -5,6 +5,7 @@ import math import os import re import uuid +import time import aiofiles import aiohttp import asyncio @@ -180,14 +181,18 @@ class APIConfig(BaseModel): TZ: str KEYS: List[str] GARBAGE: Dict[str, Any] - SPECIAL_TABLES: ClassVar[List[str]] = ['spatial_ref_sys'] - db_pools: Dict[str, Any] = Field(default_factory=dict) + offline_servers: Dict[str, float] = Field(default_factory=dict) + offline_timeout: float = Field(default=30.0) # 30 second timeout for offline servers + online_hosts_cache: Dict[str, Tuple[List[Dict[str, Any]], float]] = Field(default_factory=dict) + online_hosts_cache_ttl: float = Field(default=30.0) # Cache TTL in seconds def __init__(self, **data): super().__init__(**data) - self._db_pools = {} + self.db_pools = {} + self.online_hosts_cache = {} # Initialize the cache + self._sync_tasks = {} class Config: arbitrary_types_allowed = True @@ -306,13 +311,48 @@ class APIConfig(BaseModel): raise ValueError(f"No database configuration found for TS_ID: {ts_id}") return local_db - @asynccontextmanager + async def get_online_hosts(self) -> List[Dict[str, Any]]: + current_time = time.time() + cache_key = "online_hosts" + + if cache_key in self.online_hosts_cache: + cached_hosts, cache_time = self.online_hosts_cache[cache_key] + if current_time - cache_time < self.online_hosts_cache_ttl: + return cached_hosts + + online_hosts = [] + local_ts_id = os.environ.get('TS_ID') + + for pool_entry in self.POOL: + if pool_entry['ts_id'] != local_ts_id: + pool_key = f"{pool_entry['ts_ip']}:{pool_entry['db_port']}" + if pool_key in self.offline_servers: + if current_time - self.offline_servers[pool_key] < self.offline_timeout: + continue + else: + del self.offline_servers[pool_key] + + conn = await self.get_connection(pool_entry) + if conn is not None: + online_hosts.append(pool_entry) + await conn.close() + + self.online_hosts_cache[cache_key] = (online_hosts, current_time) + return online_hosts + async def get_connection(self, pool_entry: Dict[str, Any] = None): if pool_entry is None: pool_entry = self.local_db pool_key = f"{pool_entry['ts_ip']}:{pool_entry['db_port']}" + # Check if the server is marked as offline + if pool_key in self.offline_servers: + if time.time() - self.offline_servers[pool_key] < self.offline_timeout: + return None + else: + del self.offline_servers[pool_key] + if pool_key not in self.db_pools: try: self.db_pools[pool_key] = await asyncpg.create_pool( @@ -326,27 +366,21 @@ class APIConfig(BaseModel): timeout=5 ) except Exception as e: - err(f"Failed to create connection pool for {pool_key}: 
{str(e)}") - yield None - return + warn(f"Failed to create connection pool for {pool_key}: {str(e)}") + self.offline_servers[pool_key] = time.time() + return None try: - async with self.db_pools[pool_key].acquire() as conn: - yield conn + return await asyncio.wait_for(self.db_pools[pool_key].acquire(), timeout=5) + except asyncio.TimeoutError: + warn(f"Timeout acquiring connection from pool for {pool_key}") + self.offline_servers[pool_key] = time.time() + return None except Exception as e: - err(f"Failed to acquire connection from pool for {pool_key}: {str(e)}") - yield None + warn(f"Failed to acquire connection for {pool_key}: {str(e)}") + self.offline_servers[pool_key] = time.time() + return None - async def close_db_pools(self): - info("Closing database connection pools...") - for pool_key, pool in self.db_pools.items(): - try: - await pool.close() - debug(f"Closed pool for {pool_key}") - except Exception as e: - err(f"Error closing pool for {pool_key}: {str(e)}") - self.db_pools.clear() - info("All database connection pools closed.") async def initialize_sync(self): local_ts_id = os.environ.get('TS_ID') @@ -356,52 +390,60 @@ class APIConfig(BaseModel): if pool_entry['ts_id'] == local_ts_id: continue # Skip local database try: - async with self.get_connection(pool_entry) as conn: - if conn is None: - continue # Skip this database if connection failed - - debug(f"Starting sync initialization for {pool_entry['ts_ip']}...") - - # Check PostGIS installation - postgis_installed = await self.check_postgis(conn) - if not postgis_installed: - warn(f"PostGIS is not installed on {pool_entry['ts_id']} ({pool_entry['ts_ip']}). Some spatial operations may fail.") - - tables = await conn.fetch(""" - SELECT tablename FROM pg_tables - WHERE schemaname = 'public' - """) - - for table in tables: - table_name = table['tablename'] - await self.ensure_sync_columns(conn, table_name) - - debug(f"Sync initialization complete for {pool_entry['ts_ip']}. All tables now have necessary sync columns and triggers.") + conn = await self.get_connection(pool_entry) + if conn is None: + continue # Skip this database if connection failed + + debug(f"Starting sync initialization for {pool_entry['ts_ip']}...") + + # Check PostGIS installation + postgis_installed = await self.check_postgis(conn) + if not postgis_installed: + warn(f"PostGIS is not installed on {pool_entry['ts_id']} ({pool_entry['ts_ip']}). Some spatial operations may fail.") + + tables = await conn.fetch(""" + SELECT tablename FROM pg_tables + WHERE schemaname = 'public' + """) + + for table in tables: + table_name = table['tablename'] + await self.ensure_sync_columns(conn, table_name) + + debug(f"Sync initialization complete for {pool_entry['ts_ip']}. 
All tables now have necessary sync columns and triggers.") except Exception as e: err(f"Error initializing sync for {pool_entry['ts_ip']}: {str(e)}") err(f"Traceback: {traceback.format_exc()}") + + def _schedule_sync_task(self, table_name: str, pk_value: Any, version: int, server_id: str): + # Use a background task manager to handle syncing + task_key = f"{table_name}:{pk_value}" if pk_value else table_name + if task_key not in self._sync_tasks: + self._sync_tasks[task_key] = asyncio.create_task(self._sync_changes(table_name, pk_value, version, server_id)) + async def ensure_sync_columns(self, conn, table_name): - if conn is None: - debug(f"Skipping offline server...") - return None - - if table_name in self.SPECIAL_TABLES: - debug(f"Skipping sync columns for special table: {table_name}") + if conn is None or table_name in self.SPECIAL_TABLES: return None try: - # Get primary key information + # Check if primary key exists primary_key = await conn.fetchval(f""" SELECT a.attname FROM pg_index i - JOIN pg_attribute a ON a.attrelid = i.indrelid - AND a.attnum = ANY(i.indkey) - WHERE i.indrelid = '{table_name}'::regclass - AND i.indisprimary; + JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey) + WHERE i.indrelid = '{table_name}'::regclass AND i.indisprimary; """) + if not primary_key: + # Add an id column as primary key if it doesn't exist + await conn.execute(f""" + ALTER TABLE "{table_name}" + ADD COLUMN IF NOT EXISTS id SERIAL PRIMARY KEY; + """) + primary_key = 'id' + # Ensure version column exists await conn.execute(f""" ALTER TABLE "{table_name}" @@ -445,21 +487,11 @@ class APIConfig(BaseModel): debug(f"Successfully ensured sync columns and trigger for table {table_name}") return primary_key - + except Exception as e: err(f"Error ensuring sync columns for table {table_name}: {str(e)}") err(f"Traceback: {traceback.format_exc()}") - - async def get_online_hosts(self) -> List[Dict[str, Any]]: - online_hosts = [] - for pool_entry in self.POOL: - try: - async with self.get_connection(pool_entry) as conn: - if conn is not None: - online_hosts.append(pool_entry) - except Exception as e: - err(f"Error checking host {pool_entry['ts_ip']}:{pool_entry['db_port']}: {str(e)}") - return online_hosts + return None async def check_postgis(self, conn): if conn is None: @@ -477,134 +509,508 @@ class APIConfig(BaseModel): except Exception as e: err(f"Error checking PostGIS: {str(e)}") return False + - async def execute_write_query(self, query: str, *args, table_name: str): - if table_name in self.SPECIAL_TABLES: - return await self._execute_special_table_write(query, *args, table_name=table_name) + async def pull_changes(self, source_pool_entry, batch_size=10000): + if source_pool_entry['ts_id'] == os.environ.get('TS_ID'): + debug("Skipping self-sync") + return 0 - async with self.get_connection() as conn: - if conn is None: - raise ConnectionError("Failed to connect to local database") + total_changes = 0 + source_id = source_pool_entry['ts_id'] + source_ip = source_pool_entry['ts_ip'] + dest_id = os.environ.get('TS_ID') + dest_ip = self.local_db['ts_ip'] - # Ensure sync columns exist - primary_key = await self.ensure_sync_columns(conn, table_name) + info(f"Starting sync from source {source_id} ({source_ip}) to destination {dest_id} ({dest_ip})") - # Execute the query - result = await conn.execute(query, *args) + source_conn = None + dest_conn = None + try: + source_conn = await self.get_connection(source_pool_entry) + if source_conn is None: + warn(f"Unable to connect to source 
{source_id} ({source_ip}). Skipping sync.") + return 0 - # Get the primary key and new version of the affected row - if primary_key: - affected_row = await conn.fetchrow(f""" - SELECT "{primary_key}", version, server_id - FROM "{table_name}" - WHERE version = (SELECT MAX(version) FROM "{table_name}") - """) - if affected_row: - await self.push_change(table_name, affected_row[primary_key], affected_row['version'], affected_row['server_id']) - else: - # For tables without a primary key, we'll push all rows - await self.push_all_changes(table_name) + dest_conn = await self.get_connection(self.local_db) + if dest_conn is None: + warn(f"Unable to connect to local database. Skipping sync.") + return 0 - return result - - async def push_change(self, table_name: str, pk_value: Any, version: int, server_id: str): - online_hosts = await self.get_online_hosts() - for pool_entry in online_hosts: - if pool_entry['ts_id'] != os.environ.get('TS_ID'): + tables = await source_conn.fetch(""" + SELECT tablename FROM pg_tables + WHERE schemaname = 'public' + """) + + for table in tables: + table_name = table['tablename'] try: - async with self.get_connection(pool_entry) as remote_conn: + if table_name in self.SPECIAL_TABLES: + await self.sync_special_table(source_conn, dest_conn, table_name) + else: + primary_key = await self.ensure_sync_columns(dest_conn, table_name) + last_synced_version = await self.get_last_synced_version(dest_conn, table_name, source_id) + + changes = await source_conn.fetch(f""" + SELECT * FROM "{table_name}" + WHERE version > $1 AND server_id = $2 + ORDER BY version ASC + LIMIT $3 + """, last_synced_version, source_id, batch_size) + + if changes: + changes_count = await self.apply_batch_changes(dest_conn, table_name, changes, primary_key) + total_changes += changes_count + + if changes_count > 0: + info(f"Synced batch for {table_name}: {changes_count} changes. Total so far: {total_changes}") + else: + debug(f"No changes to sync for {table_name}") + + except Exception as e: + err(f"Error syncing table {table_name}: {str(e)}") + err(f"Traceback: {traceback.format_exc()}") + + info(f"Sync complete from {source_id} ({source_ip}) to {dest_id} ({dest_ip}). 
Total changes: {total_changes}") + + except Exception as e: + err(f"Error during sync process: {str(e)}") + err(f"Traceback: {traceback.format_exc()}") + + finally: + if source_conn: + await source_conn.close() + if dest_conn: + await dest_conn.close() + + info(f"Sync summary:") + info(f" Total changes: {total_changes}") + info(f" Tables synced: {len(tables) if 'tables' in locals() else 0}") + info(f" Source: {source_id} ({source_ip})") + info(f" Destination: {dest_id} ({dest_ip})") + + return total_changes + + + async def get_last_synced_version(self, conn, table_name, server_id): + if conn is None: + debug(f"Skipping offline server...") + return 0 + + if table_name in self.SPECIAL_TABLES: + debug(f"Skipping get_last_synced_version because {table_name} is special.") + return 0 # Special handling for tables without version column + + try: + last_version = await conn.fetchval(f""" + SELECT COALESCE(MAX(version), 0) + FROM "{table_name}" + WHERE server_id = $1 + """, server_id) + return last_version + except Exception as e: + err(f"Error getting last synced version for table {table_name}: {str(e)}") + err(f"Traceback: {traceback.format_exc()}") + return 0 + + + async def get_most_recent_source(self): + most_recent_source = None + max_version = -1 + local_ts_id = os.environ.get('TS_ID') + online_hosts = await self.get_online_hosts() + num_online_hosts = len(online_hosts) + + if num_online_hosts > 0: + online_ts_ids = [host['ts_id'] for host in online_hosts if host['ts_id'] != local_ts_id] + crit(f"Online hosts: {', '.join(online_ts_ids)}") + + for pool_entry in online_hosts: + if pool_entry['ts_id'] == local_ts_id: + continue # Skip local database + + try: + conn = await self.get_connection(pool_entry) + if conn is None: + warn(f"Unable to connect to {pool_entry['ts_id']}. Skipping.") + continue + + tables = await conn.fetch(""" + SELECT tablename FROM pg_tables + WHERE schemaname = 'public' + """) + + for table in tables: + table_name = table['tablename'] + if table_name in self.SPECIAL_TABLES: + continue + try: + result = await conn.fetchrow(f""" + SELECT MAX(version) as max_version, server_id + FROM "{table_name}" + WHERE version = (SELECT MAX(version) FROM "{table_name}") + GROUP BY server_id + ORDER BY MAX(version) DESC + LIMIT 1 + """) + if result: + version, server_id = result['max_version'], result['server_id'] + info(f"Max version for {pool_entry['ts_id']}, table {table_name}: {version} (from server {server_id})") + if version > max_version: + max_version = version + most_recent_source = pool_entry + else: + debug(f"No data in table {table_name} for {pool_entry['ts_id']}") + except asyncpg.exceptions.UndefinedColumnError: + warn(f"Version or server_id column does not exist in table {table_name} for {pool_entry['ts_id']}. Skipping.") + except Exception as e: + err(f"Error checking version for {pool_entry['ts_id']}, table {table_name}: {str(e)}") + + except asyncpg.exceptions.ConnectionFailureError: + warn(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}. Skipping.") + except Exception as e: + err(f"Unexpected error occurred while checking version for {pool_entry['ts_id']}: {str(e)}") + err(f"Traceback: {traceback.format_exc()}") + finally: + if conn: + await conn.close() + + if most_recent_source is None: + if num_online_hosts > 0: + warn("Could not determine most recent source. 
Using first available online host.") + most_recent_source = next(host for host in online_hosts if host['ts_id'] != local_ts_id) + else: + crit("No other online hosts available for sync.") + + return most_recent_source + + + + async def _sync_changes(self, table_name: str, primary_key: str): + try: + local_conn = await self.get_connection() + if local_conn is None: + return + + # Get the latest changes + changes = await local_conn.fetch(f""" + SELECT * FROM "{table_name}" + WHERE version > (SELECT COALESCE(MAX(version), 0) FROM "{table_name}" WHERE server_id != $1) + OR (version = (SELECT COALESCE(MAX(version), 0) FROM "{table_name}" WHERE server_id != $1) AND server_id = $1) + ORDER BY version ASC + """, os.environ.get('TS_ID')) + + if changes: + online_hosts = await self.get_online_hosts() + for pool_entry in online_hosts: + if pool_entry['ts_id'] != os.environ.get('TS_ID'): + remote_conn = await self.get_connection(pool_entry) if remote_conn is None: continue + try: + await self.apply_batch_changes(remote_conn, table_name, changes, primary_key) + finally: + await remote_conn.close() + except Exception as e: + err(f"Error syncing changes for {table_name}: {str(e)}") + err(f"Traceback: {traceback.format_exc()}") + finally: + if 'local_conn' in locals(): + await local_conn.close() - # Fetch the updated row from the local database - async with self.get_connection() as local_conn: - updated_row = await local_conn.fetchrow(f'SELECT * FROM "{table_name}" WHERE "{self.get_primary_key(table_name)}" = $1', pk_value) - if updated_row: - columns = updated_row.keys() - placeholders = [f'${i+1}' for i in range(len(columns))] - primary_key = self.get_primary_key(table_name) - - insert_query = f""" - INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)}) - VALUES ({', '.join(placeholders)}) - ON CONFLICT ("{primary_key}") DO UPDATE SET - {', '.join(f'"{col}" = EXCLUDED."{col}"' for col in columns if col != primary_key)}, - version = EXCLUDED.version, - server_id = EXCLUDED.server_id - WHERE "{table_name}".version < EXCLUDED.version - OR ("{table_name}".version = EXCLUDED.version AND "{table_name}".server_id < EXCLUDED.server_id) - """ - await remote_conn.execute(insert_query, *updated_row.values()) + async def execute_read_query(self, query: str, *args, table_name: str): + online_hosts = await self.get_online_hosts() + results = [] + max_version = -1 + latest_result = None + + for pool_entry in online_hosts: + conn = await self.get_connection(pool_entry) + if conn is None: + warn(f"Unable to connect to {pool_entry['ts_id']}. 
Skipping read.") + continue + + try: + # Execute the query + result = await conn.fetch(query, *args) + + if not result: + continue + + # Check version if it's not a special table + if table_name not in self.SPECIAL_TABLES: + try: + version_query = f""" + SELECT MAX(version) as max_version, server_id + FROM "{table_name}" + WHERE version = (SELECT MAX(version) FROM "{table_name}") + GROUP BY server_id + ORDER BY MAX(version) DESC + LIMIT 1 + """ + version_result = await conn.fetchrow(version_query) + if version_result: + version = version_result['max_version'] + server_id = version_result['server_id'] + info(f"Max version for {pool_entry['ts_id']}, table {table_name}: {version} (from server {server_id})") + if version > max_version: + max_version = version + latest_result = result + else: + debug(f"No data in table {table_name} for {pool_entry['ts_id']}") + except asyncpg.exceptions.UndefinedColumnError: + warn(f"Version or server_id column does not exist in table {table_name} for {pool_entry['ts_id']}. Skipping version check.") + if latest_result is None: + latest_result = result + else: + # For special tables, just use the first result + if latest_result is None: + latest_result = result + + results.append((pool_entry['ts_id'], result)) + + except Exception as e: + err(f"Error executing read query on {pool_entry['ts_id']}: {str(e)}") + err(f"Traceback: {traceback.format_exc()}") + finally: + await conn.close() + + if not latest_result: + warn(f"No results found for query on table {table_name}") + return [] + + # Log results from all databases + for ts_id, result in results: + info(f"Read result from {ts_id}: {result}") + + return [dict(r) for r in latest_result] # Convert Record objects to dictionaries + + async def execute_write_query(self, query: str, *args, table_name: str): + conn = await self.get_connection() + if conn is None: + raise ConnectionError("Failed to connect to local database") + + try: + if table_name in self.SPECIAL_TABLES: + return await self._execute_special_table_write(conn, query, *args, table_name=table_name) + + primary_key = await self.ensure_sync_columns(conn, table_name) + + result = await conn.execute(query, *args) + + asyncio.create_task(self._sync_changes(table_name, primary_key)) + + return [] + finally: + await conn.close() + + + async def _run_sync_tasks(self, tasks): + for task in tasks: + try: + await task + except Exception as e: + err(f"Error during background sync: {str(e)}") + err(f"Traceback: {traceback.format_exc()}") + + + async def push_change(self, table_name: str, pk_value: Any, version: int, server_id: str): + asyncio.create_task(self._push_change_background(table_name, pk_value, version, server_id)) + + async def _push_change_background(self, table_name: str, pk_value: Any, version: int, server_id: str): + online_hosts = await self.get_online_hosts() + successful_pushes = 0 + failed_pushes = 0 + + for pool_entry in online_hosts: + if pool_entry['ts_id'] != os.environ.get('TS_ID'): + remote_conn = await self.get_connection(pool_entry) + if remote_conn is None: + continue + + try: + local_conn = await self.get_connection() + if local_conn is None: + continue + + try: + updated_row = await local_conn.fetchrow(f'SELECT * FROM "{table_name}" WHERE "{self.get_primary_key(table_name)}" = $1', pk_value) + finally: + await local_conn.close() + + if updated_row: + columns = updated_row.keys() + placeholders = [f'${i+1}' for i in range(len(columns))] + primary_key = self.get_primary_key(table_name) + + remote_version = await remote_conn.fetchval(f""" + 
SELECT version FROM "{table_name}" + WHERE "{primary_key}" = $1 + """, updated_row[primary_key]) + + if remote_version is not None and remote_version >= updated_row['version']: + debug(f"Remote version for {table_name} in {pool_entry['ts_id']} is already up to date. Skipping push.") + successful_pushes += 1 + continue + + insert_query = f""" + INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)}) + VALUES ({', '.join(placeholders)}) + ON CONFLICT ("{primary_key}") DO UPDATE SET + {', '.join(f'"{col}" = EXCLUDED."{col}"' for col in columns if col != primary_key)}, + version = EXCLUDED.version, + server_id = EXCLUDED.server_id + WHERE "{table_name}".version < EXCLUDED.version + OR ("{table_name}".version = EXCLUDED.version AND "{table_name}".server_id < EXCLUDED.server_id) + """ + await remote_conn.execute(insert_query, *updated_row.values()) + successful_pushes += 1 except Exception as e: err(f"Error pushing change to {pool_entry['ts_id']}: {str(e)}") - err(f"Traceback: {traceback.format_exc()}") + failed_pushes += 1 + finally: + if remote_conn: + await remote_conn.close() + + if successful_pushes > 0: + info(f"Successfully pushed changes to {successful_pushes} server(s) for {table_name}") + if failed_pushes > 0: + warn(f"Failed to push changes to {failed_pushes} server(s) for {table_name}") + async def push_all_changes(self, table_name: str): online_hosts = await self.get_online_hosts() + tasks = [] + for pool_entry in online_hosts: if pool_entry['ts_id'] != os.environ.get('TS_ID'): - try: - async with self.get_connection(pool_entry) as remote_conn: - if remote_conn is None: - continue + task = asyncio.create_task(self._push_changes_to_host(pool_entry, table_name)) + tasks.append(task) - # Fetch all rows from the local database - async with self.get_connection() as local_conn: - all_rows = await local_conn.fetch(f'SELECT * FROM "{table_name}"') + results = await asyncio.gather(*tasks, return_exceptions=True) + successful_pushes = sum(1 for r in results if r is True) + failed_pushes = sum(1 for r in results if r is False or isinstance(r, Exception)) - if all_rows: - columns = all_rows[0].keys() - placeholders = [f'${i+1}' for i in range(len(columns))] - - insert_query = f""" - INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)}) - VALUES ({', '.join(placeholders)}) - ON CONFLICT DO NOTHING - """ - - # Use a transaction to insert all rows - async with remote_conn.transaction(): - for row in all_rows: - await remote_conn.execute(insert_query, *row.values()) - except Exception as e: - err(f"Error pushing all changes to {pool_entry['ts_id']}: {str(e)}") - err(f"Traceback: {traceback.format_exc()}") + info(f"Push all changes summary for {table_name}: Successful: {successful_pushes}, Failed: {failed_pushes}") + if failed_pushes > 0: + warn(f"Failed to push all changes to {failed_pushes} server(s). Data may be out of sync.") - async def _execute_special_table_write(self, query: str, *args, table_name: str): - if table_name == 'spatial_ref_sys': - return await self._execute_spatial_ref_sys_write(query, *args) - # Add more special cases as needed + async def _push_changes_to_host(self, pool_entry: Dict[str, Any], table_name: str) -> bool: + remote_conn = None + try: + remote_conn = await self.get_connection(pool_entry) + if remote_conn is None: + warn(f"Unable to connect to {pool_entry['ts_id']}. 
Skipping push.") + return False - async def _execute_spatial_ref_sys_write(self, query: str, *args): - result = None - async with self.get_connection() as local_conn: + local_conn = await self.get_connection() if local_conn is None: - raise ConnectionError("Failed to connect to local database") - - # Execute the query locally - result = await local_conn.execute(query, *args) + warn(f"Unable to connect to local database. Skipping push.") + return False + + try: + all_rows = await local_conn.fetch(f'SELECT * FROM "{table_name}"') + finally: + await local_conn.close() + + if all_rows: + columns = list(all_rows[0].keys()) + placeholders = [f'${i+1}' for i in range(len(columns))] + + insert_query = f""" + INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)}) + VALUES ({', '.join(placeholders)}) + ON CONFLICT DO NOTHING + """ + + async with remote_conn.transaction(): + for row in all_rows: + await remote_conn.execute(insert_query, *row.values()) + + return True + else: + debug(f"No rows to push for table {table_name}") + return True + except Exception as e: + err(f"Error pushing all changes to {pool_entry['ts_id']}: {str(e)}") + err(f"Traceback: {traceback.format_exc()}") + return False + finally: + if remote_conn: + await remote_conn.close() + + async def _execute_special_table_write(self, conn, query: str, *args, table_name: str): + if table_name == 'spatial_ref_sys': + return await self._execute_spatial_ref_sys_write(conn, query, *args) + + async def _execute_spatial_ref_sys_write(self, local_conn, query: str, *args): + # Execute the query locally + result = await local_conn.execute(query, *args) # Sync the entire spatial_ref_sys table with all online hosts online_hosts = await self.get_online_hosts() for pool_entry in online_hosts: if pool_entry['ts_id'] != os.environ.get('TS_ID'): + remote_conn = await self.get_connection(pool_entry) + if remote_conn is None: + continue try: - async with self.get_connection(pool_entry) as remote_conn: - if remote_conn is None: - continue - await self.sync_spatial_ref_sys(local_conn, remote_conn) + await self.sync_spatial_ref_sys(local_conn, remote_conn) except Exception as e: err(f"Error syncing spatial_ref_sys to {pool_entry['ts_id']}: {str(e)}") err(f"Traceback: {traceback.format_exc()}") + finally: + await remote_conn.close() return result - def get_primary_key(self, table_name: str) -> str: - # This method should return the primary key for the given table - # You might want to cache this information for performance - # For now, we'll assume it's always 'id', but you should implement proper logic here - return 'id' + async def apply_batch_changes(self, conn, table_name, changes, primary_key): + if conn is None or not changes: + debug(f"Skipping apply_batch_changes because conn is none or there are no changes.") + return 0 + + try: + columns = list(changes[0].keys()) + placeholders = [f'${i+1}' for i in range(len(columns))] + + if primary_key: + insert_query = f""" + INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)}) + VALUES ({', '.join(placeholders)}) + ON CONFLICT ("{primary_key}") DO UPDATE SET + {', '.join(f'"{col}" = EXCLUDED."{col}"' for col in columns if col not in [primary_key, 'version', 'server_id'])}, + version = EXCLUDED.version, + server_id = EXCLUDED.server_id + WHERE "{table_name}".version < EXCLUDED.version + OR ("{table_name}".version = EXCLUDED.version AND "{table_name}".server_id < EXCLUDED.server_id) + """ + else: + warn(f"Possible source of issue #4") + # For tables without a primary key, we'll 
use all columns for conflict resolution + insert_query = f""" + INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)}) + VALUES ({', '.join(placeholders)}) + ON CONFLICT DO NOTHING + """ + + affected_rows = 0 + async for change in tqdm(changes, desc=f"Syncing {table_name}", unit="row"): + values = [change[col] for col in columns] + result = await conn.execute(insert_query, *values) + affected_rows += int(result.split()[-1]) + + return affected_rows + + except Exception as e: + err(f"Error applying batch changes to {table_name}: {str(e)}") + err(f"Traceback: {traceback.format_exc()}") + return 0 + + async def sync_special_table(self, source_conn, dest_conn, table_name): + if table_name == 'spatial_ref_sys': + return await self.sync_spatial_ref_sys(source_conn, dest_conn) + # Add more special cases as needed async def sync_spatial_ref_sys(self, source_conn, dest_conn): try: @@ -666,6 +1072,77 @@ class APIConfig(BaseModel): return 0 + def get_primary_key(self, table_name: str) -> str: + # This method should return the primary key for the given table + # You might want to cache this information for performance + # For now, we'll assume it's always 'id', but you should implement proper logic here + return 'id' + + + async def add_primary_keys_to_local_tables(self): + conn = await self.get_connection() + + debug(f"Adding primary keys to existing tables...") + if conn is None: + raise ConnectionError("Failed to connect to local database") + + try: + tables = await conn.fetch(""" + SELECT tablename FROM pg_tables + WHERE schemaname = 'public' + """) + + for table in tables: + table_name = table['tablename'] + if table_name not in self.SPECIAL_TABLES: + await self.ensure_sync_columns(conn, table_name) + finally: + await conn.close() + + async def add_primary_keys_to_remote_tables(self): + online_hosts = await self.get_online_hosts() + + for pool_entry in online_hosts: + conn = await self.get_connection(pool_entry) + if conn is None: + warn(f"Unable to connect to {pool_entry['ts_id']}. 
Skipping primary key addition.") + continue + + try: + info(f"Adding primary keys to existing tables on {pool_entry['ts_id']}...") + tables = await conn.fetch(""" + SELECT tablename FROM pg_tables + WHERE schemaname = 'public' + """) + + for table in tables: + table_name = table['tablename'] + if table_name not in self.SPECIAL_TABLES: + primary_key = await self.ensure_sync_columns(conn, table_name) + if primary_key: + info(f"Added/ensured primary key '{primary_key}' for table '{table_name}' on {pool_entry['ts_id']}") + else: + warn(f"Failed to add/ensure primary key for table '{table_name}' on {pool_entry['ts_id']}") + + info(f"Completed adding primary keys to existing tables on {pool_entry['ts_id']}") + except Exception as e: + err(f"Error adding primary keys to existing tables on {pool_entry['ts_id']}: {str(e)}") + err(f"Traceback: {traceback.format_exc()}") + finally: + await conn.close() + + + async def close_db_pools(self): + info("Closing database connection pools...") + for pool_key, pool in self.db_pools.items(): + try: + await pool.close() + debug(f"Closed pool for {pool_key}") + except Exception as e: + err(f"Error closing pool for {pool_key}: {str(e)}") + self.db_pools.clear() + info("All database connection pools closed.") + class Location(BaseModel): latitude: float @@ -679,7 +1156,7 @@ class Location(BaseModel): state: Optional[str] = None country: Optional[str] = None context: Optional[Dict[str, Any]] = None - class_: Optional[str] = None + class_: Optional[str] = Field(None, alias="class") type: Optional[str] = None name: Optional[str] = None display_name: Optional[str] = None @@ -697,11 +1174,7 @@ class Location(BaseModel): json_encoders = { datetime: lambda dt: dt.isoformat(), } - - def model_dump(self): - data = self.dict() - data["datetime"] = self.datetime.isoformat() if self.datetime else None - return data + populate_by_name = True class Geocoder: diff --git a/sijapi/routers/gis.py b/sijapi/routers/gis.py index 65bde57..e7b91e8 100644 --- a/sijapi/routers/gis.py +++ b/sijapi/routers/gis.py @@ -120,7 +120,6 @@ async def get_last_location() -> Optional[Location]: return None - async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int, datetime, None] = None) -> List[Location]: start_datetime = await dt(start) if end is None: @@ -133,10 +132,24 @@ async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int, debug(f"Fetching locations between {start_datetime} and {end_datetime}") - async with API.get_connection() as conn: - locations = [] - # Check for records within the specified datetime range - range_locations = await conn.fetch(''' + query = ''' + SELECT id, datetime, + ST_X(ST_AsText(location)::geometry) AS longitude, + ST_Y(ST_AsText(location)::geometry) AS latitude, + ST_Z(ST_AsText(location)::geometry) AS elevation, + city, state, zip, street, + action, device_type, device_model, device_name, device_os + FROM locations + WHERE datetime >= $1 AND datetime <= $2 + ORDER BY datetime DESC + ''' + + locations = await API.execute_read_query(query, start_datetime.replace(tzinfo=None), end_datetime.replace(tzinfo=None), table_name="locations") + + debug(f"Range locations query returned: {locations}") + + if not locations and (end is None or start_datetime.date() == end_datetime.date()): + fallback_query = ''' SELECT id, datetime, ST_X(ST_AsText(location)::geometry) AS longitude, ST_Y(ST_AsText(location)::geometry) AS latitude, @@ -144,30 +157,14 @@ async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int, 
city, state, zip, street, action, device_type, device_model, device_name, device_os FROM locations - WHERE datetime >= $1 AND datetime <= $2 + WHERE datetime < $1 ORDER BY datetime DESC - ''', start_datetime.replace(tzinfo=None), end_datetime.replace(tzinfo=None)) - - debug(f"Range locations query returned: {range_locations}") - locations.extend(range_locations) - - if not locations and (end is None or start_datetime.date() == end_datetime.date()): - location_data = await conn.fetchrow(''' - SELECT id, datetime, - ST_X(ST_AsText(location)::geometry) AS longitude, - ST_Y(ST_AsText(location)::geometry) AS latitude, - ST_Z(ST_AsText(location)::geometry) AS elevation, - city, state, zip, street, - action, device_type, device_model, device_name, device_os - FROM locations - WHERE datetime < $1 - ORDER BY datetime DESC - LIMIT 1 - ''', start_datetime.replace(tzinfo=None)) - - debug(f"Fallback query returned: {location_data}") - if location_data: - locations.append(location_data) + LIMIT 1 + ''' + location_data = await API.execute_read_query(fallback_query, start_datetime.replace(tzinfo=None), table_name="locations") + debug(f"Fallback query returned: {location_data}") + if location_data: + locations = location_data debug(f"Locations found: {locations}") @@ -197,35 +194,32 @@ async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int, return location_objects if location_objects else [] -# Function to fetch the last location before the specified datetime async def fetch_last_location_before(datetime: datetime) -> Optional[Location]: datetime = await dt(datetime) debug(f"Fetching last location before {datetime}") - async with API.get_connection() as conn: + query = ''' + SELECT id, datetime, + ST_X(ST_AsText(location)::geometry) AS longitude, + ST_Y(ST_AsText(location)::geometry) AS latitude, + ST_Z(ST_AsText(location)::geometry) AS elevation, + city, state, zip, street, country, + action + FROM locations + WHERE datetime < $1 + ORDER BY datetime DESC + LIMIT 1 + ''' + + location_data = await API.execute_read_query(query, datetime.replace(tzinfo=None), table_name="locations") - location_data = await conn.fetchrow(''' - SELECT id, datetime, - ST_X(ST_AsText(location)::geometry) AS longitude, - ST_Y(ST_AsText(location)::geometry) AS latitude, - ST_Z(ST_AsText(location)::geometry) AS elevation, - city, state, zip, street, country, - action - FROM locations - WHERE datetime < $1 - ORDER BY datetime DESC - LIMIT 1 - ''', datetime.replace(tzinfo=None)) - - await conn.close() - - if location_data: - debug(f"Last location found: {location_data}") - return Location(**location_data) - else: - debug("No location found before the specified datetime") - return None + if location_data: + debug(f"Last location found: {location_data[0]}") + return Location(**location_data[0]) + else: + debug("No location found before the specified datetime") + return None @gis.get("/map", response_class=HTMLResponse) async def generate_map_endpoint( @@ -247,16 +241,12 @@ async def generate_map_endpoint( return HTMLResponse(content=html_content) async def get_date_range(): - async with API.get_connection() as conn: - query = "SELECT MIN(datetime) as min_date, MAX(datetime) as max_date FROM locations" - row = await conn.fetchrow(query) - if row and row['min_date'] and row['max_date']: - return row['min_date'], row['max_date'] - else: - return datetime(2022, 1, 1), datetime.now() - - - + query = "SELECT MIN(datetime) as min_date, MAX(datetime) as max_date FROM locations" + row = await 
API.execute_read_query(query, table_name="locations") + if row and row[0]['min_date'] and row[0]['max_date']: + return row[0]['min_date'], row[0]['max_date'] + else: + return datetime(2022, 1, 1), datetime.now() async def generate_and_save_heatmap( start_date: Union[str, int, datetime], @@ -313,8 +303,6 @@ Generate a heatmap for the given date range and save it as a PNG file using Foli err(f"Error generating and saving heatmap: {str(e)}") raise - - async def generate_map(start_date: datetime, end_date: datetime, max_points: int): locations = await fetch_locations(start_date, end_date) if not locations: @@ -343,8 +331,6 @@ async def generate_map(start_date: datetime, end_date: datetime, max_points: int folium.TileLayer('cartodbdark_matter', name='Dark Mode').add_to(m) - - # In the generate_map function: draw = Draw( draw_options={ 'polygon': True, @@ -433,70 +419,70 @@ map.on(L.Draw.Event.CREATED, function (event) { return m.get_root().render() async def post_location(location: Location): - # if not location.datetime: - # info(f"location appears to be missing datetime: {location}") - # else: - # debug(f"post_location called with {location.datetime}") - async with API.get_connection() as conn: - try: - context = location.context or {} - action = context.get('action', 'manual') - device_type = context.get('device_type', 'Unknown') - device_model = context.get('device_model', 'Unknown') - device_name = context.get('device_name', 'Unknown') - device_os = context.get('device_os', 'Unknown') - - # Parse and localize the datetime - localized_datetime = await dt(location.datetime) + try: + context = location.context or {} + action = context.get('action', 'manual') + device_type = context.get('device_type', 'Unknown') + device_model = context.get('device_model', 'Unknown') + device_name = context.get('device_name', 'Unknown') + device_os = context.get('device_os', 'Unknown') + + # Parse and localize the datetime + localized_datetime = await dt(location.datetime) - await conn.execute(''' - INSERT INTO locations ( - datetime, location, city, state, zip, street, action, device_type, device_model, device_name, device_os, - class_, type, name, display_name, amenity, house_number, road, quarter, neighbourhood, - suburb, county, country_code, country - ) - VALUES ($1, ST_SetSRID(ST_MakePoint($2, $3, $4), 4326), $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, - $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26) - ''', localized_datetime, location.longitude, location.latitude, location.elevation, location.city, location.state, - location.zip, location.street, action, device_type, device_model, device_name, device_os, - location.class_, location.type, location.name, location.display_name, - location.amenity, location.house_number, location.road, location.quarter, location.neighbourhood, - location.suburb, location.county, location.country_code, location.country) - - await conn.close() - info(f"Successfully posted location: {location.latitude}, {location.longitude}, {location.elevation} on {localized_datetime}") - return { - 'datetime': localized_datetime, - 'latitude': location.latitude, - 'longitude': location.longitude, - 'elevation': location.elevation, - 'city': location.city, - 'state': location.state, - 'zip': location.zip, - 'street': location.street, - 'action': action, - 'device_type': device_type, - 'device_model': device_model, - 'device_name': device_name, - 'device_os': device_os, - 'class_': location.class_, - 'type': location.type, - 'name': location.name, - 'display_name': 
location.display_name, - 'amenity': location.amenity, - 'house_number': location.house_number, - 'road': location.road, - 'quarter': location.quarter, - 'neighbourhood': location.neighbourhood, - 'suburb': location.suburb, - 'county': location.county, - 'country_code': location.country_code, - 'country': location.country - } - except Exception as e: - err(f"Error posting location {e}") - err(traceback.format_exc()) - return None + query = ''' + INSERT INTO locations ( + datetime, location, city, state, zip, street, action, device_type, device_model, device_name, device_os, + class_, type, name, display_name, amenity, house_number, road, quarter, neighbourhood, + suburb, county, country_code, country + ) + VALUES ($1, ST_SetSRID(ST_MakePoint($2, $3, $4), 4326), $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, + $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26) + ''' + + await API.execute_write_query( + query, + localized_datetime, location.longitude, location.latitude, location.elevation, location.city, location.state, + location.zip, location.street, action, device_type, device_model, device_name, device_os, + location.class_, location.type, location.name, location.display_name, + location.amenity, location.house_number, location.road, location.quarter, location.neighbourhood, + location.suburb, location.county, location.country_code, location.country, + table_name="locations" + ) + + info(f"Successfully posted location: {location.latitude}, {location.longitude}, {location.elevation} on {localized_datetime}") + return { + 'datetime': localized_datetime, + 'latitude': location.latitude, + 'longitude': location.longitude, + 'elevation': location.elevation, + 'city': location.city, + 'state': location.state, + 'zip': location.zip, + 'street': location.street, + 'action': action, + 'device_type': device_type, + 'device_model': device_model, + 'device_name': device_name, + 'device_os': device_os, + 'class_': location.class_, + 'type': location.type, + 'name': location.name, + 'display_name': location.display_name, + 'amenity': location.amenity, + 'house_number': location.house_number, + 'road': location.road, + 'quarter': location.quarter, + 'neighbourhood': location.neighbourhood, + 'suburb': location.suburb, + 'county': location.county, + 'country_code': location.country_code, + 'country': location.country + } + except Exception as e: + err(f"Error posting location {e}") + err(traceback.format_exc()) + return None @gis.post("/locate") @@ -553,6 +539,7 @@ async def get_last_location_endpoint() -> JSONResponse: raise HTTPException(status_code=404, detail="No location found before the specified datetime") + @gis.get("/locate/{datetime_str}", response_model=List[Location]) async def get_locate(datetime_str: str, all: bool = False): try: diff --git a/sijapi/routers/serve.py b/sijapi/routers/serve.py index cb596a4..e87eb49 100644 --- a/sijapi/routers/serve.py +++ b/sijapi/routers/serve.py @@ -212,7 +212,6 @@ if API.EXTENSIONS.shellfish == "on" or API.EXTENSIONS.shellfish == True: except requests.exceptions.RequestException: results.append(f"{address} is down") - # Generate a simple text-based graph graph = '|' * up_count + '.' 
* (len(addresses) - up_count) text_update = "\n".join(results) @@ -220,7 +219,6 @@ if API.EXTENSIONS.shellfish == "on" or API.EXTENSIONS.shellfish == True: output = shellfish_run_widget_command(widget_command) return {"output": output, "graph": graph} - def shellfish_update_widget(update: WidgetUpdate): widget_command = ["widget"] @@ -290,6 +288,7 @@ if API.EXTENSIONS.courtlistener == "on" or API.EXTENSIONS.courtlistener == True: for result in results: bg_tasks.add_task(cl_docket_process, result) return JSONResponse(content={"message": "Received"}, status_code=status.HTTP_200_OK) + async def cl_docket_process(result): async with httpx.AsyncClient() as session: @@ -346,11 +345,13 @@ if API.EXTENSIONS.courtlistener == "on" or API.EXTENSIONS.courtlistener == True: await cl_download_file(file_url, target_path, session) debug(f"Downloaded {file_name} to {target_path}") + def cl_case_details(docket): case_info = CASETABLE.get(str(docket), {"code": "000", "shortname": "UNKNOWN"}) case_code = case_info.get("code") short_name = case_info.get("shortname") return case_code, short_name + async def cl_download_file(url: str, path: Path, session: aiohttp.ClientSession = None): headers = { @@ -417,19 +418,20 @@ if API.EXTENSIONS.courtlistener == "on" or API.EXTENSIONS.courtlistener == True: await cl_download_file(download_url, target_path, session) debug(f"Downloaded {file_name} to {target_path}") + @serve.get("/s", response_class=HTMLResponse) async def shortener_form(request: Request): return templates.TemplateResponse("shortener.html", {"request": request}) + @serve.post("/s") async def create_short_url(request: Request, long_url: str = Form(...), custom_code: Optional[str] = Form(None)): - await create_tables() if custom_code: if len(custom_code) != 3 or not custom_code.isalnum(): return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code must be 3 alphanumeric characters"}) - existing = await API.execute_write_query('SELECT 1 FROM short_urls WHERE short_code = $1', custom_code, table_name="short_urls") + existing = await API.execute_read_query('SELECT 1 FROM short_urls WHERE short_code = $1', custom_code, table_name="short_urls") if existing: return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code already in use"}) @@ -437,8 +439,9 @@ async def create_short_url(request: Request, long_url: str = Form(...), custom_c else: chars = string.ascii_letters + string.digits while True: + debug(f"FOUND THE ISSUE") short_code = ''.join(random.choice(chars) for _ in range(3)) - existing = await API.execute_write_query('SELECT 1 FROM short_urls WHERE short_code = $1', short_code, table_name="short_urls") + existing = await API.execute_read_query('SELECT 1 FROM short_urls WHERE short_code = $1', short_code, table_name="short_urls") if not existing: break @@ -451,48 +454,36 @@ async def create_short_url(request: Request, long_url: str = Form(...), custom_c short_url = f"https://sij.ai/{short_code}" return templates.TemplateResponse("shortener.html", {"request": request, "short_url": short_url}) -async def create_tables(): - await API.execute_write_query(''' - CREATE TABLE IF NOT EXISTS short_urls ( - short_code VARCHAR(3) PRIMARY KEY, - long_url TEXT NOT NULL, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP - ) - ''', table_name="short_urls") - await API.execute_write_query(''' - CREATE TABLE IF NOT EXISTS click_logs ( - id SERIAL PRIMARY KEY, - short_code VARCHAR(3) REFERENCES short_urls(short_code), - clicked_at TIMESTAMP DEFAULT 
CURRENT_TIMESTAMP, - ip_address TEXT, - user_agent TEXT - ) - ''', table_name="click_logs") -@serve.get("/{short_code}", response_class=RedirectResponse, status_code=301) -async def redirect_short_url(request: Request, short_code: str = PathParam(..., min_length=3, max_length=3)): - if request.headers.get('host') != 'sij.ai': - raise HTTPException(status_code=404, detail="Not Found") - - result = await API.execute_write_query( +@serve.get("/{short_code}") +async def redirect_short_url(short_code: str): + results = await API.execute_read_query( 'SELECT long_url FROM short_urls WHERE short_code = $1', short_code, table_name="short_urls" ) - if result: - await API.execute_write_query( - 'INSERT INTO click_logs (short_code, ip_address, user_agent) VALUES ($1, $2, $3)', - short_code, request.client.host, request.headers.get("user-agent"), - table_name="click_logs" - ) - return result['long_url'] - else: + if not results: raise HTTPException(status_code=404, detail="Short URL not found") + + long_url = results[0].get('long_url') + + if not long_url: + raise HTTPException(status_code=404, detail="Long URL not found") + + # Increment click count (you may want to do this asynchronously) + await API.execute_write_query( + 'INSERT INTO click_logs (short_code, clicked_at) VALUES ($1, $2)', + short_code, datetime.now(), + table_name="click_logs" + ) + + return RedirectResponse(url=long_url) + @serve.get("/analytics/{short_code}") async def get_analytics(short_code: str): - url_info = await API.execute_write_query( + url_info = await API.execute_read_query( 'SELECT long_url, created_at FROM short_urls WHERE short_code = $1', short_code, table_name="short_urls" @@ -500,13 +491,13 @@ async def get_analytics(short_code: str): if not url_info: raise HTTPException(status_code=404, detail="Short URL not found") - click_count = await API.execute_write_query( + click_count = await API.execute_read_query( 'SELECT COUNT(*) FROM click_logs WHERE short_code = $1', short_code, table_name="click_logs" ) - clicks = await API.execute_write_query( + clicks = await API.execute_read_query( 'SELECT clicked_at, ip_address, user_agent FROM click_logs WHERE short_code = $1 ORDER BY clicked_at DESC LIMIT 100', short_code, table_name="click_logs" @@ -521,15 +512,20 @@ async def get_analytics(short_code: str): } - - async def forward_traffic(reader: asyncio.StreamReader, writer: asyncio.StreamWriter, destination: str): - dest_host, dest_port = destination.split(':') - dest_port = int(dest_port) + try: + dest_host, dest_port = destination.split(':') + dest_port = int(dest_port) + except ValueError: + warn(f"Invalid destination format: {destination}. 
Expected 'host:port'.") + writer.close() + await writer.wait_closed() + return try: dest_reader, dest_writer = await asyncio.open_connection(dest_host, dest_port) except Exception as e: + warn(f"Failed to connect to destination {destination}: {str(e)}") writer.close() await writer.wait_closed() return @@ -543,7 +539,7 @@ async def forward_traffic(reader: asyncio.StreamReader, writer: asyncio.StreamWr dst.write(data) await dst.drain() except Exception as e: - pass + warn(f"Error in forwarding: {str(e)}") finally: dst.close() await dst.wait_closed() @@ -554,8 +550,12 @@ async def forward_traffic(reader: asyncio.StreamReader, writer: asyncio.StreamWr ) async def start_server(source: str, destination: str): - host, port = source.split(':') - port = int(port) + if ':' in source: + host, port = source.split(':') + port = int(port) + else: + host = source + port = 80 server = await asyncio.start_server( lambda r, w: forward_traffic(r, w, destination), @@ -566,6 +566,7 @@ async def start_server(source: str, destination: str): async with server: await server.serve_forever() + async def start_port_forwarding(): if hasattr(Serve, 'forwarding_rules'): for rule in Serve.forwarding_rules: @@ -573,6 +574,7 @@ async def start_port_forwarding(): else: warn("No forwarding rules found in the configuration.") + @serve.get("/forward_status") async def get_forward_status(): if hasattr(Serve, 'forwarding_rules'): @@ -580,5 +582,5 @@ async def get_forward_status(): else: return {"status": "inactive", "message": "No forwarding rules configured"} -# Add this to the end of your serve.py file + asyncio.create_task(start_port_forwarding()) \ No newline at end of file diff --git a/sijapi/routers/weather.py b/sijapi/routers/weather.py index a457fa1..55b091e 100644 --- a/sijapi/routers/weather.py +++ b/sijapi/routers/weather.py @@ -116,129 +116,129 @@ async def get_weather(date_time: dt_datetime, latitude: float, longitude: float, async def store_weather_to_db(date_time: dt_datetime, weather_data: dict): warn(f"Using {date_time.strftime('%Y-%m-%d %H:%M:%S')} as our datetime in store_weather_to_db") - async with API.get_connection() as conn: - try: - day_data = weather_data.get('days')[0] - debug(f"RAW DAY_DATA: {day_data}") - # Handle preciptype and stations as PostgreSQL arrays - preciptype_array = day_data.get('preciptype', []) or [] - stations_array = day_data.get('stations', []) or [] + try: + day_data = weather_data.get('days')[0] + debug(f"RAW DAY_DATA: {day_data}") + # Handle preciptype and stations as PostgreSQL arrays + preciptype_array = day_data.get('preciptype', []) or [] + stations_array = day_data.get('stations', []) or [] - date_str = date_time.strftime("%Y-%m-%d") - warn(f"Using {date_str} in our query in store_weather_to_db.") + date_str = date_time.strftime("%Y-%m-%d") + warn(f"Using {date_str} in our query in store_weather_to_db.") - # Get location details from weather data if available - longitude = weather_data.get('longitude') - latitude = weather_data.get('latitude') - tz = await GEO.tz_at(latitude, longitude) - elevation = await GEO.elevation(latitude, longitude) - location_point = f"POINTZ({longitude} {latitude} {elevation})" if longitude and latitude and elevation else None + # Get location details from weather data if available + longitude = weather_data.get('longitude') + latitude = weather_data.get('latitude') + tz = await GEO.tz_at(latitude, longitude) + elevation = await GEO.elevation(latitude, longitude) + location_point = f"POINTZ({longitude} {latitude} {elevation})" if longitude and 
-            warn(f"Uncorrected datetimes in store_weather_to_db: {day_data['datetime']}, sunrise: {day_data['sunrise']}, sunset: {day_data['sunset']}")
-            day_data['datetime'] = await gis.dt(day_data.get('datetimeEpoch'))
-            day_data['sunrise'] = await gis.dt(day_data.get('sunriseEpoch'))
-            day_data['sunset'] = await gis.dt(day_data.get('sunsetEpoch'))
-            warn(f"Corrected datetimes in store_weather_to_db: {day_data['datetime']}, sunrise: {day_data['sunrise']}, sunset: {day_data['sunset']}")
+        warn(f"Uncorrected datetimes in store_weather_to_db: {day_data['datetime']}, sunrise: {day_data['sunrise']}, sunset: {day_data['sunset']}")
+        day_data['datetime'] = await gis.dt(day_data.get('datetimeEpoch'))
+        day_data['sunrise'] = await gis.dt(day_data.get('sunriseEpoch'))
+        day_data['sunset'] = await gis.dt(day_data.get('sunsetEpoch'))
+        warn(f"Corrected datetimes in store_weather_to_db: {day_data['datetime']}, sunrise: {day_data['sunrise']}, sunset: {day_data['sunset']}")
 
-            daily_weather_params = (
-                day_data.get('sunrise'), day_data.get('sunriseEpoch'),
-                day_data.get('sunset'), day_data.get('sunsetEpoch'),
-                day_data.get('description'), day_data.get('tempmax'),
-                day_data.get('tempmin'), day_data.get('uvindex'),
-                day_data.get('winddir'), day_data.get('windspeed'),
-                day_data.get('icon'), dt_datetime.now(tz),
-                day_data.get('datetime'), day_data.get('datetimeEpoch'),
-                day_data.get('temp'), day_data.get('feelslikemax'),
-                day_data.get('feelslikemin'), day_data.get('feelslike'),
-                day_data.get('dew'), day_data.get('humidity'),
-                day_data.get('precip'), day_data.get('precipprob'),
-                day_data.get('precipcover'), preciptype_array,
-                day_data.get('snow'), day_data.get('snowdepth'),
-                day_data.get('windgust'), day_data.get('pressure'),
-                day_data.get('cloudcover'), day_data.get('visibility'),
-                day_data.get('solarradiation'), day_data.get('solarenergy'),
-                day_data.get('severerisk', 0), day_data.get('moonphase'),
-                day_data.get('conditions'), stations_array, day_data.get('source'),
-                location_point
-            )
-        except Exception as e:
-            err(f"Failed to prepare database query in store_weather_to_db! {e}")
-
-        try:
-            daily_weather_query = '''
-            INSERT INTO DailyWeather (
-                sunrise, sunriseepoch, sunset, sunsetepoch, description,
-                tempmax, tempmin, uvindex, winddir, windspeed, icon, last_updated,
-                datetime, datetimeepoch, temp, feelslikemax, feelslikemin, feelslike,
-                dew, humidity, precip, precipprob, precipcover, preciptype,
-                snow, snowdepth, windgust, pressure, cloudcover, visibility,
-                solarradiation, solarenergy, severerisk, moonphase, conditions,
-                stations, source, location
-            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38)
-            RETURNING id
-            '''
+        daily_weather_params = (
+            day_data.get('sunrise'), day_data.get('sunriseEpoch'),
+            day_data.get('sunset'), day_data.get('sunsetEpoch'),
+            day_data.get('description'), day_data.get('tempmax'),
+            day_data.get('tempmin'), day_data.get('uvindex'),
+            day_data.get('winddir'), day_data.get('windspeed'),
+            day_data.get('icon'), dt_datetime.now(tz),
+            day_data.get('datetime'), day_data.get('datetimeEpoch'),
+            day_data.get('temp'), day_data.get('feelslikemax'),
+            day_data.get('feelslikemin'), day_data.get('feelslike'),
+            day_data.get('dew'), day_data.get('humidity'),
+            day_data.get('precip'), day_data.get('precipprob'),
+            day_data.get('precipcover'), preciptype_array,
+            day_data.get('snow'), day_data.get('snowdepth'),
+            day_data.get('windgust'), day_data.get('pressure'),
+            day_data.get('cloudcover'), day_data.get('visibility'),
+            day_data.get('solarradiation'), day_data.get('solarenergy'),
+            day_data.get('severerisk', 0), day_data.get('moonphase'),
+            day_data.get('conditions'), stations_array, day_data.get('source'),
+            location_point
+        )
+    except Exception as e:
+        err(f"Failed to prepare database query in store_weather_to_db! {e}")
{e}") + return "FAILURE" + + try: + daily_weather_query = ''' + INSERT INTO DailyWeather ( + sunrise, sunriseepoch, sunset, sunsetepoch, description, + tempmax, tempmin, uvindex, winddir, windspeed, icon, last_updated, + datetime, datetimeepoch, temp, feelslikemax, feelslikemin, feelslike, + dew, humidity, precip, precipprob, precipcover, preciptype, + snow, snowdepth, windgust, pressure, cloudcover, visibility, + solarradiation, solarenergy, severerisk, moonphase, conditions, + stations, source, location + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38) + RETURNING id + ''' + + daily_weather_id = await API.execute_write_query(daily_weather_query, *daily_weather_params, table_name="DailyWeather") + + if 'hours' in day_data: + debug(f"Processing hours now...") + for hour_data in day_data['hours']: + try: + await asyncio.sleep(0.01) + hour_data['datetime'] = await gis.dt(hour_data.get('datetimeEpoch')) + hour_preciptype_array = hour_data.get('preciptype', []) or [] + hour_stations_array = hour_data.get('stations', []) or [] + hourly_weather_params = ( + daily_weather_id, + hour_data['datetime'], + hour_data.get('datetimeEpoch'), + hour_data['temp'], + hour_data['feelslike'], + hour_data['humidity'], + hour_data['dew'], + hour_data['precip'], + hour_data['precipprob'], + hour_preciptype_array, + hour_data['snow'], + hour_data['snowdepth'], + hour_data['windgust'], + hour_data['windspeed'], + hour_data['winddir'], + hour_data['pressure'], + hour_data['cloudcover'], + hour_data['visibility'], + hour_data['solarradiation'], + hour_data['solarenergy'], + hour_data['uvindex'], + hour_data.get('severerisk', 0), + hour_data['conditions'], + hour_data['icon'], + hour_stations_array, + hour_data.get('source', ''), + ) - async with conn.transaction(): - daily_weather_id = await conn.fetchval(daily_weather_query, *daily_weather_params) - - if 'hours' in day_data: - debug(f"Processing hours now...") - for hour_data in day_data['hours']: try: - await asyncio.sleep(0.01) - hour_data['datetime'] = await gis.dt(hour_data.get('datetimeEpoch')) - hour_preciptype_array = hour_data.get('preciptype', []) or [] - hour_stations_array = hour_data.get('stations', []) or [] - hourly_weather_params = ( - daily_weather_id, - hour_data['datetime'], - hour_data.get('datetimeEpoch'), - hour_data['temp'], - hour_data['feelslike'], - hour_data['humidity'], - hour_data['dew'], - hour_data['precip'], - hour_data['precipprob'], - hour_preciptype_array, - hour_data['snow'], - hour_data['snowdepth'], - hour_data['windgust'], - hour_data['windspeed'], - hour_data['winddir'], - hour_data['pressure'], - hour_data['cloudcover'], - hour_data['visibility'], - hour_data['solarradiation'], - hour_data['solarenergy'], - hour_data['uvindex'], - hour_data.get('severerisk', 0), - hour_data['conditions'], - hour_data['icon'], - hour_stations_array, - hour_data.get('source', ''), - ) - - try: - hourly_weather_query = ''' - INSERT INTO HourlyWeather (daily_weather_id, datetime, datetimeepoch, temp, feelslike, humidity, dew, precip, precipprob, - preciptype, snow, snowdepth, windgust, windspeed, winddir, pressure, cloudcover, visibility, solarradiation, solarenergy, - uvindex, severerisk, conditions, icon, stations, source) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26) - RETURNING id - ''' - async with conn.transaction(): 
- hourly_weather_id = await conn.fetchval(hourly_weather_query, *hourly_weather_params) - debug(f"Done processing hourly_weather_id {hourly_weather_id}") - except Exception as e: - err(f"EXCEPTION: {e}") - + hourly_weather_query = ''' + INSERT INTO HourlyWeather (daily_weather_id, datetime, datetimeepoch, temp, feelslike, humidity, dew, precip, precipprob, + preciptype, snow, snowdepth, windgust, windspeed, winddir, pressure, cloudcover, visibility, solarradiation, solarenergy, + uvindex, severerisk, conditions, icon, stations, source) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26) + RETURNING id + ''' + hourly_weather_id = await API.execute_write_query(hourly_weather_query, *hourly_weather_params, table_name="HourlyWeather") + debug(f"Done processing hourly_weather_id {hourly_weather_id}") except Exception as e: err(f"EXCEPTION: {e}") - return "SUCCESS" - - except Exception as e: - err(f"Error in dailyweather storage: {e}") + except Exception as e: + err(f"EXCEPTION: {e}") + + return "SUCCESS" + + except Exception as e: + err(f"Error in dailyweather storage: {e}") + return "FAILURE" +
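
A note on the port-forwarding pieces of this patch: start_port_forwarding() iterates Serve.forwarding_rules, where each rule supplies 'source' and 'destination' strings (normally 'host:port'; start_server() now tolerates a bare host and falls back to port 80). The sketch below is an illustrative, standalone reduction of the same bidirectional relay pattern, not code from sijapi; names such as pump(), relay(), and the example addresses are hypothetical.

    # Standalone sketch (not part of sijapi): the asyncio relay pattern that
    # forward_traffic()/start_server() implement, with hypothetical names.
    import asyncio

    async def pump(src: asyncio.StreamReader, dst: asyncio.StreamWriter) -> None:
        # Copy bytes in one direction until EOF or error, then close the write side.
        try:
            while True:
                data = await src.read(8192)
                if not data:
                    break
                dst.write(data)
                await dst.drain()
        except Exception:
            pass
        finally:
            dst.close()
            await dst.wait_closed()

    async def relay(reader: asyncio.StreamReader, writer: asyncio.StreamWriter,
                    dest_host: str, dest_port: int) -> None:
        # Open the upstream connection; on failure, drop the client cleanly,
        # mirroring the error handling added to forward_traffic() above.
        try:
            dest_reader, dest_writer = await asyncio.open_connection(dest_host, dest_port)
        except OSError:
            writer.close()
            await writer.wait_closed()
            return
        # Pump both directions concurrently; each pump() closes its own write side.
        await asyncio.gather(pump(reader, dest_writer), pump(dest_reader, writer))

    async def main() -> None:
        # Hypothetical example: listen on 127.0.0.1:8080, forward to 127.0.0.1:3000.
        server = await asyncio.start_server(
            lambda r, w: relay(r, w, "127.0.0.1", 3000), "127.0.0.1", 8080)
        async with server:
            await server.serve_forever()

    if __name__ == "__main__":
        asyncio.run(main())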