From fd81cbed985aa35d927d5f620e99e5b5ac9ade48 Mon Sep 17 00:00:00 2001 From: sanj <67624670+iodrift@users.noreply.github.com> Date: Mon, 12 Aug 2024 15:46:19 -0700 Subject: [PATCH] Auto-update: Mon Aug 12 15:46:19 PDT 2024 --- sijapi/__main__.py | 8 +- sijapi/database.py | 169 ++++++++++++++++++++++++++++++-------- sijapi/helpers/start.py | 8 +- sijapi/routers/sys.py | 53 ++++++++++-- sijapi/routers/weather.py | 36 ++------ 5 files changed, 196 insertions(+), 78 deletions(-) diff --git a/sijapi/__main__.py b/sijapi/__main__.py index 7daa4a5..8d57ec3 100755 --- a/sijapi/__main__.py +++ b/sijapi/__main__.py @@ -47,12 +47,6 @@ async def lifespan(app: FastAPI): # Startup l.critical("sijapi launched") l.info(f"Arguments: {args}") - - # Log the router directory path - l.debug(f"Router directory path: {Dir.ROUTER.absolute()}") - l.debug(f"Router directory exists: {Dir.ROUTER.exists()}") - l.debug(f"Router directory is a directory: {Dir.ROUTER.is_dir()}") - l.debug(f"Contents of router directory: {list(Dir.ROUTER.iterdir())}") # Load routers if args.test: @@ -64,6 +58,7 @@ async def lifespan(app: FastAPI): try: await Db.initialize_engines() + await Db.ensure_query_tracking_table() except Exception as e: l.critical(f"Error during startup: {str(e)}") l.critical(f"Traceback: {traceback.format_exc()}") @@ -82,6 +77,7 @@ async def lifespan(app: FastAPI): l.critical(f"Error during shutdown: {str(e)}") l.critical(f"Traceback: {traceback.format_exc()}") + app = FastAPI(lifespan=lifespan) app.add_middleware( diff --git a/sijapi/database.py b/sijapi/database.py index 9eb0c3a..5783260 100644 --- a/sijapi/database.py +++ b/sijapi/database.py @@ -1,9 +1,11 @@ # database.py + import json import yaml import time import aiohttp import asyncio +import traceback from datetime import datetime as dt_datetime, date from tqdm.asyncio import tqdm import reverse_geocoder as rg @@ -19,11 +21,11 @@ from srtm import get_data import os import sys from loguru import logger -from sqlalchemy import text +from sqlalchemy import text, select, func, and_ from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession from sqlalchemy.orm import sessionmaker, declarative_base from sqlalchemy.exc import OperationalError -from sqlalchemy import Column, Integer, String, DateTime, JSON, Text, select, func +from sqlalchemy import Column, Integer, String, DateTime, JSON, Text from sqlalchemy.dialects.postgresql import JSONB from urllib.parse import urljoin import hashlib @@ -41,7 +43,6 @@ ENV_PATH = CONFIG_DIR / ".env" load_dotenv(ENV_PATH) TS_ID = os.environ.get('TS_ID') - class QueryTracking(Base): __tablename__ = 'query_tracking' @@ -85,19 +86,16 @@ class Database: self.engines[db_info['ts_id']] = engine self.sessions[db_info['ts_id']] = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) l.info(f"Initialized engine and session for {db_info['ts_id']}") + + # Create tables if they don't exist + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + l.info(f"Ensured tables exist for {db_info['ts_id']}") except Exception as e: l.error(f"Failed to initialize engine for {db_info['ts_id']}: {str(e)}") if self.local_ts_id not in self.sessions: l.error(f"Failed to initialize session for local server {self.local_ts_id}") - else: - try: - # Create tables if they don't exist - async with self.engines[self.local_ts_id].begin() as conn: - await conn.run_sync(Base.metadata.create_all) - l.info(f"Initialized tables for local server {self.local_ts_id}") - except Exception as e: - l.error(f"Failed to 
create tables for local server {self.local_ts_id}: {str(e)}") async def get_online_servers(self) -> List[str]: online_servers = [] @@ -119,7 +117,6 @@ class Database: async with self.sessions[self.local_ts_id]() as session: try: result = await session.execute(text(query), kwargs) - # Convert the result to a list of dictionaries rows = result.fetchall() if rows: columns = result.keys() @@ -138,17 +135,17 @@ class Database: async with self.sessions[self.local_ts_id]() as session: try: - # Serialize the kwargs using + # Serialize the kwargs serialized_kwargs = {key: serialize(value) for key, value in kwargs.items()} # Execute the write query result = await session.execute(text(query), serialized_kwargs) - # Log the query (use json_dumps for logging purposes) + # Log the query new_query = QueryTracking( ts_id=self.local_ts_id, query=query, - args=json_dumps(kwargs) # Use original kwargs for logging + args=json_dumps(kwargs) # Use json_dumps for logging ) session.add(new_query) await session.flush() @@ -162,13 +159,10 @@ class Database: # Update query_tracking with checksum await self.update_query_checksum(query_id, checksum) - # Replicate to online servers - online_servers = await self.get_online_servers() - for ts_id in online_servers: - if ts_id != self.local_ts_id: - asyncio.create_task(self._replicate_write(ts_id, query_id, query, serialized_kwargs, checksum)) + # Perform sync operations asynchronously + asyncio.create_task(self._async_sync_operations(query_id, query, serialized_kwargs, checksum)) - return result # Return the CursorResult + return result except Exception as e: l.error(f"Failed to execute write query: {str(e)}") @@ -178,6 +172,27 @@ class Database: l.error(f"Traceback: {traceback.format_exc()}") return None + async def _async_sync_operations(self, query_id: int, query: str, params: dict, checksum: str): + try: + await self.sync_query_tracking() + except Exception as e: + l.error(f"Failed to sync query_tracking: {str(e)}") + + try: + await self.call_db_sync_on_servers() + except Exception as e: + l.error(f"Failed to call db_sync on other servers: {str(e)}") + + # Replicate write to other servers + online_servers = await self.get_online_servers() + for ts_id in online_servers: + if ts_id != self.local_ts_id: + try: + await self._replicate_write(ts_id, query_id, query, params, checksum) + except Exception as e: + l.error(f"Failed to replicate write to {ts_id}: {str(e)}") + + async def get_primary_server(self) -> str: url = urljoin(self.config['URL'], '/id') @@ -194,7 +209,6 @@ class Database: l.error(f"Error connecting to load balancer: {str(e)}") return None - async def get_checksum_server(self) -> dict: primary_ts_id = await self.get_primary_server() online_servers = await self.get_online_servers() @@ -206,7 +220,6 @@ class Database: return random.choice(checksum_servers) - async def _local_compute_checksum(self, query: str, params: dict): async with self.sessions[self.local_ts_id]() as session: result = await session.execute(text(query), params) @@ -217,7 +230,6 @@ class Database: checksum = hashlib.md5(str(data).encode()).hexdigest() return checksum - async def _delegate_compute_checksum(self, server: Dict[str, Any], query: str, params: dict): url = f"http://{server['ts_ip']}:{server['app_port']}/sync/checksum" @@ -234,7 +246,6 @@ class Database: l.error(f"Error connecting to {server['ts_id']} for checksum: {str(e)}") return await self._local_compute_checksum(query, params) - async def update_query_checksum(self, query_id: int, checksum: str): async with 
self.sessions[self.local_ts_id]() as session: await session.execute( @@ -243,7 +254,6 @@ class Database: ) await session.commit() - async def _replicate_write(self, ts_id: str, query_id: int, query: str, params: dict, expected_checksum: str): try: async with self.sessions[ts_id]() as session: @@ -255,8 +265,8 @@ class Database: await session.commit() l.info(f"Successfully replicated write to {ts_id}") except Exception as e: - l.error(f"Failed to replicate write on {ts_id}: {str(e)}") - + l.error(f"Failed to replicate write on {ts_id}") + l.debug(f"Failed to replicate write on {ts_id}: {str(e)}") async def mark_query_completed(self, query_id: int, ts_id: str): async with self.sessions[self.local_ts_id]() as session: @@ -267,7 +277,6 @@ class Database: query.completed_by = completed_by await session.commit() - async def sync_local_server(self): async with self.sessions[self.local_ts_id]() as session: last_synced = await session.execute( @@ -295,7 +304,6 @@ class Database: await session.commit() l.info(f"Local server sync completed. Executed {unexecuted_queries.rowcount} queries.") - async def purge_completed_queries(self): async with self.sessions[self.local_ts_id]() as session: all_ts_ids = [db['ts_id'] for db in self.config['POOL']] @@ -316,9 +324,106 @@ class Database: deleted_count = result.rowcount l.info(f"Purged {deleted_count} completed queries.") + async def sync_query_tracking(self): + """Combinatorial sync method for the query_tracking table.""" + try: + online_servers = await self.get_online_servers() + + for ts_id in online_servers: + if ts_id == self.local_ts_id: + continue + + try: + async with self.sessions[ts_id]() as remote_session: + local_max_id = await self.get_max_query_id(self.local_ts_id) + remote_max_id = await self.get_max_query_id(ts_id) + + # Sync from remote to local + remote_new_queries = await remote_session.execute( + select(QueryTracking).where(QueryTracking.id > local_max_id) + ) + for query in remote_new_queries: + await self.add_or_update_query(query) + + # Sync from local to remote + async with self.sessions[self.local_ts_id]() as local_session: + local_new_queries = await local_session.execute( + select(QueryTracking).where(QueryTracking.id > remote_max_id) + ) + for query in local_new_queries: + await self.add_or_update_query_remote(ts_id, query) + except Exception as e: + l.error(f"Error syncing with {ts_id}: {str(e)}") + except Exception as e: + l.error(f"Error in sync_query_tracking: {str(e)}") + l.error(f"Traceback: {traceback.format_exc()}") + + + async def get_max_query_id(self, ts_id): + async with self.sessions[ts_id]() as session: + result = await session.execute(select(func.max(QueryTracking.id))) + return result.scalar() or 0 + + async def add_or_update_query(self, query): + async with self.sessions[self.local_ts_id]() as session: + existing_query = await session.get(QueryTracking, query.id) + if existing_query: + existing_query.completed_by = {**existing_query.completed_by, **query.completed_by} + else: + session.add(query) + await session.commit() + + async def add_or_update_query_remote(self, ts_id, query): + async with self.sessions[ts_id]() as session: + existing_query = await session.get(QueryTracking, query.id) + if existing_query: + existing_query.completed_by = {**existing_query.completed_by, **query.completed_by} + else: + new_query = QueryTracking( + id=query.id, + ts_id=query.ts_id, + query=query.query, + args=query.args, + executed_at=query.executed_at, + completed_by=query.completed_by, + result_checksum=query.result_checksum + 
) + session.add(new_query) + await session.commit() + + async def ensure_query_tracking_table(self): + for ts_id, engine in self.engines.items(): + try: + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + l.info(f"Ensured query_tracking table exists for {ts_id}") + except Exception as e: + l.error(f"Failed to create query_tracking table for {ts_id}: {str(e)}") + + + async def call_db_sync_on_servers(self): + """Call /db/sync on all online servers.""" + online_servers = await self.get_online_servers() + tasks = [] + for server in self.config['POOL']: + if server['ts_id'] in online_servers and server['ts_id'] != self.local_ts_id: + url = f"http://{server['ts_ip']}:{server['app_port']}/db/sync" + tasks.append(self.call_db_sync(url)) + await asyncio.gather(*tasks) + + async def call_db_sync(self, url): + async with aiohttp.ClientSession() as session: + try: + async with session.post(url, timeout=30) as response: + if response.status == 200: + l.info(f"Successfully called /db/sync on {url}") + else: + l.warning(f"Failed to call /db/sync on {url}. Status: {response.status}") + except asyncio.TimeoutError: + l.debug(f"Timeout while calling /db/sync on {url}") + except Exception as e: + l.error(f"Error calling /db/sync on {url}: {str(e)}") async def close(self): for engine in self.engines.values(): await engine.dispose() - - \ No newline at end of file diff --git a/sijapi/helpers/start.py b/sijapi/helpers/start.py index 031e4fb..e18fe7b 100644 --- a/sijapi/helpers/start.py +++ b/sijapi/helpers/start.py @@ -11,8 +11,8 @@ import sys logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') -def load_config(): - config_path = Path(__file__).parent.parent / 'config' / 'sys.yaml' +def load_config(cfg: str): + config_path = Path(__file__).parent.parent / 'config' / f'{cfg}.yaml' with open(config_path, 'r') as file: return yaml.safe_load(file) @@ -149,8 +149,8 @@ def kill_remote_server(server): def main(): load_env() - config = load_config() - pool = config['POOL'] + db_config = load_config('db') + pool = db_config['POOL'] local_ts_id = os.environ.get('TS_ID') parser = argparse.ArgumentParser(description='Manage sijapi servers') diff --git a/sijapi/routers/sys.py b/sijapi/routers/sys.py index 2346df4..6a33d99 100644 --- a/sijapi/routers/sys.py +++ b/sijapi/routers/sys.py @@ -1,26 +1,26 @@ -''' -System module. /health returns `'status': 'ok'`, /id returns TS_ID, /routers responds with a list of the active routers, /ip responds with the device's local IP, /ts_ip responds with its tailnet IP, and /wan_ip responds with WAN IP. 
-''' -#routers/sys.py +# routers/sys.py import os import httpx import socket -from fastapi import APIRouter +from fastapi import APIRouter, BackgroundTasks, HTTPException +from sqlalchemy import text, select from tailscale import Tailscale -from sijapi import Sys, TS_ID +from sijapi import Sys, Db, TS_ID from sijapi.logs import get_logger +from sijapi.serialization import json_loads +from sijapi.database import QueryTracking + l = get_logger(__name__) sys = APIRouter() - @sys.get("/health") def get_health(): return {"status": "ok"} @sys.get("/id") -def get_health() -> str: +def get_id() -> str: return TS_ID @sys.get("/routers") @@ -65,4 +65,39 @@ async def get_tailscale_ip(): # Assuming you want the IP of the first device in the list return devices[0]['addresses'][0] else: - return "No devices found" \ No newline at end of file + return "No devices found" + +async def sync_process(): + async with Db.sessions[TS_ID]() as session: + # Find unexecuted queries + unexecuted_queries = await session.execute( + select(QueryTracking).where(~QueryTracking.completed_by.has_key(TS_ID)).order_by(QueryTracking.id) + ) + + for query in unexecuted_queries: + try: + params = json_loads(query.args) + await session.execute(text(query.query), params) + actual_checksum = await Db._local_compute_checksum(query.query, params) + if actual_checksum != query.result_checksum: + l.error(f"Checksum mismatch for query ID {query.id}") + continue + + # Update the completed_by field + query.completed_by[TS_ID] = True + await session.commit() + + l.info(f"Successfully executed and verified query ID {query.id}") + except Exception as e: + l.error(f"Failed to execute query ID {query.id} during sync: {str(e)}") + await session.rollback() + + l.info(f"Sync process completed. Executed {unexecuted_queries.rowcount} queries.") + + # After executing all queries, perform combinatorial sync + await Db.sync_query_tracking() + +@sys.post("/db/sync") +async def db_sync(background_tasks: BackgroundTasks): + background_tasks.add_task(sync_process) + return {"message": "Sync process initiated"} diff --git a/sijapi/routers/weather.py b/sijapi/routers/weather.py index b2c9b26..49c31b0 100644 --- a/sijapi/routers/weather.py +++ b/sijapi/routers/weather.py @@ -1,14 +1,9 @@ -''' -Uses the VisualCrossing API and Postgres/PostGIS to source local weather forecasts and history. 
-''' -#routers/weather.py - import asyncio import traceback import os from fastapi import APIRouter, HTTPException, Query -from fastapi import HTTPException from fastapi.responses import JSONResponse +from fastapi.encoders import jsonable_encoder from asyncpg.cursor import Cursor from httpx import AsyncClient from typing import Dict @@ -19,11 +14,12 @@ from sijapi import VISUALCROSSING_API_KEY, TZ, Sys, GEO, Db from sijapi.utilities import haversine from sijapi.routers import gis from sijapi.logs import get_logger +from sijapi.serialization import json_dumps, serialize + l = get_logger(__name__) weather = APIRouter() - @weather.get("/weather/refresh", response_class=JSONResponse) async def get_refreshed_weather( date: str = Query(default=dt_datetime.now().strftime("%Y-%m-%d"), description="Enter a date in YYYY-MM-DD format, otherwise it will default to today."), @@ -49,18 +45,9 @@ async def get_refreshed_weather( if day is None: raise HTTPException(status_code=404, detail="No weather data found for the given date and location") - - # Convert the day object to a JSON-serializable format - day_dict = {} - for k, v in day.items(): - if k == 'DailyWeather': - day_dict[k] = {kk: vv.isoformat() if isinstance(vv, (dt_datetime, dt_date)) else vv for kk, vv in v.items()} - elif k == 'HourlyWeather': - day_dict[k] = [{kk: vv.isoformat() if isinstance(vv, (dt_datetime, dt_date)) else vv for kk, vv in hour.items()} for hour in v] - else: - day_dict[k] = v.isoformat() if isinstance(v, (dt_datetime, dt_date)) else v - return JSONResponse(content={"weather": day_dict}, status_code=200) + json_compatible_data = jsonable_encoder({"weather": day}) + return JSONResponse(content=json_compatible_data) except HTTPException as e: l.error(f"HTTP Exception in get_refreshed_weather: {e.detail}") @@ -136,9 +123,6 @@ async def get_weather(date_time: dt_datetime, latitude: float, longitude: float, return daily_weather_data - -# weather.py - async def store_weather_to_db(date_time: dt_datetime, weather_data: dict): try: day_data = weather_data.get('days', [{}])[0] @@ -231,7 +215,7 @@ async def store_weather_to_db(date_time: dt_datetime, weather_data: dict): hour_preciptype_array = hour_data.get('preciptype', []) or [] hour_stations_array = hour_data.get('stations', []) or [] hourly_weather_params = { - 'daily_weather_id': str(daily_weather_id), # Convert UUID to string + 'daily_weather_id': daily_weather_id, 'datetime': await gis.dt(hour_data.get('datetimeEpoch')), 'datetimeepoch': hour_data.get('datetimeEpoch'), 'temp': hour_data.get('temp'), @@ -287,8 +271,6 @@ async def store_weather_to_db(date_time: dt_datetime, weather_data: dict): l.error(f"Traceback: {traceback.format_exc()}") return "FAILURE" - - async def get_weather_from_db(date_time: dt_datetime, latitude: float, longitude: float): l.debug(f"Using {date_time.strftime('%Y-%m-%d %H:%M:%S')} as our datetime in get_weather_from_db.") query_date = date_time.date() @@ -311,12 +293,12 @@ async def get_weather_from_db(date_time: dt_datetime, latitude: float, longitude hourly_query = ''' SELECT * FROM hourlyweather - WHERE daily_weather_id::text = :daily_weather_id + WHERE daily_weather_id = :daily_weather_id ORDER BY datetime ASC ''' hourly_weather_records = await Db.read( hourly_query, - daily_weather_id=str(daily_weather_data['id']), + daily_weather_id=daily_weather_data['id'], table_name='hourlyweather' ) @@ -331,4 +313,4 @@ async def get_weather_from_db(date_time: dt_datetime, latitude: float, longitude except Exception as e: l.error(f"Unexpected error occurred 
in get_weather_from_db: {e}") l.error(f"Traceback: {traceback.format_exc()}") - return None \ No newline at end of file + return None
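
For orientation, a minimal self-contained sketch of the checksum-verification idea behind the patch's _replicate_write / _local_compute_checksum pair: each write is replayed on every online peer, the MD5 of the replayed, ordered result set is compared against the originating server's checksum, and mismatches are left for /db/sync to reconcile later. The in-memory "servers" below are illustrative stand-ins, not the sijapi API, which runs this over SQLAlchemy async sessions against each server in the Tailscale pool.

import hashlib

def compute_checksum(rows: list[tuple]) -> str:
    # Mirror the patch's approach: hash the stringified, ordered result set with MD5.
    return hashlib.md5(str(rows).encode()).hexdigest()

def replicate_write(servers: dict[str, list], source: str, row: tuple) -> None:
    # Apply the write locally first, then compute the expected checksum.
    servers[source].append(row)
    expected = compute_checksum(servers[source])
    for name, table in servers.items():
        if name == source:
            continue
        table.append(row)  # replay the write on the peer
        actual = compute_checksum(table)
        if actual != expected:
            # Divergence detected; in sijapi this is logged and left
            # for the /db/sync background task to reconcile.
            print(f"checksum mismatch on {name}; deferring to /db/sync")

# server-b has a stale row, so replaying the write there exposes the divergence.
servers = {"server-a": [], "server-b": [("stale", 0.0)]}
replicate_write(servers, "server-a", ("2024-08-12", 21.5))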