Auto-update: Fri Aug 9 11:31:45 PDT 2024
parent 487807bab1
commit ee6ee1ed87
6 changed files with 463 additions and 326 deletions
@@ -26,8 +26,7 @@ Db = Database.load('sys')
 
 # HOST = f"{API.BIND}:{API.PORT}"
 # LOCAL_HOSTS = [ipaddress.ip_address(localhost.strip()) for localhost in os.getenv('LOCAL_HOSTS', '127.0.0.1').split(',')] + ['localhost']
-SUBNET_BROADCAST = os.getenv("SUBNET_BROADCAST", '10.255.255.255')
+# SUBNET_BROADCAST = os.getenv("SUBNET_BROADCAST", '10.255.255.255')
 
 MAX_CPU_CORES = min(int(os.getenv("MAX_CPU_CORES", int(multiprocessing.cpu_count()/2))), multiprocessing.cpu_count())
 
@@ -21,7 +21,7 @@ from dotenv import load_dotenv
 from pathlib import Path
 from datetime import datetime
 import argparse
-from . import L, API, ROUTER_DIR
+from . import L, API, Db, ROUTER_DIR
 
 parser = argparse.ArgumentParser(description='Personal API.')
 parser.add_argument('--log', type=str, default='INFO', help='Set overall log level (e.g., DEBUG, INFO, WARNING)')
@@ -55,7 +55,8 @@ async def lifespan(app: FastAPI):
 
     try:
         # Initialize sync structures on all databases
-        await API.initialize_sync()
+        # await API.initialize_sync()
+        await Db.initialize_engines()
 
     except Exception as e:
         crit(f"Error during startup: {str(e)}")
@@ -99,16 +100,18 @@ class SimpleAPIKeyMiddleware(BaseHTTPMiddleware):
         api_key_header = request.headers.get("Authorization")
         api_key_query = request.query_params.get("api_key")
 
-        # Debug logging for API keys
-        debug(f"API.KEYS: {API.KEYS}")
+        # Convert API.KEYS to lowercase for case-insensitive comparison
+        api_keys_lower = [key.lower() for key in API.KEYS]
+        debug(f"API.KEYS (lowercase): {api_keys_lower}")
 
         if api_key_header:
             api_key_header = api_key_header.lower().split("bearer ")[-1]
             debug(f"API key provided in header: {api_key_header}")
         if api_key_query:
+            api_key_query = api_key_query.lower()
             debug(f"API key provided in query: {api_key_query}")
 
-        if api_key_header not in API.KEYS and api_key_query not in API.KEYS:
+        if api_key_header.lower() not in api_keys_lower and api_key_query.lower() not in api_keys_lower:
             err(f"Invalid API key provided by a requester.")
             if api_key_header:
                 debug(f"Invalid API key in header: {api_key_header}")
@@ -119,9 +122,9 @@ class SimpleAPIKeyMiddleware(BaseHTTPMiddleware):
                 content={"detail": "Invalid or missing API key"}
             )
         else:
-            if api_key_header in API.KEYS:
+            if api_key_header.lower() in api_keys_lower:
                 debug(f"Valid API key provided in header: {api_key_header}")
-            if api_key_query in API.KEYS:
+            if api_key_query and api_key_query.lower() in api_keys_lower:
                 debug(f"Valid API key provided in query: {api_key_query}")
 
         response = await call_next(request)
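The middleware change above makes key matching case-insensitive for both the Authorization header and the api_key query parameter. A minimal client sketch, assuming aiohttp; the base URL, route, and key below are placeholders, not values from this repo:

# Hedged sketch: either casing of the same key should now pass the middleware.
import asyncio
import aiohttp

BASE = "http://127.0.0.1:8000"   # hypothetical bind address and port
KEY = "Sample-Key-123"           # hypothetical API key

async def main():
    async with aiohttp.ClientSession() as session:
        # Header form: "Bearer <key>" is lowercased before comparison
        async with session.get(f"{BASE}/id", headers={"Authorization": f"Bearer {KEY.upper()}"}) as r:
            print("header auth:", r.status)
        # Query form: ?api_key=<key> is lowercased the same way
        async with session.get(f"{BASE}/id", params={"api_key": KEY.lower()}) as r:
            print("query auth:", r.status)

asyncio.run(main())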
@@ -27,6 +27,17 @@ from srtm import get_data
 import os
 import sys
 from loguru import logger
+from sqlalchemy import text
+from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
+from sqlalchemy.orm import sessionmaker, declarative_base
+from sqlalchemy import Column, Integer, String, DateTime, JSON, Text, select, func
+from sqlalchemy.dialects.postgresql import JSONB
+from sqlalchemy.exc import OperationalError
+from urllib.parse import urljoin
+import hashlib
+import random
 
+Base = declarative_base()
+
 # Custom logger class
 class Logger:
@@ -258,14 +269,30 @@ class DirConfig:
 
 
 
+class QueryTracking(Base):
+    __tablename__ = 'query_tracking'
+
+    id = Column(Integer, primary_key=True)
+    ts_id = Column(String, nullable=False)
+    query = Column(Text, nullable=False)
+    args = Column(JSONB)
+    executed_at = Column(DateTime(timezone=True), server_default=func.now())
+    completed_by = Column(JSONB, default={})
+    result_checksum = Column(String)
+
+
 class Database:
     @classmethod
     def load(cls, config_name: str):
         return cls(config_name)
 
     def __init__(self, config_path: str):
         self.config = self.load_config(config_path)
-        self.pool_connections = {}
+        self.engines: Dict[str, Any] = {}
+        self.sessions: Dict[str, Any] = {}
+        self.online_servers: set = set()
         self.local_ts_id = self.get_local_ts_id()
 
     def load_config(self, config_path: str) -> Dict[str, Any]:
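For orientation, each POOL entry consumed by the new initialize_engines() carries at least the keys used to build the asyncpg URL, plus the fields referenced elsewhere in this class. A hedged sketch of one entry; the values are placeholders and the on-disk config format is not shown in this diff:

# Hedged sketch of the fields Database reads from a POOL entry.
pool_entry = {
    "ts_id": "server-a",      # host identifier; also the key into self.engines / self.sessions
    "ts_ip": "100.64.64.11",
    "db_port": 5432,
    "db_name": "sij",
    "db_user": "sij",
    "db_pass": "***",
    "app_port": 4444,         # referenced by _delegate_compute_checksum for /sync/checksum
}
# initialize_engines() turns each entry into:
#   postgresql+asyncpg://{db_user}:{db_pass}@{ts_ip}:{db_port}/{db_name}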
@@ -280,91 +307,111 @@ class Database:
     def get_local_ts_id(self) -> str:
         return os.environ.get('TS_ID')
 
-    async def get_connection(self, ts_id: str = None):
-        if ts_id is None:
-            ts_id = self.local_ts_id
+    async def initialize_engines(self):
+        for db_info in self.config['POOL']:
+            url = f"postgresql+asyncpg://{db_info['db_user']}:{db_info['db_pass']}@{db_info['ts_ip']}:{db_info['db_port']}/{db_info['db_name']}"
+            try:
+                engine = create_async_engine(url, pool_pre_ping=True, pool_size=5, max_overflow=10)
+                self.engines[db_info['ts_id']] = engine
+                self.sessions[db_info['ts_id']] = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
+                info(f"Initialized engine and session for {db_info['ts_id']}")
+            except Exception as e:
+                err(f"Failed to initialize engine for {db_info['ts_id']}: {str(e)}")
 
-        if ts_id not in self.pool_connections:
-            db_info = next((db for db in self.config['POOL'] if db['ts_id'] == ts_id), None)
-            if db_info is None:
-                raise ValueError(f"No database configuration found for TS_ID: {ts_id}")
-            self.pool_connections[ts_id] = await asyncpg.create_pool(
-                host=db_info['ts_ip'],
-                port=db_info['db_port'],
-                user=db_info['db_user'],
-                password=db_info['db_pass'],
-                database=db_info['db_name'],
-                min_size=1,
-                max_size=10
-            )
-
-        return await self.pool_connections[ts_id].acquire()
-
-    async def release_connection(self, ts_id: str, connection):
-        await self.pool_connections[ts_id].release(connection)
+        if self.local_ts_id not in self.sessions:
+            err(f"Failed to initialize session for local server {self.local_ts_id}")
+        else:
+            try:
+                # Create tables if they don't exist
+                async with self.engines[self.local_ts_id].begin() as conn:
+                    await conn.run_sync(Base.metadata.create_all)
+                info(f"Initialized tables for local server {self.local_ts_id}")
+            except Exception as e:
+                err(f"Failed to create tables for local server {self.local_ts_id}: {str(e)}")
 
     async def get_online_servers(self) -> List[str]:
         online_servers = []
-        for db_info in self.config['POOL']:
+        for ts_id, engine in self.engines.items():
             try:
-                conn = await self.get_connection(db_info['ts_id'])
-                await self.release_connection(db_info['ts_id'], conn)
-                online_servers.append(db_info['ts_id'])
-            except:
+                async with engine.connect() as conn:
+                    await conn.execute(text("SELECT 1"))
+                online_servers.append(ts_id)
+            except OperationalError:
                 pass
+        self.online_servers = set(online_servers)
         return online_servers
 
-    async def initialize_query_tracking(self):
-        conn = await self.get_connection()
-        try:
-            await conn.execute("""
-                CREATE TABLE IF NOT EXISTS query_tracking (
-                    id SERIAL PRIMARY KEY,
-                    ts_id TEXT NOT NULL,
-                    query TEXT NOT NULL,
-                    args JSONB,
-                    executed_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
-                    completed_by JSONB DEFAULT '{}'::jsonb,
-                    result_checksum TEXT
-                )
-            """)
-        finally:
-            await self.release_connection(self.local_ts_id, conn)
+    async def execute_read(self, query: str, *args, **kwargs):
+        if self.local_ts_id not in self.sessions:
+            err(f"No session found for local server {self.local_ts_id}. Database may not be properly initialized.")
+            return None
 
-    async def execute_read(self, query: str, *args):
-        conn = await self.get_connection()
-        try:
-            return await conn.fetch(query, *args)
-        finally:
-            await self.release_connection(self.local_ts_id, conn)
+        params = self._normalize_params(args, kwargs)
 
-    async def execute_write(self, query: str, *args):
-        # Execute write on local database
-        local_conn = await self.get_connection()
-        try:
-            await local_conn.execute(query, *args)
+        async with self.sessions[self.local_ts_id]() as session:
+            try:
+                result = await session.execute(text(query), params)
+                return result.fetchall()
+            except Exception as e:
+                err(f"Failed to execute read query: {str(e)}")
+                return None
 
-            # Log the query
-            query_id = await local_conn.fetchval("""
-                INSERT INTO query_tracking (ts_id, query, args)
-                VALUES ($1, $2, $3)
-                RETURNING id
-            """, self.local_ts_id, query, json.dumps(args))
-        finally:
-            await self.release_connection(self.local_ts_id, local_conn)
+    async def execute_write(self, query: str, *args, **kwargs):
+        if self.local_ts_id not in self.sessions:
+            err(f"No session found for local server {self.local_ts_id}. Database may not be properly initialized.")
+            return
 
-        # Calculate checksum
-        checksum = await self.compute_checksum(query, *args)
+        params = self._normalize_params(args, kwargs)
 
-        # Update query_tracking with checksum
-        await self.update_query_checksum(query_id, checksum)
+        async with self.sessions[self.local_ts_id]() as session:
+            try:
+                # Execute the write query
+                result = await session.execute(text(query), params)
 
-        # Replicate to online servers
-        online_servers = await self.get_online_servers()
-        for ts_id in online_servers:
-            if ts_id != self.local_ts_id:
-                asyncio.create_task(self._replicate_write(ts_id, query_id, query, args, checksum))
+                # Create a serializable version of params for logging
+                serializable_params = {
+                    k: v.isoformat() if isinstance(v, datetime) else v
+                    for k, v in params.items()
+                }
+
+                # Log the query
+                new_query = QueryTracking(
+                    ts_id=self.local_ts_id,
+                    query=query,
+                    args=json.dumps(serializable_params)
+                )
+                session.add(new_query)
+                await session.flush()
+                query_id = new_query.id
+
+                await session.commit()
+                info(f"Successfully executed write query: {query[:50]}...")
+
+                # Calculate checksum
+                checksum = await self._local_compute_checksum(query, params)
+
+                # Update query_tracking with checksum
+                await self.update_query_checksum(query_id, checksum)
+
+                # Replicate to online servers
+                online_servers = await self.get_online_servers()
+                for ts_id in online_servers:
+                    if ts_id != self.local_ts_id:
+                        asyncio.create_task(self._replicate_write(ts_id, query_id, query, params, checksum))
+
+            except Exception as e:
+                err(f"Failed to execute write query: {str(e)}")
+                return
+
+    def _normalize_params(self, args, kwargs):
+        if args and isinstance(args[0], dict):
+            return args[0]
+        elif kwargs:
+            return kwargs
+        elif args:
+            return {f"param{i}": arg for i, arg in enumerate(args, start=1)}
+        else:
+            return {}
 
     async def get_primary_server(self) -> str:
         url = urljoin(self.config['URL'], '/id')
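Because execute_read() and execute_write() now route everything through _normalize_params() and SQLAlchemy text(), callers can pass a single dict, keyword arguments, or bare positionals (which become :param1, :param2, ...). A minimal usage sketch under those assumptions; only the locations and query_tracking tables are taken from this diff, the rest is illustrative:

# Hedged usage sketch for the new named-parameter API.
from datetime import datetime
from sijapi import Db

async def demo():
    # A dict of bind parameters
    rows = await Db.execute_read(
        "SELECT id, datetime FROM locations WHERE datetime >= :start ORDER BY datetime DESC",
        {"start": datetime(2024, 8, 1)},
    )

    # Keyword arguments become the bind-parameter dict via _normalize_params()
    await Db.execute_write(
        "DELETE FROM locations WHERE datetime < :cutoff",   # illustrative write
        cutoff=datetime(2020, 1, 1),
    )

    # Bare positionals map to :param1, :param2, ... in order
    stale = await Db.execute_read("SELECT * FROM query_tracking WHERE id > :param1", 0)
    return rows, stale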
@@ -376,10 +423,10 @@ class Database:
                         primary_ts_id = await response.text()
                         return primary_ts_id.strip()
                     else:
-                        logging.error(f"Failed to get primary server. Status: {response.status}")
+                        err(f"Failed to get primary server. Status: {response.status}")
                         return None
             except aiohttp.ClientError as e:
-                logging.error(f"Error connecting to load balancer: {str(e)}")
+                err(f"Error connecting to load balancer: {str(e)}")
                 return None
 
     async def get_checksum_server(self) -> dict:
@@ -393,132 +440,134 @@ class Database:
 
         return random.choice(checksum_servers)
 
-    async def compute_checksum(self, query: str, *args):
+    async def compute_checksum(self, query: str, params: dict):
         checksum_server = await self.get_checksum_server()
 
         if checksum_server['ts_id'] == self.local_ts_id:
-            return await self._local_compute_checksum(query, *args)
+            return await self._local_compute_checksum(query, params)
         else:
-            return await self._delegate_compute_checksum(checksum_server, query, *args)
+            return await self._delegate_compute_checksum(checksum_server, query, params)
 
-    async def _local_compute_checksum(self, query: str, *args):
-        conn = await self.get_connection()
-        try:
-            result = await conn.fetch(query, *args)
-            checksum = hashlib.md5(str(result).encode()).hexdigest()
+    async def _local_compute_checksum(self, query: str, params: dict):
+        async with self.sessions[self.local_ts_id]() as session:
+            result = await session.execute(text(query), params)
+            if result.returns_rows:
+                data = result.fetchall()
+            else:
+                # For INSERT, UPDATE, DELETE queries that don't return rows
+                data = str(result.rowcount) + query + str(params)
+            checksum = hashlib.md5(str(data).encode()).hexdigest()
             return checksum
-        finally:
-            await self.release_connection(self.local_ts_id, conn)
 
-    async def _delegate_compute_checksum(self, server: dict, query: str, *args):
+    async def _delegate_compute_checksum(self, server: Dict[str, Any], query: str, params: dict):
         url = f"http://{server['ts_ip']}:{server['app_port']}/sync/checksum"
 
         async with aiohttp.ClientSession() as session:
             try:
-                async with session.post(url, json={"query": query, "args": list(args)}) as response:
+                serializable_params = {
+                    k: v.isoformat() if isinstance(v, datetime) else v
+                    for k, v in params.items()
+                }
+                async with session.post(url, json={"query": query, "params": serializable_params}) as response:
                     if response.status == 200:
                         result = await response.json()
                         return result['checksum']
                     else:
-                        logging.error(f"Failed to get checksum from {server['ts_id']}. Status: {response.status}")
-                        return await self._local_compute_checksum(query, *args)
+                        err(f"Failed to get checksum from {server['ts_id']}. Status: {response.status}")
+                        return await self._local_compute_checksum(query, params)
             except aiohttp.ClientError as e:
-                logging.error(f"Error connecting to {server['ts_id']} for checksum: {str(e)}")
-                return await self._local_compute_checksum(query, *args)
+                err(f"Error connecting to {server['ts_id']} for checksum: {str(e)}")
+                return await self._local_compute_checksum(query, params)
 
     async def update_query_checksum(self, query_id: int, checksum: str):
-        conn = await self.get_connection()
-        try:
-            await conn.execute("""
-                UPDATE query_tracking
-                SET result_checksum = $1
-                WHERE id = $2
-            """, checksum, query_id)
-        finally:
-            await self.release_connection(self.local_ts_id, conn)
+        async with self.sessions[self.local_ts_id]() as session:
+            await session.execute(
                text("UPDATE query_tracking SET result_checksum = :checksum WHERE id = :id"),
+                {"checksum": checksum, "id": query_id}
+            )
+            await session.commit()
 
-    async def _replicate_write(self, ts_id: str, query_id: int, query: str, args: tuple, expected_checksum: str):
+    async def _replicate_write(self, ts_id: str, query_id: int, query: str, params: dict, expected_checksum: str):
         try:
-            conn = await self.get_connection(ts_id)
-            try:
-                await conn.execute(query, *args)
-                actual_checksum = await self.compute_checksum(query, *args)
+            async with self.sessions[ts_id]() as session:
+                await session.execute(text(query), params)
+                actual_checksum = await self.compute_checksum(query, params)
                 if actual_checksum != expected_checksum:
                     raise ValueError(f"Checksum mismatch on {ts_id}")
                 await self.mark_query_completed(query_id, ts_id)
-            finally:
-                await self.release_connection(ts_id, conn)
+                await session.commit()
+                info(f"Successfully replicated write to {ts_id}")
         except Exception as e:
-            logging.error(f"Failed to replicate write on {ts_id}: {str(e)}")
+            err(f"Failed to replicate write on {ts_id}: {str(e)}")
 
     async def mark_query_completed(self, query_id: int, ts_id: str):
-        conn = await self.get_connection()
-        try:
-            await conn.execute("""
-                UPDATE query_tracking
-                SET completed_by = completed_by || jsonb_build_object($1, true)
-                WHERE id = $2
-            """, ts_id, query_id)
-        finally:
-            await self.release_connection(self.local_ts_id, conn)
+        async with self.sessions[self.local_ts_id]() as session:
+            query = await session.get(QueryTracking, query_id)
+            if query:
+                completed_by = query.completed_by or {}
+                completed_by[ts_id] = True
+                query.completed_by = completed_by
+                await session.commit()
 
     async def sync_local_server(self):
-        conn = await self.get_connection()
-        try:
-            last_synced_id = await conn.fetchval("""
-                SELECT COALESCE(MAX(id), 0) FROM query_tracking
-                WHERE completed_by ? $1
-            """, self.local_ts_id)
+        async with self.sessions[self.local_ts_id]() as session:
+            last_synced = await session.execute(
+                text("SELECT MAX(id) FROM query_tracking WHERE completed_by ? :ts_id"),
+                {"ts_id": self.local_ts_id}
+            )
+            last_synced_id = last_synced.scalar() or 0
 
-            unexecuted_queries = await conn.fetch("""
-                SELECT id, query, args, result_checksum
-                FROM query_tracking
-                WHERE id > $1
-                ORDER BY id
-            """, last_synced_id)
+            unexecuted_queries = await session.execute(
+                text("SELECT * FROM query_tracking WHERE id > :last_id ORDER BY id"),
+                {"last_id": last_synced_id}
+            )
 
             for query in unexecuted_queries:
                 try:
-                    await conn.execute(query['query'], *json.loads(query['args']))
-                    actual_checksum = await self.compute_checksum(query['query'], *json.loads(query['args']))
-                    if actual_checksum != query['result_checksum']:
-                        raise ValueError(f"Checksum mismatch for query ID {query['id']}")
-                    await self.mark_query_completed(query['id'], self.local_ts_id)
+                    params = json.loads(query.args)
+                    # Convert ISO format strings back to datetime objects
+                    for key, value in params.items():
+                        if isinstance(value, str):
+                            try:
+                                params[key] = datetime.fromisoformat(value)
+                            except ValueError:
+                                pass  # If it's not a valid ISO format, keep it as a string
+                    await session.execute(text(query.query), params)
+                    actual_checksum = await self._local_compute_checksum(query.query, params)
+                    if actual_checksum != query.result_checksum:
+                        raise ValueError(f"Checksum mismatch for query ID {query.id}")
+                    await self.mark_query_completed(query.id, self.local_ts_id)
                 except Exception as e:
-                    logging.error(f"Failed to execute query ID {query['id']} during local sync: {str(e)}")
+                    err(f"Failed to execute query ID {query.id} during local sync: {str(e)}")
 
-            logging.info(f"Local server sync completed. Executed {len(unexecuted_queries)} queries.")
-        finally:
-            await self.release_connection(self.local_ts_id, conn)
+            await session.commit()
+            info(f"Local server sync completed. Executed {unexecuted_queries.rowcount} queries.")
 
     async def purge_completed_queries(self):
-        conn = await self.get_connection()
-        try:
+        async with self.sessions[self.local_ts_id]() as session:
             all_ts_ids = [db['ts_id'] for db in self.config['POOL']]
-            result = await conn.execute("""
-                WITH consecutive_completed AS (
-                    SELECT id,
-                           row_number() OVER (ORDER BY id) AS rn
-                    FROM query_tracking
-                    WHERE completed_by ?& $1
-                )
-                DELETE FROM query_tracking
-                WHERE id IN (
-                    SELECT id
-                    FROM consecutive_completed
-                    WHERE rn = (SELECT MAX(rn) FROM consecutive_completed)
-                )
-            """, all_ts_ids)
-            deleted_count = int(result.split()[-1])
-            logging.info(f"Purged {deleted_count} completed queries.")
-        finally:
-            await self.release_connection(self.local_ts_id, conn)
+            result = await session.execute(
+                text("""
+                    DELETE FROM query_tracking
+                    WHERE id <= (
+                        SELECT MAX(id)
+                        FROM query_tracking
+                        WHERE completed_by ?& :ts_ids
+                    )
+                """),
+                {"ts_ids": all_ts_ids}
+            )
+            await session.commit()
+            deleted_count = result.rowcount
+            info(f"Purged {deleted_count} completed queries.")
 
     async def close(self):
-        for pool in self.pool_connections.values():
-            await pool.close()
+        for engine in self.engines.values():
+            await engine.dispose()
 
 
 # Configuration class for API & Database methods.
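Taken together with the lifespan change earlier in this commit, the intended lifecycle appears to be: build engines, check which peers answer, replay anything this node missed, then prune fully replicated rows. A hedged sketch of that sequence outside FastAPI; error handling omitted:

# Hedged sketch of the Database lifecycle implied by this commit.
import asyncio
from sijapi import Db

async def main():
    await Db.initialize_engines()           # create an engine/session per POOL entry
    online = await Db.get_online_servers()  # ping each engine with SELECT 1
    print(f"online servers: {online}")
    await Db.sync_local_server()            # replay query_tracking rows not yet completed here
    await Db.purge_completed_queries()      # drop rows completed by every ts_id
    await Db.close()                        # dispose all engines

asyncio.run(main())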
sijapi/helpers/db_uuid_migrate.py (new file, 46 lines)
@@ -0,0 +1,46 @@
+import asyncio
+import asyncpg
+
+# Database connection information
+DB_INFO = {
+    'host': '100.64.64.20',
+    'port': 5432,
+    'database': 'sij',
+    'user': 'sij',
+    'password': 'Synchr0!'
+}
+
+async def update_click_logs():
+    # Connect to the database
+    conn = await asyncpg.connect(**DB_INFO)
+
+    try:
+        # Drop existing 'id' and 'new_id' columns if they exist
+        await conn.execute("""
+            ALTER TABLE click_logs
+            DROP COLUMN IF EXISTS id,
+            DROP COLUMN IF EXISTS new_id;
+        """)
+        print("Dropped existing id and new_id columns (if they existed)")
+
+        # Add new UUID column as primary key
+        await conn.execute("""
+            ALTER TABLE click_logs
+            ADD COLUMN id UUID PRIMARY KEY DEFAULT gen_random_uuid();
+        """)
+        print("Added new UUID column as primary key")
+
+        # Get the number of rows in the table
+        row_count = await conn.fetchval("SELECT COUNT(*) FROM click_logs")
+        print(f"Number of rows in click_logs: {row_count}")
+
+    except Exception as e:
+        print(f"An error occurred: {str(e)}")
+        import traceback
+        traceback.print_exc()
+    finally:
+        # Close the database connection
+        await conn.close()
+
+# Run the update
+asyncio.run(update_click_logs())
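A quick hedged spot-check after running the helper above; note that gen_random_uuid() is built into PostgreSQL 13+ and otherwise needs the pgcrypto extension. The connection values below are placeholders standing in for the DB_INFO defined in db_uuid_migrate.py:

# Hedged sketch: confirm click_logs now has UUID primary keys after the migration.
import asyncio
import asyncpg

DB_INFO = {'host': '100.64.64.20', 'port': 5432, 'database': 'sij', 'user': 'sij', 'password': '...'}

async def check():
    conn = await asyncpg.connect(**DB_INFO)
    try:
        row_count = await conn.fetchval("SELECT COUNT(*) FROM click_logs")
        sample_id = await conn.fetchval("SELECT id FROM click_logs LIMIT 1")
        print(f"{row_count} rows; sample id: {sample_id!r}")  # asyncpg returns uuid.UUID values
    finally:
        await conn.close()

asyncio.run(check())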
@@ -18,7 +18,7 @@ from dateutil.parser import parse as dateutil_parse
 from typing import Optional, List, Union
 from sijapi import L, API, Db, TZ, GEO
 from sijapi.classes import Location
-from sijapi.utilities import haversine, assemble_journal_path
+from sijapi.utilities import haversine, assemble_journal_path, json_serial
 
 gis = APIRouter()
 logger = L.get_module_logger("gis")
@@ -122,140 +122,7 @@ async def get_last_location() -> Optional[Location]:
 
     return None
 
-async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int, datetime, None] = None) -> List[Location]:
-    start_datetime = await dt(start)
-    if end is None:
-        end_datetime = await dt(start_datetime.replace(hour=23, minute=59, second=59))
-    else:
-        end_datetime = await dt(end) if not isinstance(end, datetime) else end
-
-    if start_datetime.time() == datetime.min.time() and end_datetime.time() == datetime.min.time():
-        end_datetime = await dt(end_datetime.replace(hour=23, minute=59, second=59))
-
-    debug(f"Fetching locations between {start_datetime} and {end_datetime}")
-
-    query = '''
-        SELECT id, datetime,
-            ST_X(ST_AsText(location)::geometry) AS longitude,
-            ST_Y(ST_AsText(location)::geometry) AS latitude,
-            ST_Z(ST_AsText(location)::geometry) AS elevation,
-            city, state, zip, street,
-            action, device_type, device_model, device_name, device_os
-        FROM locations
-        WHERE datetime >= $1 AND datetime <= $2
-        ORDER BY datetime DESC
-    '''
-
-    locations = await Db.execute_read(query, start_datetime.replace(tzinfo=None), end_datetime.replace(tzinfo=None))
-
-    debug(f"Range locations query returned: {locations}")
-
-    if not locations and (end is None or start_datetime.date() == end_datetime.date()):
-        fallback_query = '''
-            SELECT id, datetime,
-                ST_X(ST_AsText(location)::geometry) AS longitude,
-                ST_Y(ST_AsText(location)::geometry) AS latitude,
-                ST_Z(ST_AsText(location)::geometry) AS elevation,
-                city, state, zip, street,
-                action, device_type, device_model, device_name, device_os
-            FROM locations
-            WHERE datetime < $1
-            ORDER BY datetime DESC
-            LIMIT 1
-        '''
-        location_data = await Db.execute_read(fallback_query, start_datetime.replace(tzinfo=None))
-        debug(f"Fallback query returned: {location_data}")
-        if location_data:
-            locations = location_data
-
-    debug(f"Locations found: {locations}")
-
-    # Sort location_data based on the datetime field in descending order
-    sorted_locations = sorted(locations, key=lambda x: x['datetime'], reverse=True)
-
-    # Create Location objects directly from the location data
-    location_objects = [
-        Location(
-            latitude=location['latitude'],
-            longitude=location['longitude'],
-            datetime=location['datetime'],
-            elevation=location.get('elevation'),
-            city=location.get('city'),
-            state=location.get('state'),
-            zip=location.get('zip'),
-            street=location.get('street'),
-            context={
-                'action': location.get('action'),
-                'device_type': location.get('device_type'),
-                'device_model': location.get('device_model'),
-                'device_name': location.get('device_name'),
-                'device_os': location.get('device_os')
-            }
-        ) for location in sorted_locations if location['latitude'] is not None and location['longitude'] is not None
-    ]
-
-    return location_objects if location_objects else []
-
-
-async def fetch_last_location_before(datetime: datetime) -> Optional[Location]:
-    try:
-        datetime = await dt(datetime)
-
-        debug(f"Fetching last location before {datetime}")
-
-        query = '''
-            SELECT id, datetime,
-                ST_X(ST_AsText(location)::geometry) AS longitude,
-                ST_Y(ST_AsText(location)::geometry) AS latitude,
-                ST_Z(ST_AsText(location)::geometry) AS elevation,
-                city, state, zip, street, country,
-                action
-            FROM locations
-            WHERE datetime < $1
-            ORDER BY datetime DESC
-            LIMIT 1
-        '''
-
-        location_data = await Db.execute_read(query, datetime.replace(tzinfo=None))
-
-        if location_data:
-            debug(f"Last location found: {location_data[0]}")
-            return Location(**location_data[0])
-        else:
-            debug("No location found before the specified datetime")
-            return None
-    except Exception as e:
-        error(f"Error fetching last location: {str(e)}")
-        return None
-
-
-@gis.get("/map", response_class=HTMLResponse)
-async def generate_map_endpoint(
-    start_date: Optional[str] = Query(None),
-    end_date: Optional[str] = Query(None),
-    max_points: int = Query(32767, description="Maximum number of points to display")
-):
-    try:
-        if start_date and end_date:
-            start_date = await dt(start_date)
-            end_date = await dt(end_date)
-        else:
-            start_date, end_date = await get_date_range()
-    except ValueError:
-        raise HTTPException(status_code=400, detail="Invalid date format")
-
-    info(f"Generating map for {start_date} to {end_date}")
-    html_content = await generate_map(start_date, end_date, max_points)
-    return HTMLResponse(content=html_content)
-
-async def get_date_range():
-    query = "SELECT MIN(datetime) as min_date, MAX(datetime) as max_date FROM locations"
-    row = await Db.execute_read(query, table_name="locations")
-    if row and row[0]['min_date'] and row[0]['max_date']:
-        return row[0]['min_date'], row[0]['max_date']
-    else:
-        return datetime(2022, 1, 1), datetime.now()
-
 async def generate_and_save_heatmap(
         start_date: Union[str, int, datetime],
@@ -424,6 +291,114 @@ map.on(L.Draw.Event.CREATED, function (event) {
     return m.get_root().render()
 
 
+async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int, datetime, None] = None) -> List[Location]:
+    start_datetime = await dt(start)
+    if end is None:
+        end_datetime = await dt(start_datetime.replace(hour=23, minute=59, second=59))
+    else:
+        end_datetime = await dt(end) if not isinstance(end, datetime) else end
+
+    if start_datetime.time() == datetime.min.time() and end_datetime.time() == datetime.min.time():
+        end_datetime = await dt(end_datetime.replace(hour=23, minute=59, second=59))
+
+    debug(f"Fetching locations between {start_datetime} and {end_datetime}")
+
+    query = '''
+        SELECT id, datetime,
+            ST_X(ST_AsText(location)::geometry) AS longitude,
+            ST_Y(ST_AsText(location)::geometry) AS latitude,
+            ST_Z(ST_AsText(location)::geometry) AS elevation,
+            city, state, zip, street,
+            action, device_type, device_model, device_name, device_os
+        FROM locations
+        WHERE datetime >= :start_datetime AND datetime <= :end_datetime
+        ORDER BY datetime DESC
+    '''
+
+    locations = await Db.execute_read(query, start_datetime=start_datetime.replace(tzinfo=None), end_datetime=end_datetime.replace(tzinfo=None))
+
+    debug(f"Range locations query returned: {locations}")
+
+    if not locations and (end is None or start_datetime.date() == end_datetime.date()):
+        fallback_query = '''
+            SELECT id, datetime,
+                ST_X(ST_AsText(location)::geometry) AS longitude,
+                ST_Y(ST_AsText(location)::geometry) AS latitude,
+                ST_Z(ST_AsText(location)::geometry) AS elevation,
+                city, state, zip, street,
+                action, device_type, device_model, device_name, device_os
+            FROM locations
+            WHERE datetime < :start_datetime
+            ORDER BY datetime DESC
+            LIMIT 1
+        '''
+        location_data = await Db.execute_read(fallback_query, start_datetime=start_datetime.replace(tzinfo=None))
+        debug(f"Fallback query returned: {location_data}")
+        if location_data:
+            locations = location_data
+
+    debug(f"Locations found: {locations}")
+
+    # Sort location_data based on the datetime field in descending order
+    sorted_locations = sorted(locations, key=lambda x: x['datetime'], reverse=True)
+
+    # Create Location objects directly from the location data
+    location_objects = [
+        Location(
+            latitude=location['latitude'],
+            longitude=location['longitude'],
+            datetime=location['datetime'],
+            elevation=location.get('elevation'),
+            city=location.get('city'),
+            state=location.get('state'),
+            zip=location.get('zip'),
+            street=location.get('street'),
+            context={
+                'action': location.get('action'),
+                'device_type': location.get('device_type'),
+                'device_model': location.get('device_model'),
+                'device_name': location.get('device_name'),
+                'device_os': location.get('device_os')
+            }
+        ) for location in sorted_locations if location['latitude'] is not None and location['longitude'] is not None
+    ]
+
+    return location_objects if location_objects else []
+
+
+async def fetch_last_location_before(datetime: datetime) -> Optional[Location]:
+    try:
+        datetime = await dt(datetime)
+
+        debug(f"Fetching last location before {datetime}")
+
+        query = '''
+            SELECT id, datetime,
+                ST_X(ST_AsText(location)::geometry) AS longitude,
+                ST_Y(ST_AsText(location)::geometry) AS latitude,
+                ST_Z(ST_AsText(location)::geometry) AS elevation,
+                city, state, zip, street, country,
+                action
+            FROM locations
+            WHERE datetime < :datetime
+            ORDER BY datetime DESC
+            LIMIT 1
+        '''
+
+        location_data = await Db.execute_read(query, datetime=datetime.replace(tzinfo=None))
+
+        if location_data:
+            debug(f"Last location found: {location_data[0]}")
+            return Location(**location_data[0])
+        else:
+            debug("No location found before the specified datetime")
+            return None
+    except Exception as e:
+        error(f"Error fetching last location: {str(e)}")
+        return None
+
+
 async def post_location(location: Location):
     try:
         context = location.context or {}
@@ -435,31 +410,23 @@ async def post_location(location: Location):
 
         # Parse and localize the datetime
         localized_datetime = await dt(location.datetime)
 
         query = '''
             INSERT INTO locations (
                 datetime, location, city, state, zip, street, action, device_type, device_model, device_name, device_os,
                 class_, type, name, display_name, amenity, house_number, road, quarter, neighbourhood,
                 suburb, county, country_code, country
             )
-            VALUES ($1, ST_SetSRID(ST_MakePoint($2, $3, $4), 4326), $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15,
-                    $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26)
+            VALUES (:datetime, ST_SetSRID(ST_MakePoint(:longitude, :latitude, :elevation), 4326), :city, :state, :zip,
+                    :street, :action, :device_type, :device_model, :device_name, :device_os, :class_, :type, :name,
+                    :display_name, :amenity, :house_number, :road, :quarter, :neighbourhood, :suburb, :county,
+                    :country_code, :country)
         '''
 
-        await Db.execute_write(
-            query,
-            localized_datetime, location.longitude, location.latitude, location.elevation, location.city, location.state,
-            location.zip, location.street, action, device_type, device_model, device_name, device_os,
-            location.class_, location.type, location.name, location.display_name,
-            location.amenity, location.house_number, location.road, location.quarter, location.neighbourhood,
-            location.suburb, location.county, location.country_code, location.country
-        )
-
-        info(f"Successfully posted location: {location.latitude}, {location.longitude}, {location.elevation} on {localized_datetime}")
-        return {
+        params = {
             'datetime': localized_datetime,
-            'latitude': location.latitude,
             'longitude': location.longitude,
+            'latitude': location.latitude,
             'elevation': location.elevation,
             'city': location.city,
             'state': location.state,
@@ -484,12 +451,34 @@ async def post_location(location: Location):
             'country_code': location.country_code,
             'country': location.country
         }
 
+        await Db.execute_write(query, **params)
+
+        info(f"Successfully posted location: {location.latitude}, {location.longitude}, {location.elevation} on {localized_datetime}")
+
+        # Create a serializable version of params for the return value
+        serializable_params = {
+            k: v.isoformat() if isinstance(v, datetime) else v
+            for k, v in params.items()
+        }
+        return serializable_params
     except Exception as e:
         err(f"Error posting location {e}")
         err(traceback.format_exc())
         return None
 
 
+async def get_date_range():
+    query = "SELECT MIN(datetime) as min_date, MAX(datetime) as max_date FROM locations"
+    row = await Db.execute_read(query)
+    if row and row[0]['min_date'] and row[0]['max_date']:
+        return row[0]['min_date'], row[0]['max_date']
+    else:
+        return datetime(2022, 1, 1), datetime.now()
+
+
 @gis.post("/locate")
 async def post_locate_endpoint(locations: Union[Location, List[Location]]):
     if isinstance(locations, Location):
@@ -532,8 +521,8 @@ async def post_locate_endpoint(locations: Union[Location, List[Location]]):
 
     return {"message": "Locations and weather updated", "results": responses}
 
 
 
 @gis.get("/locate", response_model=Location)
 async def get_last_location_endpoint() -> JSONResponse:
     this_location = await get_last_location()
@@ -544,8 +533,8 @@ async def get_last_location_endpoint() -> JSONResponse:
     else:
         raise HTTPException(status_code=404, detail="No location found before the specified datetime")
 
 
 
 @gis.get("/locate/{datetime_str}", response_model=List[Location])
 async def get_locate(datetime_str: str, all: bool = False):
     try:
@@ -558,4 +547,24 @@ async def get_locate(datetime_str: str, all: bool = False):
     if not locations:
         raise HTTPException(status_code=404, detail="No nearby data found for this date and time")
 
     return locations if all else [locations[0]]
+
+
+@gis.get("/map", response_class=HTMLResponse)
+async def generate_map_endpoint(
+    start_date: Optional[str] = Query(None),
+    end_date: Optional[str] = Query(None),
+    max_points: int = Query(32767, description="Maximum number of points to display")
+):
+    try:
+        if start_date and end_date:
+            start_date = await dt(start_date)
+            end_date = await dt(end_date)
+        else:
+            start_date, end_date = await get_date_range()
+    except ValueError:
+        raise HTTPException(status_code=400, detail="Invalid date format")
+
+    info(f"Generating map for {start_date} to {end_date}")
+    html_content = await generate_map(start_date, end_date, max_points)
+    return HTMLResponse(content=html_content)
@@ -26,7 +26,7 @@ import pytesseract
 from readability import Document
 from pdf2image import convert_from_path
 from datetime import datetime as dt_datetime, date, time
-from typing import Optional, Union, Tuple, List
+from typing import Optional, Union, Tuple, List, Any
 import asyncio
 from PIL import Image
 import pandas as pd
@@ -629,3 +629,34 @@ async def html_to_markdown(url: str = None, source: str = None) -> Optional[str]
     markdown_content = md(str(soup), heading_style="ATX")
 
     return markdown_content
+
+
+def json_serial(obj: Any) -> Any:
+    """JSON serializer for objects not serializable by default json code"""
+    if isinstance(obj, (datetime, date)):
+        return obj.isoformat()
+    if isinstance(obj, time):
+        return obj.isoformat()
+    if isinstance(obj, Decimal):
+        return float(obj)
+    if isinstance(obj, UUID):
+        return str(obj)
+    if isinstance(obj, bytes):
+        return obj.decode('utf-8')
+    if isinstance(obj, Path):
+        return str(obj)
+    if hasattr(obj, '__dict__'):
+        return obj.__dict__
+    raise TypeError(f"Type {type(obj)} not serializable")
+
+def json_dumps(obj: Any) -> str:
+    """
+    Serialize obj to a JSON formatted str using the custom serializer.
+    """
+    return json.dumps(obj, default=json_serial)
+
+def json_loads(json_str: str) -> Any:
+    """
+    Deserialize json_str to a Python object.
+    """
+    return json.loads(json_str)
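A short hedged example of the new serialization helpers; json_serial falls back to __dict__ for arbitrary objects, so round-tripping is lossy for anything that is not already a plain JSON type:

# Hedged usage sketch for json_serial / json_dumps / json_loads.
from datetime import datetime
from pathlib import Path
from uuid import uuid4
from sijapi.utilities import json_dumps, json_loads

payload = {
    "when": datetime(2024, 8, 9, 11, 31, 45),
    "id": uuid4(),
    "log": Path("/tmp/sijapi.log"),  # hypothetical path
}
s = json_dumps(payload)      # datetimes become ISO strings, UUID/Path become str
data = json_loads(s)         # plain dict of strings; original types are not restored
print(s)
print(data["when"])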