Auto-update: Mon Aug 12 15:46:19 PDT 2024

sanj 2024-08-12 15:46:19 -07:00
parent ff678a5df1
commit fd81cbed98
5 changed files with 196 additions and 78 deletions

View file

@@ -48,12 +48,6 @@ async def lifespan(app: FastAPI):
     l.critical("sijapi launched")
     l.info(f"Arguments: {args}")
-
-    # Log the router directory path
-    l.debug(f"Router directory path: {Dir.ROUTER.absolute()}")
-    l.debug(f"Router directory exists: {Dir.ROUTER.exists()}")
-    l.debug(f"Router directory is a directory: {Dir.ROUTER.is_dir()}")
-    l.debug(f"Contents of router directory: {list(Dir.ROUTER.iterdir())}")
 
     # Load routers
     if args.test:
         load_router(args.test)
@@ -64,6 +58,7 @@ async def lifespan(app: FastAPI):
     try:
         await Db.initialize_engines()
+        await Db.ensure_query_tracking_table()
     except Exception as e:
         l.critical(f"Error during startup: {str(e)}")
         l.critical(f"Traceback: {traceback.format_exc()}")
@@ -82,6 +77,7 @@ async def lifespan(app: FastAPI):
         l.critical(f"Error during shutdown: {str(e)}")
         l.critical(f"Traceback: {traceback.format_exc()}")
 
 app = FastAPI(lifespan=lifespan)
 app.add_middleware(
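
Note: with table creation now handled by Db.ensure_query_tracking_table() at startup, the lifespan hook reduces to engine setup, router loading, and teardown. A minimal sketch of the pattern, assuming the Db singleton imported elsewhere in this commit (error handling and router loading omitted):

from contextlib import asynccontextmanager
from fastapi import FastAPI
from sijapi import Db  # as imported in routers/sys.py below

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: engines first, then make sure query_tracking exists
    # on every reachable database before any writes are accepted.
    await Db.initialize_engines()
    await Db.ensure_query_tracking_table()
    yield
    # Shutdown: dispose of every engine in the pool.
    await Db.close()

app = FastAPI(lifespan=lifespan)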

View file

@@ -1,9 +1,11 @@
 # database.py
 import json
 import yaml
 import time
 import aiohttp
 import asyncio
+import traceback
 from datetime import datetime as dt_datetime, date
 from tqdm.asyncio import tqdm
 import reverse_geocoder as rg
@@ -19,11 +21,11 @@ from srtm import get_data
 import os
 import sys
 from loguru import logger
-from sqlalchemy import text
+from sqlalchemy import text, select, func, and_
 from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
 from sqlalchemy.orm import sessionmaker, declarative_base
 from sqlalchemy.exc import OperationalError
-from sqlalchemy import Column, Integer, String, DateTime, JSON, Text, select, func
+from sqlalchemy import Column, Integer, String, DateTime, JSON, Text
 from sqlalchemy.dialects.postgresql import JSONB
 from urllib.parse import urljoin
 import hashlib
@@ -41,7 +43,6 @@ ENV_PATH = CONFIG_DIR / ".env"
 load_dotenv(ENV_PATH)
 TS_ID = os.environ.get('TS_ID')
-
 
 class QueryTracking(Base):
     __tablename__ = 'query_tracking'
@@ -85,19 +86,16 @@ class Database:
                 self.engines[db_info['ts_id']] = engine
                 self.sessions[db_info['ts_id']] = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
                 l.info(f"Initialized engine and session for {db_info['ts_id']}")
+
+                # Create tables if they don't exist
+                async with engine.begin() as conn:
+                    await conn.run_sync(Base.metadata.create_all)
+                l.info(f"Ensured tables exist for {db_info['ts_id']}")
             except Exception as e:
                 l.error(f"Failed to initialize engine for {db_info['ts_id']}: {str(e)}")
 
         if self.local_ts_id not in self.sessions:
             l.error(f"Failed to initialize session for local server {self.local_ts_id}")
-        else:
-            try:
-                # Create tables if they don't exist
-                async with self.engines[self.local_ts_id].begin() as conn:
-                    await conn.run_sync(Base.metadata.create_all)
-                l.info(f"Initialized tables for local server {self.local_ts_id}")
-            except Exception as e:
-                l.error(f"Failed to create tables for local server {self.local_ts_id}: {str(e)}")
 
     async def get_online_servers(self) -> List[str]:
         online_servers = []
@@ -119,7 +117,6 @@ class Database:
         async with self.sessions[self.local_ts_id]() as session:
             try:
                 result = await session.execute(text(query), kwargs)
-                # Convert the result to a list of dictionaries
                 rows = result.fetchall()
                 if rows:
                     columns = result.keys()
@@ -138,17 +135,17 @@ class Database:
         async with self.sessions[self.local_ts_id]() as session:
             try:
-                # Serialize the kwargs using
+                # Serialize the kwargs
                 serialized_kwargs = {key: serialize(value) for key, value in kwargs.items()}
 
                 # Execute the write query
                 result = await session.execute(text(query), serialized_kwargs)
 
-                # Log the query (use json_dumps for logging purposes)
+                # Log the query
                 new_query = QueryTracking(
                     ts_id=self.local_ts_id,
                     query=query,
-                    args=json_dumps(kwargs)  # Use original kwargs for logging
+                    args=json_dumps(kwargs)  # Use json_dumps for logging
                 )
                 session.add(new_query)
                 await session.flush()
@@ -162,13 +159,10 @@ class Database:
                 # Update query_tracking with checksum
                 await self.update_query_checksum(query_id, checksum)
 
-                # Replicate to online servers
-                online_servers = await self.get_online_servers()
-                for ts_id in online_servers:
-                    if ts_id != self.local_ts_id:
-                        asyncio.create_task(self._replicate_write(ts_id, query_id, query, serialized_kwargs, checksum))
+                # Perform sync operations asynchronously
+                asyncio.create_task(self._async_sync_operations(query_id, query, serialized_kwargs, checksum))
 
-                return result  # Return the CursorResult
+                return result
             except Exception as e:
                 l.error(f"Failed to execute write query: {str(e)}")
@@ -178,6 +172,27 @@ class Database:
             l.error(f"Traceback: {traceback.format_exc()}")
             return None
 
+    async def _async_sync_operations(self, query_id: int, query: str, params: dict, checksum: str):
+        try:
+            await self.sync_query_tracking()
+        except Exception as e:
+            l.error(f"Failed to sync query_tracking: {str(e)}")
+
+        try:
+            await self.call_db_sync_on_servers()
+        except Exception as e:
+            l.error(f"Failed to call db_sync on other servers: {str(e)}")
+
+        # Replicate write to other servers
+        online_servers = await self.get_online_servers()
+        for ts_id in online_servers:
+            if ts_id != self.local_ts_id:
+                try:
+                    await self._replicate_write(ts_id, query_id, query, params, checksum)
+                except Exception as e:
+                    l.error(f"Failed to replicate write to {ts_id}: {str(e)}")
+
     async def get_primary_server(self) -> str:
         url = urljoin(self.config['URL'], '/id')
@@ -194,7 +209,6 @@ class Database:
             l.error(f"Error connecting to load balancer: {str(e)}")
             return None
-
 
     async def get_checksum_server(self) -> dict:
         primary_ts_id = await self.get_primary_server()
         online_servers = await self.get_online_servers()
@@ -206,7 +220,6 @@ class Database:
         return random.choice(checksum_servers)
-
 
     async def _local_compute_checksum(self, query: str, params: dict):
         async with self.sessions[self.local_ts_id]() as session:
             result = await session.execute(text(query), params)
@@ -217,7 +230,6 @@ class Database:
             checksum = hashlib.md5(str(data).encode()).hexdigest()
             return checksum
-
 
     async def _delegate_compute_checksum(self, server: Dict[str, Any], query: str, params: dict):
         url = f"http://{server['ts_ip']}:{server['app_port']}/sync/checksum"
@@ -234,7 +246,6 @@ class Database:
             l.error(f"Error connecting to {server['ts_id']} for checksum: {str(e)}")
             return await self._local_compute_checksum(query, params)
-
 
     async def update_query_checksum(self, query_id: int, checksum: str):
         async with self.sessions[self.local_ts_id]() as session:
             await session.execute(
@@ -243,7 +254,6 @@ class Database:
             )
             await session.commit()
-
 
     async def _replicate_write(self, ts_id: str, query_id: int, query: str, params: dict, expected_checksum: str):
         try:
             async with self.sessions[ts_id]() as session:
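
Note: the checksum machinery above is what lets a replica verify a replayed write. _local_compute_checksum runs the query locally and hashes the stringified result, and update_query_checksum records that value in query_tracking. A sketch of the core idea, assuming the rows have already been fetched (how the real code orders the rows is not visible in this diff):

import hashlib

def result_checksum(rows) -> str:
    # Hash a deterministic string form of the rows; a replica recomputes
    # this after replaying the query and compares it to the stored value.
    data = [str(row) for row in rows]
    return hashlib.md5(str(data).encode()).hexdigest()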
@@ -255,8 +265,8 @@ class Database:
                 await session.commit()
                 l.info(f"Successfully replicated write to {ts_id}")
         except Exception as e:
-            l.error(f"Failed to replicate write on {ts_id}: {str(e)}")
+            l.error(f"Failed to replicate write on {ts_id}")
+            l.debug(f"Failed to replicate write on {ts_id}: {str(e)}")
 
     async def mark_query_completed(self, query_id: int, ts_id: str):
         async with self.sessions[self.local_ts_id]() as session:
@@ -267,7 +277,6 @@ class Database:
                 query.completed_by = completed_by
                 await session.commit()
-
 
     async def sync_local_server(self):
         async with self.sessions[self.local_ts_id]() as session:
             last_synced = await session.execute(
@@ -295,7 +304,6 @@ class Database:
             await session.commit()
             l.info(f"Local server sync completed. Executed {unexecuted_queries.rowcount} queries.")
-
 
     async def purge_completed_queries(self):
         async with self.sessions[self.local_ts_id]() as session:
             all_ts_ids = [db['ts_id'] for db in self.config['POOL']]
@@ -316,9 +324,106 @@ class Database:
             deleted_count = result.rowcount
             l.info(f"Purged {deleted_count} completed queries.")
 
+    async def sync_query_tracking(self):
+        """Combinatorial sync method for the query_tracking table."""
+        try:
+            online_servers = await self.get_online_servers()
+
+            for ts_id in online_servers:
+                if ts_id == self.local_ts_id:
+                    continue
+
+                try:
+                    async with self.sessions[ts_id]() as remote_session:
+                        local_max_id = await self.get_max_query_id(self.local_ts_id)
+                        remote_max_id = await self.get_max_query_id(ts_id)
+
+                        # Sync from remote to local
+                        remote_new_queries = await remote_session.execute(
+                            select(QueryTracking).where(QueryTracking.id > local_max_id)
+                        )
+                        for query in remote_new_queries:
+                            await self.add_or_update_query(query)
+
+                        # Sync from local to remote
+                        async with self.sessions[self.local_ts_id]() as local_session:
+                            local_new_queries = await local_session.execute(
+                                select(QueryTracking).where(QueryTracking.id > remote_max_id)
+                            )
+                            for query in local_new_queries:
+                                await self.add_or_update_query_remote(ts_id, query)
+                except Exception as e:
+                    l.error(f"Error syncing with {ts_id}: {str(e)}")
+        except Exception as e:
+            l.error(f"Error in sync_query_tracking: {str(e)}")
+            l.error(f"Traceback: {traceback.format_exc()}")
+
+    async def get_max_query_id(self, ts_id):
+        async with self.sessions[ts_id]() as session:
+            result = await session.execute(select(func.max(QueryTracking.id)))
+            return result.scalar() or 0
+
+    async def add_or_update_query(self, query):
+        async with self.sessions[self.local_ts_id]() as session:
+            existing_query = await session.get(QueryTracking, query.id)
+            if existing_query:
+                existing_query.completed_by = {**existing_query.completed_by, **query.completed_by}
+            else:
+                session.add(query)
+            await session.commit()
+
+    async def add_or_update_query_remote(self, ts_id, query):
+        async with self.sessions[ts_id]() as session:
+            existing_query = await session.get(QueryTracking, query.id)
+            if existing_query:
+                existing_query.completed_by = {**existing_query.completed_by, **query.completed_by}
+            else:
+                new_query = QueryTracking(
+                    id=query.id,
+                    ts_id=query.ts_id,
+                    query=query.query,
+                    args=query.args,
+                    executed_at=query.executed_at,
+                    completed_by=query.completed_by,
+                    result_checksum=query.result_checksum
+                )
+                session.add(new_query)
+            await session.commit()
+
+    async def ensure_query_tracking_table(self):
+        for ts_id, engine in self.engines.items():
+            try:
+                async with engine.begin() as conn:
+                    await conn.run_sync(Base.metadata.create_all)
+                l.info(f"Ensured query_tracking table exists for {ts_id}")
+            except Exception as e:
+                l.error(f"Failed to create query_tracking table for {ts_id}: {str(e)}")
+
+    async def call_db_sync_on_servers(self):
+        """Call /db/sync on all online servers."""
+        online_servers = await self.get_online_servers()
+        tasks = []
+        for server in self.config['POOL']:
+            if server['ts_id'] in online_servers and server['ts_id'] != self.local_ts_id:
+                url = f"http://{server['ts_ip']}:{server['app_port']}/db/sync"
+                tasks.append(self.call_db_sync(url))
+        await asyncio.gather(*tasks)
+
+    async def call_db_sync(self, url):
+        async with aiohttp.ClientSession() as session:
+            try:
+                async with session.post(url, timeout=30) as response:
+                    if response.status == 200:
+                        l.info(f"Successfully called /db/sync on {url}")
+                    else:
+                        l.warning(f"Failed to call /db/sync on {url}. Status: {response.status}")
+            except asyncio.TimeoutError:
+                l.debug(f"Timeout while calling /db/sync on {url}")
+            except Exception as e:
+                l.error(f"Error calling /db/sync on {url}: {str(e)}")
+
     async def close(self):
         for engine in self.engines.values():
             await engine.dispose()
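
Note: the upshot of the database.py changes is that Db.write() keeps the local commit and checksum bookkeeping synchronous, while replication and cross-server sync move into a single fire-and-forget task. A sketch of the flow from a caller's perspective, assuming the method names above (the INSERT statement is illustrative only):

async def record_location(db):
    # Commits locally, logs the query in query_tracking, then schedules
    # _async_sync_operations() in the background and returns immediately.
    result = await db.write(
        "INSERT INTO locations (name) VALUES (:name)",
        name="home",
    )
    if result is None:
        print("local write failed; nothing was queued for replication")

One caveat worth noting: asyncio.create_task() is called without keeping a reference to the task, so the event loop may garbage-collect it before it finishes; the pull-based /db/sync endpoint presumably serves as the backstop for replication that never lands.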

View file

@@ -11,8 +11,8 @@ import sys
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
-def load_config():
-    config_path = Path(__file__).parent.parent / 'config' / 'sys.yaml'
+def load_config(cfg: str):
+    config_path = Path(__file__).parent.parent / 'config' / f'{cfg}.yaml'
     with open(config_path, 'r') as file:
         return yaml.safe_load(file)
@@ -149,8 +149,8 @@ def kill_remote_server(server):
 def main():
     load_env()
-    config = load_config()
-    pool = config['POOL']
+    db_config = load_config('db')
+    pool = db_config['POOL']
     local_ts_id = os.environ.get('TS_ID')
 
     parser = argparse.ArgumentParser(description='Manage sijapi servers')
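
Note: load_config() now takes the config file's stem instead of hard-coding sys.yaml, so the server-management script reads its pool from config/db.yaml. A short usage sketch:

db_config = load_config('db')    # reads config/db.yaml
sys_config = load_config('sys')  # reads config/sys.yaml, the old hard-coded default
pool = db_config['POOL']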

View file

@ -1,26 +1,26 @@
'''
System module. /health returns `'status': 'ok'`, /id returns TS_ID, /routers responds with a list of the active routers, /ip responds with the device's local IP, /ts_ip responds with its tailnet IP, and /wan_ip responds with WAN IP.
'''
# routers/sys.py # routers/sys.py
import os import os
import httpx import httpx
import socket import socket
from fastapi import APIRouter from fastapi import APIRouter, BackgroundTasks, HTTPException
from sqlalchemy import text, select
from tailscale import Tailscale from tailscale import Tailscale
from sijapi import Sys, TS_ID from sijapi import Sys, Db, TS_ID
from sijapi.logs import get_logger from sijapi.logs import get_logger
from sijapi.serialization import json_loads
from sijapi.database import QueryTracking
l = get_logger(__name__) l = get_logger(__name__)
sys = APIRouter() sys = APIRouter()
@sys.get("/health") @sys.get("/health")
def get_health(): def get_health():
return {"status": "ok"} return {"status": "ok"}
@sys.get("/id") @sys.get("/id")
def get_health() -> str: def get_id() -> str:
return TS_ID return TS_ID
@sys.get("/routers") @sys.get("/routers")
@@ -66,3 +66,38 @@ async def get_tailscale_ip():
         return devices[0]['addresses'][0]
     else:
         return "No devices found"
+
+async def sync_process():
+    async with Db.sessions[TS_ID]() as session:
+        # Find unexecuted queries
+        unexecuted_queries = await session.execute(
+            select(QueryTracking).where(~QueryTracking.completed_by.has_key(TS_ID)).order_by(QueryTracking.id)
+        )
+
+        for query in unexecuted_queries:
+            try:
+                params = json_loads(query.args)
+                await session.execute(text(query.query), params)
+                actual_checksum = await Db._local_compute_checksum(query.query, params)
+                if actual_checksum != query.result_checksum:
+                    l.error(f"Checksum mismatch for query ID {query.id}")
+                    continue
+
+                # Update the completed_by field
+                query.completed_by[TS_ID] = True
+                await session.commit()
+                l.info(f"Successfully executed and verified query ID {query.id}")
+            except Exception as e:
+                l.error(f"Failed to execute query ID {query.id} during sync: {str(e)}")
+                await session.rollback()
+
+    l.info(f"Sync process completed. Executed {unexecuted_queries.rowcount} queries.")
+
+    # After executing all queries, perform combinatorial sync
+    await Db.sync_query_tracking()
+
+@sys.post("/db/sync")
+async def db_sync(background_tasks: BackgroundTasks):
+    background_tasks.add_task(sync_process)
+    return {"message": "Sync process initiated"}

View file

@@ -1,14 +1,9 @@
-'''
-Uses the VisualCrossing API and Postgres/PostGIS to source local weather forecasts and history.
-'''
-#routers/weather.py
-
 import asyncio
 import traceback
 import os
 from fastapi import APIRouter, HTTPException, Query
-from fastapi import HTTPException
 from fastapi.responses import JSONResponse
+from fastapi.encoders import jsonable_encoder
 from asyncpg.cursor import Cursor
 from httpx import AsyncClient
 from typing import Dict
@@ -19,11 +14,12 @@ from sijapi import VISUALCROSSING_API_KEY, TZ, Sys, GEO, Db
 from sijapi.utilities import haversine
 from sijapi.routers import gis
 from sijapi.logs import get_logger
+from sijapi.serialization import json_dumps, serialize
 
 l = get_logger(__name__)
 
 weather = APIRouter()
 
 @weather.get("/weather/refresh", response_class=JSONResponse)
 async def get_refreshed_weather(
     date: str = Query(default=dt_datetime.now().strftime("%Y-%m-%d"), description="Enter a date in YYYY-MM-DD format, otherwise it will default to today."),
@@ -50,17 +46,8 @@ async def get_refreshed_weather(
         if day is None:
             raise HTTPException(status_code=404, detail="No weather data found for the given date and location")
 
-        # Convert the day object to a JSON-serializable format
-        day_dict = {}
-        for k, v in day.items():
-            if k == 'DailyWeather':
-                day_dict[k] = {kk: vv.isoformat() if isinstance(vv, (dt_datetime, dt_date)) else vv for kk, vv in v.items()}
-            elif k == 'HourlyWeather':
-                day_dict[k] = [{kk: vv.isoformat() if isinstance(vv, (dt_datetime, dt_date)) else vv for kk, vv in hour.items()} for hour in v]
-            else:
-                day_dict[k] = v.isoformat() if isinstance(v, (dt_datetime, dt_date)) else v
-
-        return JSONResponse(content={"weather": day_dict}, status_code=200)
+        json_compatible_data = jsonable_encoder({"weather": day})
+        return JSONResponse(content=json_compatible_data)
 
     except HTTPException as e:
         l.error(f"HTTP Exception in get_refreshed_weather: {e.detail}")
@@ -136,9 +123,6 @@ async def get_weather(date_time: dt_datetime, latitude: float, longitude: float,
     return daily_weather_data
 
-
-# weather.py
-
 async def store_weather_to_db(date_time: dt_datetime, weather_data: dict):
     try:
         day_data = weather_data.get('days', [{}])[0]
@@ -231,7 +215,7 @@ async def store_weather_to_db(date_time: dt_datetime, weather_data: dict):
                 hour_preciptype_array = hour_data.get('preciptype', []) or []
                 hour_stations_array = hour_data.get('stations', []) or []
                 hourly_weather_params = {
-                    'daily_weather_id': str(daily_weather_id),  # Convert UUID to string
+                    'daily_weather_id': daily_weather_id,
                     'datetime': await gis.dt(hour_data.get('datetimeEpoch')),
                     'datetimeepoch': hour_data.get('datetimeEpoch'),
                     'temp': hour_data.get('temp'),
@@ -287,8 +271,6 @@ async def store_weather_to_db(date_time: dt_datetime, weather_data: dict):
         l.error(f"Traceback: {traceback.format_exc()}")
         return "FAILURE"
-
 
 async def get_weather_from_db(date_time: dt_datetime, latitude: float, longitude: float):
     l.debug(f"Using {date_time.strftime('%Y-%m-%d %H:%M:%S')} as our datetime in get_weather_from_db.")
     query_date = date_time.date()
@@ -311,12 +293,12 @@ async def get_weather_from_db(date_time: dt_datetime, latitude: float, longitude
     hourly_query = '''
         SELECT * FROM hourlyweather
-        WHERE daily_weather_id::text = :daily_weather_id
+        WHERE daily_weather_id = :daily_weather_id
        ORDER BY datetime ASC
     '''
 
     hourly_weather_records = await Db.read(
         hourly_query,
-        daily_weather_id=str(daily_weather_data['id']),
+        daily_weather_id=daily_weather_data['id'],
         table_name='hourlyweather'
     )
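
Note: two simplifications in weather.py. FastAPI's jsonable_encoder replaces the hand-rolled serialization loop, since it recursively converts datetimes, dates, and UUIDs to JSON-safe strings; and daily_weather_id is now bound as a native UUID instead of being cast to text on both sides of the comparison. A small illustration of the encoder, with made-up values:

from datetime import datetime
from uuid import uuid4
from fastapi.encoders import jsonable_encoder

day = {"DailyWeather": {"id": uuid4(), "sunrise": datetime(2024, 8, 12, 6, 4)}}
print(jsonable_encoder({"weather": day}))
# {'weather': {'DailyWeather': {'id': '<uuid as string>', 'sunrise': '2024-08-12T06:04:00'}}}

Dropping the ::text cast also lets Postgres compare the UUID column directly, which keeps any index on daily_weather_id usable.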