From 56a203ea7437389f663aaa442fe34230b6bfbdbf Mon Sep 17 00:00:00 2001
From: sanj <67624670+iodrift@users.noreply.github.com>
Date: Sat, 9 Nov 2024 18:03:34 -0800
Subject: [PATCH] auto-update

---
 sijapi/__main__.py    |   1 -
 sijapi/database.py    | 240 +-----------------------------
 sijapi/db_sync_old.py | 333 ++++++++++++++++++++++++++++++++++++++++++
 sijapi/routers/sys.py |   1 -
 4 files changed, 341 insertions(+), 234 deletions(-)
 create mode 100644 sijapi/db_sync_old.py

diff --git a/sijapi/__main__.py b/sijapi/__main__.py
index 8d57ec3..99cd631 100755
--- a/sijapi/__main__.py
+++ b/sijapi/__main__.py
@@ -58,7 +58,6 @@ async def lifespan(app: FastAPI):
 
     try:
         await Db.initialize_engines()
-        await Db.ensure_query_tracking_table()
    except Exception as e:
         l.critical(f"Error during startup: {str(e)}")
         l.critical(f"Traceback: {traceback.format_exc()}")
diff --git a/sijapi/database.py b/sijapi/database.py
index 8e4b491..307e73d 100644
--- a/sijapi/database.py
+++ b/sijapi/database.py
@@ -1,38 +1,21 @@
-import json
+# database.py
 import yaml
-import time
-import aiohttp
-import asyncio
-import traceback
-from datetime import datetime as dt_datetime, date
-from tqdm.asyncio import tqdm
-import reverse_geocoder as rg
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple, Union, TypeVar, ClassVar
+from typing import Any, Dict, List
 from dotenv import load_dotenv
-from pydantic import BaseModel, Field, create_model, PrivateAttr
-from concurrent.futures import ThreadPoolExecutor
-from contextlib import asynccontextmanager
-from datetime import datetime, timedelta, timezone
-from zoneinfo import ZoneInfo
-from srtm import get_data
-import os
-import sys
+from pydantic import BaseModel
+from datetime import datetime
 from loguru import logger
-from sqlalchemy import text, select, func, and_
+from sqlalchemy import text, create_engine
 from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
-from sqlalchemy.orm import sessionmaker, declarative_base, make_transient
-from sqlalchemy.exc import OperationalError
+from sqlalchemy.orm import sessionmaker, declarative_base
 from sqlalchemy import Column, Integer, String, DateTime, JSON, Text
 import uuid
-from sqlalchemy import Column, String, DateTime, Text, ARRAY
 from sqlalchemy.dialects.postgresql import UUID, JSONB
 from sqlalchemy.sql import func
-from urllib.parse import urljoin
-import hashlib
-import random
+import os
 from .logs import get_logger
-from .serialization import json_dumps, json_serial, serialize
+from .serialization import serialize
 
 l = get_logger(__name__)
 
@@ -44,24 +27,7 @@
 ENV_PATH = CONFIG_DIR / ".env"
 load_dotenv(ENV_PATH)
 TS_ID = os.environ.get('TS_ID')
-
-
-class QueryTracking(Base):
-    __tablename__ = 'query_tracking'
-
-    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
-    origin_ts_id = Column(String, nullable=False)
-    query = Column(Text, nullable=False)
-    args = Column(JSON)
-    executed_at = Column(DateTime(timezone=True), server_default=func.now())
-    completed_by = Column(ARRAY(String), default=[])
-    result_checksum = Column(String(32))
-
-
-
 class Database:
-    SYNC_COOLDOWN = 30  # seconds
-
     @classmethod
     def init(cls, config_name: str):
         return cls(config_name)
@@ -70,9 +36,7 @@ class Database:
         self.config = self.load_config(config_path)
         self.engines: Dict[str, Any] = {}
         self.sessions: Dict[str, Any] = {}
-        self.online_servers: set = set()
         self.local_ts_id = self.get_local_ts_id()
-        self.last_sync_time = 0
 
     def load_config(self, config_path: str) -> Dict[str, Any]:
         base_path = Path(__file__).parent.parent
@@ -95,7 +59,6 @@ class Database:
                 self.sessions[db_info['ts_id']] = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
                 l.info(f"Initialized engine and session for {db_info['ts_id']}")
 
-                # Create tables if they don't exist
                 async with engine.begin() as conn:
                     await conn.run_sync(Base.metadata.create_all)
                 l.info(f"Ensured tables exist for {db_info['ts_id']}")
