Much more efficient database sync method, first working version.

This commit is contained in:
sanj 2024-08-01 23:51:01 -07:00
parent d01f24ad45
commit 299d27b1e7
7 changed files with 950 additions and 882 deletions
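For orientation, a minimal sketch of the startup sync flow this commit introduces, using the method names that appear in the diff below (the FastAPI lifespan wiring is simplified, and `api` stands in for the loaded APIConfig instance):

async def startup_sync(api):
    # Ensure version/server_id columns and primary keys exist on every reachable database.
    await api.initialize_sync()
    # Ask the online peers which one holds the highest row version.
    source = await api.get_most_recent_source()
    if source:
        # Pull only rows newer than the last version synced from that peer.
        total_changes = await api.pull_changes(source)
        print(f"Data pull complete. Total changes: {total_changes}")
    else:
        print("No instances with more recent data found or all instances are offline.")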

View file

@ -20,6 +20,7 @@ L = Logger("Central", LOGS_DIR)
# API essentials
API = APIConfig.load('api', 'secrets')
Dir = Configuration.load('dirs')
HOST = f"{API.BIND}:{API.PORT}"
LOCAL_HOSTS = [ipaddress.ip_address(localhost.strip()) for localhost in os.getenv('LOCAL_HOSTS', '127.0.0.1').split(',')] + ['localhost']

View file

@ -39,12 +39,11 @@ def warn(text: str): logger.warning(text)
def err(text: str): logger.error(text)
def crit(text: str): logger.critical(text)
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup
crit("sijapi launched")
crit(f"Arguments: {args}")
info(f"Arguments: {args}")
# Load routers
if args.test:
@ -54,20 +53,10 @@ async def lifespan(app: FastAPI):
if getattr(API.MODULES, module_name):
load_router(module_name)
crit("Starting database synchronization...")
try:
# Initialize sync structures on all databases
await API.initialize_sync()
# Check if other instances have more recent data
source = await API.get_most_recent_source()
if source:
crit(f"Pulling changes from {source['ts_id']} ({source['ts_ip']})...")
total_changes = await API.pull_changes(source)
crit(f"Data pull complete. Total changes: {total_changes}")
else:
crit("No instances with more recent data found or all instances are offline.")
except Exception as e:
crit(f"Error during startup: {str(e)}")
crit(f"Traceback: {traceback.format_exc()}")
@ -79,8 +68,6 @@ async def lifespan(app: FastAPI):
await API.close_db_pools()
crit("Database pools closed.")
app = FastAPI(lifespan=lifespan)
app.add_middleware(
@ -135,30 +122,41 @@ async def handle_exception_middleware(request: Request, call_next):
return response
# This was removed on 7/31/2024 when we decided to instead use a targeted push sync approach.
deprecated = '''
async def push_changes_background():
@app.post("/sync/pull")
async def pull_changes():
try:
await API.push_changes_to_all()
await API.add_primary_keys_to_local_tables()
await API.add_primary_keys_to_remote_tables()
try:
source = await API.get_most_recent_source()
if source:
# Pull changes from the source
total_changes = await API.pull_changes(source)
return JSONResponse(content={
"status": "success",
"message": f"Pull complete. Total changes: {total_changes}",
"source": f"{source['ts_id']} ({source['ts_ip']})",
"changes": total_changes
})
else:
return JSONResponse(content={
"status": "info",
"message": "No instances with more recent data found or all instances are offline."
})
except Exception as e:
err(f"Error pushing changes to other databases: {str(e)}")
err(f"Error during pull: {str(e)}")
err(f"Traceback: {traceback.format_exc()}")
raise HTTPException(status_code=500, detail=f"Error during pull: {str(e)}")
@app.middleware("http")
async def sync_middleware(request: Request, call_next):
response = await call_next(request)
# Check if the request was a database write operation
if request.method in ["POST", "PUT", "PATCH", "DELETE"]:
try:
# Push changes to other databases
await API.push_changes_to_all()
except Exception as e:
err(f"Error pushing changes to other databases: {str(e)}")
err(f"Error while ensuring primary keys to tables: {str(e)}")
err(f"Traceback: {traceback.format_exc()}")
raise HTTPException(status_code=500, detail=f"Error during primary key insurance: {str(e)}")
return response
'''
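For contrast with the deprecated broadcast middleware above, a rough sketch of the targeted push the codebase moves to: only the row just written is sent to each online peer, and the peer keeps whichever copy has the higher (version, server_id). The helper name push_one_row is hypothetical; the real logic lives in APIConfig.push_change / _push_change_background further down in this diff.

async def push_one_row(local_conn, remote_conn, table_name, primary_key, pk_value):
    # Fetch the row that was just written locally.
    row = await local_conn.fetchrow(
        f'SELECT * FROM "{table_name}" WHERE "{primary_key}" = $1', pk_value)
    if row is None:
        return
    cols = list(row.keys())
    col_list = ', '.join(f'"{c}"' for c in cols)
    placeholders = ', '.join(f'${i + 1}' for i in range(len(cols)))
    updates = ', '.join(f'"{c}" = EXCLUDED."{c}"' for c in cols if c != primary_key)
    # Upsert on the remote, but only overwrite older versions
    # (last-writer-wins, with server_id as the tie-breaker).
    await remote_conn.execute(f"""
        INSERT INTO "{table_name}" ({col_list})
        VALUES ({placeholders})
        ON CONFLICT ("{primary_key}") DO UPDATE SET {updates}
        WHERE "{table_name}".version < EXCLUDED.version
           OR ("{table_name}".version = EXCLUDED.version
               AND "{table_name}".server_id < EXCLUDED.server_id)
    """, *row.values())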
def load_router(router_name):
router_file = ROUTER_DIR / f'{router_name}.py'

View file

@ -5,6 +5,7 @@ import math
import os
import re
import uuid
import time
import aiofiles
import aiohttp
import asyncio
@ -180,14 +181,18 @@ class APIConfig(BaseModel):
TZ: str
KEYS: List[str]
GARBAGE: Dict[str, Any]
SPECIAL_TABLES: ClassVar[List[str]] = ['spatial_ref_sys']
db_pools: Dict[str, Any] = Field(default_factory=dict)
offline_servers: Dict[str, float] = Field(default_factory=dict)
offline_timeout: float = Field(default=30.0) # 30 second timeout for offline servers
online_hosts_cache: Dict[str, Tuple[List[Dict[str, Any]], float]] = Field(default_factory=dict)
online_hosts_cache_ttl: float = Field(default=30.0) # Cache TTL in seconds
def __init__(self, **data):
super().__init__(**data)
self._db_pools = {}
self.db_pools = {}
self.online_hosts_cache = {} # Initialize the cache
self._sync_tasks = {}
class Config:
arbitrary_types_allowed = True
@ -306,13 +311,48 @@ class APIConfig(BaseModel):
raise ValueError(f"No database configuration found for TS_ID: {ts_id}")
return local_db
@asynccontextmanager
async def get_online_hosts(self) -> List[Dict[str, Any]]:
current_time = time.time()
cache_key = "online_hosts"
if cache_key in self.online_hosts_cache:
cached_hosts, cache_time = self.online_hosts_cache[cache_key]
if current_time - cache_time < self.online_hosts_cache_ttl:
return cached_hosts
online_hosts = []
local_ts_id = os.environ.get('TS_ID')
for pool_entry in self.POOL:
if pool_entry['ts_id'] != local_ts_id:
pool_key = f"{pool_entry['ts_ip']}:{pool_entry['db_port']}"
if pool_key in self.offline_servers:
if current_time - self.offline_servers[pool_key] < self.offline_timeout:
continue
else:
del self.offline_servers[pool_key]
conn = await self.get_connection(pool_entry)
if conn is not None:
online_hosts.append(pool_entry)
await conn.close()
self.online_hosts_cache[cache_key] = (online_hosts, current_time)
return online_hosts
async def get_connection(self, pool_entry: Dict[str, Any] = None):
if pool_entry is None:
pool_entry = self.local_db
pool_key = f"{pool_entry['ts_ip']}:{pool_entry['db_port']}"
# Check if the server is marked as offline
if pool_key in self.offline_servers:
if time.time() - self.offline_servers[pool_key] < self.offline_timeout:
return None
else:
del self.offline_servers[pool_key]
if pool_key not in self.db_pools:
try:
self.db_pools[pool_key] = await asyncpg.create_pool(
@ -326,27 +366,21 @@ class APIConfig(BaseModel):
timeout=5
)
except Exception as e:
err(f"Failed to create connection pool for {pool_key}: {str(e)}")
yield None
return
warn(f"Failed to create connection pool for {pool_key}: {str(e)}")
self.offline_servers[pool_key] = time.time()
return None
try:
async with self.db_pools[pool_key].acquire() as conn:
yield conn
return await asyncio.wait_for(self.db_pools[pool_key].acquire(), timeout=5)
except asyncio.TimeoutError:
warn(f"Timeout acquiring connection from pool for {pool_key}")
self.offline_servers[pool_key] = time.time()
return None
except Exception as e:
err(f"Failed to acquire connection from pool for {pool_key}: {str(e)}")
yield None
warn(f"Failed to acquire connection for {pool_key}: {str(e)}")
self.offline_servers[pool_key] = time.time()
return None
async def close_db_pools(self):
info("Closing database connection pools...")
for pool_key, pool in self.db_pools.items():
try:
await pool.close()
debug(f"Closed pool for {pool_key}")
except Exception as e:
err(f"Error closing pool for {pool_key}: {str(e)}")
self.db_pools.clear()
info("All database connection pools closed.")
async def initialize_sync(self):
local_ts_id = os.environ.get('TS_ID')
@ -356,7 +390,7 @@ class APIConfig(BaseModel):
if pool_entry['ts_id'] == local_ts_id:
continue # Skip local database
try:
async with self.get_connection(pool_entry) as conn:
conn = await self.get_connection(pool_entry)
if conn is None:
continue # Skip this database if connection failed
@ -382,26 +416,34 @@ class APIConfig(BaseModel):
err(f"Error initializing sync for {pool_entry['ts_ip']}: {str(e)}")
err(f"Traceback: {traceback.format_exc()}")
async def ensure_sync_columns(self, conn, table_name):
if conn is None:
debug(f"Skipping offline server...")
return None
def _schedule_sync_task(self, table_name: str, pk_value: Any, version: int, server_id: str):
# Use a background task manager to handle syncing
task_key = f"{table_name}:{pk_value}" if pk_value else table_name
if task_key not in self._sync_tasks:
self._sync_tasks[task_key] = asyncio.create_task(self._sync_changes(table_name, pk_value, version, server_id))
if table_name in self.SPECIAL_TABLES:
debug(f"Skipping sync columns for special table: {table_name}")
async def ensure_sync_columns(self, conn, table_name):
if conn is None or table_name in self.SPECIAL_TABLES:
return None
try:
# Get primary key information
# Check if primary key exists
primary_key = await conn.fetchval(f"""
SELECT a.attname
FROM pg_index i
JOIN pg_attribute a ON a.attrelid = i.indrelid
AND a.attnum = ANY(i.indkey)
WHERE i.indrelid = '{table_name}'::regclass
AND i.indisprimary;
JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
WHERE i.indrelid = '{table_name}'::regclass AND i.indisprimary;
""")
if not primary_key:
# Add an id column as primary key if it doesn't exist
await conn.execute(f"""
ALTER TABLE "{table_name}"
ADD COLUMN IF NOT EXISTS id SERIAL PRIMARY KEY;
""")
primary_key = 'id'
# Ensure version column exists
await conn.execute(f"""
ALTER TABLE "{table_name}"
@ -449,17 +491,7 @@ class APIConfig(BaseModel):
except Exception as e:
err(f"Error ensuring sync columns for table {table_name}: {str(e)}")
err(f"Traceback: {traceback.format_exc()}")
async def get_online_hosts(self) -> List[Dict[str, Any]]:
online_hosts = []
for pool_entry in self.POOL:
try:
async with self.get_connection(pool_entry) as conn:
if conn is not None:
online_hosts.append(pool_entry)
except Exception as e:
err(f"Error checking host {pool_entry['ts_ip']}:{pool_entry['db_port']}: {str(e)}")
return online_hosts
return None
async def check_postgis(self, conn):
if conn is None:
@ -478,53 +510,350 @@ class APIConfig(BaseModel):
err(f"Error checking PostGIS: {str(e)}")
return False
async def execute_write_query(self, query: str, *args, table_name: str):
async def pull_changes(self, source_pool_entry, batch_size=10000):
if source_pool_entry['ts_id'] == os.environ.get('TS_ID'):
debug("Skipping self-sync")
return 0
total_changes = 0
source_id = source_pool_entry['ts_id']
source_ip = source_pool_entry['ts_ip']
dest_id = os.environ.get('TS_ID')
dest_ip = self.local_db['ts_ip']
info(f"Starting sync from source {source_id} ({source_ip}) to destination {dest_id} ({dest_ip})")
source_conn = None
dest_conn = None
try:
source_conn = await self.get_connection(source_pool_entry)
if source_conn is None:
warn(f"Unable to connect to source {source_id} ({source_ip}). Skipping sync.")
return 0
dest_conn = await self.get_connection(self.local_db)
if dest_conn is None:
warn(f"Unable to connect to local database. Skipping sync.")
return 0
tables = await source_conn.fetch("""
SELECT tablename FROM pg_tables
WHERE schemaname = 'public'
""")
for table in tables:
table_name = table['tablename']
try:
if table_name in self.SPECIAL_TABLES:
return await self._execute_special_table_write(query, *args, table_name=table_name)
await self.sync_special_table(source_conn, dest_conn, table_name)
else:
primary_key = await self.ensure_sync_columns(dest_conn, table_name)
last_synced_version = await self.get_last_synced_version(dest_conn, table_name, source_id)
async with self.get_connection() as conn:
changes = await source_conn.fetch(f"""
SELECT * FROM "{table_name}"
WHERE version > $1 AND server_id = $2
ORDER BY version ASC
LIMIT $3
""", last_synced_version, source_id, batch_size)
if changes:
changes_count = await self.apply_batch_changes(dest_conn, table_name, changes, primary_key)
total_changes += changes_count
if changes_count > 0:
info(f"Synced batch for {table_name}: {changes_count} changes. Total so far: {total_changes}")
else:
debug(f"No changes to sync for {table_name}")
except Exception as e:
err(f"Error syncing table {table_name}: {str(e)}")
err(f"Traceback: {traceback.format_exc()}")
info(f"Sync complete from {source_id} ({source_ip}) to {dest_id} ({dest_ip}). Total changes: {total_changes}")
except Exception as e:
err(f"Error during sync process: {str(e)}")
err(f"Traceback: {traceback.format_exc()}")
finally:
if source_conn:
await source_conn.close()
if dest_conn:
await dest_conn.close()
info(f"Sync summary:")
info(f" Total changes: {total_changes}")
info(f" Tables synced: {len(tables) if 'tables' in locals() else 0}")
info(f" Source: {source_id} ({source_ip})")
info(f" Destination: {dest_id} ({dest_ip})")
return total_changes
async def get_last_synced_version(self, conn, table_name, server_id):
if conn is None:
raise ConnectionError("Failed to connect to local database")
debug(f"Skipping offline server...")
return 0
# Ensure sync columns exist
primary_key = await self.ensure_sync_columns(conn, table_name)
if table_name in self.SPECIAL_TABLES:
debug(f"Skipping get_last_synced_version because {table_name} is special.")
return 0 # Special handling for tables without version column
# Execute the query
result = await conn.execute(query, *args)
try:
last_version = await conn.fetchval(f"""
SELECT COALESCE(MAX(version), 0)
FROM "{table_name}"
WHERE server_id = $1
""", server_id)
return last_version
except Exception as e:
err(f"Error getting last synced version for table {table_name}: {str(e)}")
err(f"Traceback: {traceback.format_exc()}")
return 0
# Get the primary key and new version of the affected row
if primary_key:
affected_row = await conn.fetchrow(f"""
SELECT "{primary_key}", version, server_id
async def get_most_recent_source(self):
most_recent_source = None
max_version = -1
local_ts_id = os.environ.get('TS_ID')
online_hosts = await self.get_online_hosts()
num_online_hosts = len(online_hosts)
if num_online_hosts > 0:
online_ts_ids = [host['ts_id'] for host in online_hosts if host['ts_id'] != local_ts_id]
crit(f"Online hosts: {', '.join(online_ts_ids)}")
for pool_entry in online_hosts:
if pool_entry['ts_id'] == local_ts_id:
continue # Skip local database
try:
conn = await self.get_connection(pool_entry)
if conn is None:
warn(f"Unable to connect to {pool_entry['ts_id']}. Skipping.")
continue
tables = await conn.fetch("""
SELECT tablename FROM pg_tables
WHERE schemaname = 'public'
""")
for table in tables:
table_name = table['tablename']
if table_name in self.SPECIAL_TABLES:
continue
try:
result = await conn.fetchrow(f"""
SELECT MAX(version) as max_version, server_id
FROM "{table_name}"
WHERE version = (SELECT MAX(version) FROM "{table_name}")
GROUP BY server_id
ORDER BY MAX(version) DESC
LIMIT 1
""")
if affected_row:
await self.push_change(table_name, affected_row[primary_key], affected_row['version'], affected_row['server_id'])
if result:
version, server_id = result['max_version'], result['server_id']
info(f"Max version for {pool_entry['ts_id']}, table {table_name}: {version} (from server {server_id})")
if version > max_version:
max_version = version
most_recent_source = pool_entry
else:
# For tables without a primary key, we'll push all rows
await self.push_all_changes(table_name)
debug(f"No data in table {table_name} for {pool_entry['ts_id']}")
except asyncpg.exceptions.UndefinedColumnError:
warn(f"Version or server_id column does not exist in table {table_name} for {pool_entry['ts_id']}. Skipping.")
except Exception as e:
err(f"Error checking version for {pool_entry['ts_id']}, table {table_name}: {str(e)}")
return result
except asyncpg.exceptions.ConnectionFailureError:
warn(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}. Skipping.")
except Exception as e:
err(f"Unexpected error occurred while checking version for {pool_entry['ts_id']}: {str(e)}")
err(f"Traceback: {traceback.format_exc()}")
finally:
if conn:
await conn.close()
async def push_change(self, table_name: str, pk_value: Any, version: int, server_id: str):
if most_recent_source is None:
if num_online_hosts > 0:
warn("Could not determine most recent source. Using first available online host.")
most_recent_source = next(host for host in online_hosts if host['ts_id'] != local_ts_id)
else:
crit("No other online hosts available for sync.")
return most_recent_source
async def _sync_changes(self, table_name: str, primary_key: str):
try:
local_conn = await self.get_connection()
if local_conn is None:
return
# Get the latest changes
changes = await local_conn.fetch(f"""
SELECT * FROM "{table_name}"
WHERE version > (SELECT COALESCE(MAX(version), 0) FROM "{table_name}" WHERE server_id != $1)
OR (version = (SELECT COALESCE(MAX(version), 0) FROM "{table_name}" WHERE server_id != $1) AND server_id = $1)
ORDER BY version ASC
""", os.environ.get('TS_ID'))
if changes:
online_hosts = await self.get_online_hosts()
for pool_entry in online_hosts:
if pool_entry['ts_id'] != os.environ.get('TS_ID'):
remote_conn = await self.get_connection(pool_entry)
if remote_conn is None:
continue
try:
async with self.get_connection(pool_entry) as remote_conn:
await self.apply_batch_changes(remote_conn, table_name, changes, primary_key)
finally:
await remote_conn.close()
except Exception as e:
err(f"Error syncing changes for {table_name}: {str(e)}")
err(f"Traceback: {traceback.format_exc()}")
finally:
if 'local_conn' in locals():
await local_conn.close()
async def execute_read_query(self, query: str, *args, table_name: str):
online_hosts = await self.get_online_hosts()
results = []
max_version = -1
latest_result = None
for pool_entry in online_hosts:
conn = await self.get_connection(pool_entry)
if conn is None:
warn(f"Unable to connect to {pool_entry['ts_id']}. Skipping read.")
continue
try:
# Execute the query
result = await conn.fetch(query, *args)
if not result:
continue
# Check version if it's not a special table
if table_name not in self.SPECIAL_TABLES:
try:
version_query = f"""
SELECT MAX(version) as max_version, server_id
FROM "{table_name}"
WHERE version = (SELECT MAX(version) FROM "{table_name}")
GROUP BY server_id
ORDER BY MAX(version) DESC
LIMIT 1
"""
version_result = await conn.fetchrow(version_query)
if version_result:
version = version_result['max_version']
server_id = version_result['server_id']
info(f"Max version for {pool_entry['ts_id']}, table {table_name}: {version} (from server {server_id})")
if version > max_version:
max_version = version
latest_result = result
else:
debug(f"No data in table {table_name} for {pool_entry['ts_id']}")
except asyncpg.exceptions.UndefinedColumnError:
warn(f"Version or server_id column does not exist in table {table_name} for {pool_entry['ts_id']}. Skipping version check.")
if latest_result is None:
latest_result = result
else:
# For special tables, just use the first result
if latest_result is None:
latest_result = result
results.append((pool_entry['ts_id'], result))
except Exception as e:
err(f"Error executing read query on {pool_entry['ts_id']}: {str(e)}")
err(f"Traceback: {traceback.format_exc()}")
finally:
await conn.close()
if not latest_result:
warn(f"No results found for query on table {table_name}")
return []
# Log results from all databases
for ts_id, result in results:
info(f"Read result from {ts_id}: {result}")
return [dict(r) for r in latest_result] # Convert Record objects to dictionaries
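A minimal usage sketch for execute_read_query, mirroring the calls made in the gis router later in this diff (`api` is the loaded APIConfig instance): the query fans out to the online hosts, and the rows from whichever host reports the newest version come back as a list of dicts.

async def latest_location_row(api):
    rows = await api.execute_read_query(
        'SELECT id, datetime FROM locations ORDER BY datetime DESC LIMIT 1',
        table_name="locations",
    )
    return rows[0] if rows else None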
async def execute_write_query(self, query: str, *args, table_name: str):
conn = await self.get_connection()
if conn is None:
raise ConnectionError("Failed to connect to local database")
try:
if table_name in self.SPECIAL_TABLES:
return await self._execute_special_table_write(conn, query, *args, table_name=table_name)
primary_key = await self.ensure_sync_columns(conn, table_name)
result = await conn.execute(query, *args)
asyncio.create_task(self._sync_changes(table_name, primary_key))
return []
finally:
await conn.close()
async def _run_sync_tasks(self, tasks):
for task in tasks:
try:
await task
except Exception as e:
err(f"Error during background sync: {str(e)}")
err(f"Traceback: {traceback.format_exc()}")
async def push_change(self, table_name: str, pk_value: Any, version: int, server_id: str):
asyncio.create_task(self._push_change_background(table_name, pk_value, version, server_id))
async def _push_change_background(self, table_name: str, pk_value: Any, version: int, server_id: str):
online_hosts = await self.get_online_hosts()
successful_pushes = 0
failed_pushes = 0
for pool_entry in online_hosts:
if pool_entry['ts_id'] != os.environ.get('TS_ID'):
remote_conn = await self.get_connection(pool_entry)
if remote_conn is None:
continue
# Fetch the updated row from the local database
async with self.get_connection() as local_conn:
try:
local_conn = await self.get_connection()
if local_conn is None:
continue
try:
updated_row = await local_conn.fetchrow(f'SELECT * FROM "{table_name}" WHERE "{self.get_primary_key(table_name)}" = $1', pk_value)
finally:
await local_conn.close()
if updated_row:
columns = updated_row.keys()
placeholders = [f'${i+1}' for i in range(len(columns))]
primary_key = self.get_primary_key(table_name)
remote_version = await remote_conn.fetchval(f"""
SELECT version FROM "{table_name}"
WHERE "{primary_key}" = $1
""", updated_row[primary_key])
if remote_version is not None and remote_version >= updated_row['version']:
debug(f"Remote version for {table_name} in {pool_entry['ts_id']} is already up to date. Skipping push.")
successful_pushes += 1
continue
insert_query = f"""
INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)})
VALUES ({', '.join(placeholders)})
@ -536,25 +865,57 @@ class APIConfig(BaseModel):
OR ("{table_name}".version = EXCLUDED.version AND "{table_name}".server_id < EXCLUDED.server_id)
"""
await remote_conn.execute(insert_query, *updated_row.values())
successful_pushes += 1
except Exception as e:
err(f"Error pushing change to {pool_entry['ts_id']}: {str(e)}")
err(f"Traceback: {traceback.format_exc()}")
failed_pushes += 1
finally:
if remote_conn:
await remote_conn.close()
if successful_pushes > 0:
info(f"Successfully pushed changes to {successful_pushes} server(s) for {table_name}")
if failed_pushes > 0:
warn(f"Failed to push changes to {failed_pushes} server(s) for {table_name}")
async def push_all_changes(self, table_name: str):
online_hosts = await self.get_online_hosts()
tasks = []
for pool_entry in online_hosts:
if pool_entry['ts_id'] != os.environ.get('TS_ID'):
try:
async with self.get_connection(pool_entry) as remote_conn:
if remote_conn is None:
continue
task = asyncio.create_task(self._push_changes_to_host(pool_entry, table_name))
tasks.append(task)
# Fetch all rows from the local database
async with self.get_connection() as local_conn:
results = await asyncio.gather(*tasks, return_exceptions=True)
successful_pushes = sum(1 for r in results if r is True)
failed_pushes = sum(1 for r in results if r is False or isinstance(r, Exception))
info(f"Push all changes summary for {table_name}: Successful: {successful_pushes}, Failed: {failed_pushes}")
if failed_pushes > 0:
warn(f"Failed to push all changes to {failed_pushes} server(s). Data may be out of sync.")
async def _push_changes_to_host(self, pool_entry: Dict[str, Any], table_name: str) -> bool:
remote_conn = None
try:
remote_conn = await self.get_connection(pool_entry)
if remote_conn is None:
warn(f"Unable to connect to {pool_entry['ts_id']}. Skipping push.")
return False
local_conn = await self.get_connection()
if local_conn is None:
warn(f"Unable to connect to local database. Skipping push.")
return False
try:
all_rows = await local_conn.fetch(f'SELECT * FROM "{table_name}"')
finally:
await local_conn.close()
if all_rows:
columns = all_rows[0].keys()
columns = list(all_rows[0].keys())
placeholders = [f'${i+1}' for i in range(len(columns))]
insert_query = f"""
@ -563,25 +924,27 @@ class APIConfig(BaseModel):
ON CONFLICT DO NOTHING
"""
# Use a transaction to insert all rows
async with remote_conn.transaction():
for row in all_rows:
await remote_conn.execute(insert_query, *row.values())
return True
else:
debug(f"No rows to push for table {table_name}")
return True
except Exception as e:
err(f"Error pushing all changes to {pool_entry['ts_id']}: {str(e)}")
err(f"Traceback: {traceback.format_exc()}")
return False
finally:
if remote_conn:
await remote_conn.close()
async def _execute_special_table_write(self, query: str, *args, table_name: str):
async def _execute_special_table_write(self, conn, query: str, *args, table_name: str):
if table_name == 'spatial_ref_sys':
return await self._execute_spatial_ref_sys_write(query, *args)
# Add more special cases as needed
async def _execute_spatial_ref_sys_write(self, query: str, *args):
result = None
async with self.get_connection() as local_conn:
if local_conn is None:
raise ConnectionError("Failed to connect to local database")
return await self._execute_spatial_ref_sys_write(conn, query, *args)
async def _execute_spatial_ref_sys_write(self, local_conn, query: str, *args):
# Execute the query locally
result = await local_conn.execute(query, *args)
@ -589,22 +952,65 @@ class APIConfig(BaseModel):
online_hosts = await self.get_online_hosts()
for pool_entry in online_hosts:
if pool_entry['ts_id'] != os.environ.get('TS_ID'):
try:
async with self.get_connection(pool_entry) as remote_conn:
remote_conn = await self.get_connection(pool_entry)
if remote_conn is None:
continue
try:
await self.sync_spatial_ref_sys(local_conn, remote_conn)
except Exception as e:
err(f"Error syncing spatial_ref_sys to {pool_entry['ts_id']}: {str(e)}")
err(f"Traceback: {traceback.format_exc()}")
finally:
await remote_conn.close()
return result
def get_primary_key(self, table_name: str) -> str:
# This method should return the primary key for the given table
# You might want to cache this information for performance
# For now, we'll assume it's always 'id', but you should implement proper logic here
return 'id'
async def apply_batch_changes(self, conn, table_name, changes, primary_key):
if conn is None or not changes:
debug(f"Skipping apply_batch_changes because conn is none or there are no changes.")
return 0
try:
columns = list(changes[0].keys())
placeholders = [f'${i+1}' for i in range(len(columns))]
if primary_key:
insert_query = f"""
INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)})
VALUES ({', '.join(placeholders)})
ON CONFLICT ("{primary_key}") DO UPDATE SET
{', '.join(f'"{col}" = EXCLUDED."{col}"' for col in columns if col not in [primary_key, 'version', 'server_id'])},
version = EXCLUDED.version,
server_id = EXCLUDED.server_id
WHERE "{table_name}".version < EXCLUDED.version
OR ("{table_name}".version = EXCLUDED.version AND "{table_name}".server_id < EXCLUDED.server_id)
"""
else:
warn(f"Possible source of issue #4")
# For tables without a primary key, insert rows and skip any that conflict (ON CONFLICT DO NOTHING)
insert_query = f"""
INSERT INTO "{table_name}" ({', '.join(f'"{col}"' for col in columns)})
VALUES ({', '.join(placeholders)})
ON CONFLICT DO NOTHING
"""
affected_rows = 0
# changes is an ordinary list of records, so iterate it synchronously
for change in tqdm(changes, desc=f"Syncing {table_name}", unit="row"):
values = [change[col] for col in columns]
result = await conn.execute(insert_query, *values)
affected_rows += int(result.split()[-1])
return affected_rows
except Exception as e:
err(f"Error applying batch changes to {table_name}: {str(e)}")
err(f"Traceback: {traceback.format_exc()}")
return 0
async def sync_special_table(self, source_conn, dest_conn, table_name):
if table_name == 'spatial_ref_sys':
return await self.sync_spatial_ref_sys(source_conn, dest_conn)
# Add more special cases as needed
async def sync_spatial_ref_sys(self, source_conn, dest_conn):
try:
@ -666,6 +1072,77 @@ class APIConfig(BaseModel):
return 0
def get_primary_key(self, table_name: str) -> str:
# This method should return the primary key for the given table
# You might want to cache this information for performance
# For now, we'll assume it's always 'id', but you should implement proper logic here
return 'id'
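The comment above invites a proper implementation; a possible sketch (hypothetical, and async, so callers would need to await it) that looks the key up once per table via pg_index, the same query ensure_sync_columns already uses, and caches the result:

_primary_key_cache = {}

async def lookup_primary_key(conn, table_name: str) -> str:
    if table_name in _primary_key_cache:
        return _primary_key_cache[table_name]
    pk = await conn.fetchval("""
        SELECT a.attname
        FROM pg_index i
        JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
        WHERE i.indrelid = $1::regclass AND i.indisprimary
    """, table_name)
    # Fall back to the current assumption of 'id' if no primary key is found.
    _primary_key_cache[table_name] = pk or 'id'
    return _primary_key_cache[table_name]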
async def add_primary_keys_to_local_tables(self):
conn = await self.get_connection()
debug(f"Adding primary keys to existing tables...")
if conn is None:
raise ConnectionError("Failed to connect to local database")
try:
tables = await conn.fetch("""
SELECT tablename FROM pg_tables
WHERE schemaname = 'public'
""")
for table in tables:
table_name = table['tablename']
if table_name not in self.SPECIAL_TABLES:
await self.ensure_sync_columns(conn, table_name)
finally:
await conn.close()
async def add_primary_keys_to_remote_tables(self):
online_hosts = await self.get_online_hosts()
for pool_entry in online_hosts:
conn = await self.get_connection(pool_entry)
if conn is None:
warn(f"Unable to connect to {pool_entry['ts_id']}. Skipping primary key addition.")
continue
try:
info(f"Adding primary keys to existing tables on {pool_entry['ts_id']}...")
tables = await conn.fetch("""
SELECT tablename FROM pg_tables
WHERE schemaname = 'public'
""")
for table in tables:
table_name = table['tablename']
if table_name not in self.SPECIAL_TABLES:
primary_key = await self.ensure_sync_columns(conn, table_name)
if primary_key:
info(f"Added/ensured primary key '{primary_key}' for table '{table_name}' on {pool_entry['ts_id']}")
else:
warn(f"Failed to add/ensure primary key for table '{table_name}' on {pool_entry['ts_id']}")
info(f"Completed adding primary keys to existing tables on {pool_entry['ts_id']}")
except Exception as e:
err(f"Error adding primary keys to existing tables on {pool_entry['ts_id']}: {str(e)}")
err(f"Traceback: {traceback.format_exc()}")
finally:
await conn.close()
async def close_db_pools(self):
info("Closing database connection pools...")
for pool_key, pool in self.db_pools.items():
try:
await pool.close()
debug(f"Closed pool for {pool_key}")
except Exception as e:
err(f"Error closing pool for {pool_key}: {str(e)}")
self.db_pools.clear()
info("All database connection pools closed.")
class Location(BaseModel):
latitude: float
@ -679,7 +1156,7 @@ class Location(BaseModel):
state: Optional[str] = None
country: Optional[str] = None
context: Optional[Dict[str, Any]] = None
class_: Optional[str] = None
class_: Optional[str] = Field(None, alias="class")
type: Optional[str] = None
name: Optional[str] = None
display_name: Optional[str] = None
@ -697,11 +1174,7 @@ class Location(BaseModel):
json_encoders = {
datetime: lambda dt: dt.isoformat(),
}
def model_dump(self):
data = self.dict()
data["datetime"] = self.datetime.isoformat() if self.datetime else None
return data
populate_by_name = True
class Geocoder:

View file

@ -120,7 +120,6 @@ async def get_last_location() -> Optional[Location]:
return None
async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int, datetime, None] = None) -> List[Location]:
start_datetime = await dt(start)
if end is None:
@ -133,10 +132,7 @@ async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int,
debug(f"Fetching locations between {start_datetime} and {end_datetime}")
async with API.get_connection() as conn:
locations = []
# Check for records within the specified datetime range
range_locations = await conn.fetch('''
query = '''
SELECT id, datetime,
ST_X(ST_AsText(location)::geometry) AS longitude,
ST_Y(ST_AsText(location)::geometry) AS latitude,
@ -146,13 +142,14 @@ async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int,
FROM locations
WHERE datetime >= $1 AND datetime <= $2
ORDER BY datetime DESC
''', start_datetime.replace(tzinfo=None), end_datetime.replace(tzinfo=None))
'''
debug(f"Range locations query returned: {range_locations}")
locations.extend(range_locations)
locations = await API.execute_read_query(query, start_datetime.replace(tzinfo=None), end_datetime.replace(tzinfo=None), table_name="locations")
debug(f"Range locations query returned: {locations}")
if not locations and (end is None or start_datetime.date() == end_datetime.date()):
location_data = await conn.fetchrow('''
fallback_query = '''
SELECT id, datetime,
ST_X(ST_AsText(location)::geometry) AS longitude,
ST_Y(ST_AsText(location)::geometry) AS latitude,
@ -163,11 +160,11 @@ async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int,
WHERE datetime < $1
ORDER BY datetime DESC
LIMIT 1
''', start_datetime.replace(tzinfo=None))
'''
location_data = await API.execute_read_query(fallback_query, start_datetime.replace(tzinfo=None), table_name="locations")
debug(f"Fallback query returned: {location_data}")
if location_data:
locations.append(location_data)
locations = location_data
debug(f"Locations found: {locations}")
@ -197,15 +194,12 @@ async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int,
return location_objects if location_objects else []
# Function to fetch the last location before the specified datetime
async def fetch_last_location_before(datetime: datetime) -> Optional[Location]:
datetime = await dt(datetime)
debug(f"Fetching last location before {datetime}")
async with API.get_connection() as conn:
location_data = await conn.fetchrow('''
query = '''
SELECT id, datetime,
ST_X(ST_AsText(location)::geometry) AS longitude,
ST_Y(ST_AsText(location)::geometry) AS latitude,
@ -216,13 +210,13 @@ async def fetch_last_location_before(datetime: datetime) -> Optional[Location]:
WHERE datetime < $1
ORDER BY datetime DESC
LIMIT 1
''', datetime.replace(tzinfo=None))
'''
await conn.close()
location_data = await API.execute_read_query(query, datetime.replace(tzinfo=None), table_name="locations")
if location_data:
debug(f"Last location found: {location_data}")
return Location(**location_data)
debug(f"Last location found: {location_data[0]}")
return Location(**location_data[0])
else:
debug("No location found before the specified datetime")
return None
@ -247,17 +241,13 @@ async def generate_map_endpoint(
return HTMLResponse(content=html_content)
async def get_date_range():
async with API.get_connection() as conn:
query = "SELECT MIN(datetime) as min_date, MAX(datetime) as max_date FROM locations"
row = await conn.fetchrow(query)
if row and row['min_date'] and row['max_date']:
return row['min_date'], row['max_date']
row = await API.execute_read_query(query, table_name="locations")
if row and row[0]['min_date'] and row[0]['max_date']:
return row[0]['min_date'], row[0]['max_date']
else:
return datetime(2022, 1, 1), datetime.now()
async def generate_and_save_heatmap(
start_date: Union[str, int, datetime],
end_date: Optional[Union[str, int, datetime]] = None,
@ -313,8 +303,6 @@ Generate a heatmap for the given date range and save it as a PNG file using Foli
err(f"Error generating and saving heatmap: {str(e)}")
raise
async def generate_map(start_date: datetime, end_date: datetime, max_points: int):
locations = await fetch_locations(start_date, end_date)
if not locations:
@ -343,8 +331,6 @@ async def generate_map(start_date: datetime, end_date: datetime, max_points: int
folium.TileLayer('cartodbdark_matter', name='Dark Mode').add_to(m)
# In the generate_map function:
draw = Draw(
draw_options={
'polygon': True,
@ -433,11 +419,6 @@ map.on(L.Draw.Event.CREATED, function (event) {
return m.get_root().render()
async def post_location(location: Location):
# if not location.datetime:
# info(f"location appears to be missing datetime: {location}")
# else:
# debug(f"post_location called with {location.datetime}")
async with API.get_connection() as conn:
try:
context = location.context or {}
action = context.get('action', 'manual')
@ -449,7 +430,7 @@ async def post_location(location: Location):
# Parse and localize the datetime
localized_datetime = await dt(location.datetime)
await conn.execute('''
query = '''
INSERT INTO locations (
datetime, location, city, state, zip, street, action, device_type, device_model, device_name, device_os,
class_, type, name, display_name, amenity, house_number, road, quarter, neighbourhood,
@ -457,13 +438,18 @@ async def post_location(location: Location):
)
VALUES ($1, ST_SetSRID(ST_MakePoint($2, $3, $4), 4326), $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15,
$16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26)
''', localized_datetime, location.longitude, location.latitude, location.elevation, location.city, location.state,
'''
await API.execute_write_query(
query,
localized_datetime, location.longitude, location.latitude, location.elevation, location.city, location.state,
location.zip, location.street, action, device_type, device_model, device_name, device_os,
location.class_, location.type, location.name, location.display_name,
location.amenity, location.house_number, location.road, location.quarter, location.neighbourhood,
location.suburb, location.county, location.country_code, location.country)
location.suburb, location.county, location.country_code, location.country,
table_name="locations"
)
await conn.close()
info(f"Successfully posted location: {location.latitude}, {location.longitude}, {location.elevation} on {localized_datetime}")
return {
'datetime': localized_datetime,
@ -553,6 +539,7 @@ async def get_last_location_endpoint() -> JSONResponse:
raise HTTPException(status_code=404, detail="No location found before the specified datetime")
@gis.get("/locate/{datetime_str}", response_model=List[Location])
async def get_locate(datetime_str: str, all: bool = False):
try:

View file

@ -1,393 +0,0 @@
'''
Uses Postgres/PostGIS for location tracking (data obtained via the companion mobile Pythonista scripts), and for geocoding purposes.
'''
from fastapi import APIRouter, HTTPException, Query
from fastapi.responses import HTMLResponse, JSONResponse
import yaml
from typing import List, Tuple, Union
import traceback
from datetime import datetime, timezone
from typing import Union, List
import folium
from zoneinfo import ZoneInfo
from dateutil.parser import parse as dateutil_parse
from typing import Optional, List, Union
from datetime import datetime
from sijapi import L, API, TZ, GEO
from sijapi.classes import Location
from sijapi.utilities import haversine
loc = APIRouter()
logger = L.get_module_logger("loc")
async def dt(
date_time: Union[str, int, datetime],
tz: Union[str, ZoneInfo, None] = None
) -> datetime:
try:
# Convert integer (epoch time) to UTC datetime
if isinstance(date_time, int):
date_time = datetime.utcfromtimestamp(date_time).replace(tzinfo=timezone.utc)
logger.debug(f"Converted epoch time {date_time} to UTC datetime object.")
# Convert string to datetime if necessary
elif isinstance(date_time, str):
date_time = dateutil_parse(date_time)
logger.debug(f"Converted string '{date_time}' to datetime object.")
if not isinstance(date_time, datetime):
raise ValueError(f"Input must be a string, integer (epoch time), or datetime object. What we received: {date_time}, type {type(date_time)}")
# Ensure the datetime is timezone-aware (UTC if not specified)
if date_time.tzinfo is None:
date_time = date_time.replace(tzinfo=timezone.utc)
logger.debug("Added UTC timezone to naive datetime.")
# Handle provided timezone
if tz is not None:
if isinstance(tz, str):
if tz == "local":
last_loc = await get_timezone_without_timezone(date_time)
tz = await GEO.tz_at(last_loc.latitude, last_loc.longitude)
logger.debug(f"Using local timezone: {tz}")
else:
try:
tz = ZoneInfo(tz)
except Exception as e:
logger.error(f"Invalid timezone string '{tz}'. Error: {e}")
raise ValueError(f"Invalid timezone string: {tz}")
elif isinstance(tz, ZoneInfo):
pass # tz is already a ZoneInfo object
else:
raise ValueError(f"What we needed: tz == 'local', a string, or a ZoneInfo object. What we got: tz, a {type(tz)}, == {tz})")
# Convert to the provided or determined timezone
date_time = date_time.astimezone(tz)
logger.debug(f"Converted datetime to timezone: {tz}")
return date_time
except ValueError as e:
logger.error(f"Error in dt: {e}")
raise
except Exception as e:
logger.error(f"Unexpected error in dt: {e}")
raise ValueError(f"Failed to process datetime: {e}")
async def get_timezone_without_timezone(date_time):
# This is a bit convoluted because it must resolve a paradox: we need the location to determine the timezone, but we need the timezone to be certain we have the right location if this datetime coincided with inter-timezone travel. Our imperfect solution is to use UTC for an initial location query to determine roughly where we were at the time and get that timezone, then check the location again using that timezone; if this location differs from the one found using UTC, get the timezone again using it, otherwise keep the one we already sourced using UTC.
# Step 1: Use UTC as an interim timezone to query location
interim_dt = date_time.replace(tzinfo=ZoneInfo("UTC"))
interim_loc = await fetch_last_location_before(interim_dt)
# Step 2: Get a preliminary timezone based on the interim location
interim_tz = await GEO.tz_current((interim_loc.latitude, interim_loc.longitude))
# Step 3: Apply this preliminary timezone and query location again
query_dt = date_time.replace(tzinfo=ZoneInfo(interim_tz))
query_loc = await fetch_last_location_before(query_dt)
# Step 4: Get the final timezone, reusing interim_tz if location hasn't changed
return interim_tz if query_loc == interim_loc else await GEO.tz_current(query_loc.latitude, query_loc.longitude)
async def get_last_location() -> Optional[Location]:
query_datetime = datetime.now(TZ)
logger.debug(f"Query_datetime: {query_datetime}")
this_location = await fetch_last_location_before(query_datetime)
if this_location:
logger.debug(f"location: {this_location}")
return this_location
return None
async def fetch_locations(start: datetime, end: datetime = None) -> List[Location]:
start_datetime = await dt(start)
if end is None:
end_datetime = await dt(start_datetime.replace(hour=23, minute=59, second=59))
else:
end_datetime = await dt(end)
if start_datetime.time() == datetime.min.time() and end_datetime.time() == datetime.min.time():
end_datetime = end_datetime.replace(hour=23, minute=59, second=59)
logger.debug(f"Fetching locations between {start_datetime} and {end_datetime}")
async with API.get_connection() as conn:
locations = []
# Check for records within the specified datetime range
range_locations = await conn.fetch('''
SELECT id, datetime,
ST_X(ST_AsText(location)::geometry) AS longitude,
ST_Y(ST_AsText(location)::geometry) AS latitude,
ST_Z(ST_AsText(location)::geometry) AS elevation,
city, state, zip, street,
action, device_type, device_model, device_name, device_os
FROM locations
WHERE datetime >= $1 AND datetime <= $2
ORDER BY datetime DESC
''', start_datetime.replace(tzinfo=None), end_datetime.replace(tzinfo=None))
logger.debug(f"Range locations query returned: {range_locations}")
locations.extend(range_locations)
if not locations and (end is None or start_datetime.date() == end_datetime.date()):
location_data = await conn.fetchrow('''
SELECT id, datetime,
ST_X(ST_AsText(location)::geometry) AS longitude,
ST_Y(ST_AsText(location)::geometry) AS latitude,
ST_Z(ST_AsText(location)::geometry) AS elevation,
city, state, zip, street,
action, device_type, device_model, device_name, device_os
FROM locations
WHERE datetime < $1
ORDER BY datetime DESC
LIMIT 1
''', start_datetime.replace(tzinfo=None))
logger.debug(f"Fallback query returned: {location_data}")
if location_data:
locations.append(location_data)
logger.debug(f"Locations found: {locations}")
# Sort location_data based on the datetime field in descending order
sorted_locations = sorted(locations, key=lambda x: x['datetime'], reverse=True)
# Create Location objects directly from the location data
location_objects = [
Location(
latitude=location['latitude'],
longitude=location['longitude'],
datetime=location['datetime'],
elevation=location.get('elevation'),
city=location.get('city'),
state=location.get('state'),
zip=location.get('zip'),
street=location.get('street'),
context={
'action': location.get('action'),
'device_type': location.get('device_type'),
'device_model': location.get('device_model'),
'device_name': location.get('device_name'),
'device_os': location.get('device_os')
}
) for location in sorted_locations if location['latitude'] is not None and location['longitude'] is not None
]
return location_objects if location_objects else []
# Function to fetch the last location before the specified datetime
async def fetch_last_location_before(datetime: datetime) -> Optional[Location]:
datetime = await dt(datetime)
logger.debug(f"Fetching last location before {datetime}")
async with API.get_connection() as conn:
location_data = await conn.fetchrow('''
SELECT id, datetime,
ST_X(ST_AsText(location)::geometry) AS longitude,
ST_Y(ST_AsText(location)::geometry) AS latitude,
ST_Z(ST_AsText(location)::geometry) AS elevation,
city, state, zip, street, country,
action
FROM locations
WHERE datetime < $1
ORDER BY datetime DESC
LIMIT 1
''', datetime.replace(tzinfo=None))
await conn.close()
if location_data:
logger.debug(f"Last location found: {location_data}")
return Location(**location_data)
else:
logger.debug("No location found before the specified datetime")
return None
@loc.get("/map/start_date={start_date_str}&end_date={end_date_str}", response_class=HTMLResponse)
async def generate_map_endpoint(start_date_str: str, end_date_str: str):
try:
start_date = await dt(start_date_str)
end_date = await dt(end_date_str)
except ValueError:
raise HTTPException(status_code=400, detail="Invalid date format")
html_content = await generate_map(start_date, end_date)
return HTMLResponse(content=html_content)
@loc.get("/map", response_class=HTMLResponse)
async def generate_alltime_map_endpoint():
try:
start_date = await dt(datetime.fromisoformat("2022-01-01"))
end_date = dt(datetime.now())
except ValueError:
raise HTTPException(status_code=400, detail="Invalid date format")
html_content = await generate_map(start_date, end_date)
return HTMLResponse(content=html_content)
async def generate_map(start_date: datetime, end_date: datetime):
locations = await fetch_locations(start_date, end_date)
if not locations:
raise HTTPException(status_code=404, detail="No locations found for the given date range")
# Create a folium map centered around the first location
map_center = [locations[0].latitude, locations[0].longitude]
m = folium.Map(location=map_center, zoom_start=5)
# Add markers for each location
for location in locations:
folium.Marker(
location=[location.latitude, location.longitude],
popup=f"{location.city}, {location.state}<br>Elevation: {location.elevation}m<br>Date: {location.datetime}",
tooltip=f"{location.city}, {location.state}"
).add_to(m)
# Save the map to an HTML file and return the HTML content
map_html = "map.html"
m.save(map_html)
with open(map_html, 'r') as file:
html_content = file.read()
return html_content
async def post_location(location: Location):
async with API.get_connection() as conn:
try:
context = location.context or {}
action = context.get('action', 'manual')
device_type = context.get('device_type', 'Unknown')
device_model = context.get('device_model', 'Unknown')
device_name = context.get('device_name', 'Unknown')
device_os = context.get('device_os', 'Unknown')
# Parse and localize the datetime
localized_datetime = await dt(location.datetime)
await conn.execute('''
INSERT INTO locations (
datetime, location, city, state, zip, street, action, device_type, device_model, device_name, device_os,
class_, type, name, display_name, amenity, house_number, road, quarter, neighbourhood,
suburb, county, country_code, country
)
VALUES ($1, ST_SetSRID(ST_MakePoint($2, $3, $4), 4326), $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15,
$16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26)
''', localized_datetime, location.longitude, location.latitude, location.elevation, location.city, location.state,
location.zip, location.street, action, device_type, device_model, device_name, device_os,
location.class_, location.type, location.name, location.display_name,
location.amenity, location.house_number, location.road, location.quarter, location.neighbourhood,
location.suburb, location.county, location.country_code, location.country)
await conn.close()
logger.info(f"Successfully posted location: {location.latitude}, {location.longitude}, {location.elevation} on {localized_datetime}")
return {
'datetime': localized_datetime,
'latitude': location.latitude,
'longitude': location.longitude,
'elevation': location.elevation,
'city': location.city,
'state': location.state,
'zip': location.zip,
'street': location.street,
'action': action,
'device_type': device_type,
'device_model': device_model,
'device_name': device_name,
'device_os': device_os,
'class_': location.class_,
'type': location.type,
'name': location.name,
'display_name': location.display_name,
'amenity': location.amenity,
'house_number': location.house_number,
'road': location.road,
'quarter': location.quarter,
'neighbourhood': location.neighbourhood,
'suburb': location.suburb,
'county': location.county,
'country_code': location.country_code,
'country': location.country
}
except Exception as e:
logger.error(f"Error posting location {e}")
logger.error(traceback.format_exc())
return None
@loc.post("/locate")
async def post_locate_endpoint(locations: Union[Location, List[Location]]):
if isinstance(locations, Location):
locations = [locations]
# Prepare locations
for lcn in locations:
if not lcn.datetime:
tz = await GEO.tz_at(lcn.latitude, lcn.longitude)
lcn.datetime = datetime.now(ZoneInfo(tz)).isoformat()
if not lcn.context:
lcn.context = {
"action": "missing",
"device_type": "API",
"device_model": "Unknown",
"device_name": "Unknown",
"device_os": "Unknown"
}
logger.debug(f"Location received for processing: {lcn}")
geocoded_locations = await GEO.code(locations)
responses = []
if isinstance(geocoded_locations, List):
for location in geocoded_locations:
logger.debug(f"Final location to be submitted to database: {location}")
location_entry = await post_location(location)
if location_entry:
responses.append({"location_data": location_entry})
else:
logger.warning(f"Posting location to database appears to have failed.")
else:
logger.debug(f"Final location to be submitted to database: {geocoded_locations}")
location_entry = await post_location(geocoded_locations)
if location_entry:
responses.append({"location_data": location_entry})
else:
logger.warning(f"Posting location to database appears to have failed.")
return {"message": "Locations and weather updated", "results": responses}
@loc.get("/locate", response_model=Location)
async def get_last_location_endpoint() -> JSONResponse:
this_location = await get_last_location()
if this_location:
location_dict = this_location.model_dump()
location_dict["datetime"] = this_location.datetime.isoformat()
return JSONResponse(content=location_dict)
else:
raise HTTPException(status_code=404, detail="No location found before the specified datetime")
@loc.get("/locate/{datetime_str}", response_model=List[Location])
async def get_locate(datetime_str: str, all: bool = False):
try:
date_time = await dt(datetime_str)
except ValueError as e:
logger.error(f"Invalid datetime string provided: {datetime_str}")
return ["ERROR: INVALID DATETIME PROVIDED. USE YYYYMMDDHHmmss or YYYYMMDD format."]
locations = await fetch_locations(date_time)
if not locations:
raise HTTPException(status_code=404, detail="No nearby data found for this date and time")
return locations if all else [locations[0]]

View file

@ -212,7 +212,6 @@ if API.EXTENSIONS.shellfish == "on" or API.EXTENSIONS.shellfish == True:
except requests.exceptions.RequestException:
results.append(f"{address} is down")
# Generate a simple text-based graph
graph = '|' * up_count + '.' * (len(addresses) - up_count)
text_update = "\n".join(results)
@ -220,7 +219,6 @@ if API.EXTENSIONS.shellfish == "on" or API.EXTENSIONS.shellfish == True:
output = shellfish_run_widget_command(widget_command)
return {"output": output, "graph": graph}
def shellfish_update_widget(update: WidgetUpdate):
widget_command = ["widget"]
@ -291,6 +289,7 @@ if API.EXTENSIONS.courtlistener == "on" or API.EXTENSIONS.courtlistener == True:
bg_tasks.add_task(cl_docket_process, result)
return JSONResponse(content={"message": "Received"}, status_code=status.HTTP_200_OK)
async def cl_docket_process(result):
async with httpx.AsyncClient() as session:
await cl_docket_process_result(result, session)
@ -346,12 +345,14 @@ if API.EXTENSIONS.courtlistener == "on" or API.EXTENSIONS.courtlistener == True:
await cl_download_file(file_url, target_path, session)
debug(f"Downloaded {file_name} to {target_path}")
def cl_case_details(docket):
case_info = CASETABLE.get(str(docket), {"code": "000", "shortname": "UNKNOWN"})
case_code = case_info.get("code")
short_name = case_info.get("shortname")
return case_code, short_name
async def cl_download_file(url: str, path: Path, session: aiohttp.ClientSession = None):
headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/93.0.4577.82 Safari/537.36'
@ -417,19 +418,20 @@ if API.EXTENSIONS.courtlistener == "on" or API.EXTENSIONS.courtlistener == True:
await cl_download_file(download_url, target_path, session)
debug(f"Downloaded {file_name} to {target_path}")
@serve.get("/s", response_class=HTMLResponse)
async def shortener_form(request: Request):
return templates.TemplateResponse("shortener.html", {"request": request})
@serve.post("/s")
async def create_short_url(request: Request, long_url: str = Form(...), custom_code: Optional[str] = Form(None)):
await create_tables()
if custom_code:
if len(custom_code) != 3 or not custom_code.isalnum():
return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code must be 3 alphanumeric characters"})
existing = await API.execute_write_query('SELECT 1 FROM short_urls WHERE short_code = $1', custom_code, table_name="short_urls")
existing = await API.execute_read_query('SELECT 1 FROM short_urls WHERE short_code = $1', custom_code, table_name="short_urls")
if existing:
return templates.TemplateResponse("shortener.html", {"request": request, "error": "Custom code already in use"})
@ -437,8 +439,9 @@ async def create_short_url(request: Request, long_url: str = Form(...), custom_c
else:
chars = string.ascii_letters + string.digits
while True:
debug(f"FOUND THE ISSUE")
short_code = ''.join(random.choice(chars) for _ in range(3))
existing = await API.execute_write_query('SELECT 1 FROM short_urls WHERE short_code = $1', short_code, table_name="short_urls")
existing = await API.execute_read_query('SELECT 1 FROM short_urls WHERE short_code = $1', short_code, table_name="short_urls")
if not existing:
break
@ -451,48 +454,36 @@ async def create_short_url(request: Request, long_url: str = Form(...), custom_c
short_url = f"https://sij.ai/{short_code}"
return templates.TemplateResponse("shortener.html", {"request": request, "short_url": short_url})
async def create_tables():
await API.execute_write_query('''
CREATE TABLE IF NOT EXISTS short_urls (
short_code VARCHAR(3) PRIMARY KEY,
long_url TEXT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
''', table_name="short_urls")
await API.execute_write_query('''
CREATE TABLE IF NOT EXISTS click_logs (
id SERIAL PRIMARY KEY,
short_code VARCHAR(3) REFERENCES short_urls(short_code),
clicked_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
ip_address TEXT,
user_agent TEXT
)
''', table_name="click_logs")
@serve.get("/{short_code}", response_class=RedirectResponse, status_code=301)
async def redirect_short_url(request: Request, short_code: str = PathParam(..., min_length=3, max_length=3)):
if request.headers.get('host') != 'sij.ai':
raise HTTPException(status_code=404, detail="Not Found")
result = await API.execute_write_query(
@serve.get("/{short_code}")
async def redirect_short_url(short_code: str):
results = await API.execute_read_query(
'SELECT long_url FROM short_urls WHERE short_code = $1',
short_code,
table_name="short_urls"
)
if result:
if not results:
raise HTTPException(status_code=404, detail="Short URL not found")
long_url = results[0].get('long_url')
if not long_url:
raise HTTPException(status_code=404, detail="Long URL not found")
# Record the click for analytics (you may want to do this asynchronously)
await API.execute_write_query(
'INSERT INTO click_logs (short_code, ip_address, user_agent) VALUES ($1, $2, $3)',
short_code, request.client.host, request.headers.get("user-agent"),
'INSERT INTO click_logs (short_code, clicked_at) VALUES ($1, $2)',
short_code, datetime.now(),
table_name="click_logs"
)
return result['long_url']
else:
raise HTTPException(status_code=404, detail="Short URL not found")
return RedirectResponse(url=long_url)
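Following the comment above about doing the click logging asynchronously, a minimal sketch that schedules the insert as a background task so the redirect is not delayed (schedule_click_log is a hypothetical helper; the query and execute_write_query call match the ones used above):

import asyncio
from datetime import datetime

def schedule_click_log(api, short_code: str):
    async def _log():
        await api.execute_write_query(
            'INSERT INTO click_logs (short_code, clicked_at) VALUES ($1, $2)',
            short_code, datetime.now(),
            table_name="click_logs"
        )
    # Requires a running event loop, which is the case inside a FastAPI handler.
    asyncio.create_task(_log())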
@serve.get("/analytics/{short_code}")
async def get_analytics(short_code: str):
url_info = await API.execute_write_query(
url_info = await API.execute_read_query(
'SELECT long_url, created_at FROM short_urls WHERE short_code = $1',
short_code,
table_name="short_urls"
@ -500,13 +491,13 @@ async def get_analytics(short_code: str):
if not url_info:
raise HTTPException(status_code=404, detail="Short URL not found")
click_count = await API.execute_write_query(
click_count = await API.execute_read_query(
'SELECT COUNT(*) FROM click_logs WHERE short_code = $1',
short_code,
table_name="click_logs"
)
clicks = await API.execute_write_query(
clicks = await API.execute_read_query(
'SELECT clicked_at, ip_address, user_agent FROM click_logs WHERE short_code = $1 ORDER BY clicked_at DESC LIMIT 100',
short_code,
table_name="click_logs"
@ -521,15 +512,20 @@ async def get_analytics(short_code: str):
}
async def forward_traffic(reader: asyncio.StreamReader, writer: asyncio.StreamWriter, destination: str):
try:
dest_host, dest_port = destination.split(':')
dest_port = int(dest_port)
except ValueError:
warn(f"Invalid destination format: {destination}. Expected 'host:port'.")
writer.close()
await writer.wait_closed()
return
try:
dest_reader, dest_writer = await asyncio.open_connection(dest_host, dest_port)
except Exception as e:
warn(f"Failed to connect to destination {destination}: {str(e)}")
writer.close()
await writer.wait_closed()
return
@ -543,7 +539,7 @@ async def forward_traffic(reader: asyncio.StreamReader, writer: asyncio.StreamWr
dst.write(data)
await dst.drain()
except Exception as e:
pass
warn(f"Error in forwarding: {str(e)}")
finally:
dst.close()
await dst.wait_closed()
@ -554,8 +550,12 @@ async def forward_traffic(reader: asyncio.StreamReader, writer: asyncio.StreamWr
)
async def start_server(source: str, destination: str):
if ':' in source:
host, port = source.split(':')
port = int(port)
else:
host = source
port = 80
server = await asyncio.start_server(
lambda r, w: forward_traffic(r, w, destination),
@ -566,6 +566,7 @@ async def start_server(source: str, destination: str):
async with server:
await server.serve_forever()
async def start_port_forwarding():
if hasattr(Serve, 'forwarding_rules'):
for rule in Serve.forwarding_rules:
@ -573,6 +574,7 @@ async def start_port_forwarding():
else:
warn("No forwarding rules found in the configuration.")
@serve.get("/forward_status")
async def get_forward_status():
if hasattr(Serve, 'forwarding_rules'):
@ -580,5 +582,5 @@ async def get_forward_status():
else:
return {"status": "inactive", "message": "No forwarding rules configured"}
# Add this to the end of your serve.py file
asyncio.create_task(start_port_forwarding())

View file

@ -116,7 +116,6 @@ async def get_weather(date_time: dt_datetime, latitude: float, longitude: float,
async def store_weather_to_db(date_time: dt_datetime, weather_data: dict):
warn(f"Using {date_time.strftime('%Y-%m-%d %H:%M:%S')} as our datetime in store_weather_to_db")
async with API.get_connection() as conn:
try:
day_data = weather_data.get('days')[0]
debug(f"RAW DAY_DATA: {day_data}")
@ -163,6 +162,7 @@ async def store_weather_to_db(date_time: dt_datetime, weather_data: dict):
)
except Exception as e:
err(f"Failed to prepare database query in store_weather_to_db! {e}")
return "FAILURE"
try:
daily_weather_query = '''
@ -178,8 +178,7 @@ async def store_weather_to_db(date_time: dt_datetime, weather_data: dict):
RETURNING id
'''
async with conn.transaction():
daily_weather_id = await conn.fetchval(daily_weather_query, *daily_weather_params)
daily_weather_id = await API.execute_write_query(daily_weather_query, *daily_weather_params, table_name="DailyWeather")
if 'hours' in day_data:
debug(f"Processing hours now...")
@ -226,8 +225,7 @@ async def store_weather_to_db(date_time: dt_datetime, weather_data: dict):
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26)
RETURNING id
'''
async with conn.transaction():
hourly_weather_id = await conn.fetchval(hourly_weather_query, *hourly_weather_params)
hourly_weather_id = await API.execute_write_query(hourly_weather_query, *hourly_weather_params, table_name="HourlyWeather")
debug(f"Done processing hourly_weather_id {hourly_weather_id}")
except Exception as e:
err(f"EXCEPTION: {e}")
@ -239,6 +237,8 @@ async def store_weather_to_db(date_time: dt_datetime, weather_data: dict):
except Exception as e:
err(f"Error in dailyweather storage: {e}")
return "FAILURE"