auto-update

sanj 2024-11-09 18:03:34 -08:00
parent 79b99ef68f
commit 56a203ea74
4 changed files with 341 additions and 234 deletions


@@ -58,7 +58,6 @@ async def lifespan(app: FastAPI):
     try:
         await Db.initialize_engines()
-        await Db.ensure_query_tracking_table()
     except Exception as e:
         l.critical(f"Error during startup: {str(e)}")
         l.critical(f"Traceback: {traceback.format_exc()}")


@@ -1,38 +1,21 @@
-import json
+# database.py
 import yaml
-import time
-import aiohttp
-import asyncio
-import traceback
-from datetime import datetime as dt_datetime, date
-from tqdm.asyncio import tqdm
-import reverse_geocoder as rg
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple, Union, TypeVar, ClassVar
+from typing import Any, Dict, List
 from dotenv import load_dotenv
-from pydantic import BaseModel, Field, create_model, PrivateAttr
+from pydantic import BaseModel
-from concurrent.futures import ThreadPoolExecutor
+from datetime import datetime
-from contextlib import asynccontextmanager
-from datetime import datetime, timedelta, timezone
-from zoneinfo import ZoneInfo
-from srtm import get_data
-import os
-import sys
 from loguru import logger
-from sqlalchemy import text, select, func, and_
+from sqlalchemy import text, create_engine
 from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
-from sqlalchemy.orm import sessionmaker, declarative_base, make_transient
+from sqlalchemy.orm import sessionmaker, declarative_base
-from sqlalchemy.exc import OperationalError
 from sqlalchemy import Column, Integer, String, DateTime, JSON, Text
 import uuid
-from sqlalchemy import Column, String, DateTime, Text, ARRAY
 from sqlalchemy.dialects.postgresql import UUID, JSONB
 from sqlalchemy.sql import func
-from urllib.parse import urljoin
+import os
-import hashlib
-import random
 from .logs import get_logger
-from .serialization import json_dumps, json_serial, serialize
+from .serialization import serialize
 l = get_logger(__name__)
@@ -44,24 +27,7 @@ ENV_PATH = CONFIG_DIR / ".env"
 load_dotenv(ENV_PATH)
 TS_ID = os.environ.get('TS_ID')
-class QueryTracking(Base):
-    __tablename__ = 'query_tracking'
-    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
-    origin_ts_id = Column(String, nullable=False)
-    query = Column(Text, nullable=False)
-    args = Column(JSON)
-    executed_at = Column(DateTime(timezone=True), server_default=func.now())
-    completed_by = Column(ARRAY(String), default=[])
-    result_checksum = Column(String(32))
 class Database:
-    SYNC_COOLDOWN = 30  # seconds
     @classmethod
     def init(cls, config_name: str):
         return cls(config_name)
@@ -70,9 +36,7 @@ class Database:
         self.config = self.load_config(config_path)
         self.engines: Dict[str, Any] = {}
         self.sessions: Dict[str, Any] = {}
-        self.online_servers: set = set()
         self.local_ts_id = self.get_local_ts_id()
-        self.last_sync_time = 0
     def load_config(self, config_path: str) -> Dict[str, Any]:
         base_path = Path(__file__).parent.parent
@@ -95,7 +59,6 @@ class Database:
                 self.sessions[db_info['ts_id']] = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
                 l.info(f"Initialized engine and session for {db_info['ts_id']}")
-                # Create tables if they don't exist
                 async with engine.begin() as conn:
                     await conn.run_sync(Base.metadata.create_all)
                 l.info(f"Ensured tables exist for {db_info['ts_id']}")
@@ -105,20 +68,6 @@ class Database:
         if self.local_ts_id not in self.sessions:
             l.error(f"Failed to initialize session for local server {self.local_ts_id}")
-    async def get_online_servers(self) -> List[str]:
-        online_servers = []
-        for ts_id, engine in self.engines.items():
-            try:
-                async with engine.connect() as conn:
-                    await conn.execute(text("SELECT 1"))
-                online_servers.append(ts_id)
-                l.debug(f"Server {ts_id} is online")
-            except OperationalError:
-                l.warning(f"Server {ts_id} is offline")
-        self.online_servers = set(online_servers)
-        l.info(f"Online servers: {', '.join(online_servers)}")
-        return online_servers
     async def read(self, query: str, **kwargs):
         if self.local_ts_id not in self.sessions:
             l.error(f"No session found for local server {self.local_ts_id}. Database may not be properly initialized.")
@@ -144,190 +93,17 @@ class Database:
         async with self.sessions[self.local_ts_id]() as session:
             try:
-                # Execute the write query locally
                 serialized_kwargs = {key: serialize(value) for key, value in kwargs.items()}
                 result = await session.execute(text(query), serialized_kwargs)
                 await session.commit()
-                # Initiate async operations
-                asyncio.create_task(self._async_sync_operations(query, kwargs))
-                # Return the result
                 return result
             except Exception as e:
                 l.error(f"Failed to execute write query: {str(e)}")
                 l.error(f"Query: {query}")
                 l.error(f"Kwargs: {kwargs}")
                 l.error(f"Serialized kwargs: {serialized_kwargs}")
-                l.error(f"Traceback: {traceback.format_exc()}")
                 return None
-    async def _async_sync_operations(self, query: str, kwargs: dict):
-        try:
-            # Add the write query to the query_tracking table
-            await self.add_query_to_tracking(query, kwargs)
-            # Call /db/sync on all online servers
-            await self.call_db_sync_on_servers()
-        except Exception as e:
-            l.error(f"Error in async sync operations: {str(e)}")
-            l.error(f"Traceback: {traceback.format_exc()}")
-    async def add_query_to_tracking(self, query: str, kwargs: dict, result_checksum: str = None):
-        async with self.sessions[self.local_ts_id]() as session:
-            new_query = QueryTracking(
-                origin_ts_id=self.local_ts_id,
-                query=query,
-                args=json_dumps(kwargs),
-                completed_by=[self.local_ts_id],
-                result_checksum=result_checksum
-            )
-            session.add(new_query)
-            await session.commit()
-        l.info(f"Added query to tracking: {query[:50]}...")
-    async def sync_db(self):
-        current_time = time.time()
-        if current_time - self.last_sync_time < self.SYNC_COOLDOWN:
-            l.info(f"Skipping sync, last sync was less than {self.SYNC_COOLDOWN} seconds ago")
-            return
-        try:
-            l.info("Starting database synchronization")
-            self.last_sync_time = current_time  # Update the last sync time before starting
-            await self.pull_query_tracking_from_all_servers()
-            await self.execute_unexecuted_queries()
-            l.info("Database synchronization completed successfully")
-        except Exception as e:
-            l.error(f"Error during database sync: {str(e)}")
-            l.error(f"Traceback: {traceback.format_exc()}")
-        finally:
-            # Ensure the cooldown is respected even if an error occurs
-            self.last_sync_time = max(self.last_sync_time, current_time)
-    async def pull_query_tracking_from_all_servers(self):
-        online_servers = await self.get_online_servers()
-        l.info(f"Pulling query tracking from {len(online_servers)} online servers")
-        for server_id in online_servers:
-            if server_id == self.local_ts_id:
-                continue  # Skip local server
-            l.info(f"Pulling queries from server: {server_id}")
-            async with self.sessions[server_id]() as remote_session:
-                try:
-                    result = await remote_session.execute(select(QueryTracking))
-                    queries = result.scalars().all()
-                    l.info(f"Retrieved {len(queries)} queries from server {server_id}")
-                    async with self.sessions[self.local_ts_id]() as local_session:
-                        for query in queries:
-                            # Detach the object from its original session
-                            make_transient(query)
-                            existing = await local_session.execute(
-                                select(QueryTracking).where(QueryTracking.id == query.id)
-                            )
-                            existing = existing.scalar_one_or_none()
-                            if existing:
-                                # Update existing query
-                                existing.completed_by = list(set(existing.completed_by + query.completed_by))
-                                l.debug(f"Updated existing query: {query.id}")
-                            else:
-                                # Create a new instance for the local session
-                                new_query = QueryTracking(
-                                    id=query.id,
-                                    origin_ts_id=query.origin_ts_id,
-                                    query=query.query,
-                                    args=query.args,
-                                    executed_at=query.executed_at,
-                                    completed_by=query.completed_by,
-                                    result_checksum=query.result_checksum
-                                )
-                                local_session.add(new_query)
-                                l.debug(f"Added new query: {query.id}")
-                        await local_session.commit()
-                except Exception as e:
-                    l.error(f"Error pulling queries from server {server_id}: {str(e)}")
-                    l.error(f"Traceback: {traceback.format_exc()}")
-        l.info("Finished pulling queries from all servers")
-    async def execute_unexecuted_queries(self):
-        async with self.sessions[self.local_ts_id]() as session:
-            unexecuted_queries = await session.execute(
-                select(QueryTracking).where(~QueryTracking.completed_by.any(self.local_ts_id)).order_by(QueryTracking.executed_at)
-            )
-            unexecuted_queries = unexecuted_queries.scalars().all()
-            l.info(f"Executing {len(unexecuted_queries)} unexecuted queries")
-            for query in unexecuted_queries:
-                try:
-                    params = json.loads(query.args)
-                    # Convert string datetime to datetime objects
-                    for key, value in params.items():
-                        if isinstance(value, str) and value.endswith(('Z', '+00:00')):
-                            try:
-                                params[key] = datetime.fromisoformat(value.rstrip('Z'))
-                            except ValueError:
-                                # If conversion fails, leave the original value
-                                pass
-                    async with session.begin():
-                        await session.execute(text(query.query), params)
-                        query.completed_by = list(set(query.completed_by + [self.local_ts_id]))
-                    await session.commit()
-                    l.info(f"Successfully executed query ID {query.id}")
-                except Exception as e:
-                    l.error(f"Failed to execute query ID {query.id}: {str(e)}")
-                    await session.rollback()
-        l.info("Finished executing unexecuted queries")
-    async def call_db_sync_on_servers(self):
-        """Call /db/sync on all online servers."""
-        online_servers = await self.get_online_servers()
-        l.info(f"Calling /db/sync on {len(online_servers)} online servers")
-        for server in self.config['POOL']:
-            if server['ts_id'] in online_servers and server['ts_id'] != self.local_ts_id:
-                try:
-                    await self.call_db_sync(server)
-                except Exception as e:
-                    l.error(f"Failed to call /db/sync on {server['ts_id']}: {str(e)}")
-        l.info("Finished calling /db/sync on all servers")
-    async def call_db_sync(self, server):
-        url = f"http://{server['ts_ip']}:{server['app_port']}/db/sync"
-        headers = {
-            "Authorization": f"Bearer {server['api_key']}"
-        }
-        async with aiohttp.ClientSession() as session:
-            try:
-                async with session.post(url, headers=headers, timeout=30) as response:
-                    if response.status == 200:
-                        l.info(f"Successfully called /db/sync on {url}")
-                    else:
-                        l.warning(f"Failed to call /db/sync on {url}. Status: {response.status}")
-            except asyncio.TimeoutError:
-                l.debug(f"Timeout while calling /db/sync on {url}")
-            except Exception as e:
-                l.error(f"Error calling /db/sync on {url}: {str(e)}")
-    async def ensure_query_tracking_table(self):
-        for ts_id, engine in self.engines.items():
-            try:
-                async with engine.begin() as conn:
-                    await conn.run_sync(Base.metadata.create_all)
-                l.info(f"Ensured query_tracking table exists for {ts_id}")
-            except Exception as e:
-                l.error(f"Failed to create query_tracking table for {ts_id}: {str(e)}")
     async def close(self):
         for engine in self.engines.values():
             await engine.dispose()

sijapi/db_sync_old.py (new file, 333 lines added)

@@ -0,0 +1,333 @@
# database.py
import json
import yaml
import time
import aiohttp
import asyncio
import traceback
from datetime import datetime as dt_datetime, date
from tqdm.asyncio import tqdm
import reverse_geocoder as rg
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union, TypeVar, ClassVar
from dotenv import load_dotenv
from pydantic import BaseModel, Field, create_model, PrivateAttr
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from zoneinfo import ZoneInfo
from srtm import get_data
import os
import sys
from loguru import logger
from sqlalchemy import text, select, func, and_
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker, declarative_base, make_transient
from sqlalchemy.exc import OperationalError
from sqlalchemy import Column, Integer, String, DateTime, JSON, Text
import uuid
from sqlalchemy import Column, String, DateTime, Text, ARRAY
from sqlalchemy.dialects.postgresql import UUID, JSONB
from sqlalchemy.sql import func
from urllib.parse import urljoin
import hashlib
import random
from .logs import get_logger
from .serialization import json_dumps, json_serial, serialize
l = get_logger(__name__)
Base = declarative_base()
BASE_DIR = Path(__file__).resolve().parent
CONFIG_DIR = BASE_DIR / "config"
ENV_PATH = CONFIG_DIR / ".env"
load_dotenv(ENV_PATH)
TS_ID = os.environ.get('TS_ID')

class QueryTracking(Base):
    __tablename__ = 'query_tracking'
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    origin_ts_id = Column(String, nullable=False)
    query = Column(Text, nullable=False)
    args = Column(JSON)
    executed_at = Column(DateTime(timezone=True), server_default=func.now())
    completed_by = Column(ARRAY(String), default=[])
    result_checksum = Column(String(32))

class Database:
    SYNC_COOLDOWN = 30  # seconds

    @classmethod
    def init(cls, config_name: str):
        return cls(config_name)

    def __init__(self, config_path: str):
        self.config = self.load_config(config_path)
        self.engines: Dict[str, Any] = {}
        self.sessions: Dict[str, Any] = {}
        self.online_servers: set = set()
        self.local_ts_id = self.get_local_ts_id()
        self.last_sync_time = 0

    def load_config(self, config_path: str) -> Dict[str, Any]:
        base_path = Path(__file__).parent.parent
        full_path = base_path / "sijapi" / "config" / f"{config_path}.yaml"
        with open(full_path, 'r') as file:
            config = yaml.safe_load(file)
        return config

    def get_local_ts_id(self) -> str:
        return os.environ.get('TS_ID')

    async def initialize_engines(self):
        for db_info in self.config['POOL']:
            url = f"postgresql+asyncpg://{db_info['db_user']}:{db_info['db_pass']}@{db_info['ts_ip']}:{db_info['db_port']}/{db_info['db_name']}"
            try:
                engine = create_async_engine(url, pool_pre_ping=True, pool_size=5, max_overflow=10)
                self.engines[db_info['ts_id']] = engine
                self.sessions[db_info['ts_id']] = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
                l.info(f"Initialized engine and session for {db_info['ts_id']}")

                # Create tables if they don't exist
                async with engine.begin() as conn:
                    await conn.run_sync(Base.metadata.create_all)
                l.info(f"Ensured tables exist for {db_info['ts_id']}")
            except Exception as e:
                l.error(f"Failed to initialize engine for {db_info['ts_id']}: {str(e)}")

        if self.local_ts_id not in self.sessions:
            l.error(f"Failed to initialize session for local server {self.local_ts_id}")

    async def get_online_servers(self) -> List[str]:
        online_servers = []
        for ts_id, engine in self.engines.items():
            try:
                async with engine.connect() as conn:
                    await conn.execute(text("SELECT 1"))
                online_servers.append(ts_id)
                l.debug(f"Server {ts_id} is online")
            except OperationalError:
                l.warning(f"Server {ts_id} is offline")
        self.online_servers = set(online_servers)
        l.info(f"Online servers: {', '.join(online_servers)}")
        return online_servers

    async def read(self, query: str, **kwargs):
        if self.local_ts_id not in self.sessions:
            l.error(f"No session found for local server {self.local_ts_id}. Database may not be properly initialized.")
            return None

        async with self.sessions[self.local_ts_id]() as session:
            try:
                result = await session.execute(text(query), kwargs)
                rows = result.fetchall()
                if rows:
                    columns = result.keys()
                    return [dict(zip(columns, row)) for row in rows]
                else:
                    return []
            except Exception as e:
                l.error(f"Failed to execute read query: {str(e)}")
                return None

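    # write() commits against the local engine first, then schedules _async_sync_operations()
    # as a background task so callers are not blocked by replication to the other servers.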
    async def write(self, query: str, **kwargs):
        if self.local_ts_id not in self.sessions:
            l.error(f"No session found for local server {self.local_ts_id}. Database may not be properly initialized.")
            return None

        async with self.sessions[self.local_ts_id]() as session:
            try:
                # Execute the write query locally
                serialized_kwargs = {key: serialize(value) for key, value in kwargs.items()}
                result = await session.execute(text(query), serialized_kwargs)
                await session.commit()

                # Initiate async operations
                asyncio.create_task(self._async_sync_operations(query, kwargs))

                # Return the result
                return result
            except Exception as e:
                l.error(f"Failed to execute write query: {str(e)}")
                l.error(f"Query: {query}")
                l.error(f"Kwargs: {kwargs}")
                l.error(f"Serialized kwargs: {serialized_kwargs}")
                l.error(f"Traceback: {traceback.format_exc()}")
                return None

    async def _async_sync_operations(self, query: str, kwargs: dict):
        try:
            # Add the write query to the query_tracking table
            await self.add_query_to_tracking(query, kwargs)

            # Call /db/sync on all online servers
            await self.call_db_sync_on_servers()
        except Exception as e:
            l.error(f"Error in async sync operations: {str(e)}")
            l.error(f"Traceback: {traceback.format_exc()}")

    async def add_query_to_tracking(self, query: str, kwargs: dict, result_checksum: str = None):
        async with self.sessions[self.local_ts_id]() as session:
            new_query = QueryTracking(
                origin_ts_id=self.local_ts_id,
                query=query,
                args=json_dumps(kwargs),
                completed_by=[self.local_ts_id],
                result_checksum=result_checksum
            )
            session.add(new_query)
            await session.commit()
        l.info(f"Added query to tracking: {query[:50]}...")

    async def sync_db(self):
        current_time = time.time()
        if current_time - self.last_sync_time < self.SYNC_COOLDOWN:
            l.info(f"Skipping sync, last sync was less than {self.SYNC_COOLDOWN} seconds ago")
            return

        try:
            l.info("Starting database synchronization")
            self.last_sync_time = current_time  # Update the last sync time before starting
            await self.pull_query_tracking_from_all_servers()
            await self.execute_unexecuted_queries()
            l.info("Database synchronization completed successfully")
        except Exception as e:
            l.error(f"Error during database sync: {str(e)}")
            l.error(f"Traceback: {traceback.format_exc()}")
        finally:
            # Ensure the cooldown is respected even if an error occurs
            self.last_sync_time = max(self.last_sync_time, current_time)

    async def pull_query_tracking_from_all_servers(self):
        online_servers = await self.get_online_servers()
        l.info(f"Pulling query tracking from {len(online_servers)} online servers")

        for server_id in online_servers:
            if server_id == self.local_ts_id:
                continue  # Skip local server

            l.info(f"Pulling queries from server: {server_id}")
            async with self.sessions[server_id]() as remote_session:
                try:
                    result = await remote_session.execute(select(QueryTracking))
                    queries = result.scalars().all()
                    l.info(f"Retrieved {len(queries)} queries from server {server_id}")

                    async with self.sessions[self.local_ts_id]() as local_session:
                        for query in queries:
                            # Detach the object from its original session
                            make_transient(query)

                            existing = await local_session.execute(
                                select(QueryTracking).where(QueryTracking.id == query.id)
                            )
                            existing = existing.scalar_one_or_none()

                            if existing:
                                # Update existing query
                                existing.completed_by = list(set(existing.completed_by + query.completed_by))
                                l.debug(f"Updated existing query: {query.id}")
                            else:
                                # Create a new instance for the local session
                                new_query = QueryTracking(
                                    id=query.id,
                                    origin_ts_id=query.origin_ts_id,
                                    query=query.query,
                                    args=query.args,
                                    executed_at=query.executed_at,
                                    completed_by=query.completed_by,
                                    result_checksum=query.result_checksum
                                )
                                local_session.add(new_query)
                                l.debug(f"Added new query: {query.id}")
                        await local_session.commit()
                except Exception as e:
                    l.error(f"Error pulling queries from server {server_id}: {str(e)}")
                    l.error(f"Traceback: {traceback.format_exc()}")
        l.info("Finished pulling queries from all servers")

    async def execute_unexecuted_queries(self):
        async with self.sessions[self.local_ts_id]() as session:
            unexecuted_queries = await session.execute(
                select(QueryTracking).where(~QueryTracking.completed_by.any(self.local_ts_id)).order_by(QueryTracking.executed_at)
            )
            unexecuted_queries = unexecuted_queries.scalars().all()

            l.info(f"Executing {len(unexecuted_queries)} unexecuted queries")
            for query in unexecuted_queries:
                try:
                    params = json.loads(query.args)

                    # Convert string datetime to datetime objects
                    for key, value in params.items():
                        if isinstance(value, str) and value.endswith(('Z', '+00:00')):
                            try:
                                params[key] = datetime.fromisoformat(value.rstrip('Z'))
                            except ValueError:
                                # If conversion fails, leave the original value
                                pass

                    async with session.begin():
                        await session.execute(text(query.query), params)
                        query.completed_by = list(set(query.completed_by + [self.local_ts_id]))
                    await session.commit()
                    l.info(f"Successfully executed query ID {query.id}")
                except Exception as e:
                    l.error(f"Failed to execute query ID {query.id}: {str(e)}")
                    await session.rollback()
        l.info("Finished executing unexecuted queries")

    async def call_db_sync_on_servers(self):
        """Call /db/sync on all online servers."""
        online_servers = await self.get_online_servers()
        l.info(f"Calling /db/sync on {len(online_servers)} online servers")
        for server in self.config['POOL']:
            if server['ts_id'] in online_servers and server['ts_id'] != self.local_ts_id:
                try:
                    await self.call_db_sync(server)
                except Exception as e:
                    l.error(f"Failed to call /db/sync on {server['ts_id']}: {str(e)}")
        l.info("Finished calling /db/sync on all servers")

    async def call_db_sync(self, server):
        url = f"http://{server['ts_ip']}:{server['app_port']}/db/sync"
        headers = {
            "Authorization": f"Bearer {server['api_key']}"
        }
        async with aiohttp.ClientSession() as session:
            try:
                async with session.post(url, headers=headers, timeout=30) as response:
                    if response.status == 200:
                        l.info(f"Successfully called /db/sync on {url}")
                    else:
                        l.warning(f"Failed to call /db/sync on {url}. Status: {response.status}")
            except asyncio.TimeoutError:
                l.debug(f"Timeout while calling /db/sync on {url}")
            except Exception as e:
                l.error(f"Error calling /db/sync on {url}: {str(e)}")

    async def ensure_query_tracking_table(self):
        for ts_id, engine in self.engines.items():
            try:
                async with engine.begin() as conn:
                    await conn.run_sync(Base.metadata.create_all)
                l.info(f"Ensured query_tracking table exists for {ts_id}")
            except Exception as e:
                l.error(f"Failed to create query_tracking table for {ts_id}: {str(e)}")

    async def close(self):
        for engine in self.engines.values():
            await engine.dispose()
        l.info("Closed all database connections")


@@ -67,7 +67,6 @@ async def get_tailscale_ip():
     else:
         return "No devices found"
 @sys.post("/db/sync")
 async def db_sync(background_tasks: BackgroundTasks):
     l.info(f"Received request to /db/sync")