@@ -105,20 +68,6 @@ class Database:
         if self.local_ts_id not in self.sessions:
             l.error(f"Failed to initialize session for local server {self.local_ts_id}")
-
-    async def get_online_servers(self) -> List[str]:
-        online_servers = []
-        for ts_id, engine in self.engines.items():
-            try:
-                async with engine.connect() as conn:
-                    await conn.execute(text("SELECT 1"))
-                online_servers.append(ts_id)
-                l.debug(f"Server {ts_id} is online")
-            except OperationalError:
-                l.warning(f"Server {ts_id} is offline")
-        self.online_servers = set(online_servers)
-        l.info(f"Online servers: {', '.join(online_servers)}")
-        return online_servers
 
     async def read(self, query: str, **kwargs):
         if self.local_ts_id not in self.sessions:
             l.error(f"No session found for local server {self.local_ts_id}. Database may not be properly initialized.")
@@ -144,190 +93,17 @@ class Database:
 
         async with self.sessions[self.local_ts_id]() as session:
             try:
-                # Execute the write query locally
                 serialized_kwargs = {key: serialize(value) for key, value in kwargs.items()}
                 result = await session.execute(text(query), serialized_kwargs)
                 await session.commit()
-
-                # Initiate async operations
-                asyncio.create_task(self._async_sync_operations(query, kwargs))
-
-                # Return the result
                 return result
-
             except Exception as e:
                 l.error(f"Failed to execute write query: {str(e)}")
                 l.error(f"Query: {query}")
                 l.error(f"Kwargs: {kwargs}")
                 l.error(f"Serialized kwargs: {serialized_kwargs}")
-                l.error(f"Traceback: {traceback.format_exc()}")
                 return None
