From 299d27b1e773f1ee3bdb47e55dc824e4f176674f Mon Sep 17 00:00:00 2001 From: sanj <67624670+iodrift@users.noreply.github.com> Date: Thu, 1 Aug 2024 23:51:01 -0700 Subject: [PATCH] Much more efficient database sync method, first working version. --- sijapi/__init__.py | 1 + sijapi/__main__.py | 70 ++-- sijapi/classes.py | 801 ++++++++++++++++++++++++++++++-------- sijapi/routers/gis.py | 243 ++++++------ sijapi/routers/loc.py | 393 ------------------- sijapi/routers/serve.py | 96 ++--- sijapi/routers/weather.py | 228 +++++------ 7 files changed, 950 insertions(+), 882 deletions(-) delete mode 100644 sijapi/routers/loc.py diff --git a/sijapi/__init__.py b/sijapi/__init__.py index bdb64a4..c2a14d3 100644 --- a/sijapi/__init__.py +++ b/sijapi/__init__.py @@ -20,6 +20,7 @@ L = Logger("Central", LOGS_DIR) # API essentials API = APIConfig.load('api', 'secrets') + Dir = Configuration.load('dirs') HOST = f"{API.BIND}:{API.PORT}" LOCAL_HOSTS = [ipaddress.ip_address(localhost.strip()) for localhost in os.getenv('LOCAL_HOSTS', '127.0.0.1').split(',')] + ['localhost'] diff --git a/sijapi/__main__.py b/sijapi/__main__.py index 27eae12..e609014 100755 --- a/sijapi/__main__.py +++ b/sijapi/__main__.py @@ -39,12 +39,11 @@ def warn(text: str): logger.warning(text) def err(text: str): logger.error(text) def crit(text: str): logger.critical(text) - @asynccontextmanager async def lifespan(app: FastAPI): # Startup crit("sijapi launched") - crit(f"Arguments: {args}") + info(f"Arguments: {args}") # Load routers if args.test: @@ -54,20 +53,10 @@ async def lifespan(app: FastAPI): if getattr(API.MODULES, module_name): load_router(module_name) - crit("Starting database synchronization...") try: # Initialize sync structures on all databases await API.initialize_sync() - - # Check if other instances have more recent data - source = await API.get_most_recent_source() - if source: - crit(f"Pulling changes from {source['ts_id']} ({source['ts_ip']})...") - total_changes = await API.pull_changes(source) - crit(f"Data pull complete. Total changes: {total_changes}") - else: - crit("No instances with more recent data found or all instances are offline.") - + except Exception as e: crit(f"Error during startup: {str(e)}") crit(f"Traceback: {traceback.format_exc()}") @@ -79,8 +68,6 @@ async def lifespan(app: FastAPI): await API.close_db_pools() crit("Database pools closed.") - - app = FastAPI(lifespan=lifespan) app.add_middleware( @@ -135,30 +122,41 @@ async def handle_exception_middleware(request: Request, call_next): return response -# This was removed on 7/31/2024 when we decided to instead use a targeted push sync approach. 
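The startup-time pull and the push-on-write middleware give way to an explicit `/sync/pull` route (added just below): each instance now pulls newer rows from the freshest online peer on demand instead of fanning writes out to every host. A minimal sketch of how a peer or a cron job might trigger that pull over HTTP — the base URL and port here are assumptions for illustration, not values taken from the patch:

```python
import asyncio
import aiohttp

async def trigger_pull(base_url: str = "http://localhost:4444") -> dict:
    # Hypothetical helper: ask this sijapi instance to pull newer rows from its
    # most up-to-date online peer via the new /sync/pull endpoint.
    # The port is an assumption; substitute the API.PORT value from your config.
    async with aiohttp.ClientSession() as session:
        async with session.post(
            f"{base_url}/sync/pull",
            timeout=aiohttp.ClientTimeout(total=300),
        ) as resp:
            resp.raise_for_status()
            return await resp.json()

if __name__ == "__main__":
    print(asyncio.run(trigger_pull()))
```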
-deprecated = ''' -async def push_changes_background(): +@app.post("/sync/pull") +async def pull_changes(): try: - await API.push_changes_to_all() - except Exception as e: - err(f"Error pushing changes to other databases: {str(e)}") - err(f"Traceback: {traceback.format_exc()}") - - -@app.middleware("http") -async def sync_middleware(request: Request, call_next): - response = await call_next(request) - - # Check if the request was a database write operation - if request.method in ["POST", "PUT", "PATCH", "DELETE"]: + await API.add_primary_keys_to_local_tables() + await API.add_primary_keys_to_remote_tables() try: - # Push changes to other databases - await API.push_changes_to_all() + + source = await API.get_most_recent_source() + + if source: + # Pull changes from the source + total_changes = await API.pull_changes(source) + + return JSONResponse(content={ + "status": "success", + "message": f"Pull complete. Total changes: {total_changes}", + "source": f"{source['ts_id']} ({source['ts_ip']})", + "changes": total_changes + }) + else: + return JSONResponse(content={ + "status": "info", + "message": "No instances with more recent data found or all instances are offline." + }) + except Exception as e: - err(f"Error pushing changes to other databases: {str(e)}") - - return response -''' + err(f"Error during pull: {str(e)}") + err(f"Traceback: {traceback.format_exc()}") + raise HTTPException(status_code=500, detail=f"Error during pull: {str(e)}") + + except Exception as e: + err(f"Error while ensuring primary keys to tables: {str(e)}") + err(f"Traceback: {traceback.format_exc()}") + raise HTTPException(status_code=500, detail=f"Error during primary key insurance: {str(e)}") + def load_router(router_name): router_file = ROUTER_DIR / f'{router_name}.py' diff --git a/sijapi/classes.py b/sijapi/classes.py index 191a564..ed7e5fd 100644 --- a/sijapi/classes.py +++ b/sijapi/classes.py @@ -5,6 +5,7 @@ import math import os import re import uuid +import time import aiofiles import aiohttp import asyncio @@ -180,14 +181,18 @@ class APIConfig(BaseModel): TZ: str KEYS: List[str] GARBAGE: Dict[str, Any] - SPECIAL_TABLES: ClassVar[List[str]] = ['spatial_ref_sys'] - db_pools: Dict[str, Any] = Field(default_factory=dict) + offline_servers: Dict[str, float] = Field(default_factory=dict) + offline_timeout: float = Field(default=30.0) # 30 second timeout for offline servers + online_hosts_cache: Dict[str, Tuple[List[Dict[str, Any]], float]] = Field(default_factory=dict) + online_hosts_cache_ttl: float = Field(default=30.0) # Cache TTL in seconds def __init__(self, **data): super().__init__(**data) - self._db_pools = {} + self.db_pools = {} + self.online_hosts_cache = {} # Initialize the cache + self._sync_tasks = {} class Config: arbitrary_types_allowed = True @@ -306,13 +311,48 @@ class APIConfig(BaseModel): raise ValueError(f"No database configuration found for TS_ID: {ts_id}") return local_db - @asynccontextmanager + async def get_online_hosts(self) -> List[Dict[str, Any]]: + current_time = time.time() + cache_key = "online_hosts" + + if cache_key in self.online_hosts_cache: + cached_hosts, cache_time = self.online_hosts_cache[cache_key] + if current_time - cache_time < self.online_hosts_cache_ttl: + return cached_hosts + + online_hosts = [] + local_ts_id = os.environ.get('TS_ID') + + for pool_entry in self.POOL: + if pool_entry['ts_id'] != local_ts_id: + pool_key = f"{pool_entry['ts_ip']}:{pool_entry['db_port']}" + if pool_key in self.offline_servers: + if current_time - self.offline_servers[pool_key] < 
self.offline_timeout: + continue + else: + del self.offline_servers[pool_key] + + conn = await self.get_connection(pool_entry) + if conn is not None: + online_hosts.append(pool_entry) + await conn.close() + + self.online_hosts_cache[cache_key] = (online_hosts, current_time) + return online_hosts + async def get_connection(self, pool_entry: Dict[str, Any] = None): if pool_entry is None: pool_entry = self.local_db pool_key = f"{pool_entry['ts_ip']}:{pool_entry['db_port']}" + # Check if the server is marked as offline + if pool_key in self.offline_servers: + if time.time() - self.offline_servers[pool_key] < self.offline_timeout: + return None + else: + del self.offline_servers[pool_key] + if pool_key not in self.db_pools: try: self.db_pools[pool_key] = await asyncpg.create_pool( @@ -326,27 +366,21 @@ class APIConfig(BaseModel): timeout=5 ) except Exception as e: - err(f"Failed to create connection pool for {pool_key}: {str(e)}") - yield None - return + warn(f"Failed to create connection pool for {pool_key}: {str(e)}") + self.offline_servers[pool_key] = time.time() + return None try: - async with self.db_pools[pool_key].acquire() as conn: - yield conn + return await asyncio.wait_for(self.db_pools[pool_key].acquire(), timeout=5) + except asyncio.TimeoutError: + warn(f"Timeout acquiring connection from pool for {pool_key}") + self.offline_servers[pool_key] = time.time() + return None except Exception as e: - err(f"Failed to acquire connection from pool for {pool_key}: {str(e)}") - yield None + warn(f"Failed to acquire connection for {pool_key}: {str(e)}") + self.offline_servers[pool_key] = time.time() + return None - async def close_db_pools(self): - info("Closing database connection pools...") - for pool_key, pool in self.db_pools.items(): - try: - await pool.close() - debug(f"Closed pool for {pool_key}") - except Exception as e: - err(f"Error closing pool for {pool_key}: {str(e)}") - self.db_pools.clear() - info("All database connection pools closed.") async def initialize_sync(self): local_ts_id = os.environ.get('TS_ID') @@ -356,52 +390,60 @@ class APIConfig(BaseModel): if pool_entry['ts_id'] == local_ts_id: continue # Skip local database try: - async with self.get_connection(pool_entry) as conn: - if conn is None: - continue # Skip this database if connection failed - - debug(f"Starting sync initialization for {pool_entry['ts_ip']}...") - - # Check PostGIS installation - postgis_installed = await self.check_postgis(conn) - if not postgis_installed: - warn(f"PostGIS is not installed on {pool_entry['ts_id']} ({pool_entry['ts_ip']}). Some spatial operations may fail.") - - tables = await conn.fetch(""" - SELECT tablename FROM pg_tables - WHERE schemaname = 'public' - """) - - for table in tables: - table_name = table['tablename'] - await self.ensure_sync_columns(conn, table_name) - - debug(f"Sync initialization complete for {pool_entry['ts_ip']}. All tables now have necessary sync columns and triggers.") + conn = await self.get_connection(pool_entry) + if conn is None: + continue # Skip this database if connection failed + + debug(f"Starting sync initialization for {pool_entry['ts_ip']}...") + + # Check PostGIS installation + postgis_installed = await self.check_postgis(conn) + if not postgis_installed: + warn(f"PostGIS is not installed on {pool_entry['ts_id']} ({pool_entry['ts_ip']}). 
Some spatial operations may fail.") + + tables = await conn.fetch(""" + SELECT tablename FROM pg_tables + WHERE schemaname = 'public' + """) + + for table in tables: + table_name = table['tablename'] + await self.ensure_sync_columns(conn, table_name) + + debug(f"Sync initialization complete for {pool_entry['ts_ip']}. All tables now have necessary sync columns and triggers.") except Exception as e: err(f"Error initializing sync for {pool_entry['ts_ip']}: {str(e)}") err(f"Traceback: {traceback.format_exc()}") + + def _schedule_sync_task(self, table_name: str, pk_value: Any, version: int, server_id: str): + # Use a background task manager to handle syncing + task_key = f"{table_name}:{pk_value}" if pk_value else table_name + if task_key not in self._sync_tasks: + self._sync_tasks[task_key] = asyncio.create_task(self._sync_changes(table_name, pk_value, version, server_id)) + async def ensure_sync_columns(self, conn, table_name): - if conn is None: - debug(f"Skipping offline server...") - return None - - if table_name in self.SPECIAL_TABLES: - debug(f"Skipping sync columns for special table: {table_name}") + if conn is None or table_name in self.SPECIAL_TABLES: return None try: - # Get primary key information + # Check if primary key exists primary_key = await conn.fetchval(f""" SELECT a.attname FROM pg_index i - JOIN pg_attribute a ON a.attrelid = i.indrelid - AND a.attnum = ANY(i.indkey) - WHERE i.indrelid = '{table_name}'::regclass - AND i.indisprimary; + JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey) + WHERE i.indrelid = '{table_name}'::regclass AND i.indisprimary; """) + if not primary_key: + # Add an id column as primary key if it doesn't exist + await conn.execute(f""" + ALTER TABLE "{table_name}" + ADD COLUMN IF NOT EXISTS id SERIAL PRIMARY KEY; + """) + primary_key = 'id' + # Ensure version column exists await conn.execute(f""" ALTER TABLE "{table_name}" @@ -445,21 +487,11 @@ class APIConfig(BaseModel): debug(f"Successfully ensured sync columns and trigger for table {table_name}") return primary_key - + except Exception as e: err(f"Error ensuring sync columns for table {table_name}: {str(e)}") err(f"Traceback: {traceback.format_exc()}") - - async def get_online_hosts(self) -> List[Dict[str, Any]]: - online_hosts = [] - for pool_entry in self.POOL: - try: - async with self.get_connection(pool_entry) as conn: - if conn is not None: - online_hosts.append(pool_entry) - except Exception as e: - err(f"Error checking host {pool_entry['ts_ip']}:{pool_entry['db_port']}: {str(e)}") - return online_hosts + return None async def check_postgis(self, conn): if conn is None: @@ -477,134 +509,508 @@ class APIConfig(BaseModel): except Exception as e: err(f"Error checking PostGIS: {str(e)}") return False + - async def execute_write_query(self, query: str, *args, table_name: str): - if table_name in self.SPECIAL_TABLES: - return await self._execute_special_table_write(query, *args, table_name=table_name) + async def pull_changes(self, source_pool_entry, batch_size=10000): + if source_pool_entry['ts_id'] == os.environ.get('TS_ID'): + debug("Skipping self-sync") + return 0 - async with self.get_connection() as conn: - if conn is None: - raise ConnectionError("Failed to connect to local database") + total_changes = 0 + source_id = source_pool_entry['ts_id'] + source_ip = source_pool_entry['ts_ip'] + dest_id = os.environ.get('TS_ID') + dest_ip = self.local_db['ts_ip'] - # Ensure sync columns exist - primary_key = await self.ensure_sync_columns(conn, table_name) + info(f"Starting 
sync from source {source_id} ({source_ip}) to destination {dest_id} ({dest_ip})") - # Execute the query - result = await conn.execute(query, *args) + source_conn = None + dest_conn = None + try: + source_conn = await self.get_connection(source_pool_entry) + if source_conn is None: + warn(f"Unable to connect to source {source_id} ({source_ip}). Skipping sync.") + return 0 - # Get the primary key and new version of the affected row - if primary_key: - affected_row = await conn.fetchrow(f""" - SELECT "{primary_key}", version, server_id - FROM "{table_name}" - WHERE version = (SELECT MAX(version) FROM "{table_name}") - """) - if affected_row: - await self.push_change(table_name, affected_row[primary_key], affected_row['version'], affected_row['server_id']) - else: - # For tables without a primary key, we'll push all rows - await self.push_all_changes(table_name) + dest_conn = await self.get_connection(self.local_db) + if dest_conn is None: + warn(f"Unable to connect to local database. Skipping sync.") + return 0 - return result - - async def push_change(self, table_name: str, pk_value: Any, version: int, server_id: str): - online_hosts = await self.get_online_hosts() - for pool_entry in online_hosts: - if pool_entry['ts_id'] != os.environ.get('TS_ID'): + tables = await source_conn.fetch(""" + SELECT tablename FROM pg_tables + WHERE schemaname = 'public' + """) + + for table in tables: + table_name = table['tablename'] try: - async with self.get_connection(pool_entry) as remote_conn: + if table_name in self.SPECIAL_TABLES: + await self.sync_special_table(source_conn, dest_conn, table_name) + else: + primary_key = await self.ensure_sync_columns(dest_conn, table_name) + last_synced_version = await self.get_last_synced_version(dest_conn, table_name, source_id) + + changes = await source_conn.fetch(f""" + SELECT * FROM "{table_name}" + WHERE version > $1 AND server_id = $2 + ORDER BY version ASC + LIMIT $3 + """, last_synced_version, source_id, batch_size) + + if changes: + changes_count = await self.apply_batch_changes(dest_conn, table_name, changes, primary_key) + total_changes += changes_count + + if changes_count > 0: + info(f"Synced batch for {table_name}: {changes_count} changes. Total so far: {total_changes}") + else: + debug(f"No changes to sync for {table_name}") + + except Exception as e: + err(f"Error syncing table {table_name}: {str(e)}") + err(f"Traceback: {traceback.format_exc()}") + + info(f"Sync complete from {source_id} ({source_ip}) to {dest_id} ({dest_ip}). 
Total changes: {total_changes}") + + except Exception as e: + err(f"Error during sync process: {str(e)}") + err(f"Traceback: {traceback.format_exc()}") + + finally: + if source_conn: + await source_conn.close() + if dest_conn: + await dest_conn.close() + + info(f"Sync summary:") + info(f" Total changes: {total_changes}") + info(f" Tables synced: {len(tables) if 'tables' in locals() else 0}") + info(f" Source: {source_id} ({source_ip})") + info(f" Destination: {dest_id} ({dest_ip})") + + return total_changes + + + async def get_last_synced_version(self, conn, table_name, server_id): + if conn is None: + debug(f"Skipping offline server...") + return 0 + + if table_name in self.SPECIAL_TABLES: + debug(f"Skipping get_last_synced_version because {table_name} is special.") + return 0 # Special handling for tables without version column + + try: + last_version = await conn.fetchval(f""" + SELECT COALESCE(MAX(version), 0) + FROM "{table_name}" + WHERE server_id = $1 + """, server_id) + return last_version + except Exception as e: + err(f"Error getting last synced version for table {table_name}: {str(e)}") + err(f"Traceback: {traceback.format_exc()}") + return 0 + + + async def get_most_recent_source(self): + most_recent_source = None + max_version = -1 + local_ts_id = os.environ.get('TS_ID') + online_hosts = await self.get_online_hosts() + num_online_hosts = len(online_hosts) + + if num_online_hosts > 0: + online_ts_ids = [host['ts_id'] for host in online_hosts if host['ts_id'] != local_ts_id] + crit(f"Online hosts: {', '.join(online_ts_ids)}") + + for pool_entry in online_hosts: + if pool_entry['ts_id'] == local_ts_id: + continue # Skip local database + + try: + conn = await self.get_connection(pool_entry) + if conn is None: + warn(f"Unable to connect to {pool_entry['ts_id']}. Skipping.") + continue + + tables = await conn.fetch(""" + SELECT tablename FROM pg_tables + WHERE schemaname = 'public' + """) + + for table in tables: + table_name = table['tablename'] + if table_name in self.SPECIAL_TABLES: + continue + try: + result = await conn.fetchrow(f""" + SELECT MAX(version) as max_version, server_id + FROM "{table_name}" + WHERE version = (SELECT MAX(version) FROM "{table_name}") + GROUP BY server_id + ORDER BY MAX(version) DESC + LIMIT 1 + """) + if result: + version, server_id = result['max_version'], result['server_id'] + info(f"Max version for {pool_entry['ts_id']}, table {table_name}: {version} (from server {server_id})") + if version > max_version: + max_version = version + most_recent_source = pool_entry + else: + debug(f"No data in table {table_name} for {pool_entry['ts_id']}") + except asyncpg.exceptions.UndefinedColumnError: + warn(f"Version or server_id column does not exist in table {table_name} for {pool_entry['ts_id']}. Skipping.") + except Exception as e: + err(f"Error checking version for {pool_entry['ts_id']}, table {table_name}: {str(e)}") + + except asyncpg.exceptions.ConnectionFailureError: + warn(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}. Skipping.") + except Exception as e: + err(f"Unexpected error occurred while checking version for {pool_entry['ts_id']}: {str(e)}") + err(f"Traceback: {traceback.format_exc()}") + finally: + if conn: + await conn.close() + + if most_recent_source is None: + if num_online_hosts > 0: + warn("Could not determine most recent source. 
Using first available online host.") + most_recent_source = next(host for host in online_hosts if host['ts_id'] != local_ts_id) + else: + crit("No other online hosts available for sync.") + + return most_recent_source + + + + async def _sync_changes(self, table_name: str, primary_key: str): + try: + local_conn = await self.get_connection() + if local_conn is None: + return + + # Get the latest changes + changes = await local_conn.fetch(f""" + SELECT * FROM "{table_name}" + WHERE version > (SELECT COALESCE(MAX(version), 0) FROM "{table_name}" WHERE server_id != $1) + OR (version = (SELECT COALESCE(MAX(version), 0) FROM "{table_name}" WHERE server_id != $1) AND server_id = $1) + ORDER BY version ASC + """, os.environ.get('TS_ID')) + + if changes: + online_hosts = await self.get_online_hosts() + for pool_entry in online_hosts: + if pool_entry['ts_id'] != os.environ.get('TS_ID'): + remote_conn = await self.get_connection(pool_entry) if remote_conn is None: continue + try: + await self.apply_batch_changes(remote_conn, table_name, changes, primary_key) + finally: + await remote_conn.close() + except Exception as e: + err(f"Error syncing changes for {table_name}: {str(e)}") + err(f"Traceback: {traceback.format_exc()}") + finally: + if 'local_conn' in locals(): + await local_conn.close() - # Fetch the updated row from the local database - async with self.get_connection() as local_conn: - updated_row = await local_conn.fetchrow(f'SELECT * FROM "{table_name}" WHERE "{self.get_primary_key(table_name)}" = $1', pk_value) - if updated_row: - columns = updated_row.keys() - placeholders = [f'${i+1}' for i in range(len(columns))] - primary_key = self.get_primary_key(table_name) - - insert_query = f""" - INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)}) - VALUES ({', '.join(placeholders)}) - ON CONFLICT ("{primary_key}") DO UPDATE SET - {', '.join(f'"{col}" = EXCLUDED."{col}"' for col in columns if col != primary_key)}, - version = EXCLUDED.version, - server_id = EXCLUDED.server_id - WHERE "{table_name}".version < EXCLUDED.version - OR ("{table_name}".version = EXCLUDED.version AND "{table_name}".server_id < EXCLUDED.server_id) - """ - await remote_conn.execute(insert_query, *updated_row.values()) + async def execute_read_query(self, query: str, *args, table_name: str): + online_hosts = await self.get_online_hosts() + results = [] + max_version = -1 + latest_result = None + + for pool_entry in online_hosts: + conn = await self.get_connection(pool_entry) + if conn is None: + warn(f"Unable to connect to {pool_entry['ts_id']}. 
Skipping read.") + continue + + try: + # Execute the query + result = await conn.fetch(query, *args) + + if not result: + continue + + # Check version if it's not a special table + if table_name not in self.SPECIAL_TABLES: + try: + version_query = f""" + SELECT MAX(version) as max_version, server_id + FROM "{table_name}" + WHERE version = (SELECT MAX(version) FROM "{table_name}") + GROUP BY server_id + ORDER BY MAX(version) DESC + LIMIT 1 + """ + version_result = await conn.fetchrow(version_query) + if version_result: + version = version_result['max_version'] + server_id = version_result['server_id'] + info(f"Max version for {pool_entry['ts_id']}, table {table_name}: {version} (from server {server_id})") + if version > max_version: + max_version = version + latest_result = result + else: + debug(f"No data in table {table_name} for {pool_entry['ts_id']}") + except asyncpg.exceptions.UndefinedColumnError: + warn(f"Version or server_id column does not exist in table {table_name} for {pool_entry['ts_id']}. Skipping version check.") + if latest_result is None: + latest_result = result + else: + # For special tables, just use the first result + if latest_result is None: + latest_result = result + + results.append((pool_entry['ts_id'], result)) + + except Exception as e: + err(f"Error executing read query on {pool_entry['ts_id']}: {str(e)}") + err(f"Traceback: {traceback.format_exc()}") + finally: + await conn.close() + + if not latest_result: + warn(f"No results found for query on table {table_name}") + return [] + + # Log results from all databases + for ts_id, result in results: + info(f"Read result from {ts_id}: {result}") + + return [dict(r) for r in latest_result] # Convert Record objects to dictionaries + + async def execute_write_query(self, query: str, *args, table_name: str): + conn = await self.get_connection() + if conn is None: + raise ConnectionError("Failed to connect to local database") + + try: + if table_name in self.SPECIAL_TABLES: + return await self._execute_special_table_write(conn, query, *args, table_name=table_name) + + primary_key = await self.ensure_sync_columns(conn, table_name) + + result = await conn.execute(query, *args) + + asyncio.create_task(self._sync_changes(table_name, primary_key)) + + return [] + finally: + await conn.close() + + + async def _run_sync_tasks(self, tasks): + for task in tasks: + try: + await task + except Exception as e: + err(f"Error during background sync: {str(e)}") + err(f"Traceback: {traceback.format_exc()}") + + + async def push_change(self, table_name: str, pk_value: Any, version: int, server_id: str): + asyncio.create_task(self._push_change_background(table_name, pk_value, version, server_id)) + + async def _push_change_background(self, table_name: str, pk_value: Any, version: int, server_id: str): + online_hosts = await self.get_online_hosts() + successful_pushes = 0 + failed_pushes = 0 + + for pool_entry in online_hosts: + if pool_entry['ts_id'] != os.environ.get('TS_ID'): + remote_conn = await self.get_connection(pool_entry) + if remote_conn is None: + continue + + try: + local_conn = await self.get_connection() + if local_conn is None: + continue + + try: + updated_row = await local_conn.fetchrow(f'SELECT * FROM "{table_name}" WHERE "{self.get_primary_key(table_name)}" = $1', pk_value) + finally: + await local_conn.close() + + if updated_row: + columns = updated_row.keys() + placeholders = [f'${i+1}' for i in range(len(columns))] + primary_key = self.get_primary_key(table_name) + + remote_version = await remote_conn.fetchval(f""" + 
SELECT version FROM "{table_name}" + WHERE "{primary_key}" = $1 + """, updated_row[primary_key]) + + if remote_version is not None and remote_version >= updated_row['version']: + debug(f"Remote version for {table_name} in {pool_entry['ts_id']} is already up to date. Skipping push.") + successful_pushes += 1 + continue + + insert_query = f""" + INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)}) + VALUES ({', '.join(placeholders)}) + ON CONFLICT ("{primary_key}") DO UPDATE SET + {', '.join(f'"{col}" = EXCLUDED."{col}"' for col in columns if col != primary_key)}, + version = EXCLUDED.version, + server_id = EXCLUDED.server_id + WHERE "{table_name}".version < EXCLUDED.version + OR ("{table_name}".version = EXCLUDED.version AND "{table_name}".server_id < EXCLUDED.server_id) + """ + await remote_conn.execute(insert_query, *updated_row.values()) + successful_pushes += 1 except Exception as e: err(f"Error pushing change to {pool_entry['ts_id']}: {str(e)}") - err(f"Traceback: {traceback.format_exc()}") + failed_pushes += 1 + finally: + if remote_conn: + await remote_conn.close() + + if successful_pushes > 0: + info(f"Successfully pushed changes to {successful_pushes} server(s) for {table_name}") + if failed_pushes > 0: + warn(f"Failed to push changes to {failed_pushes} server(s) for {table_name}") + async def push_all_changes(self, table_name: str): online_hosts = await self.get_online_hosts() + tasks = [] + for pool_entry in online_hosts: if pool_entry['ts_id'] != os.environ.get('TS_ID'): - try: - async with self.get_connection(pool_entry) as remote_conn: - if remote_conn is None: - continue + task = asyncio.create_task(self._push_changes_to_host(pool_entry, table_name)) + tasks.append(task) - # Fetch all rows from the local database - async with self.get_connection() as local_conn: - all_rows = await local_conn.fetch(f'SELECT * FROM "{table_name}"') + results = await asyncio.gather(*tasks, return_exceptions=True) + successful_pushes = sum(1 for r in results if r is True) + failed_pushes = sum(1 for r in results if r is False or isinstance(r, Exception)) - if all_rows: - columns = all_rows[0].keys() - placeholders = [f'${i+1}' for i in range(len(columns))] - - insert_query = f""" - INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)}) - VALUES ({', '.join(placeholders)}) - ON CONFLICT DO NOTHING - """ - - # Use a transaction to insert all rows - async with remote_conn.transaction(): - for row in all_rows: - await remote_conn.execute(insert_query, *row.values()) - except Exception as e: - err(f"Error pushing all changes to {pool_entry['ts_id']}: {str(e)}") - err(f"Traceback: {traceback.format_exc()}") + info(f"Push all changes summary for {table_name}: Successful: {successful_pushes}, Failed: {failed_pushes}") + if failed_pushes > 0: + warn(f"Failed to push all changes to {failed_pushes} server(s). Data may be out of sync.") - async def _execute_special_table_write(self, query: str, *args, table_name: str): - if table_name == 'spatial_ref_sys': - return await self._execute_spatial_ref_sys_write(query, *args) - # Add more special cases as needed + async def _push_changes_to_host(self, pool_entry: Dict[str, Any], table_name: str) -> bool: + remote_conn = None + try: + remote_conn = await self.get_connection(pool_entry) + if remote_conn is None: + warn(f"Unable to connect to {pool_entry['ts_id']}. 
Skipping push.") + return False - async def _execute_spatial_ref_sys_write(self, query: str, *args): - result = None - async with self.get_connection() as local_conn: + local_conn = await self.get_connection() if local_conn is None: - raise ConnectionError("Failed to connect to local database") - - # Execute the query locally - result = await local_conn.execute(query, *args) + warn(f"Unable to connect to local database. Skipping push.") + return False + + try: + all_rows = await local_conn.fetch(f'SELECT * FROM "{table_name}"') + finally: + await local_conn.close() + + if all_rows: + columns = list(all_rows[0].keys()) + placeholders = [f'${i+1}' for i in range(len(columns))] + + insert_query = f""" + INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)}) + VALUES ({', '.join(placeholders)}) + ON CONFLICT DO NOTHING + """ + + async with remote_conn.transaction(): + for row in all_rows: + await remote_conn.execute(insert_query, *row.values()) + + return True + else: + debug(f"No rows to push for table {table_name}") + return True + except Exception as e: + err(f"Error pushing all changes to {pool_entry['ts_id']}: {str(e)}") + err(f"Traceback: {traceback.format_exc()}") + return False + finally: + if remote_conn: + await remote_conn.close() + + async def _execute_special_table_write(self, conn, query: str, *args, table_name: str): + if table_name == 'spatial_ref_sys': + return await self._execute_spatial_ref_sys_write(conn, query, *args) + + async def _execute_spatial_ref_sys_write(self, local_conn, query: str, *args): + # Execute the query locally + result = await local_conn.execute(query, *args) # Sync the entire spatial_ref_sys table with all online hosts online_hosts = await self.get_online_hosts() for pool_entry in online_hosts: if pool_entry['ts_id'] != os.environ.get('TS_ID'): + remote_conn = await self.get_connection(pool_entry) + if remote_conn is None: + continue try: - async with self.get_connection(pool_entry) as remote_conn: - if remote_conn is None: - continue - await self.sync_spatial_ref_sys(local_conn, remote_conn) + await self.sync_spatial_ref_sys(local_conn, remote_conn) except Exception as e: err(f"Error syncing spatial_ref_sys to {pool_entry['ts_id']}: {str(e)}") err(f"Traceback: {traceback.format_exc()}") + finally: + await remote_conn.close() return result - def get_primary_key(self, table_name: str) -> str: - # This method should return the primary key for the given table - # You might want to cache this information for performance - # For now, we'll assume it's always 'id', but you should implement proper logic here - return 'id' + async def apply_batch_changes(self, conn, table_name, changes, primary_key): + if conn is None or not changes: + debug(f"Skipping apply_batch_changes because conn is none or there are no changes.") + return 0 + + try: + columns = list(changes[0].keys()) + placeholders = [f'${i+1}' for i in range(len(columns))] + + if primary_key: + insert_query = f""" + INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)}) + VALUES ({', '.join(placeholders)}) + ON CONFLICT ("{primary_key}") DO UPDATE SET + {', '.join(f'"{col}" = EXCLUDED."{col}"' for col in columns if col not in [primary_key, 'version', 'server_id'])}, + version = EXCLUDED.version, + server_id = EXCLUDED.server_id + WHERE "{table_name}".version < EXCLUDED.version + OR ("{table_name}".version = EXCLUDED.version AND "{table_name}".server_id < EXCLUDED.server_id) + """ + else: + warn(f"Possible source of issue #4") + # For tables without a primary key, we'll 
use all columns for conflict resolution + insert_query = f""" + INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)}) + VALUES ({', '.join(placeholders)}) + ON CONFLICT DO NOTHING + """ + + affected_rows = 0 + async for change in tqdm(changes, desc=f"Syncing {table_name}", unit="row"): + values = [change[col] for col in columns] + result = await conn.execute(insert_query, *values) + affected_rows += int(result.split()[-1]) + + return affected_rows + + except Exception as e: + err(f"Error applying batch changes to {table_name}: {str(e)}") + err(f"Traceback: {traceback.format_exc()}") + return 0 + + async def sync_special_table(self, source_conn, dest_conn, table_name): + if table_name == 'spatial_ref_sys': + return await self.sync_spatial_ref_sys(source_conn, dest_conn) + # Add more special cases as needed async def sync_spatial_ref_sys(self, source_conn, dest_conn): try: @@ -666,6 +1072,77 @@ class APIConfig(BaseModel): return 0 + def get_primary_key(self, table_name: str) -> str: + # This method should return the primary key for the given table + # You might want to cache this information for performance + # For now, we'll assume it's always 'id', but you should implement proper logic here + return 'id' + + + async def add_primary_keys_to_local_tables(self): + conn = await self.get_connection() + + debug(f"Adding primary keys to existing tables...") + if conn is None: + raise ConnectionError("Failed to connect to local database") + + try: + tables = await conn.fetch(""" + SELECT tablename FROM pg_tables + WHERE schemaname = 'public' + """) + + for table in tables: + table_name = table['tablename'] + if table_name not in self.SPECIAL_TABLES: + await self.ensure_sync_columns(conn, table_name) + finally: + await conn.close() + + async def add_primary_keys_to_remote_tables(self): + online_hosts = await self.get_online_hosts() + + for pool_entry in online_hosts: + conn = await self.get_connection(pool_entry) + if conn is None: + warn(f"Unable to connect to {pool_entry['ts_id']}. 
Skipping primary key addition.") + continue + + try: + info(f"Adding primary keys to existing tables on {pool_entry['ts_id']}...") + tables = await conn.fetch(""" + SELECT tablename FROM pg_tables + WHERE schemaname = 'public' + """) + + for table in tables: + table_name = table['tablename'] + if table_name not in self.SPECIAL_TABLES: + primary_key = await self.ensure_sync_columns(conn, table_name) + if primary_key: + info(f"Added/ensured primary key '{primary_key}' for table '{table_name}' on {pool_entry['ts_id']}") + else: + warn(f"Failed to add/ensure primary key for table '{table_name}' on {pool_entry['ts_id']}") + + info(f"Completed adding primary keys to existing tables on {pool_entry['ts_id']}") + except Exception as e: + err(f"Error adding primary keys to existing tables on {pool_entry['ts_id']}: {str(e)}") + err(f"Traceback: {traceback.format_exc()}") + finally: + await conn.close() + + + async def close_db_pools(self): + info("Closing database connection pools...") + for pool_key, pool in self.db_pools.items(): + try: + await pool.close() + debug(f"Closed pool for {pool_key}") + except Exception as e: + err(f"Error closing pool for {pool_key}: {str(e)}") + self.db_pools.clear() + info("All database connection pools closed.") + class Location(BaseModel): latitude: float @@ -679,7 +1156,7 @@ class Location(BaseModel): state: Optional[str] = None country: Optional[str] = None context: Optional[Dict[str, Any]] = None - class_: Optional[str] = None + class_: Optional[str] = Field(None, alias="class") type: Optional[str] = None name: Optional[str] = None display_name: Optional[str] = None @@ -697,11 +1174,7 @@ class Location(BaseModel): json_encoders = { datetime: lambda dt: dt.isoformat(), } - - def model_dump(self): - data = self.dict() - data["datetime"] = self.datetime.isoformat() if self.datetime else None - return data + populate_by_name = True class Geocoder: diff --git a/sijapi/routers/gis.py b/sijapi/routers/gis.py index 65bde57..e7b91e8 100644 --- a/sijapi/routers/gis.py +++ b/sijapi/routers/gis.py @@ -120,7 +120,6 @@ async def get_last_location() -> Optional[Location]: return None - async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int, datetime, None] = None) -> List[Location]: start_datetime = await dt(start) if end is None: @@ -133,10 +132,24 @@ async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int, debug(f"Fetching locations between {start_datetime} and {end_datetime}") - async with API.get_connection() as conn: - locations = [] - # Check for records within the specified datetime range - range_locations = await conn.fetch(''' + query = ''' + SELECT id, datetime, + ST_X(ST_AsText(location)::geometry) AS longitude, + ST_Y(ST_AsText(location)::geometry) AS latitude, + ST_Z(ST_AsText(location)::geometry) AS elevation, + city, state, zip, street, + action, device_type, device_model, device_name, device_os + FROM locations + WHERE datetime >= $1 AND datetime <= $2 + ORDER BY datetime DESC + ''' + + locations = await API.execute_read_query(query, start_datetime.replace(tzinfo=None), end_datetime.replace(tzinfo=None), table_name="locations") + + debug(f"Range locations query returned: {locations}") + + if not locations and (end is None or start_datetime.date() == end_datetime.date()): + fallback_query = ''' SELECT id, datetime, ST_X(ST_AsText(location)::geometry) AS longitude, ST_Y(ST_AsText(location)::geometry) AS latitude, @@ -144,30 +157,14 @@ async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int, 
city, state, zip, street, action, device_type, device_model, device_name, device_os FROM locations - WHERE datetime >= $1 AND datetime <= $2 + WHERE datetime < $1 ORDER BY datetime DESC - ''', start_datetime.replace(tzinfo=None), end_datetime.replace(tzinfo=None)) - - debug(f"Range locations query returned: {range_locations}") - locations.extend(range_locations) - - if not locations and (end is None or start_datetime.date() == end_datetime.date()): - location_data = await conn.fetchrow(''' - SELECT id, datetime, - ST_X(ST_AsText(location)::geometry) AS longitude, - ST_Y(ST_AsText(location)::geometry) AS latitude, - ST_Z(ST_AsText(location)::geometry) AS elevation, - city, state, zip, street, - action, device_type, device_model, device_name, device_os - FROM locations - WHERE datetime < $1 - ORDER BY datetime DESC - LIMIT 1 - ''', start_datetime.replace(tzinfo=None)) - - debug(f"Fallback query returned: {location_data}") - if location_data: - locations.append(location_data) + LIMIT 1 + ''' + location_data = await API.execute_read_query(fallback_query, start_datetime.replace(tzinfo=None), table_name="locations") + debug(f"Fallback query returned: {location_data}") + if location_data: + locations = location_data debug(f"Locations found: {locations}") @@ -197,35 +194,32 @@ async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int, return location_objects if location_objects else [] -# Function to fetch the last location before the specified datetime async def fetch_last_location_before(datetime: datetime) -> Optional[Location]: datetime = await dt(datetime) debug(f"Fetching last location before {datetime}") - async with API.get_connection() as conn: + query = ''' + SELECT id, datetime, + ST_X(ST_AsText(location)::geometry) AS longitude, + ST_Y(ST_AsText(location)::geometry) AS latitude, + ST_Z(ST_AsText(location)::geometry) AS elevation, + city, state, zip, street, country, + action + FROM locations + WHERE datetime < $1 + ORDER BY datetime DESC + LIMIT 1 + ''' + + location_data = await API.execute_read_query(query, datetime.replace(tzinfo=None), table_name="locations") - location_data = await conn.fetchrow(''' - SELECT id, datetime, - ST_X(ST_AsText(location)::geometry) AS longitude, - ST_Y(ST_AsText(location)::geometry) AS latitude, - ST_Z(ST_AsText(location)::geometry) AS elevation, - city, state, zip, street, country, - action - FROM locations - WHERE datetime < $1 - ORDER BY datetime DESC - LIMIT 1 - ''', datetime.replace(tzinfo=None)) - - await conn.close() - - if location_data: - debug(f"Last location found: {location_data}") - return Location(**location_data) - else: - debug("No location found before the specified datetime") - return None + if location_data: + debug(f"Last location found: {location_data[0]}") + return Location(**location_data[0]) + else: + debug("No location found before the specified datetime") + return None @gis.get("/map", response_class=HTMLResponse) async def generate_map_endpoint( @@ -247,16 +241,12 @@ async def generate_map_endpoint( return HTMLResponse(content=html_content) async def get_date_range(): - async with API.get_connection() as conn: - query = "SELECT MIN(datetime) as min_date, MAX(datetime) as max_date FROM locations" - row = await conn.fetchrow(query) - if row and row['min_date'] and row['max_date']: - return row['min_date'], row['max_date'] - else: - return datetime(2022, 1, 1), datetime.now() - - - + query = "SELECT MIN(datetime) as min_date, MAX(datetime) as max_date FROM locations" + row = await 
API.execute_read_query(query, table_name="locations") + if row and row[0]['min_date'] and row[0]['max_date']: + return row[0]['min_date'], row[0]['max_date'] + else: + return datetime(2022, 1, 1), datetime.now() async def generate_and_save_heatmap( start_date: Union[str, int, datetime], @@ -313,8 +303,6 @@ Generate a heatmap for the given date range and save it as a PNG file using Foli err(f"Error generating and saving heatmap: {str(e)}") raise - - async def generate_map(start_date: datetime, end_date: datetime, max_points: int): locations = await fetch_locations(start_date, end_date) if not locations: @@ -343,8 +331,6 @@ async def generate_map(start_date: datetime, end_date: datetime, max_points: int folium.TileLayer('cartodbdark_matter', name='Dark Mode').add_to(m) - - # In the generate_map function: draw = Draw( draw_options={ 'polygon': True, @@ -433,70 +419,70 @@ map.on(L.Draw.Event.CREATED, function (event) { return m.get_root().render() async def post_location(location: Location): - # if not location.datetime: - # info(f"location appears to be missing datetime: {location}") - # else: - # debug(f"post_location called with {location.datetime}") - async with API.get_connection() as conn: - try: - context = location.context or {} - action = context.get('action', 'manual') - device_type = context.get('device_type', 'Unknown') - device_model = context.get('device_model', 'Unknown') - device_name = context.get('device_name', 'Unknown') - device_os = context.get('device_os', 'Unknown') - - # Parse and localize the datetime - localized_datetime = await dt(location.datetime) + try: + context = location.context or {} + action = context.get('action', 'manual') + device_type = context.get('device_type', 'Unknown') + device_model = context.get('device_model', 'Unknown') + device_name = context.get('device_name', 'Unknown') + device_os = context.get('device_os', 'Unknown') + + # Parse and localize the datetime + localized_datetime = await dt(location.datetime) - await conn.execute(''' - INSERT INTO locations ( - datetime, location, city, state, zip, street, action, device_type, device_model, device_name, device_os, - class_, type, name, display_name, amenity, house_number, road, quarter, neighbourhood, - suburb, county, country_code, country - ) - VALUES ($1, ST_SetSRID(ST_MakePoint($2, $3, $4), 4326), $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, - $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26) - ''', localized_datetime, location.longitude, location.latitude, location.elevation, location.city, location.state, - location.zip, location.street, action, device_type, device_model, device_name, device_os, - location.class_, location.type, location.name, location.display_name, - location.amenity, location.house_number, location.road, location.quarter, location.neighbourhood, - location.suburb, location.county, location.country_code, location.country) - - await conn.close() - info(f"Successfully posted location: {location.latitude}, {location.longitude}, {location.elevation} on {localized_datetime}") - return { - 'datetime': localized_datetime, - 'latitude': location.latitude, - 'longitude': location.longitude, - 'elevation': location.elevation, - 'city': location.city, - 'state': location.state, - 'zip': location.zip, - 'street': location.street, - 'action': action, - 'device_type': device_type, - 'device_model': device_model, - 'device_name': device_name, - 'device_os': device_os, - 'class_': location.class_, - 'type': location.type, - 'name': location.name, - 'display_name': 
location.display_name, - 'amenity': location.amenity, - 'house_number': location.house_number, - 'road': location.road, - 'quarter': location.quarter, - 'neighbourhood': location.neighbourhood, - 'suburb': location.suburb, - 'county': location.county, - 'country_code': location.country_code, - 'country': location.country - } - except Exception as e: - err(f"Error posting location {e}") - err(traceback.format_exc()) - return None + query = ''' + INSERT INTO locations ( + datetime, location, city, state, zip, street, action, device_type, device_model, device_name, device_os, + class_, type, name, display_name, amenity, house_number, road, quarter, neighbourhood, + suburb, county, country_code, country + ) + VALUES ($1, ST_SetSRID(ST_MakePoint($2, $3, $4), 4326), $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, + $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26) + ''' + + await API.execute_write_query( + query, + localized_datetime, location.longitude, location.latitude, location.elevation, location.city, location.state, + location.zip, location.street, action, device_type, device_model, device_name, device_os, + location.class_, location.type, location.name, location.display_name, + location.amenity, location.house_number, location.road, location.quarter, location.neighbourhood, + location.suburb, location.county, location.country_code, location.country, + table_name="locations" + ) + + info(f"Successfully posted location: {location.latitude}, {location.longitude}, {location.elevation} on {localized_datetime}") + return { + 'datetime': localized_datetime, + 'latitude': location.latitude, + 'longitude': location.longitude, + 'elevation': location.elevation, + 'city': location.city, + 'state': location.state, + 'zip': location.zip, + 'street': location.street, + 'action': action, + 'device_type': device_type, + 'device_model': device_model, + 'device_name': device_name, + 'device_os': device_os, + 'class_': location.class_, + 'type': location.type, + 'name': location.name, + 'display_name': location.display_name, + 'amenity': location.amenity, + 'house_number': location.house_number, + 'road': location.road, + 'quarter': location.quarter, + 'neighbourhood': location.neighbourhood, + 'suburb': location.suburb, + 'county': location.county, + 'country_code': location.country_code, + 'country': location.country + } + except Exception as e: + err(f"Error posting location {e}") + err(traceback.format_exc()) + return None @gis.post("/locate") @@ -553,6 +539,7 @@ async def get_last_location_endpoint() -> JSONResponse: raise HTTPException(status_code=404, detail="No location found before the specified datetime") + @gis.get("/locate/{datetime_str}", response_model=List[Location]) async def get_locate(datetime_str: str, all: bool = False): try: diff --git a/sijapi/routers/loc.py b/sijapi/routers/loc.py deleted file mode 100644 index fff27e9..0000000 --- a/sijapi/routers/loc.py +++ /dev/null @@ -1,393 +0,0 @@ -''' -Uses Postgres/PostGIS for location tracking (data obtained via the companion mobile Pythonista scripts), and for geocoding purposes. 
-''' -from fastapi import APIRouter, HTTPException, Query -from fastapi.responses import HTMLResponse, JSONResponse -import yaml -from typing import List, Tuple, Union -import traceback -from datetime import datetime, timezone -from typing import Union, List -import folium -from zoneinfo import ZoneInfo -from dateutil.parser import parse as dateutil_parse -from typing import Optional, List, Union -from datetime import datetime -from sijapi import L, API, TZ, GEO -from sijapi.classes import Location -from sijapi.utilities import haversine - -loc = APIRouter() -logger = L.get_module_logger("loc") - -async def dt( - date_time: Union[str, int, datetime], - tz: Union[str, ZoneInfo, None] = None -) -> datetime: - try: - # Convert integer (epoch time) to UTC datetime - if isinstance(date_time, int): - date_time = datetime.utcfromtimestamp(date_time).replace(tzinfo=timezone.utc) - logger.debug(f"Converted epoch time {date_time} to UTC datetime object.") - # Convert string to datetime if necessary - elif isinstance(date_time, str): - date_time = dateutil_parse(date_time) - logger.debug(f"Converted string '{date_time}' to datetime object.") - - if not isinstance(date_time, datetime): - raise ValueError(f"Input must be a string, integer (epoch time), or datetime object. What we received: {date_time}, type {type(date_time)}") - - # Ensure the datetime is timezone-aware (UTC if not specified) - if date_time.tzinfo is None: - date_time = date_time.replace(tzinfo=timezone.utc) - logger.debug("Added UTC timezone to naive datetime.") - - # Handle provided timezone - if tz is not None: - if isinstance(tz, str): - if tz == "local": - last_loc = await get_timezone_without_timezone(date_time) - tz = await GEO.tz_at(last_loc.latitude, last_loc.longitude) - logger.debug(f"Using local timezone: {tz}") - else: - try: - tz = ZoneInfo(tz) - except Exception as e: - logger.error(f"Invalid timezone string '{tz}'. Error: {e}") - raise ValueError(f"Invalid timezone string: {tz}") - elif isinstance(tz, ZoneInfo): - pass # tz is already a ZoneInfo object - else: - raise ValueError(f"What we needed: tz == 'local', a string, or a ZoneInfo object. What we got: tz, a {type(tz)}, == {tz})") - - # Convert to the provided or determined timezone - date_time = date_time.astimezone(tz) - logger.debug(f"Converted datetime to timezone: {tz}") - - return date_time - except ValueError as e: - logger.error(f"Error in dt: {e}") - raise - except Exception as e: - logger.error(f"Unexpected error in dt: {e}") - raise ValueError(f"Failed to process datetime: {e}") - - -async def get_timezone_without_timezone(date_time): - # This is a bit convoluted because we're trying to solve the paradox of needing to know the location in order to determine the timezone, but needing the timezone to be certain we've got the right location if this datetime coincided with inter-timezone travel. Our imperfect solution is to use UTC for an initial location query to determine roughly where we were at the time, get that timezone, then check the location again using that timezone, and if this location is different from the one using UTC, get the timezone again usng it, otherwise use the one we already sourced using UTC. 
- - # Step 1: Use UTC as an interim timezone to query location - interim_dt = date_time.replace(tzinfo=ZoneInfo("UTC")) - interim_loc = await fetch_last_location_before(interim_dt) - - # Step 2: Get a preliminary timezone based on the interim location - interim_tz = await GEO.tz_current((interim_loc.latitude, interim_loc.longitude)) - - # Step 3: Apply this preliminary timezone and query location again - query_dt = date_time.replace(tzinfo=ZoneInfo(interim_tz)) - query_loc = await fetch_last_location_before(query_dt) - - # Step 4: Get the final timezone, reusing interim_tz if location hasn't changed - return interim_tz if query_loc == interim_loc else await GEO.tz_current(query_loc.latitude, query_loc.longitude) - - -async def get_last_location() -> Optional[Location]: - query_datetime = datetime.now(TZ) - logger.debug(f"Query_datetime: {query_datetime}") - - this_location = await fetch_last_location_before(query_datetime) - - if this_location: - logger.debug(f"location: {this_location}") - return this_location - - return None - - -async def fetch_locations(start: datetime, end: datetime = None) -> List[Location]: - start_datetime = await dt(start) - if end is None: - end_datetime = await dt(start_datetime.replace(hour=23, minute=59, second=59)) - else: - end_datetime = await dt(end) - - if start_datetime.time() == datetime.min.time() and end_datetime.time() == datetime.min.time(): - end_datetime = end_datetime.replace(hour=23, minute=59, second=59) - - logger.debug(f"Fetching locations between {start_datetime} and {end_datetime}") - - async with API.get_connection() as conn: - locations = [] - # Check for records within the specified datetime range - range_locations = await conn.fetch(''' - SELECT id, datetime, - ST_X(ST_AsText(location)::geometry) AS longitude, - ST_Y(ST_AsText(location)::geometry) AS latitude, - ST_Z(ST_AsText(location)::geometry) AS elevation, - city, state, zip, street, - action, device_type, device_model, device_name, device_os - FROM locations - WHERE datetime >= $1 AND datetime <= $2 - ORDER BY datetime DESC - ''', start_datetime.replace(tzinfo=None), end_datetime.replace(tzinfo=None)) - - logger.debug(f"Range locations query returned: {range_locations}") - locations.extend(range_locations) - - if not locations and (end is None or start_datetime.date() == end_datetime.date()): - location_data = await conn.fetchrow(''' - SELECT id, datetime, - ST_X(ST_AsText(location)::geometry) AS longitude, - ST_Y(ST_AsText(location)::geometry) AS latitude, - ST_Z(ST_AsText(location)::geometry) AS elevation, - city, state, zip, street, - action, device_type, device_model, device_name, device_os - FROM locations - WHERE datetime < $1 - ORDER BY datetime DESC - LIMIT 1 - ''', start_datetime.replace(tzinfo=None)) - - logger.debug(f"Fallback query returned: {location_data}") - if location_data: - locations.append(location_data) - - logger.debug(f"Locations found: {locations}") - - # Sort location_data based on the datetime field in descending order - sorted_locations = sorted(locations, key=lambda x: x['datetime'], reverse=True) - - # Create Location objects directly from the location data - location_objects = [ - Location( - latitude=location['latitude'], - longitude=location['longitude'], - datetime=location['datetime'], - elevation=location.get('elevation'), - city=location.get('city'), - state=location.get('state'), - zip=location.get('zip'), - street=location.get('street'), - context={ - 'action': location.get('action'), - 'device_type': location.get('device_type'), - 
'device_model': location.get('device_model'), - 'device_name': location.get('device_name'), - 'device_os': location.get('device_os') - } - ) for location in sorted_locations if location['latitude'] is not None and location['longitude'] is not None - ] - - return location_objects if location_objects else [] - -# Function to fetch the last location before the specified datetime -async def fetch_last_location_before(datetime: datetime) -> Optional[Location]: - datetime = await dt(datetime) - - logger.debug(f"Fetching last location before {datetime}") - - async with API.get_connection() as conn: - - location_data = await conn.fetchrow(''' - SELECT id, datetime, - ST_X(ST_AsText(location)::geometry) AS longitude, - ST_Y(ST_AsText(location)::geometry) AS latitude, - ST_Z(ST_AsText(location)::geometry) AS elevation, - city, state, zip, street, country, - action - FROM locations - WHERE datetime < $1 - ORDER BY datetime DESC - LIMIT 1 - ''', datetime.replace(tzinfo=None)) - - await conn.close() - - if location_data: - logger.debug(f"Last location found: {location_data}") - return Location(**location_data) - else: - logger.debug("No location found before the specified datetime") - return None - -@loc.get("/map/start_date={start_date_str}&end_date={end_date_str}", response_class=HTMLResponse) -async def generate_map_endpoint(start_date_str: str, end_date_str: str): - try: - start_date = await dt(start_date_str) - end_date = await dt(end_date_str) - except ValueError: - raise HTTPException(status_code=400, detail="Invalid date format") - - html_content = await generate_map(start_date, end_date) - return HTMLResponse(content=html_content) - - -@loc.get("/map", response_class=HTMLResponse) -async def generate_alltime_map_endpoint(): - try: - start_date = await dt(datetime.fromisoformat("2022-01-01")) - end_date = dt(datetime.now()) - except ValueError: - raise HTTPException(status_code=400, detail="Invalid date format") - - html_content = await generate_map(start_date, end_date) - return HTMLResponse(content=html_content) - - -async def generate_map(start_date: datetime, end_date: datetime): - locations = await fetch_locations(start_date, end_date) - if not locations: - raise HTTPException(status_code=404, detail="No locations found for the given date range") - - # Create a folium map centered around the first location - map_center = [locations[0].latitude, locations[0].longitude] - m = folium.Map(location=map_center, zoom_start=5) - - # Add markers for each location - for location in locations: - folium.Marker( - location=[location.latitude, location.longitude], - popup=f"{location.city}, {location.state}
Elevation: {location.elevation}m
Date: {location.datetime}", - tooltip=f"{location.city}, {location.state}" - ).add_to(m) - - # Save the map to an HTML file and return the HTML content - map_html = "map.html" - m.save(map_html) - - with open(map_html, 'r') as file: - html_content = file.read() - - return html_content - -async def post_location(location: Location): - async with API.get_connection() as conn: - try: - context = location.context or {} - action = context.get('action', 'manual') - device_type = context.get('device_type', 'Unknown') - device_model = context.get('device_model', 'Unknown') - device_name = context.get('device_name', 'Unknown') - device_os = context.get('device_os', 'Unknown') - - # Parse and localize the datetime - localized_datetime = await dt(location.datetime) - - await conn.execute(''' - INSERT INTO locations ( - datetime, location, city, state, zip, street, action, device_type, device_model, device_name, device_os, - class_, type, name, display_name, amenity, house_number, road, quarter, neighbourhood, - suburb, county, country_code, country - ) - VALUES ($1, ST_SetSRID(ST_MakePoint($2, $3, $4), 4326), $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, - $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26) - ''', localized_datetime, location.longitude, location.latitude, location.elevation, location.city, location.state, - location.zip, location.street, action, device_type, device_model, device_name, device_os, - location.class_, location.type, location.name, location.display_name, - location.amenity, location.house_number, location.road, location.quarter, location.neighbourhood, - location.suburb, location.county, location.country_code, location.country) - - await conn.close() - logger.info(f"Successfully posted location: {location.latitude}, {location.longitude}, {location.elevation} on {localized_datetime}") - return { - 'datetime': localized_datetime, - 'latitude': location.latitude, - 'longitude': location.longitude, - 'elevation': location.elevation, - 'city': location.city, - 'state': location.state, - 'zip': location.zip, - 'street': location.street, - 'action': action, - 'device_type': device_type, - 'device_model': device_model, - 'device_name': device_name, - 'device_os': device_os, - 'class_': location.class_, - 'type': location.type, - 'name': location.name, - 'display_name': location.display_name, - 'amenity': location.amenity, - 'house_number': location.house_number, - 'road': location.road, - 'quarter': location.quarter, - 'neighbourhood': location.neighbourhood, - 'suburb': location.suburb, - 'county': location.county, - 'country_code': location.country_code, - 'country': location.country - } - except Exception as e: - logger.error(f"Error posting location {e}") - logger.error(traceback.format_exc()) - return None - - -@loc.post("/locate") -async def post_locate_endpoint(locations: Union[Location, List[Location]]): - if isinstance(locations, Location): - locations = [locations] - - # Prepare locations - for lcn in locations: - if not lcn.datetime: - tz = await GEO.tz_at(lcn.latitude, lcn.longitude) - lcn.datetime = datetime.now(ZoneInfo(tz)).isoformat() - - if not lcn.context: - lcn.context = { - "action": "missing", - "device_type": "API", - "device_model": "Unknown", - "device_name": "Unknown", - "device_os": "Unknown" - } - logger.debug(f"Location received for processing: {lcn}") - - geocoded_locations = await GEO.code(locations) - - responses = [] - if isinstance(geocoded_locations, List): - for location in geocoded_locations: - logger.debug(f"Final location to be 
submitted to database: {location}") - location_entry = await post_location(location) - if location_entry: - responses.append({"location_data": location_entry}) - else: - logger.warning(f"Posting location to database appears to have failed.") - else: - logger.debug(f"Final location to be submitted to database: {geocoded_locations}") - location_entry = await post_location(geocoded_locations) - if location_entry: - responses.append({"location_data": location_entry}) - else: - logger.warning(f"Posting location to database appears to have failed.") - - return {"message": "Locations and weather updated", "results": responses} - - -@loc.get("/locate", response_model=Location) -async def get_last_location_endpoint() -> JSONResponse: - this_location = await get_last_location() - - if this_location: - location_dict = this_location.model_dump() - location_dict["datetime"] = this_location.datetime.isoformat() - return JSONResponse(content=location_dict) - else: - raise HTTPException(status_code=404, detail="No location found before the specified datetime") - -@loc.get("/locate/{datetime_str}", response_model=List[Location]) -async def get_locate(datetime_str: str, all: bool = False): - try: - date_time = await dt(datetime_str) - except ValueError as e: - logger.error(f"Invalid datetime string provided: {datetime_str}") - return ["ERROR: INVALID DATETIME PROVIDED. USE YYYYMMDDHHmmss or YYYYMMDD format."] - - locations = await fetch_locations(date_time) - if not locations: - raise HTTPException(status_code=404, detail="No nearby data found for this date and time") - - return locations if all else [locations[0]] - diff --git a/sijapi/routers/serve.py b/sijapi/routers/serve.py index cb596a4..e87eb49 100644 --- a/sijapi/routers/serve.py +++ b/sijapi/routers/serve.py @@ -212,7 +212,6 @@ if API.EXTENSIONS.shellfish == "on" or API.EXTENSIONS.shellfish == True: except requests.exceptions.RequestException: results.append(f"{address} is down") - # Generate a simple text-based graph graph = '|' * up_count + '.' 
* (len(addresses) - up_count) text_update = "\n".join(results) @@ -220,7 +219,6 @@ if API.EXTENSIONS.shellfish == "on" or API.EXTENSIONS.shellfish == True: output = shellfish_run_widget_command(widget_command) return {"output": output, "graph": graph} - def shellfish_update_widget(update: WidgetUpdate): widget_command = ["widget"] @@ -290,6 +288,7 @@ if API.EXTENSIONS.courtlistener == "on" or API.EXTENSIONS.courtlistener == True: for result in results: bg_tasks.add_task(cl_docket_process, result) return JSONResponse(content={"message": "Received"}, status_code=status.HTTP_200_OK) + async def cl_docket_process(result): async with httpx.AsyncClient() as session: @@ -346,11 +345,13 @@ if API.EXTENSIONS.courtlistener == "on" or API.EXTENSIONS.courtlistener == True: await cl_download_file(file_url, target_path, session) debug(f"Downloaded {file_name} to {target_path}") + def cl_case_details(docket): case_info = CASETABLE.get(str(docket), {"code": "000", "shortname": "UNKNOWN"}) case_code = case_info.get("code") short_name = case_info.get("shortname") return case_code, short_name + async def cl_download_file(url: str, path: Path, session: aiohttp.ClientSession = None): headers = { @@ -417,19 +418,20 @@ if API.EXTENSIONS.courtlistener == "on" or API.EXTENSIONS.courtlistener == True: await cl_download_file(download_url, target_path, session) debug(f"Downloaded {file_name} to {target_path}") + @serve.get("/s", response_class=HTMLResponse) async def shortener_form(request: Request): return templates.TemplateResponse("shortener.html", {"request": request}) + @serve.post("/s") async def create_short_url(request: Request, long_url: str = Form(...), custom_code: Optional[str] = Form(None)): - await create_tables() if custom_code: if len(custom_code) != 3 or not custom_code.isalnum(): return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code must be 3 alphanumeric characters"}) - existing = await API.execute_write_query('SELECT 1 FROM short_urls WHERE short_code = $1', custom_code, table_name="short_urls") + existing = await API.execute_read_query('SELECT 1 FROM short_urls WHERE short_code = $1', custom_code, table_name="short_urls") if existing: return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code already in use"}) @@ -437,8 +439,9 @@ async def create_short_url(request: Request, long_url: str = Form(...), custom_c else: chars = string.ascii_letters + string.digits while True: + debug(f"FOUND THE ISSUE") short_code = ''.join(random.choice(chars) for _ in range(3)) - existing = await API.execute_write_query('SELECT 1 FROM short_urls WHERE short_code = $1', short_code, table_name="short_urls") + existing = await API.execute_read_query('SELECT 1 FROM short_urls WHERE short_code = $1', short_code, table_name="short_urls") if not existing: break @@ -451,48 +454,36 @@ async def create_short_url(request: Request, long_url: str = Form(...), custom_c short_url = f"https://sij.ai/{short_code}" return templates.TemplateResponse("shortener.html", {"request": request, "short_url": short_url}) -async def create_tables(): - await API.execute_write_query(''' - CREATE TABLE IF NOT EXISTS short_urls ( - short_code VARCHAR(3) PRIMARY KEY, - long_url TEXT NOT NULL, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP - ) - ''', table_name="short_urls") - await API.execute_write_query(''' - CREATE TABLE IF NOT EXISTS click_logs ( - id SERIAL PRIMARY KEY, - short_code VARCHAR(3) REFERENCES short_urls(short_code), - clicked_at TIMESTAMP DEFAULT 
CURRENT_TIMESTAMP, - ip_address TEXT, - user_agent TEXT - ) - ''', table_name="click_logs") -@serve.get("/{short_code}", response_class=RedirectResponse, status_code=301) -async def redirect_short_url(request: Request, short_code: str = PathParam(..., min_length=3, max_length=3)): - if request.headers.get('host') != 'sij.ai': - raise HTTPException(status_code=404, detail="Not Found") - - result = await API.execute_write_query( +@serve.get("/{short_code}") +async def redirect_short_url(short_code: str): + results = await API.execute_read_query( 'SELECT long_url FROM short_urls WHERE short_code = $1', short_code, table_name="short_urls" ) - if result: - await API.execute_write_query( - 'INSERT INTO click_logs (short_code, ip_address, user_agent) VALUES ($1, $2, $3)', - short_code, request.client.host, request.headers.get("user-agent"), - table_name="click_logs" - ) - return result['long_url'] - else: + if not results: raise HTTPException(status_code=404, detail="Short URL not found") + + long_url = results[0].get('long_url') + + if not long_url: + raise HTTPException(status_code=404, detail="Long URL not found") + + # Increment click count (you may want to do this asynchronously) + await API.execute_write_query( + 'INSERT INTO click_logs (short_code, clicked_at) VALUES ($1, $2)', + short_code, datetime.now(), + table_name="click_logs" + ) + + return RedirectResponse(url=long_url) + @serve.get("/analytics/{short_code}") async def get_analytics(short_code: str): - url_info = await API.execute_write_query( + url_info = await API.execute_read_query( 'SELECT long_url, created_at FROM short_urls WHERE short_code = $1', short_code, table_name="short_urls" @@ -500,13 +491,13 @@ async def get_analytics(short_code: str): if not url_info: raise HTTPException(status_code=404, detail="Short URL not found") - click_count = await API.execute_write_query( + click_count = await API.execute_read_query( 'SELECT COUNT(*) FROM click_logs WHERE short_code = $1', short_code, table_name="click_logs" ) - clicks = await API.execute_write_query( + clicks = await API.execute_read_query( 'SELECT clicked_at, ip_address, user_agent FROM click_logs WHERE short_code = $1 ORDER BY clicked_at DESC LIMIT 100', short_code, table_name="click_logs" @@ -521,15 +512,20 @@ async def get_analytics(short_code: str): } - - async def forward_traffic(reader: asyncio.StreamReader, writer: asyncio.StreamWriter, destination: str): - dest_host, dest_port = destination.split(':') - dest_port = int(dest_port) + try: + dest_host, dest_port = destination.split(':') + dest_port = int(dest_port) + except ValueError: + warn(f"Invalid destination format: {destination}. 
Expected 'host:port'.") + writer.close() + await writer.wait_closed() + return try: dest_reader, dest_writer = await asyncio.open_connection(dest_host, dest_port) except Exception as e: + warn(f"Failed to connect to destination {destination}: {str(e)}") writer.close() await writer.wait_closed() return @@ -543,7 +539,7 @@ async def forward_traffic(reader: asyncio.StreamReader, writer: asyncio.StreamWr dst.write(data) await dst.drain() except Exception as e: - pass + warn(f"Error in forwarding: {str(e)}") finally: dst.close() await dst.wait_closed() @@ -554,8 +550,12 @@ async def forward_traffic(reader: asyncio.StreamReader, writer: asyncio.StreamWr ) async def start_server(source: str, destination: str): - host, port = source.split(':') - port = int(port) + if ':' in source: + host, port = source.split(':') + port = int(port) + else: + host = source + port = 80 server = await asyncio.start_server( lambda r, w: forward_traffic(r, w, destination), @@ -566,6 +566,7 @@ async def start_server(source: str, destination: str): async with server: await server.serve_forever() + async def start_port_forwarding(): if hasattr(Serve, 'forwarding_rules'): for rule in Serve.forwarding_rules: @@ -573,6 +574,7 @@ async def start_port_forwarding(): else: warn("No forwarding rules found in the configuration.") + @serve.get("/forward_status") async def get_forward_status(): if hasattr(Serve, 'forwarding_rules'): @@ -580,5 +582,5 @@ async def get_forward_status(): else: return {"status": "inactive", "message": "No forwarding rules configured"} -# Add this to the end of your serve.py file + asyncio.create_task(start_port_forwarding()) \ No newline at end of file diff --git a/sijapi/routers/weather.py b/sijapi/routers/weather.py index a457fa1..55b091e 100644 --- a/sijapi/routers/weather.py +++ b/sijapi/routers/weather.py @@ -116,129 +116,129 @@ async def get_weather(date_time: dt_datetime, latitude: float, longitude: float, async def store_weather_to_db(date_time: dt_datetime, weather_data: dict): warn(f"Using {date_time.strftime('%Y-%m-%d %H:%M:%S')} as our datetime in store_weather_to_db") - async with API.get_connection() as conn: - try: - day_data = weather_data.get('days')[0] - debug(f"RAW DAY_DATA: {day_data}") - # Handle preciptype and stations as PostgreSQL arrays - preciptype_array = day_data.get('preciptype', []) or [] - stations_array = day_data.get('stations', []) or [] + try: + day_data = weather_data.get('days')[0] + debug(f"RAW DAY_DATA: {day_data}") + # Handle preciptype and stations as PostgreSQL arrays + preciptype_array = day_data.get('preciptype', []) or [] + stations_array = day_data.get('stations', []) or [] - date_str = date_time.strftime("%Y-%m-%d") - warn(f"Using {date_str} in our query in store_weather_to_db.") + date_str = date_time.strftime("%Y-%m-%d") + warn(f"Using {date_str} in our query in store_weather_to_db.") - # Get location details from weather data if available - longitude = weather_data.get('longitude') - latitude = weather_data.get('latitude') - tz = await GEO.tz_at(latitude, longitude) - elevation = await GEO.elevation(latitude, longitude) - location_point = f"POINTZ({longitude} {latitude} {elevation})" if longitude and latitude and elevation else None + # Get location details from weather data if available + longitude = weather_data.get('longitude') + latitude = weather_data.get('latitude') + tz = await GEO.tz_at(latitude, longitude) + elevation = await GEO.elevation(latitude, longitude) + location_point = f"POINTZ({longitude} {latitude} {elevation})" if longitude and 
latitude and elevation else None - warn(f"Uncorrected datetimes in store_weather_to_db: {day_data['datetime']}, sunrise: {day_data['sunrise']}, sunset: {day_data['sunset']}") - day_data['datetime'] = await gis.dt(day_data.get('datetimeEpoch')) - day_data['sunrise'] = await gis.dt(day_data.get('sunriseEpoch')) - day_data['sunset'] = await gis.dt(day_data.get('sunsetEpoch')) - warn(f"Corrected datetimes in store_weather_to_db: {day_data['datetime']}, sunrise: {day_data['sunrise']}, sunset: {day_data['sunset']}") + warn(f"Uncorrected datetimes in store_weather_to_db: {day_data['datetime']}, sunrise: {day_data['sunrise']}, sunset: {day_data['sunset']}") + day_data['datetime'] = await gis.dt(day_data.get('datetimeEpoch')) + day_data['sunrise'] = await gis.dt(day_data.get('sunriseEpoch')) + day_data['sunset'] = await gis.dt(day_data.get('sunsetEpoch')) + warn(f"Corrected datetimes in store_weather_to_db: {day_data['datetime']}, sunrise: {day_data['sunrise']}, sunset: {day_data['sunset']}") - daily_weather_params = ( - day_data.get('sunrise'), day_data.get('sunriseEpoch'), - day_data.get('sunset'), day_data.get('sunsetEpoch'), - day_data.get('description'), day_data.get('tempmax'), - day_data.get('tempmin'), day_data.get('uvindex'), - day_data.get('winddir'), day_data.get('windspeed'), - day_data.get('icon'), dt_datetime.now(tz), - day_data.get('datetime'), day_data.get('datetimeEpoch'), - day_data.get('temp'), day_data.get('feelslikemax'), - day_data.get('feelslikemin'), day_data.get('feelslike'), - day_data.get('dew'), day_data.get('humidity'), - day_data.get('precip'), day_data.get('precipprob'), - day_data.get('precipcover'), preciptype_array, - day_data.get('snow'), day_data.get('snowdepth'), - day_data.get('windgust'), day_data.get('pressure'), - day_data.get('cloudcover'), day_data.get('visibility'), - day_data.get('solarradiation'), day_data.get('solarenergy'), - day_data.get('severerisk', 0), day_data.get('moonphase'), - day_data.get('conditions'), stations_array, day_data.get('source'), - location_point - ) - except Exception as e: - err(f"Failed to prepare database query in store_weather_to_db! 
{e}") - - try: - daily_weather_query = ''' - INSERT INTO DailyWeather ( - sunrise, sunriseepoch, sunset, sunsetepoch, description, - tempmax, tempmin, uvindex, winddir, windspeed, icon, last_updated, - datetime, datetimeepoch, temp, feelslikemax, feelslikemin, feelslike, - dew, humidity, precip, precipprob, precipcover, preciptype, - snow, snowdepth, windgust, pressure, cloudcover, visibility, - solarradiation, solarenergy, severerisk, moonphase, conditions, - stations, source, location - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38) - RETURNING id - ''' + daily_weather_params = ( + day_data.get('sunrise'), day_data.get('sunriseEpoch'), + day_data.get('sunset'), day_data.get('sunsetEpoch'), + day_data.get('description'), day_data.get('tempmax'), + day_data.get('tempmin'), day_data.get('uvindex'), + day_data.get('winddir'), day_data.get('windspeed'), + day_data.get('icon'), dt_datetime.now(tz), + day_data.get('datetime'), day_data.get('datetimeEpoch'), + day_data.get('temp'), day_data.get('feelslikemax'), + day_data.get('feelslikemin'), day_data.get('feelslike'), + day_data.get('dew'), day_data.get('humidity'), + day_data.get('precip'), day_data.get('precipprob'), + day_data.get('precipcover'), preciptype_array, + day_data.get('snow'), day_data.get('snowdepth'), + day_data.get('windgust'), day_data.get('pressure'), + day_data.get('cloudcover'), day_data.get('visibility'), + day_data.get('solarradiation'), day_data.get('solarenergy'), + day_data.get('severerisk', 0), day_data.get('moonphase'), + day_data.get('conditions'), stations_array, day_data.get('source'), + location_point + ) + except Exception as e: + err(f"Failed to prepare database query in store_weather_to_db! 
{e}") + return "FAILURE" + + try: + daily_weather_query = ''' + INSERT INTO DailyWeather ( + sunrise, sunriseepoch, sunset, sunsetepoch, description, + tempmax, tempmin, uvindex, winddir, windspeed, icon, last_updated, + datetime, datetimeepoch, temp, feelslikemax, feelslikemin, feelslike, + dew, humidity, precip, precipprob, precipcover, preciptype, + snow, snowdepth, windgust, pressure, cloudcover, visibility, + solarradiation, solarenergy, severerisk, moonphase, conditions, + stations, source, location + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38) + RETURNING id + ''' + + daily_weather_id = await API.execute_write_query(daily_weather_query, *daily_weather_params, table_name="DailyWeather") + + if 'hours' in day_data: + debug(f"Processing hours now...") + for hour_data in day_data['hours']: + try: + await asyncio.sleep(0.01) + hour_data['datetime'] = await gis.dt(hour_data.get('datetimeEpoch')) + hour_preciptype_array = hour_data.get('preciptype', []) or [] + hour_stations_array = hour_data.get('stations', []) or [] + hourly_weather_params = ( + daily_weather_id, + hour_data['datetime'], + hour_data.get('datetimeEpoch'), + hour_data['temp'], + hour_data['feelslike'], + hour_data['humidity'], + hour_data['dew'], + hour_data['precip'], + hour_data['precipprob'], + hour_preciptype_array, + hour_data['snow'], + hour_data['snowdepth'], + hour_data['windgust'], + hour_data['windspeed'], + hour_data['winddir'], + hour_data['pressure'], + hour_data['cloudcover'], + hour_data['visibility'], + hour_data['solarradiation'], + hour_data['solarenergy'], + hour_data['uvindex'], + hour_data.get('severerisk', 0), + hour_data['conditions'], + hour_data['icon'], + hour_stations_array, + hour_data.get('source', ''), + ) - async with conn.transaction(): - daily_weather_id = await conn.fetchval(daily_weather_query, *daily_weather_params) - - if 'hours' in day_data: - debug(f"Processing hours now...") - for hour_data in day_data['hours']: try: - await asyncio.sleep(0.01) - hour_data['datetime'] = await gis.dt(hour_data.get('datetimeEpoch')) - hour_preciptype_array = hour_data.get('preciptype', []) or [] - hour_stations_array = hour_data.get('stations', []) or [] - hourly_weather_params = ( - daily_weather_id, - hour_data['datetime'], - hour_data.get('datetimeEpoch'), - hour_data['temp'], - hour_data['feelslike'], - hour_data['humidity'], - hour_data['dew'], - hour_data['precip'], - hour_data['precipprob'], - hour_preciptype_array, - hour_data['snow'], - hour_data['snowdepth'], - hour_data['windgust'], - hour_data['windspeed'], - hour_data['winddir'], - hour_data['pressure'], - hour_data['cloudcover'], - hour_data['visibility'], - hour_data['solarradiation'], - hour_data['solarenergy'], - hour_data['uvindex'], - hour_data.get('severerisk', 0), - hour_data['conditions'], - hour_data['icon'], - hour_stations_array, - hour_data.get('source', ''), - ) - - try: - hourly_weather_query = ''' - INSERT INTO HourlyWeather (daily_weather_id, datetime, datetimeepoch, temp, feelslike, humidity, dew, precip, precipprob, - preciptype, snow, snowdepth, windgust, windspeed, winddir, pressure, cloudcover, visibility, solarradiation, solarenergy, - uvindex, severerisk, conditions, icon, stations, source) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26) - RETURNING id - ''' - async with conn.transaction(): 
- hourly_weather_id = await conn.fetchval(hourly_weather_query, *hourly_weather_params) - debug(f"Done processing hourly_weather_id {hourly_weather_id}") - except Exception as e: - err(f"EXCEPTION: {e}") - + hourly_weather_query = ''' + INSERT INTO HourlyWeather (daily_weather_id, datetime, datetimeepoch, temp, feelslike, humidity, dew, precip, precipprob, + preciptype, snow, snowdepth, windgust, windspeed, winddir, pressure, cloudcover, visibility, solarradiation, solarenergy, + uvindex, severerisk, conditions, icon, stations, source) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26) + RETURNING id + ''' + hourly_weather_id = await API.execute_write_query(hourly_weather_query, *hourly_weather_params, table_name="HourlyWeather") + debug(f"Done processing hourly_weather_id {hourly_weather_id}") except Exception as e: err(f"EXCEPTION: {e}") - return "SUCCESS" - - except Exception as e: - err(f"Error in dailyweather storage: {e}") + except Exception as e: + err(f"EXCEPTION: {e}") + + return "SUCCESS" + + except Exception as e: + err(f"Error in dailyweather storage: {e}") + return "FAILURE" +
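
Usage sketch (not part of the patch): the serve.py and weather.py hunks above route SELECTs through API.execute_read_query and INSERTs through API.execute_write_query, both taking a table_name keyword. The sketch below assumes API is importable from the sijapi package (as set up in sijapi/__init__.py), that execute_read_query returns a list of record-like mappings, and that execute_write_query returns the value of a RETURNING clause when one is present; those semantics are inferred from the new router code, not verified against classes.py.

    # Sketch only: import path and helper semantics are assumptions based on the
    # router code in this patch, not a verified API.
    from datetime import datetime
    from sijapi import API

    async def resolve_and_log(short_code: str):
        # Reads go through execute_read_query, which may return an empty list.
        rows = await API.execute_read_query(
            'SELECT long_url FROM short_urls WHERE short_code = $1',
            short_code, table_name="short_urls"
        )
        if not rows:
            return None
        # Writes go through execute_write_query so they are visible to the sync machinery.
        await API.execute_write_query(
            'INSERT INTO click_logs (short_code, clicked_at) VALUES ($1, $2)',
            short_code, datetime.now(), table_name="click_logs"
        )
        return rows[0].get('long_url')

This mirrors the new redirect_short_url handler: read-only queries presumably no longer pass through the write/sync path, and only genuine writes are issued through execute_write_query.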
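
store_weather_to_db now reports its outcome as the strings "SUCCESS" or "FAILURE" instead of leaving failures to surface through an open connection. A hedged sketch of a caller branching on that value; refresh_weather and its argument are hypothetical, while store_weather_to_db, debug, and warn are the names already used in weather.py.

    # Hypothetical caller; only store_weather_to_db and the debug/warn helpers
    # come from the patch, the rest is illustrative.
    from datetime import datetime, timezone

    async def refresh_weather(weather_data: dict) -> str:
        now = datetime.now(timezone.utc)
        result = await store_weather_to_db(now, weather_data)
        if result == "SUCCESS":
            debug(f"Weather stored for {now.isoformat()}")
        else:
            # The underlying exception was already logged inside store_weather_to_db.
            warn(f"Weather storage failed for {now.isoformat()}; caller may retry later")
        return result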
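
The port-forwarding changes above validate "host:port" destinations and fall back to port 80 when a source omits the port. A small helper expressing that parsing in one place; parse_endpoint is hypothetical and not part of the patch.

    # Hypothetical helper mirroring the parsing added to forward_traffic/start_server.
    def parse_endpoint(endpoint: str, default_port: int = 80) -> tuple[str, int]:
        # "host:port" -> (host, port); a bare "host" falls back to default_port,
        # matching the fallback start_server now applies to its source argument.
        if ':' in endpoint:
            host, port_str = endpoint.split(':', 1)
            return host, int(port_str)
        return endpoint, default_port

For example, parse_endpoint('localhost:8080') returns ('localhost', 8080), while parse_endpoint('localhost') returns ('localhost', 80).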