Auto-update: Fri Aug 9 11:31:45 PDT 2024

sanj 2024-08-09 11:31:45 -07:00
parent 487807bab1
commit ee6ee1ed87
6 changed files with 463 additions and 326 deletions

View file

@@ -26,8 +26,7 @@ Db = Database.load('sys')
 # HOST = f"{API.BIND}:{API.PORT}"
 # LOCAL_HOSTS = [ipaddress.ip_address(localhost.strip()) for localhost in os.getenv('LOCAL_HOSTS', '127.0.0.1').split(',')] + ['localhost']
-# SUBNET_BROADCAST = os.getenv("SUBNET_BROADCAST", '10.255.255.255')
+SUBNET_BROADCAST = os.getenv("SUBNET_BROADCAST", '10.255.255.255')
 MAX_CPU_CORES = min(int(os.getenv("MAX_CPU_CORES", int(multiprocessing.cpu_count()/2))), multiprocessing.cpu_count())
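
A note on the MAX_CPU_CORES line kept as context above: it defaults to half the detected cores and clamps any MAX_CPU_CORES environment override to the machine's actual core count. The same pattern in isolation, as a sketch:

import multiprocessing
import os

cores = multiprocessing.cpu_count()
# Unset -> half the cores; an oversized override is clamped to the core count.
max_cpu_cores = min(int(os.getenv("MAX_CPU_CORES", cores // 2)), cores)
# e.g. on an 8-core machine: unset -> 4; MAX_CPU_CORES=32 -> 8; MAX_CPU_CORES=2 -> 2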

View file

@@ -21,7 +21,7 @@ from dotenv import load_dotenv
 from pathlib import Path
 from datetime import datetime
 import argparse
-from . import L, API, ROUTER_DIR
+from . import L, API, Db, ROUTER_DIR
 parser = argparse.ArgumentParser(description='Personal API.')
 parser.add_argument('--log', type=str, default='INFO', help='Set overall log level (e.g., DEBUG, INFO, WARNING)')
@@ -55,7 +55,8 @@ async def lifespan(app: FastAPI):
     try:
         # Initialize sync structures on all databases
-        await API.initialize_sync()
+        # await API.initialize_sync()
+        await Db.initialize_engines()
     except Exception as e:
         crit(f"Error during startup: {str(e)}")
@@ -99,16 +100,18 @@ class SimpleAPIKeyMiddleware(BaseHTTPMiddleware):
         api_key_header = request.headers.get("Authorization")
         api_key_query = request.query_params.get("api_key")
-        # Debug logging for API keys
-        debug(f"API.KEYS: {API.KEYS}")
+        # Convert API.KEYS to lowercase for case-insensitive comparison
+        api_keys_lower = [key.lower() for key in API.KEYS]
+        debug(f"API.KEYS (lowercase): {api_keys_lower}")
         if api_key_header:
             api_key_header = api_key_header.lower().split("bearer ")[-1]
             debug(f"API key provided in header: {api_key_header}")
         if api_key_query:
+            api_key_query = api_key_query.lower()
             debug(f"API key provided in query: {api_key_query}")
-        if api_key_header not in API.KEYS and api_key_query not in API.KEYS:
+        if api_key_header.lower() not in api_keys_lower and api_key_query.lower() not in api_keys_lower:
             err(f"Invalid API key provided by a requester.")
             if api_key_header:
                 debug(f"Invalid API key in header: {api_key_header}")
@@ -119,9 +122,9 @@ class SimpleAPIKeyMiddleware(BaseHTTPMiddleware):
                 content={"detail": "Invalid or missing API key"}
             )
         else:
-            if api_key_header in API.KEYS:
+            if api_key_header.lower() in api_keys_lower:
                 debug(f"Valid API key provided in header: {api_key_header}")
-            if api_key_query in API.KEYS:
+            if api_key_query and api_key_query.lower() in api_keys_lower:
                 debug(f"Valid API key provided in query: {api_key_query}")
         response = await call_next(request)
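
Note: as committed, the combined rejection check calls .lower() on api_key_header and api_key_query unconditionally, and either can be None when a request supplies only one of the two (or neither), which would raise AttributeError before the invalid-key branch runs. A None-safe variant of the check, as a sketch (the rejection status code is assumed; the diff shows only the response body):

# Sketch only; api_keys_lower as defined in the committed middleware.
header_ok = api_key_header is not None and api_key_header.lower() in api_keys_lower
query_ok = api_key_query is not None and api_key_query.lower() in api_keys_lower
if not (header_ok or query_ok):
    # Reject as the committed code does (status code assumed here).
    return JSONResponse(status_code=401, content={"detail": "Invalid or missing API key"})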

View file

@@ -27,6 +27,17 @@ from srtm import get_data
 import os
 import sys
 from loguru import logger
+from sqlalchemy import text
+from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
+from sqlalchemy.orm import sessionmaker, declarative_base
+from sqlalchemy import Column, Integer, String, DateTime, JSON, Text, select, func
+from sqlalchemy.dialects.postgresql import JSONB
+from sqlalchemy.exc import OperationalError
+from urllib.parse import urljoin
+import hashlib
+import random
+
+Base = declarative_base()
 # Custom logger class
 class Logger:
@@ -258,6 +269,20 @@ class DirConfig:
+class QueryTracking(Base):
+    __tablename__ = 'query_tracking'
+
+    id = Column(Integer, primary_key=True)
+    ts_id = Column(String, nullable=False)
+    query = Column(Text, nullable=False)
+    args = Column(JSONB)
+    executed_at = Column(DateTime(timezone=True), server_default=func.now())
+    completed_by = Column(JSONB, default={})
+    result_checksum = Column(String)
+
 class Database:
     @classmethod
     def load(cls, config_name: str):
@@ -265,7 +290,9 @@ class Database:
     def __init__(self, config_path: str):
         self.config = self.load_config(config_path)
-        self.pool_connections = {}
+        self.engines: Dict[str, Any] = {}
+        self.sessions: Dict[str, Any] = {}
+        self.online_servers: set = set()
         self.local_ts_id = self.get_local_ts_id()

     def load_config(self, config_path: str) -> Dict[str, Any]:
@@ -280,91 +307,111 @@ class Database:
     def get_local_ts_id(self) -> str:
         return os.environ.get('TS_ID')

-    async def get_connection(self, ts_id: str = None):
-        if ts_id is None:
-            ts_id = self.local_ts_id
-        if ts_id not in self.pool_connections:
-            db_info = next((db for db in self.config['POOL'] if db['ts_id'] == ts_id), None)
-            if db_info is None:
-                raise ValueError(f"No database configuration found for TS_ID: {ts_id}")
-            self.pool_connections[ts_id] = await asyncpg.create_pool(
-                host=db_info['ts_ip'],
-                port=db_info['db_port'],
-                user=db_info['db_user'],
-                password=db_info['db_pass'],
-                database=db_info['db_name'],
-                min_size=1,
-                max_size=10
-            )
-        return await self.pool_connections[ts_id].acquire()
-
-    async def release_connection(self, ts_id: str, connection):
-        await self.pool_connections[ts_id].release(connection)
+    async def initialize_engines(self):
+        for db_info in self.config['POOL']:
+            url = f"postgresql+asyncpg://{db_info['db_user']}:{db_info['db_pass']}@{db_info['ts_ip']}:{db_info['db_port']}/{db_info['db_name']}"
+            try:
+                engine = create_async_engine(url, pool_pre_ping=True, pool_size=5, max_overflow=10)
+                self.engines[db_info['ts_id']] = engine
+                self.sessions[db_info['ts_id']] = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
+                info(f"Initialized engine and session for {db_info['ts_id']}")
+            except Exception as e:
+                err(f"Failed to initialize engine for {db_info['ts_id']}: {str(e)}")
+        if self.local_ts_id not in self.sessions:
+            err(f"Failed to initialize session for local server {self.local_ts_id}")
+        else:
+            try:
+                # Create tables if they don't exist
+                async with self.engines[self.local_ts_id].begin() as conn:
+                    await conn.run_sync(Base.metadata.create_all)
+                info(f"Initialized tables for local server {self.local_ts_id}")
+            except Exception as e:
+                err(f"Failed to create tables for local server {self.local_ts_id}: {str(e)}")

     async def get_online_servers(self) -> List[str]:
         online_servers = []
-        for db_info in self.config['POOL']:
+        for ts_id, engine in self.engines.items():
             try:
-                conn = await self.get_connection(db_info['ts_id'])
-                await self.release_connection(db_info['ts_id'], conn)
-                online_servers.append(db_info['ts_id'])
-            except:
+                async with engine.connect() as conn:
+                    await conn.execute(text("SELECT 1"))
+                online_servers.append(ts_id)
+            except OperationalError:
                 pass
+        self.online_servers = set(online_servers)
         return online_servers

-    async def initialize_query_tracking(self):
-        conn = await self.get_connection()
-        try:
-            await conn.execute("""
-                CREATE TABLE IF NOT EXISTS query_tracking (
-                    id SERIAL PRIMARY KEY,
-                    ts_id TEXT NOT NULL,
-                    query TEXT NOT NULL,
-                    args JSONB,
-                    executed_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
-                    completed_by JSONB DEFAULT '{}'::jsonb,
-                    result_checksum TEXT
-                )
-            """)
-        finally:
-            await self.release_connection(self.local_ts_id, conn)
-
-    async def execute_read(self, query: str, *args):
-        conn = await self.get_connection()
-        try:
-            return await conn.fetch(query, *args)
-        finally:
-            await self.release_connection(self.local_ts_id, conn)
-
-    async def execute_write(self, query: str, *args):
-        # Execute write on local database
-        local_conn = await self.get_connection()
-        try:
-            await local_conn.execute(query, *args)
-            # Log the query
-            query_id = await local_conn.fetchval("""
-                INSERT INTO query_tracking (ts_id, query, args)
-                VALUES ($1, $2, $3)
-                RETURNING id
-            """, self.local_ts_id, query, json.dumps(args))
-        finally:
-            await self.release_connection(self.local_ts_id, local_conn)
-
-        # Calculate checksum
-        checksum = await self.compute_checksum(query, *args)
-
-        # Update query_tracking with checksum
-        await self.update_query_checksum(query_id, checksum)
-
-        # Replicate to online servers
-        online_servers = await self.get_online_servers()
-        for ts_id in online_servers:
-            if ts_id != self.local_ts_id:
-                asyncio.create_task(self._replicate_write(ts_id, query_id, query, args, checksum))
+    async def execute_read(self, query: str, *args, **kwargs):
+        if self.local_ts_id not in self.sessions:
+            err(f"No session found for local server {self.local_ts_id}. Database may not be properly initialized.")
+            return None
+
+        params = self._normalize_params(args, kwargs)
+
+        async with self.sessions[self.local_ts_id]() as session:
+            try:
+                result = await session.execute(text(query), params)
+                return result.fetchall()
+            except Exception as e:
+                err(f"Failed to execute read query: {str(e)}")
+                return None
+
+    async def execute_write(self, query: str, *args, **kwargs):
+        if self.local_ts_id not in self.sessions:
+            err(f"No session found for local server {self.local_ts_id}. Database may not be properly initialized.")
+            return
+
+        params = self._normalize_params(args, kwargs)
+
+        async with self.sessions[self.local_ts_id]() as session:
+            try:
+                # Execute the write query
+                result = await session.execute(text(query), params)
+
+                # Create a serializable version of params for logging
+                serializable_params = {
+                    k: v.isoformat() if isinstance(v, datetime) else v
+                    for k, v in params.items()
+                }
+
+                # Log the query
+                new_query = QueryTracking(
+                    ts_id=self.local_ts_id,
+                    query=query,
+                    args=json.dumps(serializable_params)
+                )
+                session.add(new_query)
+                await session.flush()
+                query_id = new_query.id
+                await session.commit()
+                info(f"Successfully executed write query: {query[:50]}...")
+
+                # Calculate checksum
+                checksum = await self._local_compute_checksum(query, params)
+
+                # Update query_tracking with checksum
+                await self.update_query_checksum(query_id, checksum)
+
+                # Replicate to online servers
+                online_servers = await self.get_online_servers()
+                for ts_id in online_servers:
+                    if ts_id != self.local_ts_id:
+                        asyncio.create_task(self._replicate_write(ts_id, query_id, query, params, checksum))
+            except Exception as e:
+                err(f"Failed to execute write query: {str(e)}")
+                return
+
+    def _normalize_params(self, args, kwargs):
+        if args and isinstance(args[0], dict):
+            return args[0]
+        elif kwargs:
+            return kwargs
+        elif args:
+            return {f"param{i}": arg for i, arg in enumerate(args, start=1)}
+        else:
+            return {}

     async def get_primary_server(self) -> str:
         url = urljoin(self.config['URL'], '/id')
@@ -376,10 +423,10 @@ class Database:
                     primary_ts_id = await response.text()
                     return primary_ts_id.strip()
                 else:
-                    logging.error(f"Failed to get primary server. Status: {response.status}")
+                    err(f"Failed to get primary server. Status: {response.status}")
                     return None
             except aiohttp.ClientError as e:
-                logging.error(f"Error connecting to load balancer: {str(e)}")
+                err(f"Error connecting to load balancer: {str(e)}")
                 return None

     async def get_checksum_server(self) -> dict:
@@ -393,131 +440,133 @@ class Database:
         return random.choice(checksum_servers)

-    async def compute_checksum(self, query: str, *args):
+    async def compute_checksum(self, query: str, params: dict):
         checksum_server = await self.get_checksum_server()
         if checksum_server['ts_id'] == self.local_ts_id:
-            return await self._local_compute_checksum(query, *args)
+            return await self._local_compute_checksum(query, params)
         else:
-            return await self._delegate_compute_checksum(checksum_server, query, *args)
+            return await self._delegate_compute_checksum(checksum_server, query, params)

-    async def _local_compute_checksum(self, query: str, *args):
-        conn = await self.get_connection()
-        try:
-            result = await conn.fetch(query, *args)
-            checksum = hashlib.md5(str(result).encode()).hexdigest()
-            return checksum
-        finally:
-            await self.release_connection(self.local_ts_id, conn)
+    async def _local_compute_checksum(self, query: str, params: dict):
+        async with self.sessions[self.local_ts_id]() as session:
+            result = await session.execute(text(query), params)
+            if result.returns_rows:
+                data = result.fetchall()
+            else:
+                # For INSERT, UPDATE, DELETE queries that don't return rows
+                data = str(result.rowcount) + query + str(params)
+            checksum = hashlib.md5(str(data).encode()).hexdigest()
+            return checksum

-    async def _delegate_compute_checksum(self, server: dict, query: str, *args):
+    async def _delegate_compute_checksum(self, server: Dict[str, Any], query: str, params: dict):
         url = f"http://{server['ts_ip']}:{server['app_port']}/sync/checksum"
         async with aiohttp.ClientSession() as session:
             try:
-                async with session.post(url, json={"query": query, "args": list(args)}) as response:
+                serializable_params = {
+                    k: v.isoformat() if isinstance(v, datetime) else v
+                    for k, v in params.items()
+                }
+                async with session.post(url, json={"query": query, "params": serializable_params}) as response:
                     if response.status == 200:
                         result = await response.json()
                         return result['checksum']
                     else:
-                        logging.error(f"Failed to get checksum from {server['ts_id']}. Status: {response.status}")
-                        return await self._local_compute_checksum(query, *args)
+                        err(f"Failed to get checksum from {server['ts_id']}. Status: {response.status}")
+                        return await self._local_compute_checksum(query, params)
             except aiohttp.ClientError as e:
-                logging.error(f"Error connecting to {server['ts_id']} for checksum: {str(e)}")
-                return await self._local_compute_checksum(query, *args)
+                err(f"Error connecting to {server['ts_id']} for checksum: {str(e)}")
+                return await self._local_compute_checksum(query, params)

     async def update_query_checksum(self, query_id: int, checksum: str):
-        conn = await self.get_connection()
-        try:
-            await conn.execute("""
-                UPDATE query_tracking
-                SET result_checksum = $1
-                WHERE id = $2
-            """, checksum, query_id)
-        finally:
-            await self.release_connection(self.local_ts_id, conn)
+        async with self.sessions[self.local_ts_id]() as session:
+            await session.execute(
+                text("UPDATE query_tracking SET result_checksum = :checksum WHERE id = :id"),
+                {"checksum": checksum, "id": query_id}
+            )
+            await session.commit()

-    async def _replicate_write(self, ts_id: str, query_id: int, query: str, args: tuple, expected_checksum: str):
+    async def _replicate_write(self, ts_id: str, query_id: int, query: str, params: dict, expected_checksum: str):
         try:
-            conn = await self.get_connection(ts_id)
-            try:
-                await conn.execute(query, *args)
-                actual_checksum = await self.compute_checksum(query, *args)
+            async with self.sessions[ts_id]() as session:
+                await session.execute(text(query), params)
+                actual_checksum = await self.compute_checksum(query, params)
                 if actual_checksum != expected_checksum:
                     raise ValueError(f"Checksum mismatch on {ts_id}")
                 await self.mark_query_completed(query_id, ts_id)
-            finally:
-                await self.release_connection(ts_id, conn)
+                await session.commit()
+                info(f"Successfully replicated write to {ts_id}")
         except Exception as e:
-            logging.error(f"Failed to replicate write on {ts_id}: {str(e)}")
+            err(f"Failed to replicate write on {ts_id}: {str(e)}")

     async def mark_query_completed(self, query_id: int, ts_id: str):
-        conn = await self.get_connection()
-        try:
-            await conn.execute("""
-                UPDATE query_tracking
-                SET completed_by = completed_by || jsonb_build_object($1, true)
-                WHERE id = $2
-            """, ts_id, query_id)
-        finally:
-            await self.release_connection(self.local_ts_id, conn)
+        async with self.sessions[self.local_ts_id]() as session:
+            query = await session.get(QueryTracking, query_id)
+            if query:
+                completed_by = query.completed_by or {}
+                completed_by[ts_id] = True
+                query.completed_by = completed_by
+                await session.commit()

     async def sync_local_server(self):
-        conn = await self.get_connection()
-        try:
-            last_synced_id = await conn.fetchval("""
-                SELECT COALESCE(MAX(id), 0) FROM query_tracking
-                WHERE completed_by ? $1
-            """, self.local_ts_id)
-            unexecuted_queries = await conn.fetch("""
-                SELECT id, query, args, result_checksum
-                FROM query_tracking
-                WHERE id > $1
-                ORDER BY id
-            """, last_synced_id)
+        async with self.sessions[self.local_ts_id]() as session:
+            last_synced = await session.execute(
+                text("SELECT MAX(id) FROM query_tracking WHERE completed_by ? :ts_id"),
+                {"ts_id": self.local_ts_id}
+            )
+            last_synced_id = last_synced.scalar() or 0
+            unexecuted_queries = await session.execute(
+                text("SELECT * FROM query_tracking WHERE id > :last_id ORDER BY id"),
+                {"last_id": last_synced_id}
+            )

             for query in unexecuted_queries:
                 try:
-                    await conn.execute(query['query'], *json.loads(query['args']))
-                    actual_checksum = await self.compute_checksum(query['query'], *json.loads(query['args']))
-                    if actual_checksum != query['result_checksum']:
-                        raise ValueError(f"Checksum mismatch for query ID {query['id']}")
-                    await self.mark_query_completed(query['id'], self.local_ts_id)
+                    params = json.loads(query.args)
+                    # Convert ISO format strings back to datetime objects
+                    for key, value in params.items():
+                        if isinstance(value, str):
+                            try:
+                                params[key] = datetime.fromisoformat(value)
+                            except ValueError:
+                                pass  # If it's not a valid ISO format, keep it as a string
+                    await session.execute(text(query.query), params)
+                    actual_checksum = await self._local_compute_checksum(query.query, params)
+                    if actual_checksum != query.result_checksum:
+                        raise ValueError(f"Checksum mismatch for query ID {query.id}")
+                    await self.mark_query_completed(query.id, self.local_ts_id)
                 except Exception as e:
-                    logging.error(f"Failed to execute query ID {query['id']} during local sync: {str(e)}")
+                    err(f"Failed to execute query ID {query.id} during local sync: {str(e)}")

-            logging.info(f"Local server sync completed. Executed {len(unexecuted_queries)} queries.")
-        finally:
-            await self.release_connection(self.local_ts_id, conn)
+            await session.commit()
+            info(f"Local server sync completed. Executed {unexecuted_queries.rowcount} queries.")

     async def purge_completed_queries(self):
-        conn = await self.get_connection()
-        try:
+        async with self.sessions[self.local_ts_id]() as session:
             all_ts_ids = [db['ts_id'] for db in self.config['POOL']]
-            result = await conn.execute("""
-                WITH consecutive_completed AS (
-                    SELECT id,
-                           row_number() OVER (ORDER BY id) AS rn
-                    FROM query_tracking
-                    WHERE completed_by ?& $1
-                )
-                DELETE FROM query_tracking
-                WHERE id IN (
-                    SELECT id
-                    FROM consecutive_completed
-                    WHERE rn = (SELECT MAX(rn) FROM consecutive_completed)
-                )
-            """, all_ts_ids)
-            deleted_count = int(result.split()[-1])
-            logging.info(f"Purged {deleted_count} completed queries.")
-        finally:
-            await self.release_connection(self.local_ts_id, conn)
+            result = await session.execute(
+                text("""
+                    DELETE FROM query_tracking
+                    WHERE id <= (
+                        SELECT MAX(id)
+                        FROM query_tracking
+                        WHERE completed_by ?& :ts_ids
+                    )
+                """),
+                {"ts_ids": all_ts_ids}
+            )
+            await session.commit()
+            deleted_count = result.rowcount
+            info(f"Purged {deleted_count} completed queries.")

     async def close(self):
-        for pool in self.pool_connections.values():
-            await pool.close()
+        for engine in self.engines.values():
+            await engine.dispose()
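
With this rewrite, callers pass named bind parameters instead of asyncpg's positional $n placeholders. _normalize_params accepts a single dict, keyword arguments, or bare positional values (exposed as :param1, :param2, ...). A usage sketch against the new interface; the table and column names here are hypothetical:

from datetime import datetime
from sijapi import Db

async def example():
    # Named placeholders map to keyword arguments via _normalize_params:
    rows = await Db.execute_read(
        "SELECT id, name FROM widgets WHERE created_at >= :since",
        since=datetime(2024, 8, 1),
    )
    # A single dict argument works the same way:
    rows = await Db.execute_read(
        "SELECT id, name FROM widgets WHERE created_at >= :since",
        {"since": datetime(2024, 8, 1)},
    )
    # Bare positional args become :param1, :param2, ... in order:
    await Db.execute_write("DELETE FROM widgets WHERE id = :param1", 42)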

View file

@@ -0,0 +1,46 @@
+import asyncio
+import asyncpg
+
+# Database connection information
+DB_INFO = {
+    'host': '100.64.64.20',
+    'port': 5432,
+    'database': 'sij',
+    'user': 'sij',
+    'password': 'Synchr0!'
+}
+
+async def update_click_logs():
+    # Connect to the database
+    conn = await asyncpg.connect(**DB_INFO)
+    try:
+        # Drop existing 'id' and 'new_id' columns if they exist
+        await conn.execute("""
+            ALTER TABLE click_logs
+            DROP COLUMN IF EXISTS id,
+            DROP COLUMN IF EXISTS new_id;
+        """)
+        print("Dropped existing id and new_id columns (if they existed)")
+
+        # Add new UUID column as primary key
+        await conn.execute("""
+            ALTER TABLE click_logs
+            ADD COLUMN id UUID PRIMARY KEY DEFAULT gen_random_uuid();
+        """)
+        print("Added new UUID column as primary key")
+
+        # Get the number of rows in the table
+        row_count = await conn.fetchval("SELECT COUNT(*) FROM click_logs")
+        print(f"Number of rows in click_logs: {row_count}")
+
+    except Exception as e:
+        print(f"An error occurred: {str(e)}")
+        import traceback
+        traceback.print_exc()
+    finally:
+        # Close the database connection
+        await conn.close()
+
+# Run the update
+asyncio.run(update_click_logs())
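
One portability note on this migration script: gen_random_uuid() ships with core PostgreSQL only from version 13 onward; on older servers it comes from the pgcrypto extension. If the target server predates 13, a guard like the following would be needed before the ALTER TABLE (a sketch, assuming sufficient privileges):

# Hypothetical guard for PostgreSQL 12 and earlier:
await conn.execute("CREATE EXTENSION IF NOT EXISTS pgcrypto;")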

View file

@@ -18,7 +18,7 @@ from dateutil.parser import parse as dateutil_parse
 from typing import Optional, List, Union
 from sijapi import L, API, Db, TZ, GEO
 from sijapi.classes import Location
-from sijapi.utilities import haversine, assemble_journal_path
+from sijapi.utilities import haversine, assemble_journal_path, json_serial
 gis = APIRouter()
 logger = L.get_module_logger("gis")
@@ -122,140 +122,7 @@ async def get_last_location() -> Optional[Location]:
     return None

-async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int, datetime, None] = None) -> List[Location]:
-    start_datetime = await dt(start)
-    if end is None:
-        end_datetime = await dt(start_datetime.replace(hour=23, minute=59, second=59))
-    else:
-        end_datetime = await dt(end) if not isinstance(end, datetime) else end
-
-    if start_datetime.time() == datetime.min.time() and end_datetime.time() == datetime.min.time():
-        end_datetime = await dt(end_datetime.replace(hour=23, minute=59, second=59))
-
-    debug(f"Fetching locations between {start_datetime} and {end_datetime}")
-
-    query = '''
-        SELECT id, datetime,
-        ST_X(ST_AsText(location)::geometry) AS longitude,
-        ST_Y(ST_AsText(location)::geometry) AS latitude,
-        ST_Z(ST_AsText(location)::geometry) AS elevation,
-        city, state, zip, street,
-        action, device_type, device_model, device_name, device_os
-        FROM locations
-        WHERE datetime >= $1 AND datetime <= $2
-        ORDER BY datetime DESC
-    '''
-    locations = await Db.execute_read(query, start_datetime.replace(tzinfo=None), end_datetime.replace(tzinfo=None))
-    debug(f"Range locations query returned: {locations}")
-
-    if not locations and (end is None or start_datetime.date() == end_datetime.date()):
-        fallback_query = '''
-            SELECT id, datetime,
-            ST_X(ST_AsText(location)::geometry) AS longitude,
-            ST_Y(ST_AsText(location)::geometry) AS latitude,
-            ST_Z(ST_AsText(location)::geometry) AS elevation,
-            city, state, zip, street,
-            action, device_type, device_model, device_name, device_os
-            FROM locations
-            WHERE datetime < $1
-            ORDER BY datetime DESC
-            LIMIT 1
-        '''
-        location_data = await Db.execute_read(fallback_query, start_datetime.replace(tzinfo=None))
-        debug(f"Fallback query returned: {location_data}")
-        if location_data:
-            locations = location_data
-
-    debug(f"Locations found: {locations}")
-
-    # Sort location_data based on the datetime field in descending order
-    sorted_locations = sorted(locations, key=lambda x: x['datetime'], reverse=True)
-
-    # Create Location objects directly from the location data
-    location_objects = [
-        Location(
-            latitude=location['latitude'],
-            longitude=location['longitude'],
-            datetime=location['datetime'],
-            elevation=location.get('elevation'),
-            city=location.get('city'),
-            state=location.get('state'),
-            zip=location.get('zip'),
-            street=location.get('street'),
-            context={
-                'action': location.get('action'),
-                'device_type': location.get('device_type'),
-                'device_model': location.get('device_model'),
-                'device_name': location.get('device_name'),
-                'device_os': location.get('device_os')
-            }
-        ) for location in sorted_locations if location['latitude'] is not None and location['longitude'] is not None
-    ]
-
-    return location_objects if location_objects else []
-
-async def fetch_last_location_before(datetime: datetime) -> Optional[Location]:
-    try:
-        datetime = await dt(datetime)
-        debug(f"Fetching last location before {datetime}")
-        query = '''
-            SELECT id, datetime,
-            ST_X(ST_AsText(location)::geometry) AS longitude,
-            ST_Y(ST_AsText(location)::geometry) AS latitude,
-            ST_Z(ST_AsText(location)::geometry) AS elevation,
-            city, state, zip, street, country,
-            action
-            FROM locations
-            WHERE datetime < $1
-            ORDER BY datetime DESC
-            LIMIT 1
-        '''
-        location_data = await Db.execute_read(query, datetime.replace(tzinfo=None))
-        if location_data:
-            debug(f"Last location found: {location_data[0]}")
-            return Location(**location_data[0])
-        else:
-            debug("No location found before the specified datetime")
-            return None
-    except Exception as e:
-        error(f"Error fetching last location: {str(e)}")
-        return None
-
-@gis.get("/map", response_class=HTMLResponse)
-async def generate_map_endpoint(
-    start_date: Optional[str] = Query(None),
-    end_date: Optional[str] = Query(None),
-    max_points: int = Query(32767, description="Maximum number of points to display")
-):
-    try:
-        if start_date and end_date:
-            start_date = await dt(start_date)
-            end_date = await dt(end_date)
-        else:
-            start_date, end_date = await get_date_range()
-    except ValueError:
-        raise HTTPException(status_code=400, detail="Invalid date format")
-
-    info(f"Generating map for {start_date} to {end_date}")
-    html_content = await generate_map(start_date, end_date, max_points)
-    return HTMLResponse(content=html_content)
-
-async def get_date_range():
-    query = "SELECT MIN(datetime) as min_date, MAX(datetime) as max_date FROM locations"
-    row = await Db.execute_read(query, table_name="locations")
-    if row and row[0]['min_date'] and row[0]['max_date']:
-        return row[0]['min_date'], row[0]['max_date']
-    else:
-        return datetime(2022, 1, 1), datetime.now()

 async def generate_and_save_heatmap(
         start_date: Union[str, int, datetime],
@@ -424,6 +291,114 @@ map.on(L.Draw.Event.CREATED, function (event) {
     return m.get_root().render()

+async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int, datetime, None] = None) -> List[Location]:
+    start_datetime = await dt(start)
+    if end is None:
+        end_datetime = await dt(start_datetime.replace(hour=23, minute=59, second=59))
+    else:
+        end_datetime = await dt(end) if not isinstance(end, datetime) else end
+
+    if start_datetime.time() == datetime.min.time() and end_datetime.time() == datetime.min.time():
+        end_datetime = await dt(end_datetime.replace(hour=23, minute=59, second=59))
+
+    debug(f"Fetching locations between {start_datetime} and {end_datetime}")
+
+    query = '''
+        SELECT id, datetime,
+        ST_X(ST_AsText(location)::geometry) AS longitude,
+        ST_Y(ST_AsText(location)::geometry) AS latitude,
+        ST_Z(ST_AsText(location)::geometry) AS elevation,
+        city, state, zip, street,
+        action, device_type, device_model, device_name, device_os
+        FROM locations
+        WHERE datetime >= :start_datetime AND datetime <= :end_datetime
+        ORDER BY datetime DESC
+    '''
+    locations = await Db.execute_read(query, start_datetime=start_datetime.replace(tzinfo=None), end_datetime=end_datetime.replace(tzinfo=None))
+    debug(f"Range locations query returned: {locations}")
+
+    if not locations and (end is None or start_datetime.date() == end_datetime.date()):
+        fallback_query = '''
+            SELECT id, datetime,
+            ST_X(ST_AsText(location)::geometry) AS longitude,
+            ST_Y(ST_AsText(location)::geometry) AS latitude,
+            ST_Z(ST_AsText(location)::geometry) AS elevation,
+            city, state, zip, street,
+            action, device_type, device_model, device_name, device_os
+            FROM locations
+            WHERE datetime < :start_datetime
+            ORDER BY datetime DESC
+            LIMIT 1
+        '''
+        location_data = await Db.execute_read(fallback_query, start_datetime=start_datetime.replace(tzinfo=None))
+        debug(f"Fallback query returned: {location_data}")
+        if location_data:
+            locations = location_data
+
+    debug(f"Locations found: {locations}")
+
+    # Sort location_data based on the datetime field in descending order
+    sorted_locations = sorted(locations, key=lambda x: x['datetime'], reverse=True)
+
+    # Create Location objects directly from the location data
+    location_objects = [
+        Location(
+            latitude=location['latitude'],
+            longitude=location['longitude'],
+            datetime=location['datetime'],
+            elevation=location.get('elevation'),
+            city=location.get('city'),
+            state=location.get('state'),
+            zip=location.get('zip'),
+            street=location.get('street'),
+            context={
+                'action': location.get('action'),
+                'device_type': location.get('device_type'),
+                'device_model': location.get('device_model'),
+                'device_name': location.get('device_name'),
+                'device_os': location.get('device_os')
+            }
+        ) for location in sorted_locations if location['latitude'] is not None and location['longitude'] is not None
+    ]
+
+    return location_objects if location_objects else []
+
+async def fetch_last_location_before(datetime: datetime) -> Optional[Location]:
+    try:
+        datetime = await dt(datetime)
+        debug(f"Fetching last location before {datetime}")
+        query = '''
+            SELECT id, datetime,
+            ST_X(ST_AsText(location)::geometry) AS longitude,
+            ST_Y(ST_AsText(location)::geometry) AS latitude,
+            ST_Z(ST_AsText(location)::geometry) AS elevation,
+            city, state, zip, street, country,
+            action
+            FROM locations
+            WHERE datetime < :datetime
+            ORDER BY datetime DESC
+            LIMIT 1
+        '''
+        location_data = await Db.execute_read(query, datetime=datetime.replace(tzinfo=None))
+        if location_data:
+            debug(f"Last location found: {location_data[0]}")
+            return Location(**location_data[0])
+        else:
+            debug("No location found before the specified datetime")
+            return None
+    except Exception as e:
+        error(f"Error fetching last location: {str(e)}")
+        return None

 async def post_location(location: Location):
     try:
         context = location.context or {}
@@ -442,24 +417,16 @@ async def post_location(location: Location):
                 class_, type, name, display_name, amenity, house_number, road, quarter, neighbourhood,
                 suburb, county, country_code, country
             )
-            VALUES ($1, ST_SetSRID(ST_MakePoint($2, $3, $4), 4326), $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15,
-            $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26)
+            VALUES (:datetime, ST_SetSRID(ST_MakePoint(:longitude, :latitude, :elevation), 4326), :city, :state, :zip,
+            :street, :action, :device_type, :device_model, :device_name, :device_os, :class_, :type, :name,
+            :display_name, :amenity, :house_number, :road, :quarter, :neighbourhood, :suburb, :county,
+            :country_code, :country)
         '''
-        await Db.execute_write(
-            query,
-            localized_datetime, location.longitude, location.latitude, location.elevation, location.city, location.state,
-            location.zip, location.street, action, device_type, device_model, device_name, device_os,
-            location.class_, location.type, location.name, location.display_name,
-            location.amenity, location.house_number, location.road, location.quarter, location.neighbourhood,
-            location.suburb, location.county, location.country_code, location.country
-        )
-        info(f"Successfully posted location: {location.latitude}, {location.longitude}, {location.elevation} on {localized_datetime}")
-        return {
+        params = {
             'datetime': localized_datetime,
-            'latitude': location.latitude,
             'longitude': location.longitude,
+            'latitude': location.latitude,
             'elevation': location.elevation,
             'city': location.city,
             'state': location.state,
@@ -484,12 +451,34 @@ async def post_location(location: Location):
             'country_code': location.country_code,
             'country': location.country
         }
+        await Db.execute_write(query, **params)
+        info(f"Successfully posted location: {location.latitude}, {location.longitude}, {location.elevation} on {localized_datetime}")
+
+        # Create a serializable version of params for the return value
+        serializable_params = {
+            k: v.isoformat() if isinstance(v, datetime) else v
+            for k, v in params.items()
+        }
+        return serializable_params
     except Exception as e:
         err(f"Error posting location {e}")
         err(traceback.format_exc())
         return None
+
+async def get_date_range():
+    query = "SELECT MIN(datetime) as min_date, MAX(datetime) as max_date FROM locations"
+    row = await Db.execute_read(query)
+    if row and row[0]['min_date'] and row[0]['max_date']:
+        return row[0]['min_date'], row[0]['max_date']
+    else:
+        return datetime(2022, 1, 1), datetime.now()

 @gis.post("/locate")
 async def post_locate_endpoint(locations: Union[Location, List[Location]]):
     if isinstance(locations, Location):
@@ -559,3 +548,23 @@ async def get_locate(datetime_str: str, all: bool = False):
         raise HTTPException(status_code=404, detail="No nearby data found for this date and time")
     return locations if all else [locations[0]]
+
+@gis.get("/map", response_class=HTMLResponse)
+async def generate_map_endpoint(
+    start_date: Optional[str] = Query(None),
+    end_date: Optional[str] = Query(None),
+    max_points: int = Query(32767, description="Maximum number of points to display")
+):
+    try:
+        if start_date and end_date:
+            start_date = await dt(start_date)
+            end_date = await dt(end_date)
+        else:
+            start_date, end_date = await get_date_range()
+    except ValueError:
+        raise HTTPException(status_code=400, detail="Invalid date format")
+
+    info(f"Generating map for {start_date} to {end_date}")
+    html_content = await generate_map(start_date, end_date, max_points)
+    return HTMLResponse(content=html_content)
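
For a quick manual check of the relocated /map endpoint, a request like the following should return HTML; the host, port, and API key are placeholders for whatever the deployment actually uses:

import asyncio
import aiohttp

async def check_map():
    url = "http://localhost:4444/map"  # hypothetical local instance
    params = {"start_date": "2024-08-01", "end_date": "2024-08-09", "api_key": "YOUR_KEY"}
    async with aiohttp.ClientSession() as session:
        async with session.get(url, params=params) as resp:
            print(resp.status)  # expect 200 with an HTML body on success

asyncio.run(check_map())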

View file

@@ -26,7 +26,7 @@ import pytesseract
 from readability import Document
 from pdf2image import convert_from_path
 from datetime import datetime as dt_datetime, date, time
-from typing import Optional, Union, Tuple, List
+from typing import Optional, Union, Tuple, List, Any
 import asyncio
 from PIL import Image
 import pandas as pd
@@ -629,3 +629,34 @@ async def html_to_markdown(url: str = None, source: str = None) -> Optional[str]
     markdown_content = md(str(soup), heading_style="ATX")
     return markdown_content
+
+def json_serial(obj: Any) -> Any:
+    """JSON serializer for objects not serializable by default json code"""
+    if isinstance(obj, (datetime, date)):
+        return obj.isoformat()
+    if isinstance(obj, time):
+        return obj.isoformat()
+    if isinstance(obj, Decimal):
+        return float(obj)
+    if isinstance(obj, UUID):
+        return str(obj)
+    if isinstance(obj, bytes):
+        return obj.decode('utf-8')
+    if isinstance(obj, Path):
+        return str(obj)
+    if hasattr(obj, '__dict__'):
+        return obj.__dict__
+    raise TypeError(f"Type {type(obj)} not serializable")
+
+def json_dumps(obj: Any) -> str:
+    """
+    Serialize obj to a JSON formatted str using the custom serializer.
+    """
+    return json.dumps(obj, default=json_serial)
+
+def json_loads(json_str: str) -> Any:
+    """
+    Deserialize json_str to a Python object.
+    """
+    return json.loads(json_str)
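
Note that json_serial references several names this hunk does not import: the visible import line brings in datetime only under the alias dt_datetime, and Decimal and UUID do not appear in the touched imports at all. Presumably they are imported elsewhere in the file (as Path and json likely are); if not, the function would need something like the following (a sketch):

from datetime import datetime  # the visible import aliases this as dt_datetime
from decimal import Decimal
from uuid import UUID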