-
-    async def _async_sync_operations(self, query: str, kwargs: dict):
-        try:
-            # Add the write query to the query_tracking table
-            await self.add_query_to_tracking(query, kwargs)
-
-            # Call /db/sync on all online servers
-            await self.call_db_sync_on_servers()
-        except Exception as e:
-            l.error(f"Error in async sync operations: {str(e)}")
-            l.error(f"Traceback: {traceback.format_exc()}")
-
-
-
-    async def add_query_to_tracking(self, query: str, kwargs: dict, result_checksum: str = None):
-        async with self.sessions[self.local_ts_id]() as session:
-            new_query = QueryTracking(
-                origin_ts_id=self.local_ts_id,
-                query=query,
-                args=json_dumps(kwargs),
-                completed_by=[self.local_ts_id],
-                result_checksum=result_checksum
-            )
-            session.add(new_query)
-            await session.commit()
-        l.info(f"Added query to tracking: {query[:50]}...")
-
-
-    async def sync_db(self):
-        current_time = time.time()
-        if current_time - self.last_sync_time < self.SYNC_COOLDOWN:
-            l.info(f"Skipping sync, last sync was less than {self.SYNC_COOLDOWN} seconds ago")
-            return
-
-        try:
-            l.info("Starting database synchronization")
-            self.last_sync_time = current_time  # Update the last sync time before starting
-            await self.pull_query_tracking_from_all_servers()
-            await self.execute_unexecuted_queries()
-            l.info("Database synchronization completed successfully")
-        except Exception as e:
-            l.error(f"Error during database sync: {str(e)}")
-            l.error(f"Traceback: {traceback.format_exc()}")
-        finally:
-            # Ensure the cooldown is respected even if an error occurs
-            self.last_sync_time = max(self.last_sync_time, current_time)
-
-
-    async def pull_query_tracking_from_all_servers(self):
-        online_servers = await self.get_online_servers()
-        l.info(f"Pulling query tracking from {len(online_servers)} online servers")
-
-        for server_id in online_servers:
-            if server_id == self.local_ts_id:
-                continue  # Skip local server
-
-            l.info(f"Pulling queries from server: {server_id}")
-            async with self.sessions[server_id]() as remote_session:
-                try:
-                    result = await remote_session.execute(select(QueryTracking))
-                    queries = result.scalars().all()
-                    l.info(f"Retrieved {len(queries)} queries from server {server_id}")
-
-                    async with self.sessions[self.local_ts_id]() as local_session:
-                        for query in queries:
-                            # Detach the object from its original session
-                            make_transient(query)
-
-                            existing = await local_session.execute(
-                                select(QueryTracking).where(QueryTracking.id == query.id)
-                            )
-                            existing = existing.scalar_one_or_none()
-
-                            if existing:
-                                # Update existing query
-                                existing.completed_by = list(set(existing.completed_by + query.completed_by))
-                                l.debug(f"Updated existing query: {query.id}")
-                            else:
-                                # Create a new instance for the local session
-                                new_query = QueryTracking(
-                                    id=query.id,
-                                    origin_ts_id=query.origin_ts_id,
-                                    query=query.query,
-                                    args=query.args,
-                                    executed_at=query.executed_at,
-                                    completed_by=query.completed_by,
-                                    result_checksum=query.result_checksum
-                                )
-                                local_session.add(new_query)
-                                l.debug(f"Added new query: {query.id}")
-                        await local_session.commit()
-                except Exception as e:
-                    l.error(f"Error pulling queries from server {server_id}: {str(e)}")
-                    l.error(f"Traceback: {traceback.format_exc()}")
-        l.info("Finished pulling queries from all servers")
-
-
-    async def execute_unexecuted_queries(self):
-        async with self.sessions[self.local_ts_id]() as session:
-            unexecuted_queries = await session.execute(
-                select(QueryTracking).where(~QueryTracking.completed_by.any(self.local_ts_id)).order_by(QueryTracking.executed_at)
-            )
-            unexecuted_queries = unexecuted_queries.scalars().all()
-
-            l.info(f"Executing {len(unexecuted_queries)} unexecuted queries")
-            for query in unexecuted_queries:
-                try:
-                    params = json.loads(query.args)
-
-                    # Convert string datetime to datetime objects
-                    for key, value in params.items():
-                        if isinstance(value, str) and value.endswith(('Z', '+00:00')):
-                            try:
-                                params[key] = datetime.fromisoformat(value.rstrip('Z'))
-                            except ValueError:
-                                # If conversion fails, leave the original value
-                                pass
-
-                    async with session.begin():
-                        await session.execute(text(query.query), params)
-                        query.completed_by = list(set(query.completed_by + [self.local_ts_id]))
-                        await session.commit()
-                    l.info(f"Successfully executed query ID {query.id}")
-                except Exception as e:
-                    l.error(f"Failed to execute query ID {query.id}: {str(e)}")
-                    await session.rollback()
-        l.info("Finished executing unexecuted queries")
-
-    async def call_db_sync_on_servers(self):
-        """Call /db/sync on all online servers."""
-        online_servers = await self.get_online_servers()
-        l.info(f"Calling /db/sync on {len(online_servers)} online servers")
-        for server in self.config['POOL']:
-            if server['ts_id'] in online_servers and server['ts_id'] != self.local_ts_id:
-                try:
-                    await self.call_db_sync(server)
-                except Exception as e:
-                    l.error(f"Failed to call /db/sync on {server['ts_id']}: {str(e)}")
-        l.info("Finished calling /db/sync on all servers")
-
-    async def call_db_sync(self, server):
-        url = f"http://{server['ts_ip']}:{server['app_port']}/db/sync"
-        headers = {
-            "Authorization": f"Bearer {server['api_key']}"
-        }
-        async with aiohttp.ClientSession() as session:
-            try:
-                async with session.post(url, headers=headers, timeout=30) as response:
-                    if response.status == 200:
-                        l.info(f"Successfully called /db/sync on {url}")
-                    else:
-                        l.warning(f"Failed to call /db/sync on {url}. Status: {response.status}")
-            except asyncio.TimeoutError:
-                l.debug(f"Timeout while calling /db/sync on {url}")
-            except Exception as e:
-                l.error(f"Error calling /db/sync on {url}: {str(e)}")
-
-    async def ensure_query_tracking_table(self):
-        for ts_id, engine in self.engines.items():
-            try:
-                async with engine.begin() as conn:
-                    await conn.run_sync(Base.metadata.create_all)
-                l.info(f"Ensured query_tracking table exists for {ts_id}")
-            except Exception as e:
-                l.error(f"Failed to create query_tracking table for {ts_id}: {str(e)}")
 
     async def close(self):
         for engine in self.engines.values():
             await engine.dispose()
