diff --git a/sijapi/__init__.py b/sijapi/__init__.py
index ab6db13..79121bb 100644
--- a/sijapi/__init__.py
+++ b/sijapi/__init__.py
@@ -26,8 +26,7 @@ Db = Database.load('sys')
 
 # HOST = f"{API.BIND}:{API.PORT}"
 # LOCAL_HOSTS = [ipaddress.ip_address(localhost.strip()) for localhost in os.getenv('LOCAL_HOSTS', '127.0.0.1').split(',')] + ['localhost']
-
-SUBNET_BROADCAST = os.getenv("SUBNET_BROADCAST", '10.255.255.255')
+# SUBNET_BROADCAST = os.getenv("SUBNET_BROADCAST", '10.255.255.255')
 MAX_CPU_CORES = min(int(os.getenv("MAX_CPU_CORES", int(multiprocessing.cpu_count()/2))), multiprocessing.cpu_count())
 
diff --git a/sijapi/__main__.py b/sijapi/__main__.py
index b8df014..d355347 100755
--- a/sijapi/__main__.py
+++ b/sijapi/__main__.py
@@ -21,7 +21,7 @@ from dotenv import load_dotenv
 from pathlib import Path
 from datetime import datetime
 import argparse
-from . import L, API, ROUTER_DIR
+from . import L, API, Db, ROUTER_DIR
 
 parser = argparse.ArgumentParser(description='Personal API.')
 parser.add_argument('--log', type=str, default='INFO', help='Set overall log level (e.g., DEBUG, INFO, WARNING)')
@@ -55,7 +55,8 @@ async def lifespan(app: FastAPI):
 
     try:
         # Initialize sync structures on all databases
-        await API.initialize_sync()
+        # await API.initialize_sync()
+        await Db.initialize_engines()
 
     except Exception as e:
         crit(f"Error during startup: {str(e)}")
@@ -99,16 +100,18 @@ class SimpleAPIKeyMiddleware(BaseHTTPMiddleware):
         api_key_header = request.headers.get("Authorization")
         api_key_query = request.query_params.get("api_key")
 
-        # Debug logging for API keys
-        debug(f"API.KEYS: {API.KEYS}")
+        # Convert API.KEYS to lowercase for case-insensitive comparison
+        api_keys_lower = [key.lower() for key in API.KEYS]
+        debug(f"API.KEYS (lowercase): {api_keys_lower}")
 
         if api_key_header:
            api_key_header = api_key_header.lower().split("bearer ")[-1]
            debug(f"API key provided in header: {api_key_header}")
        if api_key_query:
+            api_key_query = api_key_query.lower()
            debug(f"API key provided in query: {api_key_query}")
 
-        if api_key_header not in API.KEYS and api_key_query not in API.KEYS:
+        if api_key_header not in api_keys_lower and api_key_query not in api_keys_lower:  # both lowercased above; 'not in' is safe when either is None
             err(f"Invalid API key provided by a requester.")
             if api_key_header:
                 debug(f"Invalid API key in header: {api_key_header}")
@@ -119,9 +122,9 @@ class SimpleAPIKeyMiddleware(BaseHTTPMiddleware):
                 content={"detail": "Invalid or missing API key"}
             )
         else:
-            if api_key_header in API.KEYS:
+            if api_key_header and api_key_header in api_keys_lower:
                 debug(f"Valid API key provided in header: {api_key_header}")
-            if api_key_query in API.KEYS:
+            if api_key_query and api_key_query in api_keys_lower:
                 debug(f"Valid API key provided in query: {api_key_query}")
 
         response = await call_next(request)
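
# --- usage note (illustrative, not part of the patch): with the middleware change
# above, key comparison is case-insensitive and the key may arrive in either the
# Authorization header or the query string. A minimal sketch; the key value,
# host, and port are hypothetical:
#
#     import httpx
#     httpx.get("http://localhost:4444/id", headers={"Authorization": "Bearer SK-TEST"})
#     httpx.get("http://localhost:4444/id", params={"api_key": "sk-test"})
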
diff --git a/sijapi/classes.py b/sijapi/classes.py
index d0a0d1e..6f09f9e 100644
--- a/sijapi/classes.py
+++ b/sijapi/classes.py
@@ -27,6 +27,17 @@ from srtm import get_data
 import os
 import sys
 from loguru import logger
+from sqlalchemy import text
+from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
+from sqlalchemy.orm import sessionmaker, declarative_base
+from sqlalchemy import Column, Integer, String, DateTime, JSON, Text, select, func
+from sqlalchemy.dialects.postgresql import JSONB
+from sqlalchemy.exc import OperationalError
+from urllib.parse import urljoin
+import hashlib
+import random
+
+Base = declarative_base()
 
 # Custom logger class
 class Logger:
@@ -258,14 +269,30 @@ class DirConfig:
 
 
 
+class QueryTracking(Base):
+    __tablename__ = 'query_tracking'
+
+    id = Column(Integer, primary_key=True)
+    ts_id = Column(String, nullable=False)
+    query = Column(Text, nullable=False)
+    args = Column(JSONB)
+    executed_at = Column(DateTime(timezone=True), server_default=func.now())
+    completed_by = Column(JSONB, default=dict)  # callable default: a literal {} would be shared across rows
+    result_checksum = Column(String)
+
+
+
 class Database:
     @classmethod
     def load(cls, config_name: str):
         return cls(config_name)
-    
+
     def __init__(self, config_path: str):
         self.config = self.load_config(config_path)
-        self.pool_connections = {}
+        self.engines: Dict[str, Any] = {}
+        self.sessions: Dict[str, Any] = {}
+        self.online_servers: set = set()
         self.local_ts_id = self.get_local_ts_id()
 
     def load_config(self, config_path: str) -> Dict[str, Any]:
@@ -280,91 +307,111 @@ class Database:
     def get_local_ts_id(self) -> str:
         return os.environ.get('TS_ID')
 
-    async def get_connection(self, ts_id: str = None):
-        if ts_id is None:
-            ts_id = self.local_ts_id
+    async def initialize_engines(self):
+        for db_info in self.config['POOL']:
+            url = f"postgresql+asyncpg://{db_info['db_user']}:{db_info['db_pass']}@{db_info['ts_ip']}:{db_info['db_port']}/{db_info['db_name']}"
+            try:
+                engine = create_async_engine(url, pool_pre_ping=True, pool_size=5, max_overflow=10)
+                self.engines[db_info['ts_id']] = engine
+                self.sessions[db_info['ts_id']] = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
+                info(f"Initialized engine and session for {db_info['ts_id']}")
+            except Exception as e:
+                err(f"Failed to initialize engine for {db_info['ts_id']}: {str(e)}")
 
-        if ts_id not in self.pool_connections:
-            db_info = next((db for db in self.config['POOL'] if db['ts_id'] == ts_id), None)
-            if db_info is None:
-                raise ValueError(f"No database configuration found for TS_ID: {ts_id}")
-
-            self.pool_connections[ts_id] = await asyncpg.create_pool(
-                host=db_info['ts_ip'],
-                port=db_info['db_port'],
-                user=db_info['db_user'],
-                password=db_info['db_pass'],
-                database=db_info['db_name'],
-                min_size=1,
-                max_size=10
-            )
-
-        return await self.pool_connections[ts_id].acquire()
-
-    async def release_connection(self, ts_id: str, connection):
-        await self.pool_connections[ts_id].release(connection)
+        if self.local_ts_id not in self.sessions:
+            err(f"Failed to initialize session for local server {self.local_ts_id}")
+        else:
+            try:
+                # Create tables if they don't exist
+                async with self.engines[self.local_ts_id].begin() as conn:
+                    await conn.run_sync(Base.metadata.create_all)
+                info(f"Initialized tables for local server {self.local_ts_id}")
+            except Exception as e:
+                err(f"Failed to create tables for local server {self.local_ts_id}: {str(e)}")
 
     async def get_online_servers(self) -> List[str]:
         online_servers = []
-        for db_info in self.config['POOL']:
+        for ts_id, engine in self.engines.items():
             try:
-                conn = await self.get_connection(db_info['ts_id'])
-                await self.release_connection(db_info['ts_id'], conn)
-                online_servers.append(db_info['ts_id'])
-            except:
+                async with engine.connect() as conn:
+                    await conn.execute(text("SELECT 1"))
+                online_servers.append(ts_id)
+            except OperationalError:
                 pass
+        self.online_servers = set(online_servers)
         return online_servers
 
-    async def initialize_query_tracking(self):
-        conn = await self.get_connection()
-        try:
-            await conn.execute("""
-                CREATE TABLE IF NOT EXISTS query_tracking (
-                    id SERIAL PRIMARY KEY,
-                    ts_id TEXT NOT NULL,
-                    query TEXT NOT NULL,
-                    args JSONB,
-                    executed_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
-                    completed_by JSONB DEFAULT '{}'::jsonb,
-                    result_checksum TEXT
+    async def execute_read(self, query: str, *args, **kwargs):
+        if self.local_ts_id not in self.sessions:
+            err(f"No session found for local server {self.local_ts_id}. Database may not be properly initialized.")
+            return None
+
+        params = self._normalize_params(args, kwargs)
+
+        async with self.sessions[self.local_ts_id]() as session:
+            try:
+                result = await session.execute(text(query), params)
+                return [dict(row._mapping) for row in result.fetchall()]  # mapping rows, so callers can index by column name
+            except Exception as e:
+                err(f"Failed to execute read query: {str(e)}")
+                return None
+
+    async def execute_write(self, query: str, *args, **kwargs):
+        if self.local_ts_id not in self.sessions:
+            err(f"No session found for local server {self.local_ts_id}. Database may not be properly initialized.")
+            return
+
+        params = self._normalize_params(args, kwargs)
+
+        async with self.sessions[self.local_ts_id]() as session:
+            try:
+                # Execute the write query
+                result = await session.execute(text(query), params)
+
+                # Create a serializable version of params for logging
+                serializable_params = {
+                    k: v.isoformat() if isinstance(v, datetime) else v
+                    for k, v in params.items()
+                }
+
+                # Log the query
+                new_query = QueryTracking(
+                    ts_id=self.local_ts_id,
+                    query=query,
+                    args=json.dumps(serializable_params)
                 )
-            """)
-        finally:
-            await self.release_connection(self.local_ts_id, conn)
+                session.add(new_query)
+                await session.flush()
+                query_id = new_query.id
 
-    async def execute_read(self, query: str, *args):
-        conn = await self.get_connection()
-        try:
-            return await conn.fetch(query, *args)
-        finally:
-            await self.release_connection(self.local_ts_id, conn)
+                await session.commit()
+                info(f"Successfully executed write query: {query[:50]}...")
+
+                # Calculate checksum
+                checksum = await self._local_compute_checksum(query, params)
 
-    async def execute_write(self, query: str, *args):
-        # Execute write on local database
-        local_conn = await self.get_connection()
-        try:
-            await local_conn.execute(query, *args)
-
-            # Log the query
-            query_id = await local_conn.fetchval("""
-                INSERT INTO query_tracking (ts_id, query, args)
-                VALUES ($1, $2, $3)
-                RETURNING id
-            """, self.local_ts_id, query, json.dumps(args))
-        finally:
-            await self.release_connection(self.local_ts_id, local_conn)
+                # Update query_tracking with checksum
+                await self.update_query_checksum(query_id, checksum)
 
-        # Calculate checksum
-        checksum = await self.compute_checksum(query, *args)
+                # Replicate to online servers
+                online_servers = await self.get_online_servers()
+                for ts_id in online_servers:
+                    if ts_id != self.local_ts_id:
+                        asyncio.create_task(self._replicate_write(ts_id, query_id, query, params, checksum))
 
-        # Update query_tracking with checksum
-        await self.update_query_checksum(query_id, checksum)
+            except Exception as e:
+                err(f"Failed to execute write query: {str(e)}")
+                return
 
-        # Replicate to online servers
-        online_servers = await self.get_online_servers()
-        for ts_id in online_servers:
-            if ts_id != self.local_ts_id:
-                asyncio.create_task(self._replicate_write(ts_id, query_id, query, args, checksum))
+    def _normalize_params(self, args, kwargs):
+        if args and isinstance(args[0], dict):
+            return args[0]
+        elif kwargs:
+            return kwargs
+        elif args:
+            return {f"param{i}": arg for i, arg in enumerate(args, start=1)}
+        else:
+            return {}
 
     async def get_primary_server(self) -> str:
         url = urljoin(self.config['URL'], '/id')
@@ -376,10 +423,10 @@ class Database:
                     primary_ts_id = await response.text()
                     return primary_ts_id.strip()
                 else:
Status: {response.status}") + err(f"Failed to get primary server. Status: {response.status}") return None except aiohttp.ClientError as e: - logging.error(f"Error connecting to load balancer: {str(e)}") + err(f"Error connecting to load balancer: {str(e)}") return None async def get_checksum_server(self) -> dict: @@ -393,132 +440,134 @@ class Database: return random.choice(checksum_servers) - async def compute_checksum(self, query: str, *args): + async def compute_checksum(self, query: str, params: dict): checksum_server = await self.get_checksum_server() if checksum_server['ts_id'] == self.local_ts_id: - return await self._local_compute_checksum(query, *args) + return await self._local_compute_checksum(query, params) else: - return await self._delegate_compute_checksum(checksum_server, query, *args) + return await self._delegate_compute_checksum(checksum_server, query, params) - async def _local_compute_checksum(self, query: str, *args): - conn = await self.get_connection() - try: - result = await conn.fetch(query, *args) - checksum = hashlib.md5(str(result).encode()).hexdigest() + async def _local_compute_checksum(self, query: str, params: dict): + async with self.sessions[self.local_ts_id]() as session: + result = await session.execute(text(query), params) + if result.returns_rows: + data = result.fetchall() + else: + # For INSERT, UPDATE, DELETE queries that don't return rows + data = str(result.rowcount) + query + str(params) + checksum = hashlib.md5(str(data).encode()).hexdigest() return checksum - finally: - await self.release_connection(self.local_ts_id, conn) - async def _delegate_compute_checksum(self, server: dict, query: str, *args): + async def _delegate_compute_checksum(self, server: Dict[str, Any], query: str, params: dict): url = f"http://{server['ts_ip']}:{server['app_port']}/sync/checksum" async with aiohttp.ClientSession() as session: try: - async with session.post(url, json={"query": query, "args": list(args)}) as response: + serializable_params = { + k: v.isoformat() if isinstance(v, datetime) else v + for k, v in params.items() + } + async with session.post(url, json={"query": query, "params": serializable_params}) as response: if response.status == 200: result = await response.json() return result['checksum'] else: - logging.error(f"Failed to get checksum from {server['ts_id']}. Status: {response.status}") - return await self._local_compute_checksum(query, *args) + err(f"Failed to get checksum from {server['ts_id']}. 
Status: {response.status}") + return await self._local_compute_checksum(query, params) except aiohttp.ClientError as e: - logging.error(f"Error connecting to {server['ts_id']} for checksum: {str(e)}") - return await self._local_compute_checksum(query, *args) + err(f"Error connecting to {server['ts_id']} for checksum: {str(e)}") + return await self._local_compute_checksum(query, params) async def update_query_checksum(self, query_id: int, checksum: str): - conn = await self.get_connection() - try: - await conn.execute(""" - UPDATE query_tracking - SET result_checksum = $1 - WHERE id = $2 - """, checksum, query_id) - finally: - await self.release_connection(self.local_ts_id, conn) + async with self.sessions[self.local_ts_id]() as session: + await session.execute( + text("UPDATE query_tracking SET result_checksum = :checksum WHERE id = :id"), + {"checksum": checksum, "id": query_id} + ) + await session.commit() - async def _replicate_write(self, ts_id: str, query_id: int, query: str, args: tuple, expected_checksum: str): + async def _replicate_write(self, ts_id: str, query_id: int, query: str, params: dict, expected_checksum: str): try: - conn = await self.get_connection(ts_id) - try: - await conn.execute(query, *args) - actual_checksum = await self.compute_checksum(query, *args) + async with self.sessions[ts_id]() as session: + await session.execute(text(query), params) + actual_checksum = await self.compute_checksum(query, params) if actual_checksum != expected_checksum: raise ValueError(f"Checksum mismatch on {ts_id}") await self.mark_query_completed(query_id, ts_id) - finally: - await self.release_connection(ts_id, conn) + await session.commit() + info(f"Successfully replicated write to {ts_id}") except Exception as e: - logging.error(f"Failed to replicate write on {ts_id}: {str(e)}") + err(f"Failed to replicate write on {ts_id}: {str(e)}") async def mark_query_completed(self, query_id: int, ts_id: str): - conn = await self.get_connection() - try: - await conn.execute(""" - UPDATE query_tracking - SET completed_by = completed_by || jsonb_build_object($1, true) - WHERE id = $2 - """, ts_id, query_id) - finally: - await self.release_connection(self.local_ts_id, conn) + async with self.sessions[self.local_ts_id]() as session: + query = await session.get(QueryTracking, query_id) + if query: + completed_by = query.completed_by or {} + completed_by[ts_id] = True + query.completed_by = completed_by + await session.commit() async def sync_local_server(self): - conn = await self.get_connection() - try: - last_synced_id = await conn.fetchval(""" - SELECT COALESCE(MAX(id), 0) FROM query_tracking - WHERE completed_by ? $1 - """, self.local_ts_id) + async with self.sessions[self.local_ts_id]() as session: + last_synced = await session.execute( + text("SELECT MAX(id) FROM query_tracking WHERE completed_by ? 
:ts_id"), + {"ts_id": self.local_ts_id} + ) + last_synced_id = last_synced.scalar() or 0 - unexecuted_queries = await conn.fetch(""" - SELECT id, query, args, result_checksum - FROM query_tracking - WHERE id > $1 - ORDER BY id - """, last_synced_id) + unexecuted_queries = await session.execute( + text("SELECT * FROM query_tracking WHERE id > :last_id ORDER BY id"), + {"last_id": last_synced_id} + ) for query in unexecuted_queries: try: - await conn.execute(query['query'], *json.loads(query['args'])) - actual_checksum = await self.compute_checksum(query['query'], *json.loads(query['args'])) - if actual_checksum != query['result_checksum']: - raise ValueError(f"Checksum mismatch for query ID {query['id']}") - await self.mark_query_completed(query['id'], self.local_ts_id) + params = json.loads(query.args) + # Convert ISO format strings back to datetime objects + for key, value in params.items(): + if isinstance(value, str): + try: + params[key] = datetime.fromisoformat(value) + except ValueError: + pass # If it's not a valid ISO format, keep it as a string + await session.execute(text(query.query), params) + actual_checksum = await self._local_compute_checksum(query.query, params) + if actual_checksum != query.result_checksum: + raise ValueError(f"Checksum mismatch for query ID {query.id}") + await self.mark_query_completed(query.id, self.local_ts_id) except Exception as e: - logging.error(f"Failed to execute query ID {query['id']} during local sync: {str(e)}") + err(f"Failed to execute query ID {query.id} during local sync: {str(e)}") - logging.info(f"Local server sync completed. Executed {len(unexecuted_queries)} queries.") - - finally: - await self.release_connection(self.local_ts_id, conn) + await session.commit() + info(f"Local server sync completed. Executed {unexecuted_queries.rowcount} queries.") async def purge_completed_queries(self): - conn = await self.get_connection() - try: + async with self.sessions[self.local_ts_id]() as session: all_ts_ids = [db['ts_id'] for db in self.config['POOL']] - result = await conn.execute(""" - WITH consecutive_completed AS ( - SELECT id, - row_number() OVER (ORDER BY id) AS rn - FROM query_tracking - WHERE completed_by ?& $1 - ) - DELETE FROM query_tracking - WHERE id IN ( - SELECT id - FROM consecutive_completed - WHERE rn = (SELECT MAX(rn) FROM consecutive_completed) - ) - """, all_ts_ids) - deleted_count = int(result.split()[-1]) - logging.info(f"Purged {deleted_count} completed queries.") - finally: - await self.release_connection(self.local_ts_id, conn) + + result = await session.execute( + text(""" + DELETE FROM query_tracking + WHERE id <= ( + SELECT MAX(id) + FROM query_tracking + WHERE completed_by ?& :ts_ids + ) + """), + {"ts_ids": all_ts_ids} + ) + await session.commit() + + deleted_count = result.rowcount + info(f"Purged {deleted_count} completed queries.") async def close(self): - for pool in self.pool_connections.values(): - await pool.close() - + for engine in self.engines.values(): + await engine.dispose() + + # Configuration class for API & Database methods. diff --git a/sijapi/helpers/db_uuid_migrate.py b/sijapi/helpers/db_uuid_migrate.py new file mode 100644 index 0000000..b7bb5d1 --- /dev/null +++ b/sijapi/helpers/db_uuid_migrate.py @@ -0,0 +1,46 @@ +import asyncio +import asyncpg + +# Database connection information +DB_INFO = { + 'host': '100.64.64.20', + 'port': 5432, + 'database': 'sij', + 'user': 'sij', + 'password': 'Synchr0!' 
diff --git a/sijapi/helpers/db_uuid_migrate.py b/sijapi/helpers/db_uuid_migrate.py
new file mode 100644
index 0000000..b7bb5d1
--- /dev/null
+++ b/sijapi/helpers/db_uuid_migrate.py
@@ -0,0 +1,46 @@
+import asyncio
+import asyncpg
+
+# Database connection information
+DB_INFO = {
+    'host': '100.64.64.20',
+    'port': 5432,
+    'database': 'sij',
+    'user': 'sij',
+    'password': 'Synchr0!'
+}
+
+async def update_click_logs():
+    # Connect to the database
+    conn = await asyncpg.connect(**DB_INFO)
+
+    try:
+        # Drop existing 'id' and 'new_id' columns if they exist
+        await conn.execute("""
+            ALTER TABLE click_logs
+            DROP COLUMN IF EXISTS id,
+            DROP COLUMN IF EXISTS new_id;
+        """)
+        print("Dropped existing id and new_id columns (if they existed)")
+
+        # Add new UUID column as primary key
+        await conn.execute("""
+            ALTER TABLE click_logs
+            ADD COLUMN id UUID PRIMARY KEY DEFAULT gen_random_uuid();
+        """)
+        print("Added new UUID column as primary key")
+
+        # Get the number of rows in the table
+        row_count = await conn.fetchval("SELECT COUNT(*) FROM click_logs")
+        print(f"Number of rows in click_logs: {row_count}")
+
+    except Exception as e:
+        print(f"An error occurred: {str(e)}")
+        import traceback
+        traceback.print_exc()
+    finally:
+        # Close the database connection
+        await conn.close()
+
+# Run the update
+asyncio.run(update_click_logs())
diff --git a/sijapi/routers/gis.py b/sijapi/routers/gis.py
index 363f8c4..f7bfc90 100644
--- a/sijapi/routers/gis.py
+++ b/sijapi/routers/gis.py
@@ -18,7 +18,7 @@ from dateutil.parser import parse as dateutil_parse
 from typing import Optional, List, Union
 from sijapi import L, API, Db, TZ, GEO
 from sijapi.classes import Location
-from sijapi.utilities import haversine, assemble_journal_path
+from sijapi.utilities import haversine, assemble_journal_path, json_serial
 
 gis = APIRouter()
 logger = L.get_module_logger("gis")
@@ -122,140 +122,7 @@ async def get_last_location() -> Optional[Location]:
 
     return None
 
-async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int, datetime, None] = None) -> List[Location]:
-    start_datetime = await dt(start)
-    if end is None:
-        end_datetime = await dt(start_datetime.replace(hour=23, minute=59, second=59))
-    else:
-        end_datetime = await dt(end) if not isinstance(end, datetime) else end
-
-    if start_datetime.time() == datetime.min.time() and end_datetime.time() == datetime.min.time():
-        end_datetime = await dt(end_datetime.replace(hour=23, minute=59, second=59))
-
-    debug(f"Fetching locations between {start_datetime} and {end_datetime}")
-
-    query = '''
-        SELECT id, datetime,
-        ST_X(ST_AsText(location)::geometry) AS longitude,
-        ST_Y(ST_AsText(location)::geometry) AS latitude,
-        ST_Z(ST_AsText(location)::geometry) AS elevation,
-        city, state, zip, street,
-        action, device_type, device_model, device_name, device_os
-        FROM locations
-        WHERE datetime >= $1 AND datetime <= $2
-        ORDER BY datetime DESC
-    '''
-
-    locations = await Db.execute_read(query, start_datetime.replace(tzinfo=None), end_datetime.replace(tzinfo=None))
-
-    debug(f"Range locations query returned: {locations}")
-
-    if not locations and (end is None or start_datetime.date() == end_datetime.date()):
-        fallback_query = '''
-            SELECT id, datetime,
-            ST_X(ST_AsText(location)::geometry) AS longitude,
-            ST_Y(ST_AsText(location)::geometry) AS latitude,
-            ST_Z(ST_AsText(location)::geometry) AS elevation,
-            city, state, zip, street,
-            action, device_type, device_model, device_name, device_os
-            FROM locations
-            WHERE datetime < $1
-            ORDER BY datetime DESC
-            LIMIT 1
-        '''
-        location_data = await Db.execute_read(fallback_query, start_datetime.replace(tzinfo=None))
-        debug(f"Fallback query returned: {location_data}")
-        if location_data:
-            locations = location_data
-
-    debug(f"Locations found: {locations}")
-
-    # Sort location_data based on the datetime field in descending order
-    sorted_locations = sorted(locations, key=lambda x: x['datetime'], reverse=True)
-
-    # Create Location objects directly from the location data
-    location_objects = [
-        Location(
-            latitude=location['latitude'],
-            longitude=location['longitude'],
-            datetime=location['datetime'],
-            elevation=location.get('elevation'),
-            city=location.get('city'),
-            state=location.get('state'),
-            zip=location.get('zip'),
-            street=location.get('street'),
-            context={
-                'action': location.get('action'),
-                'device_type': location.get('device_type'),
-                'device_model': location.get('device_model'),
-                'device_name': location.get('device_name'),
-                'device_os': location.get('device_os')
-            }
-        ) for location in sorted_locations if location['latitude'] is not None and location['longitude'] is not None
-    ]
-
-    return location_objects if location_objects else []
-
-
-async def fetch_last_location_before(datetime: datetime) -> Optional[Location]:
-    try:
-        datetime = await dt(datetime)
-
-        debug(f"Fetching last location before {datetime}")
-
-        query = '''
-            SELECT id, datetime,
-            ST_X(ST_AsText(location)::geometry) AS longitude,
-            ST_Y(ST_AsText(location)::geometry) AS latitude,
-            ST_Z(ST_AsText(location)::geometry) AS elevation,
-            city, state, zip, street, country,
-            action
-            FROM locations
-            WHERE datetime < $1
-            ORDER BY datetime DESC
-            LIMIT 1
-        '''
-
-        location_data = await Db.execute_read(query, datetime.replace(tzinfo=None))
-
-        if location_data:
-            debug(f"Last location found: {location_data[0]}")
-            return Location(**location_data[0])
-        else:
-            debug("No location found before the specified datetime")
-            return None
-    except Exception as e:
-        error(f"Error fetching last location: {str(e)}")
-        return None
-
-
-
-@gis.get("/map", response_class=HTMLResponse)
-async def generate_map_endpoint(
-    start_date: Optional[str] = Query(None),
-    end_date: Optional[str] = Query(None),
-    max_points: int = Query(32767, description="Maximum number of points to display")
-):
-    try:
-        if start_date and end_date:
-            start_date = await dt(start_date)
-            end_date = await dt(end_date)
-        else:
-            start_date, end_date = await get_date_range()
-    except ValueError:
-        raise HTTPException(status_code=400, detail="Invalid date format")
-
-    info(f"Generating map for {start_date} to {end_date}")
-    html_content = await generate_map(start_date, end_date, max_points)
-    return HTMLResponse(content=html_content)
-
-async def get_date_range():
-    query = "SELECT MIN(datetime) as min_date, MAX(datetime) as max_date FROM locations"
-    row = await Db.execute_read(query, table_name="locations")
-    if row and row[0]['min_date'] and row[0]['max_date']:
-        return row[0]['min_date'], row[0]['max_date']
-    else:
-        return datetime(2022, 1, 1), datetime.now()
 
 async def generate_and_save_heatmap(
         start_date: Union[str, int, datetime],
@@ -424,6 +291,114 @@ map.on(L.Draw.Event.CREATED, function (event) {
     return m.get_root().render()
 
 
+async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int, datetime, None] = None) -> List[Location]:
+    start_datetime = await dt(start)
+    if end is None:
+        end_datetime = await dt(start_datetime.replace(hour=23, minute=59, second=59))
+    else:
+        end_datetime = await dt(end) if not isinstance(end, datetime) else end
+
+    if start_datetime.time() == datetime.min.time() and end_datetime.time() == datetime.min.time():
+        end_datetime = await dt(end_datetime.replace(hour=23, minute=59, second=59))
+
+    debug(f"Fetching locations between {start_datetime} and {end_datetime}")
+
+    query = '''
+        SELECT id, datetime,
+        ST_X(ST_AsText(location)::geometry) AS longitude,
+        ST_Y(ST_AsText(location)::geometry) AS latitude,
+        ST_Z(ST_AsText(location)::geometry) AS elevation,
+        city, state, zip, street,
+        action, device_type, device_model, device_name, device_os
+        FROM locations
+        WHERE datetime >= :start_datetime AND datetime <= :end_datetime
+        ORDER BY datetime DESC
+    '''
+
+    locations = await Db.execute_read(query, start_datetime=start_datetime.replace(tzinfo=None), end_datetime=end_datetime.replace(tzinfo=None))
+
+    debug(f"Range locations query returned: {locations}")
+
+    if not locations and (end is None or start_datetime.date() == end_datetime.date()):
+        fallback_query = '''
+            SELECT id, datetime,
+            ST_X(ST_AsText(location)::geometry) AS longitude,
+            ST_Y(ST_AsText(location)::geometry) AS latitude,
+            ST_Z(ST_AsText(location)::geometry) AS elevation,
+            city, state, zip, street,
+            action, device_type, device_model, device_name, device_os
+            FROM locations
+            WHERE datetime < :start_datetime
+            ORDER BY datetime DESC
+            LIMIT 1
+        '''
+        location_data = await Db.execute_read(fallback_query, start_datetime=start_datetime.replace(tzinfo=None))
+        debug(f"Fallback query returned: {location_data}")
+        if location_data:
+            locations = location_data
+
+    debug(f"Locations found: {locations}")
+
+    # Sort location_data based on the datetime field in descending order
+    sorted_locations = sorted(locations, key=lambda x: x['datetime'], reverse=True)
+
+    # Create Location objects directly from the location data
+    location_objects = [
+        Location(
+            latitude=location['latitude'],
+            longitude=location['longitude'],
+            datetime=location['datetime'],
+            elevation=location.get('elevation'),
+            city=location.get('city'),
+            state=location.get('state'),
+            zip=location.get('zip'),
+            street=location.get('street'),
+            context={
+                'action': location.get('action'),
+                'device_type': location.get('device_type'),
+                'device_model': location.get('device_model'),
+                'device_name': location.get('device_name'),
+                'device_os': location.get('device_os')
+            }
+        ) for location in sorted_locations if location['latitude'] is not None and location['longitude'] is not None
+    ]
+
+    return location_objects if location_objects else []
+
+
+async def fetch_last_location_before(datetime: datetime) -> Optional[Location]:
+    try:
+        datetime = await dt(datetime)
+
+        debug(f"Fetching last location before {datetime}")
+
+        query = '''
+            SELECT id, datetime,
+            ST_X(ST_AsText(location)::geometry) AS longitude,
+            ST_Y(ST_AsText(location)::geometry) AS latitude,
+            ST_Z(ST_AsText(location)::geometry) AS elevation,
+            city, state, zip, street, country,
+            action
+            FROM locations
+            WHERE datetime < :datetime
+            ORDER BY datetime DESC
+            LIMIT 1
+        '''
+
+        location_data = await Db.execute_read(query, datetime=datetime.replace(tzinfo=None))
+
+        if location_data:
+            debug(f"Last location found: {location_data[0]}")
+            return Location(**location_data[0])
+        else:
+            debug("No location found before the specified datetime")
+            return None
+    except Exception as e:
+        error(f"Error fetching last location: {str(e)}")
+        return None
+
+
+
 async def post_location(location: Location):
     try:
         context = location.context or {}
@@ -435,31 +410,23 @@ async def post_location(location: Location):
         # Parse and localize the datetime
         localized_datetime = await dt(location.datetime)
-    
+
         query = '''
         INSERT INTO locations (
             datetime, location, city, state, zip, street, action, device_type, device_model, device_name, device_os,
             class_, type, name, display_name, amenity, house_number, road, quarter, neighbourhood,
             suburb, county, country_code, country
         )
-        VALUES ($1, ST_SetSRID(ST_MakePoint($2, $3, $4), 4326), $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15,
-                $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26)
+        VALUES (:datetime, ST_SetSRID(ST_MakePoint(:longitude, :latitude, :elevation), 4326), :city, :state, :zip,
+                :street, :action, :device_type, :device_model, :device_name, :device_os, :class_, :type, :name,
+                :display_name, :amenity, :house_number, :road, :quarter, :neighbourhood, :suburb, :county,
+                :country_code, :country)
         '''
-        await Db.execute_write(
-            query,
-            localized_datetime, location.longitude, location.latitude, location.elevation, location.city, location.state,
-            location.zip, location.street, action, device_type, device_model, device_name, device_os,
-            location.class_, location.type, location.name, location.display_name,
-            location.amenity, location.house_number, location.road, location.quarter, location.neighbourhood,
-            location.suburb, location.county, location.country_code, location.country
-        )
-
-        info(f"Successfully posted location: {location.latitude}, {location.longitude}, {location.elevation} on {localized_datetime}")
-        return {
+        params = {
             'datetime': localized_datetime,
-            'latitude': location.latitude,
             'longitude': location.longitude,
+            'latitude': location.latitude,
             'elevation': location.elevation,
             'city': location.city,
             'state': location.state,
@@ -484,12 +451,34 @@ async def post_location(location: Location):
             'country_code': location.country_code,
             'country': location.country
         }
+
+        await Db.execute_write(query, **params)
+
+        info(f"Successfully posted location: {location.latitude}, {location.longitude}, {location.elevation} on {localized_datetime}")
+
+        # Create a serializable version of params for the return value
+        serializable_params = {
+            k: v.isoformat() if isinstance(v, datetime) else v
+            for k, v in params.items()
+        }
+        return serializable_params
     except Exception as e:
         err(f"Error posting location {e}")
         err(traceback.format_exc())
         return None
 
+
+async def get_date_range():
+    query = "SELECT MIN(datetime) as min_date, MAX(datetime) as max_date FROM locations"
+    row = await Db.execute_read(query)
+    if row and row[0]['min_date'] and row[0]['max_date']:
+        return row[0]['min_date'], row[0]['max_date']
+    else:
+        return datetime(2022, 1, 1), datetime.now()
+
+
+
 @gis.post("/locate")
 async def post_locate_endpoint(locations: Union[Location, List[Location]]):
     if isinstance(locations, Location):
@@ -532,8 +521,8 @@ async def post_locate_endpoint(locations: Union[Location, List[Location]]):
 
     return {"message": "Locations and weather updated", "results": responses}
 
-    
-    
+
+
 @gis.get("/locate", response_model=Location)
 async def get_last_location_endpoint() -> JSONResponse:
     this_location = await get_last_location()
@@ -544,8 +533,8 @@ async def get_last_location_endpoint() -> JSONResponse:
     else:
         raise HTTPException(status_code=404, detail="No location found before the specified datetime")
 
-    
-    
+
+
 @gis.get("/locate/{datetime_str}", response_model=List[Location])
 async def get_locate(datetime_str: str, all: bool = False):
     try:
@@ -558,4 +547,24 @@ async def get_locate(datetime_str: str, all: bool = False):
     if not locations:
         raise HTTPException(status_code=404, detail="No nearby data found for this date and time")
 
-    return locations if all else [locations[0]]
\ No newline at end of file
+    return locations if all else [locations[0]]
+
+
+@gis.get("/map", response_class=HTMLResponse)
+async def generate_map_endpoint(
+    start_date: Optional[str] = Query(None),
+    end_date: Optional[str] = Query(None),
+    max_points: int = Query(32767, description="Maximum number of points to display")
+):
+    try:
+        if start_date and end_date:
+            start_date = await dt(start_date)
+            end_date = await dt(end_date)
+        else:
+            start_date, end_date = await get_date_range()
+    except ValueError:
+        raise HTTPException(status_code=400, detail="Invalid date format")
+
+    info(f"Generating map for {start_date} to {end_date}")
+    html_content = await generate_map(start_date, end_date, max_points)
+    return HTMLResponse(content=html_content)
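
# --- usage note (illustrative, not part of the patch): posting one location
# through the new named-parameter path. The coordinates are hypothetical, and
# this assumes Location accepts the same fields the reads above construct it
# with (latitude, longitude, datetime).
#
#     from sijapi.classes import Location
#     from sijapi.routers.gis import post_location
#
#     async def demo():
#         loc = Location(latitude=45.5, longitude=-122.6, datetime="2024-06-01T12:00:00")
#         return await post_location(loc)  # returns the inserted params, ISO-serialized
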
diff --git a/sijapi/utilities.py b/sijapi/utilities.py
index 80e3f4f..14ba284 100644
--- a/sijapi/utilities.py
+++ b/sijapi/utilities.py
@@ -26,7 +26,9 @@ import pytesseract
 from readability import Document
 from pdf2image import convert_from_path
 from datetime import datetime as dt_datetime, date, time
-from typing import Optional, Union, Tuple, List
+from typing import Optional, Union, Tuple, List, Any
+from decimal import Decimal
+from uuid import UUID
 import asyncio
 from PIL import Image
 import pandas as pd
@@ -629,3 +631,34 @@ async def html_to_markdown(url: str = None, source: str = None) -> Optional[str]:
     markdown_content = md(str(soup), heading_style="ATX")
 
     return markdown_content
+
+
+def json_serial(obj: Any) -> Any:
+    """JSON serializer for objects not serializable by default json code"""
+    if isinstance(obj, (dt_datetime, date)):  # this module imports datetime as dt_datetime
+        return obj.isoformat()
+    if isinstance(obj, time):
+        return obj.isoformat()
+    if isinstance(obj, Decimal):
+        return float(obj)
+    if isinstance(obj, UUID):
+        return str(obj)
+    if isinstance(obj, bytes):
+        return obj.decode('utf-8')
+    if isinstance(obj, Path):
+        return str(obj)
+    if hasattr(obj, '__dict__'):
+        return obj.__dict__
+    raise TypeError(f"Type {type(obj)} not serializable")
+
+def json_dumps(obj: Any) -> str:
+    """
+    Serialize obj to a JSON formatted str using the custom serializer.
+    """
+    return json.dumps(obj, default=json_serial)
+
+def json_loads(json_str: str) -> Any:
+    """
+    Deserialize json_str to a Python object.
+    """
+    return json.loads(json_str)
\ No newline at end of file
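
# --- usage note (illustrative, not part of the patch): the new helpers accept
# types the stdlib JSON encoder rejects.
#
#     from uuid import uuid4
#     from decimal import Decimal
#     from datetime import datetime
#     from sijapi.utilities import json_dumps, json_loads
#
#     payload = {"id": uuid4(), "price": Decimal("9.99"), "at": datetime.now()}
#     s = json_dumps(payload)    # UUID -> str, Decimal -> float, datetime -> ISO string
#     data = json_loads(s)       # back to a plain dict of strings/floats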