First working version of a much more efficient database sync method.

This commit is contained in:
sanj 2024-08-01 23:51:01 -07:00
parent d01f24ad45
commit 299d27b1e7
7 changed files with 950 additions and 882 deletions
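The diff below replaces the old push-on-every-write middleware with a targeted scheme: sync columns and triggers are initialized at startup, writes schedule background pushes, and a new POST /sync/pull endpoint lets an instance pull the latest changes from whichever peer has the most recent data. Purely as an illustration, here is how an operator might trigger that pull against a running instance; the endpoint path and response fields come from the diff, while the host, port, and the use of aiohttp as the client are assumptions:

import asyncio
import aiohttp

async def trigger_pull(base_url: str = "http://localhost:4444") -> None:  # host/port are placeholders
    # Ask the instance to find the most recent online peer and pull its changes.
    async with aiohttp.ClientSession() as session:
        async with session.post(f"{base_url}/sync/pull") as resp:
            data = await resp.json()
            # On success the handler also returns "source" and "changes".
            print(data.get("status"), data.get("message"))

if __name__ == "__main__":
    asyncio.run(trigger_pull())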


@@ -20,6 +20,7 @@ L = Logger("Central", LOGS_DIR)
# API essentials
API = APIConfig.load('api', 'secrets')
Dir = Configuration.load('dirs')
HOST = f"{API.BIND}:{API.PORT}"
LOCAL_HOSTS = [ipaddress.ip_address(localhost.strip()) for localhost in os.getenv('LOCAL_HOSTS', '127.0.0.1').split(',')] + ['localhost']


@@ -39,12 +39,11 @@ def warn(text: str): logger.warning(text)
def err(text: str): logger.error(text)
def crit(text: str): logger.critical(text)

@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup
crit("sijapi launched")
-crit(f"Arguments: {args}")
info(f"Arguments: {args}")

# Load routers
if args.test:
@@ -54,20 +53,10 @@ async def lifespan(app: FastAPI):
if getattr(API.MODULES, module_name):
load_router(module_name)

-crit("Starting database synchronization...")
try:
# Initialize sync structures on all databases
await API.initialize_sync()
-# Check if other instances have more recent data
-source = await API.get_most_recent_source()
-if source:
-crit(f"Pulling changes from {source['ts_id']} ({source['ts_ip']})...")
-total_changes = await API.pull_changes(source)
-crit(f"Data pull complete. Total changes: {total_changes}")
-else:
-crit("No instances with more recent data found or all instances are offline.")
except Exception as e:
crit(f"Error during startup: {str(e)}")
crit(f"Traceback: {traceback.format_exc()}")
@@ -79,8 +68,6 @@ async def lifespan(app: FastAPI):
await API.close_db_pools()
crit("Database pools closed.")

app = FastAPI(lifespan=lifespan)
app.add_middleware(
@@ -135,30 +122,41 @@ async def handle_exception_middleware(request: Request, call_next):
return response

-# This was removed on 7/31/2024 when we decided to instead use a targeted push sync approach.
-deprecated = '''
-async def push_changes_background():
-try:
-await API.push_changes_to_all()
-except Exception as e:
-err(f"Error pushing changes to other databases: {str(e)}")
-err(f"Traceback: {traceback.format_exc()}")
-@app.middleware("http")
-async def sync_middleware(request: Request, call_next):
-response = await call_next(request)
-# Check if the request was a database write operation
-if request.method in ["POST", "PUT", "PATCH", "DELETE"]:
-try:
-# Push changes to other databases
-await API.push_changes_to_all()
-except Exception as e:
-err(f"Error pushing changes to other databases: {str(e)}")
-return response
-'''

@app.post("/sync/pull")
async def pull_changes():
try:
await API.add_primary_keys_to_local_tables()
await API.add_primary_keys_to_remote_tables()
try:
source = await API.get_most_recent_source()
if source:
# Pull changes from the source
total_changes = await API.pull_changes(source)
return JSONResponse(content={
"status": "success",
"message": f"Pull complete. Total changes: {total_changes}",
"source": f"{source['ts_id']} ({source['ts_ip']})",
"changes": total_changes
})
else:
return JSONResponse(content={
"status": "info",
"message": "No instances with more recent data found or all instances are offline."
})
except Exception as e:
err(f"Error during pull: {str(e)}")
err(f"Traceback: {traceback.format_exc()}")
raise HTTPException(status_code=500, detail=f"Error during pull: {str(e)}")
except Exception as e:
err(f"Error while ensuring primary keys to tables: {str(e)}")
err(f"Traceback: {traceback.format_exc()}")
raise HTTPException(status_code=500, detail=f"Error during primary key insurance: {str(e)}")

def load_router(router_name):
router_file = ROUTER_DIR / f'{router_name}.py'


@@ -5,6 +5,7 @@ import math
import os
import re
import uuid
import time
import aiofiles
import aiohttp
import asyncio
@@ -180,14 +181,18 @@ class APIConfig(BaseModel):
TZ: str
KEYS: List[str]
GARBAGE: Dict[str, Any]
SPECIAL_TABLES: ClassVar[List[str]] = ['spatial_ref_sys']
db_pools: Dict[str, Any] = Field(default_factory=dict)
offline_servers: Dict[str, float] = Field(default_factory=dict)
offline_timeout: float = Field(default=30.0) # 30 second timeout for offline servers
online_hosts_cache: Dict[str, Tuple[List[Dict[str, Any]], float]] = Field(default_factory=dict)
online_hosts_cache_ttl: float = Field(default=30.0) # Cache TTL in seconds

def __init__(self, **data):
super().__init__(**data)
-self._db_pools = {}
self.db_pools = {}
self.online_hosts_cache = {} # Initialize the cache
self._sync_tasks = {}

class Config:
arbitrary_types_allowed = True
@@ -306,13 +311,48 @@ class APIConfig(BaseModel):
raise ValueError(f"No database configuration found for TS_ID: {ts_id}")
return local_db

async def get_online_hosts(self) -> List[Dict[str, Any]]:
current_time = time.time()
cache_key = "online_hosts"
if cache_key in self.online_hosts_cache:
cached_hosts, cache_time = self.online_hosts_cache[cache_key]
if current_time - cache_time < self.online_hosts_cache_ttl:
return cached_hosts
online_hosts = []
local_ts_id = os.environ.get('TS_ID')
for pool_entry in self.POOL:
if pool_entry['ts_id'] != local_ts_id:
pool_key = f"{pool_entry['ts_ip']}:{pool_entry['db_port']}"
if pool_key in self.offline_servers:
if current_time - self.offline_servers[pool_key] < self.offline_timeout:
continue
else:
del self.offline_servers[pool_key]
conn = await self.get_connection(pool_entry)
if conn is not None:
online_hosts.append(pool_entry)
await conn.close()
self.online_hosts_cache[cache_key] = (online_hosts, current_time)
return online_hosts

-@asynccontextmanager
async def get_connection(self, pool_entry: Dict[str, Any] = None):
if pool_entry is None:
pool_entry = self.local_db
pool_key = f"{pool_entry['ts_ip']}:{pool_entry['db_port']}"
# Check if the server is marked as offline
if pool_key in self.offline_servers:
if time.time() - self.offline_servers[pool_key] < self.offline_timeout:
return None
else:
del self.offline_servers[pool_key]
if pool_key not in self.db_pools:
try:
self.db_pools[pool_key] = await asyncpg.create_pool(
@@ -326,27 +366,21 @@ class APIConfig(BaseModel):
timeout=5
)
except Exception as e:
-err(f"Failed to create connection pool for {pool_key}: {str(e)}")
-yield None
-return
warn(f"Failed to create connection pool for {pool_key}: {str(e)}")
self.offline_servers[pool_key] = time.time()
return None

try:
-async with self.db_pools[pool_key].acquire() as conn:
-yield conn
return await asyncio.wait_for(self.db_pools[pool_key].acquire(), timeout=5)
except asyncio.TimeoutError:
warn(f"Timeout acquiring connection from pool for {pool_key}")
self.offline_servers[pool_key] = time.time()
return None
except Exception as e:
-err(f"Failed to acquire connection from pool for {pool_key}: {str(e)}")
-yield None
warn(f"Failed to acquire connection for {pool_key}: {str(e)}")
self.offline_servers[pool_key] = time.time()
return None

-async def close_db_pools(self):
-info("Closing database connection pools...")
-for pool_key, pool in self.db_pools.items():
-try:
-await pool.close()
-debug(f"Closed pool for {pool_key}")
-except Exception as e:
-err(f"Error closing pool for {pool_key}: {str(e)}")
-self.db_pools.clear()
-info("All database connection pools closed.")
async def initialize_sync(self):
local_ts_id = os.environ.get('TS_ID')
@@ -356,52 +390,60 @@ class APIConfig(BaseModel):
if pool_entry['ts_id'] == local_ts_id:
continue # Skip local database
try:
-async with self.get_connection(pool_entry) as conn:
conn = await self.get_connection(pool_entry)
if conn is None:
continue # Skip this database if connection failed
debug(f"Starting sync initialization for {pool_entry['ts_ip']}...")
# Check PostGIS installation
postgis_installed = await self.check_postgis(conn)
if not postgis_installed:
warn(f"PostGIS is not installed on {pool_entry['ts_id']} ({pool_entry['ts_ip']}). Some spatial operations may fail.")
tables = await conn.fetch("""
SELECT tablename FROM pg_tables
WHERE schemaname = 'public'
""")
for table in tables:
table_name = table['tablename']
await self.ensure_sync_columns(conn, table_name)
debug(f"Sync initialization complete for {pool_entry['ts_ip']}. All tables now have necessary sync columns and triggers.")
except Exception as e:
err(f"Error initializing sync for {pool_entry['ts_ip']}: {str(e)}")
err(f"Traceback: {traceback.format_exc()}")
def _schedule_sync_task(self, table_name: str, pk_value: Any, version: int, server_id: str):
# Use a background task manager to handle syncing
task_key = f"{table_name}:{pk_value}" if pk_value else table_name
if task_key not in self._sync_tasks:
self._sync_tasks[task_key] = asyncio.create_task(self._sync_changes(table_name, pk_value, version, server_id))

async def ensure_sync_columns(self, conn, table_name):
-if conn is None:
-debug(f"Skipping offline server...")
-return None
-if table_name in self.SPECIAL_TABLES:
-debug(f"Skipping sync columns for special table: {table_name}")
-return None
if conn is None or table_name in self.SPECIAL_TABLES:
return None
try:
-# Get primary key information
# Check if primary key exists
primary_key = await conn.fetchval(f"""
SELECT a.attname
FROM pg_index i
-JOIN pg_attribute a ON a.attrelid = i.indrelid
-AND a.attnum = ANY(i.indkey)
-WHERE i.indrelid = '{table_name}'::regclass
-AND i.indisprimary;
JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
WHERE i.indrelid = '{table_name}'::regclass AND i.indisprimary;
""")
if not primary_key:
# Add an id column as primary key if it doesn't exist
await conn.execute(f"""
ALTER TABLE "{table_name}"
ADD COLUMN IF NOT EXISTS id SERIAL PRIMARY KEY;
""")
primary_key = 'id'
# Ensure version column exists
await conn.execute(f"""
ALTER TABLE "{table_name}"
@@ -449,17 +491,7 @@ class APIConfig(BaseModel):
except Exception as e:
err(f"Error ensuring sync columns for table {table_name}: {str(e)}")
err(f"Traceback: {traceback.format_exc()}")
return None

-async def get_online_hosts(self) -> List[Dict[str, Any]]:
-online_hosts = []
-for pool_entry in self.POOL:
-try:
-async with self.get_connection(pool_entry) as conn:
-if conn is not None:
-online_hosts.append(pool_entry)
-except Exception as e:
-err(f"Error checking host {pool_entry['ts_ip']}:{pool_entry['db_port']}: {str(e)}")
-return online_hosts
async def check_postgis(self, conn):
if conn is None:
@@ -478,133 +510,507 @@ class APIConfig(BaseModel):
err(f"Error checking PostGIS: {str(e)}")
return False
-async def execute_write_query(self, query: str, *args, table_name: str):
-if table_name in self.SPECIAL_TABLES:
-return await self._execute_special_table_write(query, *args, table_name=table_name)
-async with self.get_connection() as conn:
-if conn is None:
-raise ConnectionError("Failed to connect to local database")
-# Ensure sync columns exist
-primary_key = await self.ensure_sync_columns(conn, table_name)
-# Execute the query
-result = await conn.execute(query, *args)
-# Get the primary key and new version of the affected row
-if primary_key:
-affected_row = await conn.fetchrow(f"""
-SELECT "{primary_key}", version, server_id
-FROM "{table_name}"
-WHERE version = (SELECT MAX(version) FROM "{table_name}")
-""")
-if affected_row:
-await self.push_change(table_name, affected_row[primary_key], affected_row['version'], affected_row['server_id'])
-else:
-# For tables without a primary key, we'll push all rows
-await self.push_all_changes(table_name)
-return result

-async def push_change(self, table_name: str, pk_value: Any, version: int, server_id: str):
-online_hosts = await self.get_online_hosts()
-for pool_entry in online_hosts:
-if pool_entry['ts_id'] != os.environ.get('TS_ID'):
-try:
-async with self.get_connection(pool_entry) as remote_conn:

async def pull_changes(self, source_pool_entry, batch_size=10000):
if source_pool_entry['ts_id'] == os.environ.get('TS_ID'):
debug("Skipping self-sync")
return 0
total_changes = 0
source_id = source_pool_entry['ts_id']
source_ip = source_pool_entry['ts_ip']
dest_id = os.environ.get('TS_ID')
dest_ip = self.local_db['ts_ip']
info(f"Starting sync from source {source_id} ({source_ip}) to destination {dest_id} ({dest_ip})")
source_conn = None
dest_conn = None
try:
source_conn = await self.get_connection(source_pool_entry)
if source_conn is None:
warn(f"Unable to connect to source {source_id} ({source_ip}). Skipping sync.")
return 0
dest_conn = await self.get_connection(self.local_db)
if dest_conn is None:
warn(f"Unable to connect to local database. Skipping sync.")
return 0
tables = await source_conn.fetch("""
SELECT tablename FROM pg_tables
WHERE schemaname = 'public'
""")
for table in tables:
table_name = table['tablename']
try:
if table_name in self.SPECIAL_TABLES:
await self.sync_special_table(source_conn, dest_conn, table_name)
else:
primary_key = await self.ensure_sync_columns(dest_conn, table_name)
last_synced_version = await self.get_last_synced_version(dest_conn, table_name, source_id)
changes = await source_conn.fetch(f"""
SELECT * FROM "{table_name}"
WHERE version > $1 AND server_id = $2
ORDER BY version ASC
LIMIT $3
""", last_synced_version, source_id, batch_size)
if changes:
changes_count = await self.apply_batch_changes(dest_conn, table_name, changes, primary_key)
total_changes += changes_count
if changes_count > 0:
info(f"Synced batch for {table_name}: {changes_count} changes. Total so far: {total_changes}")
else:
debug(f"No changes to sync for {table_name}")
except Exception as e:
err(f"Error syncing table {table_name}: {str(e)}")
err(f"Traceback: {traceback.format_exc()}")
info(f"Sync complete from {source_id} ({source_ip}) to {dest_id} ({dest_ip}). Total changes: {total_changes}")
except Exception as e:
err(f"Error during sync process: {str(e)}")
err(f"Traceback: {traceback.format_exc()}")
finally:
if source_conn:
await source_conn.close()
if dest_conn:
await dest_conn.close()
info(f"Sync summary:")
info(f" Total changes: {total_changes}")
info(f" Tables synced: {len(tables) if 'tables' in locals() else 0}")
info(f" Source: {source_id} ({source_ip})")
info(f" Destination: {dest_id} ({dest_ip})")
return total_changes
async def get_last_synced_version(self, conn, table_name, server_id):
if conn is None:
debug(f"Skipping offline server...")
return 0
if table_name in self.SPECIAL_TABLES:
debug(f"Skipping get_last_synced_version because {table_name} is special.")
return 0 # Special handling for tables without version column
try:
last_version = await conn.fetchval(f"""
SELECT COALESCE(MAX(version), 0)
FROM "{table_name}"
WHERE server_id = $1
""", server_id)
return last_version
except Exception as e:
err(f"Error getting last synced version for table {table_name}: {str(e)}")
err(f"Traceback: {traceback.format_exc()}")
return 0
async def get_most_recent_source(self):
most_recent_source = None
max_version = -1
local_ts_id = os.environ.get('TS_ID')
online_hosts = await self.get_online_hosts()
num_online_hosts = len(online_hosts)
if num_online_hosts > 0:
online_ts_ids = [host['ts_id'] for host in online_hosts if host['ts_id'] != local_ts_id]
crit(f"Online hosts: {', '.join(online_ts_ids)}")
for pool_entry in online_hosts:
if pool_entry['ts_id'] == local_ts_id:
continue # Skip local database
try:
conn = await self.get_connection(pool_entry)
if conn is None:
warn(f"Unable to connect to {pool_entry['ts_id']}. Skipping.")
continue
tables = await conn.fetch("""
SELECT tablename FROM pg_tables
WHERE schemaname = 'public'
""")
for table in tables:
table_name = table['tablename']
if table_name in self.SPECIAL_TABLES:
continue
try:
result = await conn.fetchrow(f"""
SELECT MAX(version) as max_version, server_id
FROM "{table_name}"
WHERE version = (SELECT MAX(version) FROM "{table_name}")
GROUP BY server_id
ORDER BY MAX(version) DESC
LIMIT 1
""")
if result:
version, server_id = result['max_version'], result['server_id']
info(f"Max version for {pool_entry['ts_id']}, table {table_name}: {version} (from server {server_id})")
if version > max_version:
max_version = version
most_recent_source = pool_entry
else:
debug(f"No data in table {table_name} for {pool_entry['ts_id']}")
except asyncpg.exceptions.UndefinedColumnError:
warn(f"Version or server_id column does not exist in table {table_name} for {pool_entry['ts_id']}. Skipping.")
except Exception as e:
err(f"Error checking version for {pool_entry['ts_id']}, table {table_name}: {str(e)}")
except asyncpg.exceptions.ConnectionFailureError:
warn(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}. Skipping.")
except Exception as e:
err(f"Unexpected error occurred while checking version for {pool_entry['ts_id']}: {str(e)}")
err(f"Traceback: {traceback.format_exc()}")
finally:
if conn:
await conn.close()
if most_recent_source is None:
if num_online_hosts > 0:
warn("Could not determine most recent source. Using first available online host.")
most_recent_source = next(host for host in online_hosts if host['ts_id'] != local_ts_id)
else:
crit("No other online hosts available for sync.")
return most_recent_source
async def _sync_changes(self, table_name: str, primary_key: str):
try:
local_conn = await self.get_connection()
if local_conn is None:
return
# Get the latest changes
changes = await local_conn.fetch(f"""
SELECT * FROM "{table_name}"
WHERE version > (SELECT COALESCE(MAX(version), 0) FROM "{table_name}" WHERE server_id != $1)
OR (version = (SELECT COALESCE(MAX(version), 0) FROM "{table_name}" WHERE server_id != $1) AND server_id = $1)
ORDER BY version ASC
""", os.environ.get('TS_ID'))
if changes:
online_hosts = await self.get_online_hosts()
for pool_entry in online_hosts:
if pool_entry['ts_id'] != os.environ.get('TS_ID'):
remote_conn = await self.get_connection(pool_entry)
if remote_conn is None:
continue
try:
await self.apply_batch_changes(remote_conn, table_name, changes, primary_key)
finally:
await remote_conn.close()
except Exception as e:
err(f"Error syncing changes for {table_name}: {str(e)}")
err(f"Traceback: {traceback.format_exc()}")
finally:
if 'local_conn' in locals():
await local_conn.close()
# Fetch the updated row from the local database
async with self.get_connection() as local_conn:
updated_row = await local_conn.fetchrow(f'SELECT * FROM "{table_name}" WHERE "{self.get_primary_key(table_name)}" = $1', pk_value)
if updated_row: async def execute_read_query(self, query: str, *args, table_name: str):
columns = updated_row.keys() online_hosts = await self.get_online_hosts()
placeholders = [f'${i+1}' for i in range(len(columns))] results = []
primary_key = self.get_primary_key(table_name) max_version = -1
latest_result = None
insert_query = f""" for pool_entry in online_hosts:
INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)}) conn = await self.get_connection(pool_entry)
VALUES ({', '.join(placeholders)}) if conn is None:
ON CONFLICT ("{primary_key}") DO UPDATE SET warn(f"Unable to connect to {pool_entry['ts_id']}. Skipping read.")
{', '.join(f'"{col}" = EXCLUDED."{col}"' for col in columns if col != primary_key)}, continue
version = EXCLUDED.version,
server_id = EXCLUDED.server_id try:
WHERE "{table_name}".version < EXCLUDED.version # Execute the query
OR ("{table_name}".version = EXCLUDED.version AND "{table_name}".server_id < EXCLUDED.server_id) result = await conn.fetch(query, *args)
"""
await remote_conn.execute(insert_query, *updated_row.values()) if not result:
continue
# Check version if it's not a special table
if table_name not in self.SPECIAL_TABLES:
try:
version_query = f"""
SELECT MAX(version) as max_version, server_id
FROM "{table_name}"
WHERE version = (SELECT MAX(version) FROM "{table_name}")
GROUP BY server_id
ORDER BY MAX(version) DESC
LIMIT 1
"""
version_result = await conn.fetchrow(version_query)
if version_result:
version = version_result['max_version']
server_id = version_result['server_id']
info(f"Max version for {pool_entry['ts_id']}, table {table_name}: {version} (from server {server_id})")
if version > max_version:
max_version = version
latest_result = result
else:
debug(f"No data in table {table_name} for {pool_entry['ts_id']}")
except asyncpg.exceptions.UndefinedColumnError:
warn(f"Version or server_id column does not exist in table {table_name} for {pool_entry['ts_id']}. Skipping version check.")
if latest_result is None:
latest_result = result
else:
# For special tables, just use the first result
if latest_result is None:
latest_result = result
results.append((pool_entry['ts_id'], result))
except Exception as e:
err(f"Error executing read query on {pool_entry['ts_id']}: {str(e)}")
err(f"Traceback: {traceback.format_exc()}")
finally:
await conn.close()
if not latest_result:
warn(f"No results found for query on table {table_name}")
return []
# Log results from all databases
for ts_id, result in results:
info(f"Read result from {ts_id}: {result}")
return [dict(r) for r in latest_result] # Convert Record objects to dictionaries
async def execute_write_query(self, query: str, *args, table_name: str):
conn = await self.get_connection()
if conn is None:
raise ConnectionError("Failed to connect to local database")
try:
if table_name in self.SPECIAL_TABLES:
return await self._execute_special_table_write(conn, query, *args, table_name=table_name)
primary_key = await self.ensure_sync_columns(conn, table_name)
result = await conn.execute(query, *args)
asyncio.create_task(self._sync_changes(table_name, primary_key))
return []
finally:
await conn.close()
async def _run_sync_tasks(self, tasks):
for task in tasks:
try:
await task
except Exception as e:
err(f"Error during background sync: {str(e)}")
err(f"Traceback: {traceback.format_exc()}")
async def push_change(self, table_name: str, pk_value: Any, version: int, server_id: str):
asyncio.create_task(self._push_change_background(table_name, pk_value, version, server_id))
async def _push_change_background(self, table_name: str, pk_value: Any, version: int, server_id: str):
online_hosts = await self.get_online_hosts()
successful_pushes = 0
failed_pushes = 0
for pool_entry in online_hosts:
if pool_entry['ts_id'] != os.environ.get('TS_ID'):
remote_conn = await self.get_connection(pool_entry)
if remote_conn is None:
continue
try:
local_conn = await self.get_connection()
if local_conn is None:
continue
try:
updated_row = await local_conn.fetchrow(f'SELECT * FROM "{table_name}" WHERE "{self.get_primary_key(table_name)}" = $1', pk_value)
finally:
await local_conn.close()
if updated_row:
columns = updated_row.keys()
placeholders = [f'${i+1}' for i in range(len(columns))]
primary_key = self.get_primary_key(table_name)
remote_version = await remote_conn.fetchval(f"""
SELECT version FROM "{table_name}"
WHERE "{primary_key}" = $1
""", updated_row[primary_key])
if remote_version is not None and remote_version >= updated_row['version']:
debug(f"Remote version for {table_name} in {pool_entry['ts_id']} is already up to date. Skipping push.")
successful_pushes += 1
continue
insert_query = f"""
INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)})
VALUES ({', '.join(placeholders)})
ON CONFLICT ("{primary_key}") DO UPDATE SET
{', '.join(f'"{col}" = EXCLUDED."{col}"' for col in columns if col != primary_key)},
version = EXCLUDED.version,
server_id = EXCLUDED.server_id
WHERE "{table_name}".version < EXCLUDED.version
OR ("{table_name}".version = EXCLUDED.version AND "{table_name}".server_id < EXCLUDED.server_id)
"""
await remote_conn.execute(insert_query, *updated_row.values())
successful_pushes += 1
except Exception as e:
err(f"Error pushing change to {pool_entry['ts_id']}: {str(e)}")
-err(f"Traceback: {traceback.format_exc()}")
failed_pushes += 1
finally:
if remote_conn:
await remote_conn.close()
if successful_pushes > 0:
info(f"Successfully pushed changes to {successful_pushes} server(s) for {table_name}")
if failed_pushes > 0:
warn(f"Failed to push changes to {failed_pushes} server(s) for {table_name}")
async def push_all_changes(self, table_name: str): async def push_all_changes(self, table_name: str):
online_hosts = await self.get_online_hosts() online_hosts = await self.get_online_hosts()
tasks = []
for pool_entry in online_hosts: for pool_entry in online_hosts:
if pool_entry['ts_id'] != os.environ.get('TS_ID'): if pool_entry['ts_id'] != os.environ.get('TS_ID'):
try: task = asyncio.create_task(self._push_changes_to_host(pool_entry, table_name))
async with self.get_connection(pool_entry) as remote_conn: tasks.append(task)
if remote_conn is None:
continue
# Fetch all rows from the local database results = await asyncio.gather(*tasks, return_exceptions=True)
async with self.get_connection() as local_conn: successful_pushes = sum(1 for r in results if r is True)
all_rows = await local_conn.fetch(f'SELECT * FROM "{table_name}"') failed_pushes = sum(1 for r in results if r is False or isinstance(r, Exception))
if all_rows: info(f"Push all changes summary for {table_name}: Successful: {successful_pushes}, Failed: {failed_pushes}")
columns = all_rows[0].keys() if failed_pushes > 0:
placeholders = [f'${i+1}' for i in range(len(columns))] warn(f"Failed to push all changes to {failed_pushes} server(s). Data may be out of sync.")
insert_query = f""" async def _push_changes_to_host(self, pool_entry: Dict[str, Any], table_name: str) -> bool:
INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)}) remote_conn = None
VALUES ({', '.join(placeholders)}) try:
ON CONFLICT DO NOTHING remote_conn = await self.get_connection(pool_entry)
""" if remote_conn is None:
warn(f"Unable to connect to {pool_entry['ts_id']}. Skipping push.")
return False
# Use a transaction to insert all rows local_conn = await self.get_connection()
async with remote_conn.transaction():
for row in all_rows:
await remote_conn.execute(insert_query, *row.values())
except Exception as e:
err(f"Error pushing all changes to {pool_entry['ts_id']}: {str(e)}")
err(f"Traceback: {traceback.format_exc()}")
async def _execute_special_table_write(self, query: str, *args, table_name: str):
if table_name == 'spatial_ref_sys':
return await self._execute_spatial_ref_sys_write(query, *args)
# Add more special cases as needed
async def _execute_spatial_ref_sys_write(self, query: str, *args):
result = None
async with self.get_connection() as local_conn:
if local_conn is None: if local_conn is None:
raise ConnectionError("Failed to connect to local database") warn(f"Unable to connect to local database. Skipping push.")
return False
# Execute the query locally try:
result = await local_conn.execute(query, *args) all_rows = await local_conn.fetch(f'SELECT * FROM "{table_name}"')
finally:
await local_conn.close()
if all_rows:
columns = list(all_rows[0].keys())
placeholders = [f'${i+1}' for i in range(len(columns))]
insert_query = f"""
INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)})
VALUES ({', '.join(placeholders)})
ON CONFLICT DO NOTHING
"""
async with remote_conn.transaction():
for row in all_rows:
await remote_conn.execute(insert_query, *row.values())
return True
else:
debug(f"No rows to push for table {table_name}")
return True
except Exception as e:
err(f"Error pushing all changes to {pool_entry['ts_id']}: {str(e)}")
err(f"Traceback: {traceback.format_exc()}")
return False
finally:
if remote_conn:
await remote_conn.close()
async def _execute_special_table_write(self, conn, query: str, *args, table_name: str):
if table_name == 'spatial_ref_sys':
return await self._execute_spatial_ref_sys_write(conn, query, *args)
async def _execute_spatial_ref_sys_write(self, local_conn, query: str, *args):
# Execute the query locally
result = await local_conn.execute(query, *args)
# Sync the entire spatial_ref_sys table with all online hosts # Sync the entire spatial_ref_sys table with all online hosts
online_hosts = await self.get_online_hosts() online_hosts = await self.get_online_hosts()
for pool_entry in online_hosts: for pool_entry in online_hosts:
if pool_entry['ts_id'] != os.environ.get('TS_ID'): if pool_entry['ts_id'] != os.environ.get('TS_ID'):
remote_conn = await self.get_connection(pool_entry)
if remote_conn is None:
continue
try: try:
async with self.get_connection(pool_entry) as remote_conn: await self.sync_spatial_ref_sys(local_conn, remote_conn)
if remote_conn is None:
continue
await self.sync_spatial_ref_sys(local_conn, remote_conn)
except Exception as e: except Exception as e:
err(f"Error syncing spatial_ref_sys to {pool_entry['ts_id']}: {str(e)}") err(f"Error syncing spatial_ref_sys to {pool_entry['ts_id']}: {str(e)}")
err(f"Traceback: {traceback.format_exc()}") err(f"Traceback: {traceback.format_exc()}")
finally:
await remote_conn.close()
return result return result
-def get_primary_key(self, table_name: str) -> str:
-# This method should return the primary key for the given table
-# You might want to cache this information for performance
-# For now, we'll assume it's always 'id', but you should implement proper logic here
-return 'id'

async def apply_batch_changes(self, conn, table_name, changes, primary_key):
if conn is None or not changes:
debug(f"Skipping apply_batch_changes because conn is none or there are no changes.")
return 0
try:
columns = list(changes[0].keys())
placeholders = [f'${i+1}' for i in range(len(columns))]
if primary_key:
insert_query = f"""
INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)})
VALUES ({', '.join(placeholders)})
ON CONFLICT ("{primary_key}") DO UPDATE SET
{', '.join(f'"{col}" = EXCLUDED."{col}"' for col in columns if col not in [primary_key, 'version', 'server_id'])},
version = EXCLUDED.version,
server_id = EXCLUDED.server_id
WHERE "{table_name}".version < EXCLUDED.version
OR ("{table_name}".version = EXCLUDED.version AND "{table_name}".server_id < EXCLUDED.server_id)
"""
else:
warn(f"Possible source of issue #4")
# For tables without a primary key, we'll use all columns for conflict resolution
insert_query = f"""
INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)})
VALUES ({', '.join(placeholders)})
ON CONFLICT DO NOTHING
"""
affected_rows = 0
async for change in tqdm(changes, desc=f"Syncing {table_name}", unit="row"):
values = [change[col] for col in columns]
result = await conn.execute(insert_query, *values)
affected_rows += int(result.split()[-1])
return affected_rows
except Exception as e:
err(f"Error applying batch changes to {table_name}: {str(e)}")
err(f"Traceback: {traceback.format_exc()}")
return 0
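For clarity, the ON CONFLICT clause above only overwrites an existing row when the incoming copy carries a higher version, and falls back to the lexically larger server_id when the versions tie. The same last-writer-wins rule, written out as a plain function for illustration (a sketch, not code from the repository):

def incoming_wins(local_version: int, local_server_id: str,
                  incoming_version: int, incoming_server_id: str) -> bool:
    # Mirrors the SQL predicate: version wins first, server_id breaks ties.
    if incoming_version > local_version:
        return True
    return incoming_version == local_version and incoming_server_id > local_server_id

# A tied version is resolved by comparing server IDs.
assert incoming_wins(7, "server-a", 8, "server-b") is True
assert incoming_wins(7, "server-b", 7, "server-a") is False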
async def sync_special_table(self, source_conn, dest_conn, table_name):
if table_name == 'spatial_ref_sys':
return await self.sync_spatial_ref_sys(source_conn, dest_conn)
# Add more special cases as needed
async def sync_spatial_ref_sys(self, source_conn, dest_conn):
try:
@@ -666,6 +1072,77 @@ class APIConfig(BaseModel):
return 0
def get_primary_key(self, table_name: str) -> str:
# This method should return the primary key for the given table
# You might want to cache this information for performance
# For now, we'll assume it's always 'id', but you should implement proper logic here
return 'id'
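get_primary_key is still a stub that assumes every table's key is id. A possible replacement would reuse the pg_index lookup that ensure_sync_columns already performs and cache the result per table; the helper below is only a sketch under that assumption, not part of the commit:

from typing import Optional

_pk_cache = {}

async def fetch_primary_key(conn, table_name: str) -> Optional[str]:
    # Look the primary key up once per table instead of assuming 'id'.
    if table_name in _pk_cache:
        return _pk_cache[table_name]
    pk = await conn.fetchval("""
        SELECT a.attname
        FROM pg_index i
        JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
        WHERE i.indrelid = $1::regclass AND i.indisprimary;
    """, table_name)
    if pk:
        _pk_cache[table_name] = pk
    return pk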
async def add_primary_keys_to_local_tables(self):
conn = await self.get_connection()
debug(f"Adding primary keys to existing tables...")
if conn is None:
raise ConnectionError("Failed to connect to local database")
try:
tables = await conn.fetch("""
SELECT tablename FROM pg_tables
WHERE schemaname = 'public'
""")
for table in tables:
table_name = table['tablename']
if table_name not in self.SPECIAL_TABLES:
await self.ensure_sync_columns(conn, table_name)
finally:
await conn.close()
async def add_primary_keys_to_remote_tables(self):
online_hosts = await self.get_online_hosts()
for pool_entry in online_hosts:
conn = await self.get_connection(pool_entry)
if conn is None:
warn(f"Unable to connect to {pool_entry['ts_id']}. Skipping primary key addition.")
continue
try:
info(f"Adding primary keys to existing tables on {pool_entry['ts_id']}...")
tables = await conn.fetch("""
SELECT tablename FROM pg_tables
WHERE schemaname = 'public'
""")
for table in tables:
table_name = table['tablename']
if table_name not in self.SPECIAL_TABLES:
primary_key = await self.ensure_sync_columns(conn, table_name)
if primary_key:
info(f"Added/ensured primary key '{primary_key}' for table '{table_name}' on {pool_entry['ts_id']}")
else:
warn(f"Failed to add/ensure primary key for table '{table_name}' on {pool_entry['ts_id']}")
info(f"Completed adding primary keys to existing tables on {pool_entry['ts_id']}")
except Exception as e:
err(f"Error adding primary keys to existing tables on {pool_entry['ts_id']}: {str(e)}")
err(f"Traceback: {traceback.format_exc()}")
finally:
await conn.close()
async def close_db_pools(self):
info("Closing database connection pools...")
for pool_key, pool in self.db_pools.items():
try:
await pool.close()
debug(f"Closed pool for {pool_key}")
except Exception as e:
err(f"Error closing pool for {pool_key}: {str(e)}")
self.db_pools.clear()
info("All database connection pools closed.")
class Location(BaseModel):
latitude: float
@@ -679,7 +1156,7 @@ class Location(BaseModel):
state: Optional[str] = None
country: Optional[str] = None
context: Optional[Dict[str, Any]] = None
-class_: Optional[str] = None
class_: Optional[str] = Field(None, alias="class")
type: Optional[str] = None
name: Optional[str] = None
display_name: Optional[str] = None
@@ -697,11 +1174,7 @@ class Location(BaseModel):
json_encoders = {
datetime: lambda dt: dt.isoformat(),
}
populate_by_name = True
-def model_dump(self):
-data = self.dict()
-data["datetime"] = self.datetime.isoformat() if self.datetime else None
-return data

class Geocoder:


@@ -120,7 +120,6 @@ async def get_last_location() -> Optional[Location]:
return None

async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int, datetime, None] = None) -> List[Location]:
start_datetime = await dt(start)
if end is None:
@@ -133,10 +132,24 @@ async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int,
debug(f"Fetching locations between {start_datetime} and {end_datetime}")

-async with API.get_connection() as conn:
-locations = []
-# Check for records within the specified datetime range
-range_locations = await conn.fetch('''
-SELECT id, datetime,
-ST_X(ST_AsText(location)::geometry) AS longitude,
-ST_Y(ST_AsText(location)::geometry) AS latitude,
-ST_Z(ST_AsText(location)::geometry) AS elevation,
-city, state, zip, street,
-action, device_type, device_model, device_name, device_os
-FROM locations
-WHERE datetime >= $1 AND datetime <= $2
-ORDER BY datetime DESC
-''', start_datetime.replace(tzinfo=None), end_datetime.replace(tzinfo=None))
-debug(f"Range locations query returned: {range_locations}")
-locations.extend(range_locations)
-if not locations and (end is None or start_datetime.date() == end_datetime.date()):
-location_data = await conn.fetchrow('''
-SELECT id, datetime,
-ST_X(ST_AsText(location)::geometry) AS longitude,
-ST_Y(ST_AsText(location)::geometry) AS latitude,
-ST_Z(ST_AsText(location)::geometry) AS elevation,
-city, state, zip, street,
-action, device_type, device_model, device_name, device_os
-FROM locations
-WHERE datetime < $1
-ORDER BY datetime DESC
-LIMIT 1
-''', start_datetime.replace(tzinfo=None))
-debug(f"Fallback query returned: {location_data}")
-if location_data:
-locations.append(location_data)

query = '''
SELECT id, datetime,
ST_X(ST_AsText(location)::geometry) AS longitude,
ST_Y(ST_AsText(location)::geometry) AS latitude,
ST_Z(ST_AsText(location)::geometry) AS elevation,
city, state, zip, street,
action, device_type, device_model, device_name, device_os
FROM locations
WHERE datetime >= $1 AND datetime <= $2
ORDER BY datetime DESC
'''
locations = await API.execute_read_query(query, start_datetime.replace(tzinfo=None), end_datetime.replace(tzinfo=None), table_name="locations")
debug(f"Range locations query returned: {locations}")
if not locations and (end is None or start_datetime.date() == end_datetime.date()):
fallback_query = '''
SELECT id, datetime,
ST_X(ST_AsText(location)::geometry) AS longitude,
ST_Y(ST_AsText(location)::geometry) AS latitude,
ST_Z(ST_AsText(location)::geometry) AS elevation,
city, state, zip, street,
action, device_type, device_model, device_name, device_os
FROM locations
WHERE datetime < $1
ORDER BY datetime DESC
LIMIT 1
'''
location_data = await API.execute_read_query(fallback_query, start_datetime.replace(tzinfo=None), table_name="locations")
debug(f"Fallback query returned: {location_data}")
if location_data:
locations = location_data

debug(f"Locations found: {locations}")
@@ -197,35 +194,32 @@ async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int,
return location_objects if location_objects else []

-# Function to fetch the last location before the specified datetime
async def fetch_last_location_before(datetime: datetime) -> Optional[Location]:
datetime = await dt(datetime)
debug(f"Fetching last location before {datetime}")

-async with API.get_connection() as conn:
-location_data = await conn.fetchrow('''
-SELECT id, datetime,
-ST_X(ST_AsText(location)::geometry) AS longitude,
-ST_Y(ST_AsText(location)::geometry) AS latitude,
-ST_Z(ST_AsText(location)::geometry) AS elevation,
-city, state, zip, street, country,
-action
-FROM locations
-WHERE datetime < $1
-ORDER BY datetime DESC
-LIMIT 1
-''', datetime.replace(tzinfo=None))
-await conn.close()
-if location_data:
-debug(f"Last location found: {location_data}")
-return Location(**location_data)
-else:
-debug("No location found before the specified datetime")
-return None

query = '''
SELECT id, datetime,
ST_X(ST_AsText(location)::geometry) AS longitude,
ST_Y(ST_AsText(location)::geometry) AS latitude,
ST_Z(ST_AsText(location)::geometry) AS elevation,
city, state, zip, street, country,
action
FROM locations
WHERE datetime < $1
ORDER BY datetime DESC
LIMIT 1
'''
location_data = await API.execute_read_query(query, datetime.replace(tzinfo=None), table_name="locations")
if location_data:
debug(f"Last location found: {location_data[0]}")
return Location(**location_data[0])
else:
debug("No location found before the specified datetime")
return None
@gis.get("/map", response_class=HTMLResponse) @gis.get("/map", response_class=HTMLResponse)
async def generate_map_endpoint( async def generate_map_endpoint(
@ -247,16 +241,12 @@ async def generate_map_endpoint(
return HTMLResponse(content=html_content) return HTMLResponse(content=html_content)
async def get_date_range(): async def get_date_range():
async with API.get_connection() as conn: query = "SELECT MIN(datetime) as min_date, MAX(datetime) as max_date FROM locations"
query = "SELECT MIN(datetime) as min_date, MAX(datetime) as max_date FROM locations" row = await API.execute_read_query(query, table_name="locations")
row = await conn.fetchrow(query) if row and row[0]['min_date'] and row[0]['max_date']:
if row and row['min_date'] and row['max_date']: return row[0]['min_date'], row[0]['max_date']
return row['min_date'], row['max_date'] else:
else: return datetime(2022, 1, 1), datetime.now()
return datetime(2022, 1, 1), datetime.now()
async def generate_and_save_heatmap(
start_date: Union[str, int, datetime],
@@ -313,8 +303,6 @@ Generate a heatmap for the given date range and save it as a PNG file using Foli
err(f"Error generating and saving heatmap: {str(e)}")
raise

async def generate_map(start_date: datetime, end_date: datetime, max_points: int):
locations = await fetch_locations(start_date, end_date)
if not locations:
@@ -343,8 +331,6 @@ async def generate_map(start_date: datetime, end_date: datetime, max_points: int
folium.TileLayer('cartodbdark_matter', name='Dark Mode').add_to(m)

-# In the generate_map function:
draw = Draw(
draw_options={
'polygon': True,
@@ -433,70 +419,70 @@ map.on(L.Draw.Event.CREATED, function (event) {
return m.get_root().render()
async def post_location(location: Location): async def post_location(location: Location):
# if not location.datetime: try:
# info(f"location appears to be missing datetime: {location}") context = location.context or {}
# else: action = context.get('action', 'manual')
# debug(f"post_location called with {location.datetime}") device_type = context.get('device_type', 'Unknown')
async with API.get_connection() as conn: device_model = context.get('device_model', 'Unknown')
try: device_name = context.get('device_name', 'Unknown')
context = location.context or {} device_os = context.get('device_os', 'Unknown')
action = context.get('action', 'manual')
device_type = context.get('device_type', 'Unknown')
device_model = context.get('device_model', 'Unknown')
device_name = context.get('device_name', 'Unknown')
device_os = context.get('device_os', 'Unknown')
# Parse and localize the datetime # Parse and localize the datetime
localized_datetime = await dt(location.datetime) localized_datetime = await dt(location.datetime)
await conn.execute(''' query = '''
INSERT INTO locations ( INSERT INTO locations (
datetime, location, city, state, zip, street, action, device_type, device_model, device_name, device_os, datetime, location, city, state, zip, street, action, device_type, device_model, device_name, device_os,
class_, type, name, display_name, amenity, house_number, road, quarter, neighbourhood, class_, type, name, display_name, amenity, house_number, road, quarter, neighbourhood,
suburb, county, country_code, country suburb, county, country_code, country
) )
VALUES ($1, ST_SetSRID(ST_MakePoint($2, $3, $4), 4326), $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, VALUES ($1, ST_SetSRID(ST_MakePoint($2, $3, $4), 4326), $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15,
$16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26) $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26)
''', localized_datetime, location.longitude, location.latitude, location.elevation, location.city, location.state, '''
location.zip, location.street, action, device_type, device_model, device_name, device_os,
location.class_, location.type, location.name, location.display_name,
location.amenity, location.house_number, location.road, location.quarter, location.neighbourhood,
location.suburb, location.county, location.country_code, location.country)
await conn.close() await API.execute_write_query(
info(f"Successfully posted location: {location.latitude}, {location.longitude}, {location.elevation} on {localized_datetime}") query,
return { localized_datetime, location.longitude, location.latitude, location.elevation, location.city, location.state,
'datetime': localized_datetime, location.zip, location.street, action, device_type, device_model, device_name, device_os,
'latitude': location.latitude, location.class_, location.type, location.name, location.display_name,
'longitude': location.longitude, location.amenity, location.house_number, location.road, location.quarter, location.neighbourhood,
'elevation': location.elevation, location.suburb, location.county, location.country_code, location.country,
'city': location.city, table_name="locations"
'state': location.state, )
'zip': location.zip,
'street': location.street, info(f"Successfully posted location: {location.latitude}, {location.longitude}, {location.elevation} on {localized_datetime}")
'action': action, return {
'device_type': device_type, 'datetime': localized_datetime,
'device_model': device_model, 'latitude': location.latitude,
'device_name': device_name, 'longitude': location.longitude,
'device_os': device_os, 'elevation': location.elevation,
'class_': location.class_, 'city': location.city,
'type': location.type, 'state': location.state,
'name': location.name, 'zip': location.zip,
'display_name': location.display_name, 'street': location.street,
'amenity': location.amenity, 'action': action,
'house_number': location.house_number, 'device_type': device_type,
'road': location.road, 'device_model': device_model,
'quarter': location.quarter, 'device_name': device_name,
'neighbourhood': location.neighbourhood, 'device_os': device_os,
'suburb': location.suburb, 'class_': location.class_,
'county': location.county, 'type': location.type,
'country_code': location.country_code, 'name': location.name,
'country': location.country 'display_name': location.display_name,
} 'amenity': location.amenity,
except Exception as e: 'house_number': location.house_number,
err(f"Error posting location {e}") 'road': location.road,
err(traceback.format_exc()) 'quarter': location.quarter,
return None 'neighbourhood': location.neighbourhood,
'suburb': location.suburb,
'county': location.county,
'country_code': location.country_code,
'country': location.country
}
except Exception as e:
err(f"Error posting location {e}")
err(traceback.format_exc())
return None
@gis.post("/locate") @gis.post("/locate")
@ -553,6 +539,7 @@ async def get_last_location_endpoint() -> JSONResponse:
raise HTTPException(status_code=404, detail="No location found before the specified datetime") raise HTTPException(status_code=404, detail="No location found before the specified datetime")
@gis.get("/locate/{datetime_str}", response_model=List[Location]) @gis.get("/locate/{datetime_str}", response_model=List[Location])
async def get_locate(datetime_str: str, all: bool = False): async def get_locate(datetime_str: str, all: bool = False):
try: try:


@@ -1,393 +0,0 @@
'''
Uses Postgres/PostGIS for location tracking (data obtained via the companion mobile Pythonista scripts), and for geocoding purposes.
'''
from fastapi import APIRouter, HTTPException, Query
from fastapi.responses import HTMLResponse, JSONResponse
import yaml
from typing import List, Tuple, Union
import traceback
from datetime import datetime, timezone
from typing import Union, List
import folium
from zoneinfo import ZoneInfo
from dateutil.parser import parse as dateutil_parse
from typing import Optional, List, Union
from datetime import datetime
from sijapi import L, API, TZ, GEO
from sijapi.classes import Location
from sijapi.utilities import haversine
loc = APIRouter()
logger = L.get_module_logger("loc")
async def dt(
date_time: Union[str, int, datetime],
tz: Union[str, ZoneInfo, None] = None
) -> datetime:
try:
# Convert integer (epoch time) to UTC datetime
if isinstance(date_time, int):
date_time = datetime.utcfromtimestamp(date_time).replace(tzinfo=timezone.utc)
logger.debug(f"Converted epoch time {date_time} to UTC datetime object.")
# Convert string to datetime if necessary
elif isinstance(date_time, str):
date_time = dateutil_parse(date_time)
logger.debug(f"Converted string '{date_time}' to datetime object.")
if not isinstance(date_time, datetime):
raise ValueError(f"Input must be a string, integer (epoch time), or datetime object. What we received: {date_time}, type {type(date_time)}")
# Ensure the datetime is timezone-aware (UTC if not specified)
if date_time.tzinfo is None:
date_time = date_time.replace(tzinfo=timezone.utc)
logger.debug("Added UTC timezone to naive datetime.")
# Handle provided timezone
if tz is not None:
if isinstance(tz, str):
if tz == "local":
last_loc = await get_timezone_without_timezone(date_time)
tz = await GEO.tz_at(last_loc.latitude, last_loc.longitude)
logger.debug(f"Using local timezone: {tz}")
else:
try:
tz = ZoneInfo(tz)
except Exception as e:
logger.error(f"Invalid timezone string '{tz}'. Error: {e}")
raise ValueError(f"Invalid timezone string: {tz}")
elif isinstance(tz, ZoneInfo):
pass # tz is already a ZoneInfo object
else:
raise ValueError(f"What we needed: tz == 'local', a string, or a ZoneInfo object. What we got: tz, a {type(tz)}, == {tz})")
# Convert to the provided or determined timezone
date_time = date_time.astimezone(tz)
logger.debug(f"Converted datetime to timezone: {tz}")
return date_time
except ValueError as e:
logger.error(f"Error in dt: {e}")
raise
except Exception as e:
logger.error(f"Unexpected error in dt: {e}")
raise ValueError(f"Failed to process datetime: {e}")
async def get_timezone_without_timezone(date_time):
# This is a bit convoluted because we're trying to solve the paradox of needing to know the location in order to determine the timezone, but needing the timezone to be certain we've got the right location if this datetime coincided with inter-timezone travel. Our imperfect solution is to use UTC for an initial location query to determine roughly where we were at the time, get that timezone, then check the location again using that timezone, and if this location is different from the one using UTC, get the timezone again usng it, otherwise use the one we already sourced using UTC.
# Step 1: Use UTC as an interim timezone to query location
interim_dt = date_time.replace(tzinfo=ZoneInfo("UTC"))
interim_loc = await fetch_last_location_before(interim_dt)
# Step 2: Get a preliminary timezone based on the interim location
interim_tz = await GEO.tz_current((interim_loc.latitude, interim_loc.longitude))
# Step 3: Apply this preliminary timezone and query location again
query_dt = date_time.replace(tzinfo=ZoneInfo(interim_tz))
query_loc = await fetch_last_location_before(query_dt)
# Step 4: Get the final timezone, reusing interim_tz if location hasn't changed
return interim_tz if query_loc == interim_loc else await GEO.tz_current(query_loc.latitude, query_loc.longitude)
async def get_last_location() -> Optional[Location]:
query_datetime = datetime.now(TZ)
logger.debug(f"Query_datetime: {query_datetime}")
this_location = await fetch_last_location_before(query_datetime)
if this_location:
logger.debug(f"location: {this_location}")
return this_location
return None
async def fetch_locations(start: datetime, end: datetime = None) -> List[Location]:
start_datetime = await dt(start)
if end is None:
end_datetime = await dt(start_datetime.replace(hour=23, minute=59, second=59))
else:
end_datetime = await dt(end)
if start_datetime.time() == datetime.min.time() and end_datetime.time() == datetime.min.time():
end_datetime = end_datetime.replace(hour=23, minute=59, second=59)
logger.debug(f"Fetching locations between {start_datetime} and {end_datetime}")
async with API.get_connection() as conn:
locations = []
# Check for records within the specified datetime range
range_locations = await conn.fetch('''
SELECT id, datetime,
ST_X(ST_AsText(location)::geometry) AS longitude,
ST_Y(ST_AsText(location)::geometry) AS latitude,
ST_Z(ST_AsText(location)::geometry) AS elevation,
city, state, zip, street,
action, device_type, device_model, device_name, device_os
FROM locations
WHERE datetime >= $1 AND datetime <= $2
ORDER BY datetime DESC
''', start_datetime.replace(tzinfo=None), end_datetime.replace(tzinfo=None))
logger.debug(f"Range locations query returned: {range_locations}")
locations.extend(range_locations)
if not locations and (end is None or start_datetime.date() == end_datetime.date()):
location_data = await conn.fetchrow('''
SELECT id, datetime,
ST_X(ST_AsText(location)::geometry) AS longitude,
ST_Y(ST_AsText(location)::geometry) AS latitude,
ST_Z(ST_AsText(location)::geometry) AS elevation,
city, state, zip, street,
action, device_type, device_model, device_name, device_os
FROM locations
WHERE datetime < $1
ORDER BY datetime DESC
LIMIT 1
''', start_datetime.replace(tzinfo=None))
logger.debug(f"Fallback query returned: {location_data}")
if location_data:
locations.append(location_data)
logger.debug(f"Locations found: {locations}")
# Sort location_data based on the datetime field in descending order
sorted_locations = sorted(locations, key=lambda x: x['datetime'], reverse=True)
# Create Location objects directly from the location data
location_objects = [
Location(
latitude=location['latitude'],
longitude=location['longitude'],
datetime=location['datetime'],
elevation=location.get('elevation'),
city=location.get('city'),
state=location.get('state'),
zip=location.get('zip'),
street=location.get('street'),
context={
'action': location.get('action'),
'device_type': location.get('device_type'),
'device_model': location.get('device_model'),
'device_name': location.get('device_name'),
'device_os': location.get('device_os')
}
) for location in sorted_locations if location['latitude'] is not None and location['longitude'] is not None
]
return location_objects if location_objects else []
# Function to fetch the last location before the specified datetime
async def fetch_last_location_before(datetime: datetime) -> Optional[Location]:
datetime = await dt(datetime)
logger.debug(f"Fetching last location before {datetime}")
async with API.get_connection() as conn:
location_data = await conn.fetchrow('''
SELECT id, datetime,
ST_X(ST_AsText(location)::geometry) AS longitude,
ST_Y(ST_AsText(location)::geometry) AS latitude,
ST_Z(ST_AsText(location)::geometry) AS elevation,
city, state, zip, street, country,
action
FROM locations
WHERE datetime < $1
ORDER BY datetime DESC
LIMIT 1
''', datetime.replace(tzinfo=None))
await conn.close()
if location_data:
logger.debug(f"Last location found: {location_data}")
return Location(**location_data)
else:
logger.debug("No location found before the specified datetime")
return None
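# Map endpoints: render location history for a date range as an interactive folium map.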
@loc.get("/map/start_date={start_date_str}&end_date={end_date_str}", response_class=HTMLResponse)
async def generate_map_endpoint(start_date_str: str, end_date_str: str):
try:
start_date = await dt(start_date_str)
end_date = await dt(end_date_str)
except ValueError:
raise HTTPException(status_code=400, detail="Invalid date format")
html_content = await generate_map(start_date, end_date)
return HTMLResponse(content=html_content)
@loc.get("/map", response_class=HTMLResponse)
async def generate_alltime_map_endpoint():
try:
start_date = await dt(datetime.fromisoformat("2022-01-01"))
        end_date = await dt(datetime.now())
except ValueError:
raise HTTPException(status_code=400, detail="Invalid date format")
html_content = await generate_map(start_date, end_date)
return HTMLResponse(content=html_content)
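# Build the folium map itself: one marker per location fix, returned as rendered HTML.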
async def generate_map(start_date: datetime, end_date: datetime):
locations = await fetch_locations(start_date, end_date)
if not locations:
raise HTTPException(status_code=404, detail="No locations found for the given date range")
# Create a folium map centered around the first location
map_center = [locations[0].latitude, locations[0].longitude]
m = folium.Map(location=map_center, zoom_start=5)
# Add markers for each location
for location in locations:
folium.Marker(
location=[location.latitude, location.longitude],
popup=f"{location.city}, {location.state}<br>Elevation: {location.elevation}m<br>Date: {location.datetime}",
tooltip=f"{location.city}, {location.state}"
).add_to(m)
# Save the map to an HTML file and return the HTML content
map_html = "map.html"
m.save(map_html)
with open(map_html, 'r') as file:
html_content = file.read()
return html_content
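# Insert a single geocoded Location (with its device/action context) into the locations table.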
async def post_location(location: Location):
async with API.get_connection() as conn:
try:
context = location.context or {}
action = context.get('action', 'manual')
device_type = context.get('device_type', 'Unknown')
device_model = context.get('device_model', 'Unknown')
device_name = context.get('device_name', 'Unknown')
device_os = context.get('device_os', 'Unknown')
# Parse and localize the datetime
localized_datetime = await dt(location.datetime)
await conn.execute('''
INSERT INTO locations (
datetime, location, city, state, zip, street, action, device_type, device_model, device_name, device_os,
class_, type, name, display_name, amenity, house_number, road, quarter, neighbourhood,
suburb, county, country_code, country
)
VALUES ($1, ST_SetSRID(ST_MakePoint($2, $3, $4), 4326), $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15,
$16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26)
''', localized_datetime, location.longitude, location.latitude, location.elevation, location.city, location.state,
location.zip, location.street, action, device_type, device_model, device_name, device_os,
location.class_, location.type, location.name, location.display_name,
location.amenity, location.house_number, location.road, location.quarter, location.neighbourhood,
location.suburb, location.county, location.country_code, location.country)
logger.info(f"Successfully posted location: {location.latitude}, {location.longitude}, {location.elevation} on {localized_datetime}")
return {
'datetime': localized_datetime,
'latitude': location.latitude,
'longitude': location.longitude,
'elevation': location.elevation,
'city': location.city,
'state': location.state,
'zip': location.zip,
'street': location.street,
'action': action,
'device_type': device_type,
'device_model': device_model,
'device_name': device_name,
'device_os': device_os,
'class_': location.class_,
'type': location.type,
'name': location.name,
'display_name': location.display_name,
'amenity': location.amenity,
'house_number': location.house_number,
'road': location.road,
'quarter': location.quarter,
'neighbourhood': location.neighbourhood,
'suburb': location.suburb,
'county': location.county,
'country_code': location.country_code,
'country': location.country
}
except Exception as e:
logger.error(f"Error posting location {e}")
logger.error(traceback.format_exc())
return None
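# Accept one or more Location payloads, fill in any missing datetime and context,
# geocode them, and store each result in the database.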
@loc.post("/locate")
async def post_locate_endpoint(locations: Union[Location, List[Location]]):
if isinstance(locations, Location):
locations = [locations]
# Prepare locations
for lcn in locations:
if not lcn.datetime:
tz = await GEO.tz_at(lcn.latitude, lcn.longitude)
lcn.datetime = datetime.now(ZoneInfo(tz)).isoformat()
if not lcn.context:
lcn.context = {
"action": "missing",
"device_type": "API",
"device_model": "Unknown",
"device_name": "Unknown",
"device_os": "Unknown"
}
logger.debug(f"Location received for processing: {lcn}")
geocoded_locations = await GEO.code(locations)
responses = []
    if isinstance(geocoded_locations, list):
for location in geocoded_locations:
logger.debug(f"Final location to be submitted to database: {location}")
location_entry = await post_location(location)
if location_entry:
responses.append({"location_data": location_entry})
else:
logger.warning(f"Posting location to database appears to have failed.")
else:
logger.debug(f"Final location to be submitted to database: {geocoded_locations}")
location_entry = await post_location(geocoded_locations)
if location_entry:
responses.append({"location_data": location_entry})
else:
logger.warning(f"Posting location to database appears to have failed.")
return {"message": "Locations and weather updated", "results": responses}
@loc.get("/locate", response_model=Location)
async def get_last_location_endpoint() -> JSONResponse:
this_location = await get_last_location()
if this_location:
location_dict = this_location.model_dump()
location_dict["datetime"] = this_location.datetime.isoformat()
return JSONResponse(content=location_dict)
else:
raise HTTPException(status_code=404, detail="No location found before the specified datetime")
@loc.get("/locate/{datetime_str}", response_model=List[Location])
async def get_locate(datetime_str: str, all: bool = False):
try:
date_time = await dt(datetime_str)
    except ValueError:
        logger.error(f"Invalid datetime string provided: {datetime_str}")
        raise HTTPException(status_code=400, detail="Invalid datetime provided. Use YYYYMMDDHHmmss or YYYYMMDD format.")
locations = await fetch_locations(date_time)
if not locations:
raise HTTPException(status_code=404, detail="No nearby data found for this date and time")
return locations if all else [locations[0]]
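# Example requests (hypothetical host, port, and router prefix; adjust to your deployment):
#   curl -X POST http://localhost:4444/locate \
#        -H "Content-Type: application/json" \
#        -d '{"latitude": 45.52, "longitude": -122.68}'
#   curl http://localhost:4444/locate                      # most recent fix
#   curl "http://localhost:4444/locate/20240801?all=true"  # every fix from that day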

View file

@@ -212,7 +212,6 @@ if API.EXTENSIONS.shellfish == "on" or API.EXTENSIONS.shellfish == True:
            except requests.exceptions.RequestException:
                results.append(f"{address} is down")
-           # Generate a simple text-based graph
            graph = '|' * up_count + '.' * (len(addresses) - up_count)
            text_update = "\n".join(results)

@@ -220,7 +219,6 @@ if API.EXTENSIONS.shellfish == "on" or API.EXTENSIONS.shellfish == True:
        output = shellfish_run_widget_command(widget_command)
        return {"output": output, "graph": graph}

    def shellfish_update_widget(update: WidgetUpdate):
        widget_command = ["widget"]

@@ -291,6 +289,7 @@ if API.EXTENSIONS.courtlistener == "on" or API.EXTENSIONS.courtlistener == True:
        bg_tasks.add_task(cl_docket_process, result)
        return JSONResponse(content={"message": "Received"}, status_code=status.HTTP_200_OK)

    async def cl_docket_process(result):
        async with httpx.AsyncClient() as session:
            await cl_docket_process_result(result, session)

@@ -346,12 +345,14 @@ if API.EXTENSIONS.courtlistener == "on" or API.EXTENSIONS.courtlistener == True:
            await cl_download_file(file_url, target_path, session)
            debug(f"Downloaded {file_name} to {target_path}")

    def cl_case_details(docket):
        case_info = CASETABLE.get(str(docket), {"code": "000", "shortname": "UNKNOWN"})
        case_code = case_info.get("code")
        short_name = case_info.get("shortname")
        return case_code, short_name

    async def cl_download_file(url: str, path: Path, session: aiohttp.ClientSession = None):
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/93.0.4577.82 Safari/537.36'

@@ -417,19 +418,20 @@ if API.EXTENSIONS.courtlistener == "on" or API.EXTENSIONS.courtlistener == True:
        await cl_download_file(download_url, target_path, session)
        debug(f"Downloaded {file_name} to {target_path}")

@serve.get("/s", response_class=HTMLResponse)
async def shortener_form(request: Request):
    return templates.TemplateResponse("shortener.html", {"request": request})

@serve.post("/s")
async def create_short_url(request: Request, long_url: str = Form(...), custom_code: Optional[str] = Form(None)):
-   await create_tables()
    if custom_code:
        if len(custom_code) != 3 or not custom_code.isalnum():
            return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code must be 3 alphanumeric characters"})

-       existing = await API.execute_write_query('SELECT 1 FROM short_urls WHERE short_code = $1', custom_code, table_name="short_urls")
+       existing = await API.execute_read_query('SELECT 1 FROM short_urls WHERE short_code = $1', custom_code, table_name="short_urls")
        if existing:
            return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code already in use"})

@@ -437,8 +439,9 @@ async def create_short_url(request: Request, long_url: str = Form(...), custom_c
    else:
        chars = string.ascii_letters + string.digits
        while True:
+           debug(f"FOUND THE ISSUE")
            short_code = ''.join(random.choice(chars) for _ in range(3))
-           existing = await API.execute_write_query('SELECT 1 FROM short_urls WHERE short_code = $1', short_code, table_name="short_urls")
+           existing = await API.execute_read_query('SELECT 1 FROM short_urls WHERE short_code = $1', short_code, table_name="short_urls")
            if not existing:
                break

@@ -451,48 +454,36 @@ async def create_short_url(request: Request, long_url: str = Form(...), custom_c
    short_url = f"https://sij.ai/{short_code}"
    return templates.TemplateResponse("shortener.html", {"request": request, "short_url": short_url})

-async def create_tables():
-    await API.execute_write_query('''
-        CREATE TABLE IF NOT EXISTS short_urls (
-            short_code VARCHAR(3) PRIMARY KEY,
-            long_url TEXT NOT NULL,
-            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
-        )
-    ''', table_name="short_urls")
-    await API.execute_write_query('''
-        CREATE TABLE IF NOT EXISTS click_logs (
-            id SERIAL PRIMARY KEY,
-            short_code VARCHAR(3) REFERENCES short_urls(short_code),
-            clicked_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
-            ip_address TEXT,
-            user_agent TEXT
-        )
-    ''', table_name="click_logs")

-@serve.get("/{short_code}", response_class=RedirectResponse, status_code=301)
-async def redirect_short_url(request: Request, short_code: str = PathParam(..., min_length=3, max_length=3)):
-    if request.headers.get('host') != 'sij.ai':
-        raise HTTPException(status_code=404, detail="Not Found")
-    result = await API.execute_write_query(
-        'SELECT long_url FROM short_urls WHERE short_code = $1',
-        short_code,
-        table_name="short_urls"
-    )
-    if result:
-        await API.execute_write_query(
-            'INSERT INTO click_logs (short_code, ip_address, user_agent) VALUES ($1, $2, $3)',
-            short_code, request.client.host, request.headers.get("user-agent"),
-            table_name="click_logs"
-        )
-        return result['long_url']
-    else:
-        raise HTTPException(status_code=404, detail="Short URL not found")
+@serve.get("/{short_code}")
+async def redirect_short_url(short_code: str):
+    results = await API.execute_read_query(
+        'SELECT long_url FROM short_urls WHERE short_code = $1',
+        short_code,
+        table_name="short_urls"
+    )
+    if not results:
+        raise HTTPException(status_code=404, detail="Short URL not found")
+    long_url = results[0].get('long_url')
+    if not long_url:
+        raise HTTPException(status_code=404, detail="Long URL not found")
+    # Increment click count (you may want to do this asynchronously)
+    await API.execute_write_query(
+        'INSERT INTO click_logs (short_code, clicked_at) VALUES ($1, $2)',
+        short_code, datetime.now(),
+        table_name="click_logs"
+    )
+    return RedirectResponse(url=long_url)

@serve.get("/analytics/{short_code}")
async def get_analytics(short_code: str):
-   url_info = await API.execute_write_query(
+   url_info = await API.execute_read_query(
        'SELECT long_url, created_at FROM short_urls WHERE short_code = $1',
        short_code,
        table_name="short_urls"

@@ -500,13 +491,13 @@ async def get_analytics(short_code: str):
    if not url_info:
        raise HTTPException(status_code=404, detail="Short URL not found")

-   click_count = await API.execute_write_query(
+   click_count = await API.execute_read_query(
        'SELECT COUNT(*) FROM click_logs WHERE short_code = $1',
        short_code,
        table_name="click_logs"
    )

-   clicks = await API.execute_write_query(
+   clicks = await API.execute_read_query(
        'SELECT clicked_at, ip_address, user_agent FROM click_logs WHERE short_code = $1 ORDER BY clicked_at DESC LIMIT 100',
        short_code,
        table_name="click_logs"

@@ -521,15 +512,20 @@ async def get_analytics(short_code: str):
    }

async def forward_traffic(reader: asyncio.StreamReader, writer: asyncio.StreamWriter, destination: str):
-   dest_host, dest_port = destination.split(':')
-   dest_port = int(dest_port)
+   try:
+       dest_host, dest_port = destination.split(':')
+       dest_port = int(dest_port)
+   except ValueError:
+       warn(f"Invalid destination format: {destination}. Expected 'host:port'.")
+       writer.close()
+       await writer.wait_closed()
+       return

    try:
        dest_reader, dest_writer = await asyncio.open_connection(dest_host, dest_port)
    except Exception as e:
+       warn(f"Failed to connect to destination {destination}: {str(e)}")
        writer.close()
        await writer.wait_closed()
        return

@@ -543,7 +539,7 @@ async def forward_traffic(reader: asyncio.StreamReader, writer: asyncio.StreamWr
                dst.write(data)
                await dst.drain()
            except Exception as e:
-               pass
+               warn(f"Error in forwarding: {str(e)}")
            finally:
                dst.close()
                await dst.wait_closed()

@@ -554,8 +550,12 @@ async def forward_traffic(reader: asyncio.StreamReader, writer: asyncio.StreamWr
    )

async def start_server(source: str, destination: str):
-   host, port = source.split(':')
-   port = int(port)
+   if ':' in source:
+       host, port = source.split(':')
+       port = int(port)
+   else:
+       host = source
+       port = 80

    server = await asyncio.start_server(
        lambda r, w: forward_traffic(r, w, destination),

@@ -566,6 +566,7 @@ async def start_server(source: str, destination: str):
    async with server:
        await server.serve_forever()

async def start_port_forwarding():
    if hasattr(Serve, 'forwarding_rules'):
        for rule in Serve.forwarding_rules:

@@ -573,6 +574,7 @@ async def start_port_forwarding():
    else:
        warn("No forwarding rules found in the configuration.")

@serve.get("/forward_status")
async def get_forward_status():
    if hasattr(Serve, 'forwarding_rules'):

@@ -580,5 +582,5 @@ async def get_forward_status():
    else:
        return {"status": "inactive", "message": "No forwarding rules configured"}

-# Add this to the end of your serve.py file
asyncio.create_task(start_port_forwarding())

View file

@@ -116,129 +116,129 @@ async def get_weather(date_time: dt_datetime, latitude: float, longitude: float,
async def store_weather_to_db(date_time: dt_datetime, weather_data: dict):
    warn(f"Using {date_time.strftime('%Y-%m-%d %H:%M:%S')} as our datetime in store_weather_to_db")
-   async with API.get_connection() as conn:
    try:
        day_data = weather_data.get('days')[0]
        debug(f"RAW DAY_DATA: {day_data}")

        # Handle preciptype and stations as PostgreSQL arrays
        preciptype_array = day_data.get('preciptype', []) or []
        stations_array = day_data.get('stations', []) or []

        date_str = date_time.strftime("%Y-%m-%d")
        warn(f"Using {date_str} in our query in store_weather_to_db.")

        # Get location details from weather data if available
        longitude = weather_data.get('longitude')
        latitude = weather_data.get('latitude')
        tz = await GEO.tz_at(latitude, longitude)
        elevation = await GEO.elevation(latitude, longitude)
        location_point = f"POINTZ({longitude} {latitude} {elevation})" if longitude and latitude and elevation else None

        warn(f"Uncorrected datetimes in store_weather_to_db: {day_data['datetime']}, sunrise: {day_data['sunrise']}, sunset: {day_data['sunset']}")
        day_data['datetime'] = await gis.dt(day_data.get('datetimeEpoch'))
        day_data['sunrise'] = await gis.dt(day_data.get('sunriseEpoch'))
        day_data['sunset'] = await gis.dt(day_data.get('sunsetEpoch'))
        warn(f"Corrected datetimes in store_weather_to_db: {day_data['datetime']}, sunrise: {day_data['sunrise']}, sunset: {day_data['sunset']}")

        daily_weather_params = (
            day_data.get('sunrise'), day_data.get('sunriseEpoch'),
            day_data.get('sunset'), day_data.get('sunsetEpoch'),
            day_data.get('description'), day_data.get('tempmax'),
            day_data.get('tempmin'), day_data.get('uvindex'),
            day_data.get('winddir'), day_data.get('windspeed'),
            day_data.get('icon'), dt_datetime.now(tz),
            day_data.get('datetime'), day_data.get('datetimeEpoch'),
            day_data.get('temp'), day_data.get('feelslikemax'),
            day_data.get('feelslikemin'), day_data.get('feelslike'),
            day_data.get('dew'), day_data.get('humidity'),
            day_data.get('precip'), day_data.get('precipprob'),
            day_data.get('precipcover'), preciptype_array,
            day_data.get('snow'), day_data.get('snowdepth'),
            day_data.get('windgust'), day_data.get('pressure'),
            day_data.get('cloudcover'), day_data.get('visibility'),
            day_data.get('solarradiation'), day_data.get('solarenergy'),
            day_data.get('severerisk', 0), day_data.get('moonphase'),
            day_data.get('conditions'), stations_array, day_data.get('source'),
            location_point
        )
    except Exception as e:
        err(f"Failed to prepare database query in store_weather_to_db! {e}")
+       return "FAILURE"

    try:
        daily_weather_query = '''
            INSERT INTO DailyWeather (
                sunrise, sunriseepoch, sunset, sunsetepoch, description,
                tempmax, tempmin, uvindex, winddir, windspeed, icon, last_updated,
                datetime, datetimeepoch, temp, feelslikemax, feelslikemin, feelslike,
                dew, humidity, precip, precipprob, precipcover, preciptype,
                snow, snowdepth, windgust, pressure, cloudcover, visibility,
                solarradiation, solarenergy, severerisk, moonphase, conditions,
                stations, source, location
            ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38)
            RETURNING id
        '''
-       async with conn.transaction():
-           daily_weather_id = await conn.fetchval(daily_weather_query, *daily_weather_params)
+       daily_weather_id = await API.execute_write_query(daily_weather_query, *daily_weather_params, table_name="DailyWeather")

        if 'hours' in day_data:
            debug(f"Processing hours now...")
            for hour_data in day_data['hours']:
                try:
                    await asyncio.sleep(0.01)
                    hour_data['datetime'] = await gis.dt(hour_data.get('datetimeEpoch'))
                    hour_preciptype_array = hour_data.get('preciptype', []) or []
                    hour_stations_array = hour_data.get('stations', []) or []
                    hourly_weather_params = (
                        daily_weather_id,
                        hour_data['datetime'],
                        hour_data.get('datetimeEpoch'),
                        hour_data['temp'],
                        hour_data['feelslike'],
                        hour_data['humidity'],
                        hour_data['dew'],
                        hour_data['precip'],
                        hour_data['precipprob'],
                        hour_preciptype_array,
                        hour_data['snow'],
                        hour_data['snowdepth'],
                        hour_data['windgust'],
                        hour_data['windspeed'],
                        hour_data['winddir'],
                        hour_data['pressure'],
                        hour_data['cloudcover'],
                        hour_data['visibility'],
                        hour_data['solarradiation'],
                        hour_data['solarenergy'],
                        hour_data['uvindex'],
                        hour_data.get('severerisk', 0),
                        hour_data['conditions'],
                        hour_data['icon'],
                        hour_stations_array,
                        hour_data.get('source', ''),
                    )
                    try:
                        hourly_weather_query = '''
                            INSERT INTO HourlyWeather (daily_weather_id, datetime, datetimeepoch, temp, feelslike, humidity, dew, precip, precipprob,
                                preciptype, snow, snowdepth, windgust, windspeed, winddir, pressure, cloudcover, visibility, solarradiation, solarenergy,
                                uvindex, severerisk, conditions, icon, stations, source)
                            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26)
                            RETURNING id
                        '''
-                       async with conn.transaction():
-                           hourly_weather_id = await conn.fetchval(hourly_weather_query, *hourly_weather_params)
+                       hourly_weather_id = await API.execute_write_query(hourly_weather_query, *hourly_weather_params, table_name="HourlyWeather")
                        debug(f"Done processing hourly_weather_id {hourly_weather_id}")
                    except Exception as e:
                        err(f"EXCEPTION: {e}")
                except Exception as e:
                    err(f"EXCEPTION: {e}")
        return "SUCCESS"
    except Exception as e:
        err(f"Error in dailyweather storage: {e}")
+       return "FAILURE"