Auto-update: Mon Aug 12 15:46:19 PDT 2024

This commit is contained in:
sanj 2024-08-12 15:46:19 -07:00
parent ff678a5df1
commit fd81cbed98
5 changed files with 196 additions and 78 deletions

View file

@ -47,12 +47,6 @@ async def lifespan(app: FastAPI):
# Startup
l.critical("sijapi launched")
l.info(f"Arguments: {args}")
# Log the router directory path
l.debug(f"Router directory path: {Dir.ROUTER.absolute()}")
l.debug(f"Router directory exists: {Dir.ROUTER.exists()}")
l.debug(f"Router directory is a directory: {Dir.ROUTER.is_dir()}")
l.debug(f"Contents of router directory: {list(Dir.ROUTER.iterdir())}")
# Load routers
if args.test:
@ -64,6 +58,7 @@ async def lifespan(app: FastAPI):
try:
await Db.initialize_engines()
await Db.ensure_query_tracking_table()
except Exception as e:
l.critical(f"Error during startup: {str(e)}")
l.critical(f"Traceback: {traceback.format_exc()}")
@ -82,6 +77,7 @@ async def lifespan(app: FastAPI):
l.critical(f"Error during shutdown: {str(e)}")
l.critical(f"Traceback: {traceback.format_exc()}")
app = FastAPI(lifespan=lifespan)
app.add_middleware(

View file

@ -1,9 +1,11 @@
# database.py
import json
import yaml
import time
import aiohttp
import asyncio
import traceback
from datetime import datetime as dt_datetime, date
from tqdm.asyncio import tqdm
import reverse_geocoder as rg
@ -19,11 +21,11 @@ from srtm import get_data
import os
import sys
from loguru import logger
from sqlalchemy import text
from sqlalchemy import text, select, func, and_
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy.exc import OperationalError
from sqlalchemy import Column, Integer, String, DateTime, JSON, Text, select, func
from sqlalchemy import Column, Integer, String, DateTime, JSON, Text
from sqlalchemy.dialects.postgresql import JSONB
from urllib.parse import urljoin
import hashlib
@ -41,7 +43,6 @@ ENV_PATH = CONFIG_DIR / ".env"
load_dotenv(ENV_PATH)
TS_ID = os.environ.get('TS_ID')
class QueryTracking(Base):
__tablename__ = 'query_tracking'
@ -85,19 +86,16 @@ class Database:
self.engines[db_info['ts_id']] = engine
self.sessions[db_info['ts_id']] = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
l.info(f"Initialized engine and session for {db_info['ts_id']}")
# Create tables if they don't exist
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
l.info(f"Ensured tables exist for {db_info['ts_id']}")
except Exception as e:
l.error(f"Failed to initialize engine for {db_info['ts_id']}: {str(e)}")
if self.local_ts_id not in self.sessions:
l.error(f"Failed to initialize session for local server {self.local_ts_id}")
else:
try:
# Create tables if they don't exist
async with self.engines[self.local_ts_id].begin() as conn:
await conn.run_sync(Base.metadata.create_all)
l.info(f"Initialized tables for local server {self.local_ts_id}")
except Exception as e:
l.error(f"Failed to create tables for local server {self.local_ts_id}: {str(e)}")
async def get_online_servers(self) -> List[str]:
online_servers = []
@ -119,7 +117,6 @@ class Database:
async with self.sessions[self.local_ts_id]() as session:
try:
result = await session.execute(text(query), kwargs)
# Convert the result to a list of dictionaries
rows = result.fetchall()
if rows:
columns = result.keys()
@ -138,17 +135,17 @@ class Database:
async with self.sessions[self.local_ts_id]() as session:
try:
# Serialize the kwargs using
# Serialize the kwargs
serialized_kwargs = {key: serialize(value) for key, value in kwargs.items()}
# Execute the write query
result = await session.execute(text(query), serialized_kwargs)
# Log the query (use json_dumps for logging purposes)
# Log the query
new_query = QueryTracking(
ts_id=self.local_ts_id,
query=query,
args=json_dumps(kwargs) # Use original kwargs for logging
args=json_dumps(kwargs) # Use json_dumps for logging
)
session.add(new_query)
await session.flush()
@ -162,13 +159,10 @@ class Database:
# Update query_tracking with checksum
await self.update_query_checksum(query_id, checksum)
# Replicate to online servers
online_servers = await self.get_online_servers()
for ts_id in online_servers:
if ts_id != self.local_ts_id:
asyncio.create_task(self._replicate_write(ts_id, query_id, query, serialized_kwargs, checksum))
# Perform sync operations asynchronously
asyncio.create_task(self._async_sync_operations(query_id, query, serialized_kwargs, checksum))
return result # Return the CursorResult
return result
except Exception as e:
l.error(f"Failed to execute write query: {str(e)}")
@ -178,6 +172,27 @@ class Database:
l.error(f"Traceback: {traceback.format_exc()}")
return None
async def _async_sync_operations(self, query_id: int, query: str, params: dict, checksum: str):
try:
await self.sync_query_tracking()
except Exception as e:
l.error(f"Failed to sync query_tracking: {str(e)}")
try:
await self.call_db_sync_on_servers()
except Exception as e:
l.error(f"Failed to call db_sync on other servers: {str(e)}")
# Replicate write to other servers
online_servers = await self.get_online_servers()
for ts_id in online_servers:
if ts_id != self.local_ts_id:
try:
await self._replicate_write(ts_id, query_id, query, params, checksum)
except Exception as e:
l.error(f"Failed to replicate write to {ts_id}: {str(e)}")
async def get_primary_server(self) -> str:
url = urljoin(self.config['URL'], '/id')
@ -194,7 +209,6 @@ class Database:
l.error(f"Error connecting to load balancer: {str(e)}")
return None
async def get_checksum_server(self) -> dict:
primary_ts_id = await self.get_primary_server()
online_servers = await self.get_online_servers()
@ -206,7 +220,6 @@ class Database:
return random.choice(checksum_servers)
async def _local_compute_checksum(self, query: str, params: dict):
async with self.sessions[self.local_ts_id]() as session:
result = await session.execute(text(query), params)
@ -217,7 +230,6 @@ class Database:
checksum = hashlib.md5(str(data).encode()).hexdigest()
return checksum
async def _delegate_compute_checksum(self, server: Dict[str, Any], query: str, params: dict):
url = f"http://{server['ts_ip']}:{server['app_port']}/sync/checksum"
@ -234,7 +246,6 @@ class Database:
l.error(f"Error connecting to {server['ts_id']} for checksum: {str(e)}")
return await self._local_compute_checksum(query, params)
async def update_query_checksum(self, query_id: int, checksum: str):
async with self.sessions[self.local_ts_id]() as session:
await session.execute(
@ -243,7 +254,6 @@ class Database:
)
await session.commit()
async def _replicate_write(self, ts_id: str, query_id: int, query: str, params: dict, expected_checksum: str):
try:
async with self.sessions[ts_id]() as session:
@ -255,8 +265,8 @@ class Database:
await session.commit()
l.info(f"Successfully replicated write to {ts_id}")
except Exception as e:
l.error(f"Failed to replicate write on {ts_id}: {str(e)}")
l.error(f"Failed to replicate write on {ts_id}")
l.debug(f"Failed to replicate write on {ts_id}: {str(e)}")
async def mark_query_completed(self, query_id: int, ts_id: str):
async with self.sessions[self.local_ts_id]() as session:
@ -267,7 +277,6 @@ class Database:
query.completed_by = completed_by
await session.commit()
async def sync_local_server(self):
async with self.sessions[self.local_ts_id]() as session:
last_synced = await session.execute(
@ -295,7 +304,6 @@ class Database:
await session.commit()
l.info(f"Local server sync completed. Executed {unexecuted_queries.rowcount} queries.")
async def purge_completed_queries(self):
async with self.sessions[self.local_ts_id]() as session:
all_ts_ids = [db['ts_id'] for db in self.config['POOL']]
@ -316,9 +324,106 @@ class Database:
deleted_count = result.rowcount
l.info(f"Purged {deleted_count} completed queries.")
async def sync_query_tracking(self):
"""Combinatorial sync method for the query_tracking table."""
try:
online_servers = await self.get_online_servers()
for ts_id in online_servers:
if ts_id == self.local_ts_id:
continue
try:
async with self.sessions[ts_id]() as remote_session:
local_max_id = await self.get_max_query_id(self.local_ts_id)
remote_max_id = await self.get_max_query_id(ts_id)
# Sync from remote to local
remote_new_queries = await remote_session.execute(
select(QueryTracking).where(QueryTracking.id > local_max_id)
)
for query in remote_new_queries:
await self.add_or_update_query(query)
# Sync from local to remote
async with self.sessions[self.local_ts_id]() as local_session:
local_new_queries = await local_session.execute(
select(QueryTracking).where(QueryTracking.id > remote_max_id)
)
for query in local_new_queries:
await self.add_or_update_query_remote(ts_id, query)
except Exception as e:
l.error(f"Error syncing with {ts_id}: {str(e)}")
except Exception as e:
l.error(f"Error in sync_query_tracking: {str(e)}")
l.error(f"Traceback: {traceback.format_exc()}")
async def get_max_query_id(self, ts_id):
async with self.sessions[ts_id]() as session:
result = await session.execute(select(func.max(QueryTracking.id)))
return result.scalar() or 0
async def add_or_update_query(self, query):
async with self.sessions[self.local_ts_id]() as session:
existing_query = await session.get(QueryTracking, query.id)
if existing_query:
existing_query.completed_by = {**existing_query.completed_by, **query.completed_by}
else:
session.add(query)
await session.commit()
async def add_or_update_query_remote(self, ts_id, query):
async with self.sessions[ts_id]() as session:
existing_query = await session.get(QueryTracking, query.id)
if existing_query:
existing_query.completed_by = {**existing_query.completed_by, **query.completed_by}
else:
new_query = QueryTracking(
id=query.id,
ts_id=query.ts_id,
query=query.query,
args=query.args,
executed_at=query.executed_at,
completed_by=query.completed_by,
result_checksum=query.result_checksum
)
session.add(new_query)
await session.commit()
async def ensure_query_tracking_table(self):
for ts_id, engine in self.engines.items():
try:
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
l.info(f"Ensured query_tracking table exists for {ts_id}")
except Exception as e:
l.error(f"Failed to create query_tracking table for {ts_id}: {str(e)}")
async def call_db_sync_on_servers(self):
"""Call /db/sync on all online servers."""
online_servers = await self.get_online_servers()
tasks = []
for server in self.config['POOL']:
if server['ts_id'] in online_servers and server['ts_id'] != self.local_ts_id:
url = f"http://{server['ts_ip']}:{server['app_port']}/db/sync"
tasks.append(self.call_db_sync(url))
await asyncio.gather(*tasks)
async def call_db_sync(self, url):
async with aiohttp.ClientSession() as session:
try:
async with session.post(url, timeout=30) as response:
if response.status == 200:
l.info(f"Successfully called /db/sync on {url}")
else:
l.warning(f"Failed to call /db/sync on {url}. Status: {response.status}")
except asyncio.TimeoutError:
l.debug(f"Timeout while calling /db/sync on {url}")
except Exception as e:
l.error(f"Error calling /db/sync on {url}: {str(e)}")
async def close(self):
for engine in self.engines.values():
await engine.dispose()

View file

@ -11,8 +11,8 @@ import sys
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def load_config():
config_path = Path(__file__).parent.parent / 'config' / 'sys.yaml'
def load_config(cfg: str):
config_path = Path(__file__).parent.parent / 'config' / f'{cfg}.yaml'
with open(config_path, 'r') as file:
return yaml.safe_load(file)
@ -149,8 +149,8 @@ def kill_remote_server(server):
def main():
load_env()
config = load_config()
pool = config['POOL']
db_config = load_config('db')
pool = db_config['POOL']
local_ts_id = os.environ.get('TS_ID')
parser = argparse.ArgumentParser(description='Manage sijapi servers')

View file

@ -1,26 +1,26 @@
'''
System module. /health returns `'status': 'ok'`, /id returns TS_ID, /routers responds with a list of the active routers, /ip responds with the device's local IP, /ts_ip responds with its tailnet IP, and /wan_ip responds with WAN IP.
'''
#routers/sys.py
# routers/sys.py
import os
import httpx
import socket
from fastapi import APIRouter
from fastapi import APIRouter, BackgroundTasks, HTTPException
from sqlalchemy import text, select
from tailscale import Tailscale
from sijapi import Sys, TS_ID
from sijapi import Sys, Db, TS_ID
from sijapi.logs import get_logger
from sijapi.serialization import json_loads
from sijapi.database import QueryTracking
l = get_logger(__name__)
sys = APIRouter()
@sys.get("/health")
def get_health():
return {"status": "ok"}
@sys.get("/id")
def get_health() -> str:
def get_id() -> str:
return TS_ID
@sys.get("/routers")
@ -65,4 +65,39 @@ async def get_tailscale_ip():
# Assuming you want the IP of the first device in the list
return devices[0]['addresses'][0]
else:
return "No devices found"
return "No devices found"
async def sync_process():
async with Db.sessions[TS_ID]() as session:
# Find unexecuted queries
unexecuted_queries = await session.execute(
select(QueryTracking).where(~QueryTracking.completed_by.has_key(TS_ID)).order_by(QueryTracking.id)
)
for query in unexecuted_queries:
try:
params = json_loads(query.args)
await session.execute(text(query.query), params)
actual_checksum = await Db._local_compute_checksum(query.query, params)
if actual_checksum != query.result_checksum:
l.error(f"Checksum mismatch for query ID {query.id}")
continue
# Update the completed_by field
query.completed_by[TS_ID] = True
await session.commit()
l.info(f"Successfully executed and verified query ID {query.id}")
except Exception as e:
l.error(f"Failed to execute query ID {query.id} during sync: {str(e)}")
await session.rollback()
l.info(f"Sync process completed. Executed {unexecuted_queries.rowcount} queries.")
# After executing all queries, perform combinatorial sync
await Db.sync_query_tracking()
@sys.post("/db/sync")
async def db_sync(background_tasks: BackgroundTasks):
background_tasks.add_task(sync_process)
return {"message": "Sync process initiated"}

View file

@ -1,14 +1,9 @@
'''
Uses the VisualCrossing API and Postgres/PostGIS to source local weather forecasts and history.
'''
#routers/weather.py
import asyncio
import traceback
import os
from fastapi import APIRouter, HTTPException, Query
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from fastapi.encoders import jsonable_encoder
from asyncpg.cursor import Cursor
from httpx import AsyncClient
from typing import Dict
@ -19,11 +14,12 @@ from sijapi import VISUALCROSSING_API_KEY, TZ, Sys, GEO, Db
from sijapi.utilities import haversine
from sijapi.routers import gis
from sijapi.logs import get_logger
from sijapi.serialization import json_dumps, serialize
l = get_logger(__name__)
weather = APIRouter()
@weather.get("/weather/refresh", response_class=JSONResponse)
async def get_refreshed_weather(
date: str = Query(default=dt_datetime.now().strftime("%Y-%m-%d"), description="Enter a date in YYYY-MM-DD format, otherwise it will default to today."),
@ -49,18 +45,9 @@ async def get_refreshed_weather(
if day is None:
raise HTTPException(status_code=404, detail="No weather data found for the given date and location")
# Convert the day object to a JSON-serializable format
day_dict = {}
for k, v in day.items():
if k == 'DailyWeather':
day_dict[k] = {kk: vv.isoformat() if isinstance(vv, (dt_datetime, dt_date)) else vv for kk, vv in v.items()}
elif k == 'HourlyWeather':
day_dict[k] = [{kk: vv.isoformat() if isinstance(vv, (dt_datetime, dt_date)) else vv for kk, vv in hour.items()} for hour in v]
else:
day_dict[k] = v.isoformat() if isinstance(v, (dt_datetime, dt_date)) else v
return JSONResponse(content={"weather": day_dict}, status_code=200)
json_compatible_data = jsonable_encoder({"weather": day})
return JSONResponse(content=json_compatible_data)
except HTTPException as e:
l.error(f"HTTP Exception in get_refreshed_weather: {e.detail}")
@ -136,9 +123,6 @@ async def get_weather(date_time: dt_datetime, latitude: float, longitude: float,
return daily_weather_data
# weather.py
async def store_weather_to_db(date_time: dt_datetime, weather_data: dict):
try:
day_data = weather_data.get('days', [{}])[0]
@ -231,7 +215,7 @@ async def store_weather_to_db(date_time: dt_datetime, weather_data: dict):
hour_preciptype_array = hour_data.get('preciptype', []) or []
hour_stations_array = hour_data.get('stations', []) or []
hourly_weather_params = {
'daily_weather_id': str(daily_weather_id), # Convert UUID to string
'daily_weather_id': daily_weather_id,
'datetime': await gis.dt(hour_data.get('datetimeEpoch')),
'datetimeepoch': hour_data.get('datetimeEpoch'),
'temp': hour_data.get('temp'),
@ -287,8 +271,6 @@ async def store_weather_to_db(date_time: dt_datetime, weather_data: dict):
l.error(f"Traceback: {traceback.format_exc()}")
return "FAILURE"
async def get_weather_from_db(date_time: dt_datetime, latitude: float, longitude: float):
l.debug(f"Using {date_time.strftime('%Y-%m-%d %H:%M:%S')} as our datetime in get_weather_from_db.")
query_date = date_time.date()
@ -311,12 +293,12 @@ async def get_weather_from_db(date_time: dt_datetime, latitude: float, longitude
hourly_query = '''
SELECT * FROM hourlyweather
WHERE daily_weather_id::text = :daily_weather_id
WHERE daily_weather_id = :daily_weather_id
ORDER BY datetime ASC
'''
hourly_weather_records = await Db.read(
hourly_query,
daily_weather_id=str(daily_weather_data['id']),
daily_weather_id=daily_weather_data['id'],
table_name='hourlyweather'
)
@ -331,4 +313,4 @@ async def get_weather_from_db(date_time: dt_datetime, latitude: float, longitude
except Exception as e:
l.error(f"Unexpected error occurred in get_weather_from_db: {e}")
l.error(f"Traceback: {traceback.format_exc()}")
return None
return None