diff --git a/sijapi/db_sync_old.py b/sijapi/db_sync_old.py
new file mode 100644
index 0000000..519c086
--- /dev/null
+++ b/sijapi/db_sync_old.py
@@ -0,0 +1,333 @@
+#database.py
+import json
+import yaml
+import time
+import aiohttp
+import asyncio
+import traceback
+from datetime import datetime as dt_datetime, date
+from tqdm.asyncio import tqdm
+import reverse_geocoder as rg
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, Union, TypeVar, ClassVar
+from dotenv import load_dotenv
+from pydantic import BaseModel, Field, create_model, PrivateAttr
+from concurrent.futures import ThreadPoolExecutor
+from contextlib import asynccontextmanager
+from datetime import datetime, timedelta, timezone
+from zoneinfo import ZoneInfo
+from srtm import get_data
+import os
+import sys
+from loguru import logger
+from sqlalchemy import text, select, func, and_
+from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
+from sqlalchemy.orm import sessionmaker, declarative_base, make_transient
+from sqlalchemy.exc import OperationalError
+from sqlalchemy import Column, Integer, String, DateTime, JSON, Text
+import uuid
+from sqlalchemy import Column, String, DateTime, Text, ARRAY
+from sqlalchemy.dialects.postgresql import UUID, JSONB
+from sqlalchemy.sql import func
+from urllib.parse import urljoin
+import hashlib
+import random
+from .logs import get_logger
+from .serialization import json_dumps, json_serial, serialize
+
+l = get_logger(__name__)
+
+Base = declarative_base()
+
+BASE_DIR = Path(__file__).resolve().parent
+CONFIG_DIR = BASE_DIR / "config"
+ENV_PATH = CONFIG_DIR / ".env"
+load_dotenv(ENV_PATH)
+TS_ID = os.environ.get('TS_ID')
+
+
+class QueryTracking(Base):
+    __tablename__ = 'query_tracking'
+
+    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
+    origin_ts_id = Column(String, nullable=False)
+    query = Column(Text, nullable=False)
+    args = Column(JSON)
+    executed_at = Column(DateTime(timezone=True), server_default=func.now())
+    completed_by = Column(ARRAY(String), default=[])
+    result_checksum = Column(String(32))
+
+
+class Database:
+    SYNC_COOLDOWN = 30  # seconds
+
+    @classmethod
+    def init(cls, config_name: str):
+        return cls(config_name)
+
+    def __init__(self, config_path: str):
+        self.config = self.load_config(config_path)
+        self.engines: Dict[str, Any] = {}
+        self.sessions: Dict[str, Any] = {}
+        self.online_servers: set = set()
+        self.local_ts_id = self.get_local_ts_id()
+        self.last_sync_time = 0
+
+    def load_config(self, config_path: str) -> Dict[str, Any]:
+        base_path = Path(__file__).parent.parent
+        full_path = base_path / "sijapi" / "config" / f"{config_path}.yaml"
+
+        with open(full_path, 'r') as file:
+            config = yaml.safe_load(file)
+
+        return config
+
+    def get_local_ts_id(self) -> str:
+        return os.environ.get('TS_ID')
+
+    async def initialize_engines(self):
+        for db_info in self.config['POOL']:
+            url = f"postgresql+asyncpg://{db_info['db_user']}:{db_info['db_pass']}@{db_info['ts_ip']}:{db_info['db_port']}/{db_info['db_name']}"
+            try:
+                engine = create_async_engine(url, pool_pre_ping=True, pool_size=5, max_overflow=10)
+                self.engines[db_info['ts_id']] = engine
+                self.sessions[db_info['ts_id']] = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
+                l.info(f"Initialized engine and session for {db_info['ts_id']}")
+
+                # Create tables if they don't exist
+                async with engine.begin() as conn:
+                    await conn.run_sync(Base.metadata.create_all)
+                l.info(f"Ensured tables exist for {db_info['ts_id']}")
+            except Exception as e:
+                l.error(f"Failed to initialize engine for {db_info['ts_id']}: {str(e)}")
+
+        if self.local_ts_id not in self.sessions:
+            l.error(f"Failed to initialize session for local server {self.local_ts_id}")
+
+    async def get_online_servers(self) -> List[str]:
+        online_servers = []
+        for ts_id, engine in self.engines.items():
+            try:
+                async with engine.connect() as conn:
+                    await conn.execute(text("SELECT 1"))
+                online_servers.append(ts_id)
+                l.debug(f"Server {ts_id} is online")
+            except OperationalError:
+                l.warning(f"Server {ts_id} is offline")
+        self.online_servers = set(online_servers)
+        l.info(f"Online servers: {', '.join(online_servers)}")
+        return online_servers
+
+    async def read(self, query: str, **kwargs):
+        if self.local_ts_id not in self.sessions:
+            l.error(f"No session found for local server {self.local_ts_id}. Database may not be properly initialized.")
+            return None
+
+        async with self.sessions[self.local_ts_id]() as session:
+            try:
+                result = await session.execute(text(query), kwargs)
+                rows = result.fetchall()
+                if rows:
+                    columns = result.keys()
+                    return [dict(zip(columns, row)) for row in rows]
+                else:
+                    return []
+            except Exception as e:
+                l.error(f"Failed to execute read query: {str(e)}")
+                return None
+
+    async def write(self, query: str, **kwargs):
+        if self.local_ts_id not in self.sessions:
+            l.error(f"No session found for local server {self.local_ts_id}. Database may not be properly initialized.")
+            return None
+
+        async with self.sessions[self.local_ts_id]() as session:
+            try:
+                # Execute the write query locally
+                serialized_kwargs = {key: serialize(value) for key, value in kwargs.items()}
+                result = await session.execute(text(query), serialized_kwargs)
+                await session.commit()
+
+                # Initiate async operations
+                asyncio.create_task(self._async_sync_operations(query, kwargs))
+
+                # Return the result
+                return result
+
+            except Exception as e:
+                l.error(f"Failed to execute write query: {str(e)}")
+                l.error(f"Query: {query}")
+                l.error(f"Kwargs: {kwargs}")
+                l.error(f"Serialized kwargs: {serialized_kwargs}")
+                l.error(f"Traceback: {traceback.format_exc()}")
+                return None
+
+    async def _async_sync_operations(self, query: str, kwargs: dict):
+        try:
+            # Add the write query to the query_tracking table
+            await self.add_query_to_tracking(query, kwargs)
+
+            # Call /db/sync on all online servers
+            await self.call_db_sync_on_servers()
+        except Exception as e:
+            l.error(f"Error in async sync operations: {str(e)}")
+            l.error(f"Traceback: {traceback.format_exc()}")
+
+
+
+    async def add_query_to_tracking(self, query: str, kwargs: dict, result_checksum: str = None):
+        async with self.sessions[self.local_ts_id]() as session:
+            new_query = QueryTracking(
+                origin_ts_id=self.local_ts_id,
+                query=query,
+                args=json_dumps(kwargs),
+                completed_by=[self.local_ts_id],
+                result_checksum=result_checksum
+            )
+            session.add(new_query)
+            await session.commit()
+        l.info(f"Added query to tracking: {query[:50]}...")
+
+
+    async def sync_db(self):
+        current_time = time.time()
+        if current_time - self.last_sync_time < self.SYNC_COOLDOWN:
+            l.info(f"Skipping sync, last sync was less than {self.SYNC_COOLDOWN} seconds ago")
+            return
+
+        try:
+            l.info("Starting database synchronization")
+            self.last_sync_time = current_time  # Update the last sync time before starting
+            await self.pull_query_tracking_from_all_servers()
+            await self.execute_unexecuted_queries()
+            l.info("Database synchronization completed successfully")
+        except Exception as e:
+            l.error(f"Error during database sync: {str(e)}")
+            l.error(f"Traceback: {traceback.format_exc()}")
+        finally:
+            # Ensure the cooldown is respected even if an error occurs
+            self.last_sync_time = max(self.last_sync_time, current_time)
+
+
+    async def pull_query_tracking_from_all_servers(self):
+        online_servers = await self.get_online_servers()
+        l.info(f"Pulling query tracking from {len(online_servers)} online servers")
+
+        for server_id in online_servers:
+            if server_id == self.local_ts_id:
+                continue  # Skip local server
+
+            l.info(f"Pulling queries from server: {server_id}")
+            async with self.sessions[server_id]() as remote_session:
+                try:
+                    result = await remote_session.execute(select(QueryTracking))
+                    queries = result.scalars().all()
+                    l.info(f"Retrieved {len(queries)} queries from server {server_id}")
+
+                    async with self.sessions[self.local_ts_id]() as local_session:
+                        for query in queries:
+                            # Detach the object from its original session
+                            make_transient(query)
+
+                            existing = await local_session.execute(
+                                select(QueryTracking).where(QueryTracking.id == query.id)
+                            )
+                            existing = existing.scalar_one_or_none()
+
+                            if existing:
+                                # Update existing query
+                                existing.completed_by = list(set(existing.completed_by + query.completed_by))
+                                l.debug(f"Updated existing query: {query.id}")
+                            else:
+                                # Create a new instance for the local session
+                                new_query = QueryTracking(
+                                    id=query.id,
+                                    origin_ts_id=query.origin_ts_id,
+                                    query=query.query,
+                                    args=query.args,
+                                    executed_at=query.executed_at,
+                                    completed_by=query.completed_by,
+                                    result_checksum=query.result_checksum
+                                )
+                                local_session.add(new_query)
+                                l.debug(f"Added new query: {query.id}")
+                        await local_session.commit()
+                except Exception as e:
+                    l.error(f"Error pulling queries from server {server_id}: {str(e)}")
+                    l.error(f"Traceback: {traceback.format_exc()}")
+        l.info("Finished pulling queries from all servers")
+
+
+    async def execute_unexecuted_queries(self):
+        async with self.sessions[self.local_ts_id]() as session:
+            unexecuted_queries = await session.execute(
+                select(QueryTracking).where(~QueryTracking.completed_by.any(self.local_ts_id)).order_by(QueryTracking.executed_at)
+            )
+            unexecuted_queries = unexecuted_queries.scalars().all()
+
+            l.info(f"Executing {len(unexecuted_queries)} unexecuted queries")
+            for query in unexecuted_queries:
+                try:
+                    params = json.loads(query.args)
+
+                    # Convert string datetime to datetime objects
+                    for key, value in params.items():
+                        if isinstance(value, str) and value.endswith(('Z', '+00:00')):
+                            try:
+                                params[key] = datetime.fromisoformat(value.rstrip('Z'))
+                            except ValueError:
+                                # If conversion fails, leave the original value
+                                pass
+
+                    async with session.begin():
+                        await session.execute(text(query.query), params)
+                        query.completed_by = list(set(query.completed_by + [self.local_ts_id]))
+                        await session.commit()
+                    l.info(f"Successfully executed query ID {query.id}")
+                except Exception as e:
+                    l.error(f"Failed to execute query ID {query.id}: {str(e)}")
+                    await session.rollback()
+        l.info("Finished executing unexecuted queries")
+
+    async def call_db_sync_on_servers(self):
+        """Call /db/sync on all online servers."""
+        online_servers = await self.get_online_servers()
+        l.info(f"Calling /db/sync on {len(online_servers)} online servers")
+        for server in self.config['POOL']:
+            if server['ts_id'] in online_servers and server['ts_id'] != self.local_ts_id:
+                try:
+                    await self.call_db_sync(server)
+                except Exception as e:
+                    l.error(f"Failed to call /db/sync on {server['ts_id']}: {str(e)}")
+        l.info("Finished calling /db/sync on all servers")
+
+    async def call_db_sync(self, server):
+        url = f"http://{server['ts_ip']}:{server['app_port']}/db/sync"
+        headers = {
+            "Authorization": f"Bearer {server['api_key']}"
+        }
+        async with aiohttp.ClientSession() as session:
+            try:
+                async with session.post(url, headers=headers, timeout=30) as response:
+                    if response.status == 200:
+                        l.info(f"Successfully called /db/sync on {url}")
+                    else:
+                        l.warning(f"Failed to call /db/sync on {url}. Status: {response.status}")
+            except asyncio.TimeoutError:
+                l.debug(f"Timeout while calling /db/sync on {url}")
+            except Exception as e:
+                l.error(f"Error calling /db/sync on {url}: {str(e)}")
+
+    async def ensure_query_tracking_table(self):
+        for ts_id, engine in self.engines.items():
+            try:
+                async with engine.begin() as conn:
+                    await conn.run_sync(Base.metadata.create_all)
+                l.info(f"Ensured query_tracking table exists for {ts_id}")
+            except Exception as e:
+                l.error(f"Failed to create query_tracking table for {ts_id}: {str(e)}")
+
+    async def close(self):
+        for engine in self.engines.values():
+            await engine.dispose()
+        l.info("Closed all database connections")
diff --git a/sijapi/routers/sys.py b/sijapi/routers/sys.py
index 4954af8..0c7e050 100644
--- a/sijapi/routers/sys.py
+++ b/sijapi/routers/sys.py
@@ -67,7 +67,6 @@ async def get_tailscale_ip():
     else:
         return "No devices found"
 
-
 @sys.post("/db/sync")
 async def db_sync(background_tasks: BackgroundTasks):
     l.info(f"Received request to /db/sync")