auto-update

sanj 2024-11-09 18:03:34 -08:00
parent 79b99ef68f
commit 56a203ea74
4 changed files with 341 additions and 234 deletions

View file

@@ -58,7 +58,6 @@ async def lifespan(app: FastAPI):
try:
await Db.initialize_engines()
await Db.ensure_query_tracking_table()
except Exception as e:
l.critical(f"Error during startup: {str(e)}")
l.critical(f"Traceback: {traceback.format_exc()}")

View file

@@ -1,38 +1,21 @@
import json
# database.py
import yaml
import time
import aiohttp
import asyncio
import traceback
from datetime import datetime as dt_datetime, date
from tqdm.asyncio import tqdm
import reverse_geocoder as rg
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union, TypeVar, ClassVar
from typing import Any, Dict, List
from dotenv import load_dotenv
from pydantic import BaseModel, Field, create_model, PrivateAttr
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from zoneinfo import ZoneInfo
from srtm import get_data
import os
import sys
from pydantic import BaseModel
from datetime import datetime
from loguru import logger
from sqlalchemy import text, select, func, and_
from sqlalchemy import text, create_engine
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker, declarative_base, make_transient
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy import Column, Integer, String, DateTime, JSON, Text
import uuid
from sqlalchemy import Column, String, DateTime, Text, ARRAY
from sqlalchemy.dialects.postgresql import UUID, JSONB
from sqlalchemy.sql import func
from urllib.parse import urljoin
import hashlib
import random
import os
from .logs import get_logger
from .serialization import json_dumps, json_serial, serialize
from .serialization import serialize
l = get_logger(__name__)
@@ -44,24 +27,7 @@ ENV_PATH = CONFIG_DIR / ".env"
load_dotenv(ENV_PATH)
TS_ID = os.environ.get('TS_ID')
class QueryTracking(Base):
__tablename__ = 'query_tracking'
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
origin_ts_id = Column(String, nullable=False)
query = Column(Text, nullable=False)
args = Column(JSON)
executed_at = Column(DateTime(timezone=True), server_default=func.now())
completed_by = Column(ARRAY(String), default=[])
result_checksum = Column(String(32))
class Database:
SYNC_COOLDOWN = 30 # seconds
@classmethod
def init(cls, config_name: str):
return cls(config_name)
@@ -70,9 +36,7 @@ class Database:
self.config = self.load_config(config_path)
self.engines: Dict[str, Any] = {}
self.sessions: Dict[str, Any] = {}
self.online_servers: set = set()
self.local_ts_id = self.get_local_ts_id()
self.last_sync_time = 0
def load_config(self, config_path: str) -> Dict[str, Any]:
base_path = Path(__file__).parent.parent
@@ -95,7 +59,6 @@ class Database:
self.sessions[db_info['ts_id']] = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
l.info(f"Initialized engine and session for {db_info['ts_id']}")
# Create tables if they don't exist
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
l.info(f"Ensured tables exist for {db_info['ts_id']}")
@@ -105,20 +68,6 @@ class Database:
if self.local_ts_id not in self.sessions:
l.error(f"Failed to initialize session for local server {self.local_ts_id}")
async def get_online_servers(self) -> List[str]:
online_servers = []
for ts_id, engine in self.engines.items():
try:
async with engine.connect() as conn:
await conn.execute(text("SELECT 1"))
online_servers.append(ts_id)
l.debug(f"Server {ts_id} is online")
except OperationalError:
l.warning(f"Server {ts_id} is offline")
self.online_servers = set(online_servers)
l.info(f"Online servers: {', '.join(online_servers)}")
return online_servers
async def read(self, query: str, **kwargs):
if self.local_ts_id not in self.sessions:
l.error(f"No session found for local server {self.local_ts_id}. Database may not be properly initialized.")
@@ -144,190 +93,17 @@ class Database:
async with self.sessions[self.local_ts_id]() as session:
try:
# Execute the write query locally
serialized_kwargs = {key: serialize(value) for key, value in kwargs.items()}
result = await session.execute(text(query), serialized_kwargs)
await session.commit()
# Initiate async operations
asyncio.create_task(self._async_sync_operations(query, kwargs))
# Return the result
return result
except Exception as e:
l.error(f"Failed to execute write query: {str(e)}")
l.error(f"Query: {query}")
l.error(f"Kwargs: {kwargs}")
l.error(f"Serialized kwargs: {serialized_kwargs}")
l.error(f"Traceback: {traceback.format_exc()}")
return None
async def _async_sync_operations(self, query: str, kwargs: dict):
try:
# Add the write query to the query_tracking table
await self.add_query_to_tracking(query, kwargs)
# Call /db/sync on all online servers
await self.call_db_sync_on_servers()
except Exception as e:
l.error(f"Error in async sync operations: {str(e)}")
l.error(f"Traceback: {traceback.format_exc()}")
async def add_query_to_tracking(self, query: str, kwargs: dict, result_checksum: str = None):
async with self.sessions[self.local_ts_id]() as session:
new_query = QueryTracking(
origin_ts_id=self.local_ts_id,
query=query,
args=json_dumps(kwargs),
completed_by=[self.local_ts_id],
result_checksum=result_checksum
)
session.add(new_query)
await session.commit()
l.info(f"Added query to tracking: {query[:50]}...")
async def sync_db(self):
current_time = time.time()
if current_time - self.last_sync_time < self.SYNC_COOLDOWN:
l.info(f"Skipping sync, last sync was less than {self.SYNC_COOLDOWN} seconds ago")
return
try:
l.info("Starting database synchronization")
self.last_sync_time = current_time # Update the last sync time before starting
await self.pull_query_tracking_from_all_servers()
await self.execute_unexecuted_queries()
l.info("Database synchronization completed successfully")
except Exception as e:
l.error(f"Error during database sync: {str(e)}")
l.error(f"Traceback: {traceback.format_exc()}")
finally:
# Ensure the cooldown is respected even if an error occurs
self.last_sync_time = max(self.last_sync_time, current_time)
async def pull_query_tracking_from_all_servers(self):
online_servers = await self.get_online_servers()
l.info(f"Pulling query tracking from {len(online_servers)} online servers")
for server_id in online_servers:
if server_id == self.local_ts_id:
continue # Skip local server
l.info(f"Pulling queries from server: {server_id}")
async with self.sessions[server_id]() as remote_session:
try:
result = await remote_session.execute(select(QueryTracking))
queries = result.scalars().all()
l.info(f"Retrieved {len(queries)} queries from server {server_id}")
async with self.sessions[self.local_ts_id]() as local_session:
for query in queries:
# Detach the object from its original session
make_transient(query)
existing = await local_session.execute(
select(QueryTracking).where(QueryTracking.id == query.id)
)
existing = existing.scalar_one_or_none()
if existing:
# Update existing query
existing.completed_by = list(set(existing.completed_by + query.completed_by))
l.debug(f"Updated existing query: {query.id}")
else:
# Create a new instance for the local session
new_query = QueryTracking(
id=query.id,
origin_ts_id=query.origin_ts_id,
query=query.query,
args=query.args,
executed_at=query.executed_at,
completed_by=query.completed_by,
result_checksum=query.result_checksum
)
local_session.add(new_query)
l.debug(f"Added new query: {query.id}")
await local_session.commit()
except Exception as e:
l.error(f"Error pulling queries from server {server_id}: {str(e)}")
l.error(f"Traceback: {traceback.format_exc()}")
l.info("Finished pulling queries from all servers")
async def execute_unexecuted_queries(self):
async with self.sessions[self.local_ts_id]() as session:
unexecuted_queries = await session.execute(
select(QueryTracking).where(~QueryTracking.completed_by.any(self.local_ts_id)).order_by(QueryTracking.executed_at)
)
unexecuted_queries = unexecuted_queries.scalars().all()
l.info(f"Executing {len(unexecuted_queries)} unexecuted queries")
for query in unexecuted_queries:
try:
params = json.loads(query.args)
# Convert string datetime to datetime objects
for key, value in params.items():
if isinstance(value, str) and value.endswith(('Z', '+00:00')):
try:
params[key] = datetime.fromisoformat(value.rstrip('Z'))
except ValueError:
# If conversion fails, leave the original value
pass
async with session.begin():
await session.execute(text(query.query), params)
query.completed_by = list(set(query.completed_by + [self.local_ts_id]))
await session.commit()
l.info(f"Successfully executed query ID {query.id}")
except Exception as e:
l.error(f"Failed to execute query ID {query.id}: {str(e)}")
await session.rollback()
l.info("Finished executing unexecuted queries")
async def call_db_sync_on_servers(self):
"""Call /db/sync on all online servers."""
online_servers = await self.get_online_servers()
l.info(f"Calling /db/sync on {len(online_servers)} online servers")
for server in self.config['POOL']:
if server['ts_id'] in online_servers and server['ts_id'] != self.local_ts_id:
try:
await self.call_db_sync(server)
except Exception as e:
l.error(f"Failed to call /db/sync on {server['ts_id']}: {str(e)}")
l.info("Finished calling /db/sync on all servers")
async def call_db_sync(self, server):
url = f"http://{server['ts_ip']}:{server['app_port']}/db/sync"
headers = {
"Authorization": f"Bearer {server['api_key']}"
}
async with aiohttp.ClientSession() as session:
try:
async with session.post(url, headers=headers, timeout=30) as response:
if response.status == 200:
l.info(f"Successfully called /db/sync on {url}")
else:
l.warning(f"Failed to call /db/sync on {url}. Status: {response.status}")
except asyncio.TimeoutError:
l.debug(f"Timeout while calling /db/sync on {url}")
except Exception as e:
l.error(f"Error calling /db/sync on {url}: {str(e)}")
async def ensure_query_tracking_table(self):
for ts_id, engine in self.engines.items():
try:
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
l.info(f"Ensured query_tracking table exists for {ts_id}")
except Exception as e:
l.error(f"Failed to create query_tracking table for {ts_id}: {str(e)}")
async def close(self):
for engine in self.engines.values():
await engine.dispose()

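Taken together, the hunks above strip the multi-server sync machinery (the QueryTracking model, get_online_servers, and the tracking/replay/sync-call methods) out of database.py; that code is preserved in the new db_sync_old.py below. A hedged usage sketch of the slimmed-down class, assuming read and write keep the signatures shown in the hunks (the "db" config name and the notes table are hypothetical, not from this commit):

```python
import asyncio
import uuid

async def main():
    # "db" is an assumed config name; the notes table is hypothetical.
    db = Database.init("db")
    await db.initialize_engines()

    note_id = str(uuid.uuid4())
    # write() serializes kwargs, binds them as named parameters, and commits.
    await db.write("INSERT INTO notes (id, content) VALUES (:id, :content)",
                   id=note_id, content="hello")
    # read() returns a list of dicts (one per row), or None on failure.
    rows = await db.read("SELECT * FROM notes WHERE id = :id", id=note_id)

    await db.close()

asyncio.run(main())
```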
sijapi/db_sync_old.py (new file, 333 additions)
View file

@@ -0,0 +1,333 @@
#database.py
import json
import yaml
import time
import aiohttp
import asyncio
import traceback
from datetime import datetime as dt_datetime, date
from tqdm.asyncio import tqdm
import reverse_geocoder as rg
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union, TypeVar, ClassVar
from dotenv import load_dotenv
from pydantic import BaseModel, Field, create_model, PrivateAttr
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from zoneinfo import ZoneInfo
from srtm import get_data
import os
import sys
from loguru import logger
from sqlalchemy import text, select, func, and_
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker, declarative_base, make_transient
from sqlalchemy.exc import OperationalError
from sqlalchemy import Column, Integer, String, DateTime, JSON, Text
import uuid
from sqlalchemy import Column, String, DateTime, Text, ARRAY
from sqlalchemy.dialects.postgresql import UUID, JSONB
from sqlalchemy.sql import func
from urllib.parse import urljoin
import hashlib
import random
from .logs import get_logger
from .serialization import json_dumps, json_serial, serialize
l = get_logger(__name__)
Base = declarative_base()
BASE_DIR = Path(__file__).resolve().parent
CONFIG_DIR = BASE_DIR / "config"
ENV_PATH = CONFIG_DIR / ".env"
load_dotenv(ENV_PATH)
TS_ID = os.environ.get('TS_ID')
class QueryTracking(Base):
__tablename__ = 'query_tracking'
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
origin_ts_id = Column(String, nullable=False)
query = Column(Text, nullable=False)
args = Column(JSON)
executed_at = Column(DateTime(timezone=True), server_default=func.now())
completed_by = Column(ARRAY(String), default=[])
result_checksum = Column(String(32))
class Database:
SYNC_COOLDOWN = 30 # seconds
@classmethod
def init(cls, config_name: str):
return cls(config_name)
def __init__(self, config_path: str):
self.config = self.load_config(config_path)
self.engines: Dict[str, Any] = {}
self.sessions: Dict[str, Any] = {}
self.online_servers: set = set()
self.local_ts_id = self.get_local_ts_id()
self.last_sync_time = 0
def load_config(self, config_path: str) -> Dict[str, Any]:
base_path = Path(__file__).parent.parent
full_path = base_path / "sijapi" / "config" / f"{config_path}.yaml"
with open(full_path, 'r') as file:
config = yaml.safe_load(file)
return config
def get_local_ts_id(self) -> str:
return os.environ.get('TS_ID')
async def initialize_engines(self):
for db_info in self.config['POOL']:
url = f"postgresql+asyncpg://{db_info['db_user']}:{db_info['db_pass']}@{db_info['ts_ip']}:{db_info['db_port']}/{db_info['db_name']}"
try:
engine = create_async_engine(url, pool_pre_ping=True, pool_size=5, max_overflow=10)
self.engines[db_info['ts_id']] = engine
self.sessions[db_info['ts_id']] = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
l.info(f"Initialized engine and session for {db_info['ts_id']}")
# Create tables if they don't exist
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
l.info(f"Ensured tables exist for {db_info['ts_id']}")
except Exception as e:
l.error(f"Failed to initialize engine for {db_info['ts_id']}: {str(e)}")
if self.local_ts_id not in self.sessions:
l.error(f"Failed to initialize session for local server {self.local_ts_id}")
async def get_online_servers(self) -> List[str]:
online_servers = []
for ts_id, engine in self.engines.items():
try:
async with engine.connect() as conn:
await conn.execute(text("SELECT 1"))
online_servers.append(ts_id)
l.debug(f"Server {ts_id} is online")
except OperationalError:
l.warning(f"Server {ts_id} is offline")
self.online_servers = set(online_servers)
l.info(f"Online servers: {', '.join(online_servers)}")
return online_servers
async def read(self, query: str, **kwargs):
if self.local_ts_id not in self.sessions:
l.error(f"No session found for local server {self.local_ts_id}. Database may not be properly initialized.")
return None
async with self.sessions[self.local_ts_id]() as session:
try:
result = await session.execute(text(query), kwargs)
rows = result.fetchall()
if rows:
columns = result.keys()
return [dict(zip(columns, row)) for row in rows]
else:
return []
except Exception as e:
l.error(f"Failed to execute read query: {str(e)}")
return None
async def write(self, query: str, **kwargs):
if self.local_ts_id not in self.sessions:
l.error(f"No session found for local server {self.local_ts_id}. Database may not be properly initialized.")
return None
async with self.sessions[self.local_ts_id]() as session:
try:
# Execute the write query locally
serialized_kwargs = {key: serialize(value) for key, value in kwargs.items()}
result = await session.execute(text(query), serialized_kwargs)
await session.commit()
# Initiate async operations
asyncio.create_task(self._async_sync_operations(query, kwargs))
# Return the result
return result
except Exception as e:
l.error(f"Failed to execute write query: {str(e)}")
l.error(f"Query: {query}")
l.error(f"Kwargs: {kwargs}")
l.error(f"Serialized kwargs: {serialized_kwargs}")
l.error(f"Traceback: {traceback.format_exc()}")
return None
async def _async_sync_operations(self, query: str, kwargs: dict):
try:
# Add the write query to the query_tracking table
await self.add_query_to_tracking(query, kwargs)
# Call /db/sync on all online servers
await self.call_db_sync_on_servers()
except Exception as e:
l.error(f"Error in async sync operations: {str(e)}")
l.error(f"Traceback: {traceback.format_exc()}")
async def add_query_to_tracking(self, query: str, kwargs: dict, result_checksum: str = None):
async with self.sessions[self.local_ts_id]() as session:
new_query = QueryTracking(
origin_ts_id=self.local_ts_id,
query=query,
args=json_dumps(kwargs),
completed_by=[self.local_ts_id],
result_checksum=result_checksum
)
session.add(new_query)
await session.commit()
l.info(f"Added query to tracking: {query[:50]}...")
async def sync_db(self):
current_time = time.time()
if current_time - self.last_sync_time < self.SYNC_COOLDOWN:
l.info(f"Skipping sync, last sync was less than {self.SYNC_COOLDOWN} seconds ago")
return
try:
l.info("Starting database synchronization")
self.last_sync_time = current_time # Update the last sync time before starting
await self.pull_query_tracking_from_all_servers()
await self.execute_unexecuted_queries()
l.info("Database synchronization completed successfully")
except Exception as e:
l.error(f"Error during database sync: {str(e)}")
l.error(f"Traceback: {traceback.format_exc()}")
finally:
# Ensure the cooldown is respected even if an error occurs
self.last_sync_time = max(self.last_sync_time, current_time)
async def pull_query_tracking_from_all_servers(self):
online_servers = await self.get_online_servers()
l.info(f"Pulling query tracking from {len(online_servers)} online servers")
for server_id in online_servers:
if server_id == self.local_ts_id:
continue # Skip local server
l.info(f"Pulling queries from server: {server_id}")
async with self.sessions[server_id]() as remote_session:
try:
result = await remote_session.execute(select(QueryTracking))
queries = result.scalars().all()
l.info(f"Retrieved {len(queries)} queries from server {server_id}")
async with self.sessions[self.local_ts_id]() as local_session:
for query in queries:
# Detach the object from its original session
make_transient(query)
existing = await local_session.execute(
select(QueryTracking).where(QueryTracking.id == query.id)
)
existing = existing.scalar_one_or_none()
if existing:
# Update existing query
existing.completed_by = list(set(existing.completed_by + query.completed_by))
l.debug(f"Updated existing query: {query.id}")
else:
# Create a new instance for the local session
new_query = QueryTracking(
id=query.id,
origin_ts_id=query.origin_ts_id,
query=query.query,
args=query.args,
executed_at=query.executed_at,
completed_by=query.completed_by,
result_checksum=query.result_checksum
)
local_session.add(new_query)
l.debug(f"Added new query: {query.id}")
await local_session.commit()
except Exception as e:
l.error(f"Error pulling queries from server {server_id}: {str(e)}")
l.error(f"Traceback: {traceback.format_exc()}")
l.info("Finished pulling queries from all servers")
async def execute_unexecuted_queries(self):
async with self.sessions[self.local_ts_id]() as session:
unexecuted_queries = await session.execute(
select(QueryTracking).where(~QueryTracking.completed_by.any(self.local_ts_id)).order_by(QueryTracking.executed_at)
)
unexecuted_queries = unexecuted_queries.scalars().all()
l.info(f"Executing {len(unexecuted_queries)} unexecuted queries")
for query in unexecuted_queries:
try:
params = json.loads(query.args)
# Convert string datetime to datetime objects
for key, value in params.items():
if isinstance(value, str) and value.endswith(('Z', '+00:00')):
try:
params[key] = datetime.fromisoformat(value.rstrip('Z'))
except ValueError:
# If conversion fails, leave the original value
pass
async with session.begin():
await session.execute(text(query.query), params)
query.completed_by = list(set(query.completed_by + [self.local_ts_id]))
await session.commit()
l.info(f"Successfully executed query ID {query.id}")
except Exception as e:
l.error(f"Failed to execute query ID {query.id}: {str(e)}")
await session.rollback()
l.info("Finished executing unexecuted queries")
async def call_db_sync_on_servers(self):
"""Call /db/sync on all online servers."""
online_servers = await self.get_online_servers()
l.info(f"Calling /db/sync on {len(online_servers)} online servers")
for server in self.config['POOL']:
if server['ts_id'] in online_servers and server['ts_id'] != self.local_ts_id:
try:
await self.call_db_sync(server)
except Exception as e:
l.error(f"Failed to call /db/sync on {server['ts_id']}: {str(e)}")
l.info("Finished calling /db/sync on all servers")
async def call_db_sync(self, server):
url = f"http://{server['ts_ip']}:{server['app_port']}/db/sync"
headers = {
"Authorization": f"Bearer {server['api_key']}"
}
async with aiohttp.ClientSession() as session:
try:
async with session.post(url, headers=headers, timeout=30) as response:
if response.status == 200:
l.info(f"Successfully called /db/sync on {url}")
else:
l.warning(f"Failed to call /db/sync on {url}. Status: {response.status}")
except asyncio.TimeoutError:
l.debug(f"Timeout while calling /db/sync on {url}")
except Exception as e:
l.error(f"Error calling /db/sync on {url}: {str(e)}")
async def ensure_query_tracking_table(self):
for ts_id, engine in self.engines.items():
try:
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
l.info(f"Ensured query_tracking table exists for {ts_id}")
except Exception as e:
l.error(f"Failed to create query_tracking table for {ts_id}: {str(e)}")
async def close(self):
for engine in self.engines.values():
await engine.dispose()
l.info("Closed all database connections")

View file

@@ -67,7 +67,6 @@ async def get_tailscale_ip():
else:
return "No devices found"
@sys.post("/db/sync")
async def db_sync(background_tasks: BackgroundTasks):
l.info(f"Received request to /db/sync")