From 7865486079fba712ae7f02fa27ee873ca0b25530 Mon Sep 17 00:00:00 2001
From: sanj <67624670+iodrift@users.noreply.github.com>
Date: Wed, 24 Jul 2024 13:52:35 -0700
Subject: [PATCH] Auto-update: Wed Jul 24 13:52:35 PDT 2024

---
 sijapi/__init__.py        |  2 --
 sijapi/__main__.py        | 56 ++++++++++++++++++++++++++++++++++++++-
 sijapi/classes.py         | 40 +++++++++++++++-------------
 sijapi/routers/gis.py     | 10 +++----
 sijapi/routers/serve.py   |  8 +++---
 sijapi/routers/weather.py |  6 ++---
 6 files changed, 89 insertions(+), 33 deletions(-)

diff --git a/sijapi/__init__.py b/sijapi/__init__.py
index 8959769..94cc479 100644
--- a/sijapi/__init__.py
+++ b/sijapi/__init__.py
@@ -25,8 +25,6 @@ HOST = f"{API.BIND}:{API.PORT}"
 LOCAL_HOSTS = [ipaddress.ip_address(localhost.strip()) for localhost in os.getenv('LOCAL_HOSTS', '127.0.0.1').split(',')] + ['localhost']
 SUBNET_BROADCAST = os.getenv("SUBNET_BROADCAST", '10.255.255.255')
 MAX_CPU_CORES = min(int(os.getenv("MAX_CPU_CORES", int(multiprocessing.cpu_count()/2))), multiprocessing.cpu_count())
-
-DB = Database.from_env()
 IMG = Configuration.load('img', 'secrets')
 News = Configuration.load('news', 'secrets')
 Scrape = Configuration.load('scrape', 'secrets', Dir)
diff --git a/sijapi/__main__.py b/sijapi/__main__.py
index 9b7e3f9..e9619c7 100755
--- a/sijapi/__main__.py
+++ b/sijapi/__main__.py
@@ -1,5 +1,6 @@
 #!/Users/sij/miniforge3/envs/api/bin/python
 #__main__.py
+from contextlib import asynccontextmanager
 from fastapi import FastAPI, Request, HTTPException, Response
 from fastapi.responses import JSONResponse
 from fastapi.middleware.cors import CORSMiddleware
@@ -9,6 +10,8 @@ from starlette.requests import ClientDisconnect
 from hypercorn.asyncio import serve
 from hypercorn.config import Config as HypercornConfig
 import sys
+import os
+import traceback
 import asyncio 
 import httpx
 import argparse
@@ -41,7 +44,58 @@ err(f"Error message.")
 def crit(text: str): logger.critical(text)
 crit(f"Critical message.")
 
-app = FastAPI()
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    # Startup
+    crit("sijapi launched")
+    crit(f"Arguments: {args}")
+
+    # Load routers
+    for module_name in API.MODULES.__fields__:
+        if getattr(API.MODULES, module_name):
+            load_router(module_name)
+
+    crit("Starting database synchronization...")
+    try:
+        # Log the current TS_ID
+        crit(f"Current TS_ID: {os.environ.get('TS_ID', 'Not set')}")
+        
+        # Log the local_db configuration
+        local_db = API.local_db
+        crit(f"Local DB configuration: {local_db}")
+        
+        # Test local connection
+        async with API.get_connection() as conn:
+            version = await conn.fetchval("SELECT version()")
+            crit(f"Successfully connected to local database. PostgreSQL version: {version}")
+        
+        # Sync schema across all databases
+        await API.sync_schema()
+        crit("Schema synchronization complete.")
+        
+        # Attempt to pull changes from another database
+        source = await API.get_default_source()
+        if source:
+            crit(f"Pulling changes from {source['ts_id']}...")
+            await API.pull_changes(source)
+            crit("Data pull complete.")
+        else:
+            crit("No available source for pulling changes. This might be the only active database.")
+
+    except Exception as e:
+        crit(f"Error during startup: {str(e)}")
+        crit(f"Traceback: {traceback.format_exc()}")
+
+    yield  # This is where the app runs
+
+    # Shutdown
+    crit("Shutting down...")
+    # Perform any cleanup operations here if needed
+
+
+app = FastAPI(lifespan=lifespan)
+
 app.add_middleware(
     CORSMiddleware,
     allow_origins=['*'],
diff --git a/sijapi/classes.py b/sijapi/classes.py
index 04eff65..fa95457 100644
--- a/sijapi/classes.py
+++ b/sijapi/classes.py
@@ -285,20 +285,26 @@ class APIConfig(BaseModel):
         if pool_entry is None:
             pool_entry = self.local_db
         
-        conn = await asyncpg.connect(
-            host=pool_entry['ts_ip'],
-            port=pool_entry['db_port'],
-            user=pool_entry['db_user'],
-            password=pool_entry['db_pass'],
-            database=pool_entry['db_name']
-        )
+        crit(f"Attempting to connect to database: {pool_entry}")
         try:
-            yield conn
-        finally:
-            await conn.close()
+            conn = await asyncpg.connect(
+                host=pool_entry['ts_ip'],
+                port=pool_entry['db_port'],
+                user=pool_entry['db_user'],
+                password=pool_entry['db_pass'],
+                database=pool_entry['db_name']
+            )
+            try:
+                yield conn
+            finally:
+                await conn.close()
+        except Exception as e:
+            crit(f"Failed to connect to database: {pool_entry['ts_ip']}:{pool_entry['db_port']}")
+            crit(f"Error: {str(e)}")
+            raise
+
 
     async def push_changes(self, query: str, *args):
-        logger = Logger("DatabaseReplication")
         connections = []
         try:
             for pool_entry in self.POOL[1:]:  # Skip the first (local) database
@@ -312,9 +318,9 @@ class APIConfig(BaseModel):
 
             for pool_entry, result in zip(self.POOL[1:], results):
                 if isinstance(result, Exception):
-                    logger.error(f"Failed to push to {pool_entry['ts_ip']}: {str(result)}")
+                    err(f"Failed to push to {pool_entry['ts_ip']}: {str(result)}")
                 else:
-                    logger.info(f"Successfully pushed to {pool_entry['ts_ip']}")
+                    info(f"Successfully pushed to {pool_entry['ts_ip']}")
 
         finally:
             for conn in connections:
@@ -336,10 +342,9 @@ class APIConfig(BaseModel):
             source_pool_entry = await self.get_default_source()
         
         if source_pool_entry is None:
-            logger.error("No available source for pulling changes")
+            err("No available source for pulling changes")
             return
         
-        logger = Logger("DatabaseReplication")
         async with self.get_connection(source_pool_entry) as source_conn:
             async with self.get_connection() as dest_conn:
                 # This is a simplistic approach. You might need a more sophisticated
@@ -356,17 +361,16 @@ class APIConfig(BaseModel):
                         await dest_conn.copy_records_to_table(
                             table_name, records=rows, columns=columns
                         )
-                logger.info(f"Successfully pulled changes from {source_pool_entry['ts_ip']}")
+                info(f"Successfully pulled changes from {source_pool_entry['ts_ip']}")
 
     async def sync_schema(self):
-        logger = Logger("SchemaSync")
         source_entry = self.POOL[0]  # Use the local database as the source
         source_schema = await self.get_schema(source_entry)
         
         for pool_entry in self.POOL[1:]:
             target_schema = await self.get_schema(pool_entry)
             await self.apply_schema_changes(pool_entry, source_schema, target_schema)
-            logger.info(f"Synced schema to {pool_entry['ts_ip']}")
+            info(f"Synced schema to {pool_entry['ts_ip']}")
 
     async def get_schema(self, pool_entry: Dict[str, Any]):
         async with self.get_connection(pool_entry) as conn:
diff --git a/sijapi/routers/gis.py b/sijapi/routers/gis.py
index cb0daf0..65bde57 100644
--- a/sijapi/routers/gis.py
+++ b/sijapi/routers/gis.py
@@ -14,7 +14,7 @@ from folium.plugins import Fullscreen, MiniMap, MousePosition, Geocoder, Draw, M
 from zoneinfo import ZoneInfo
 from dateutil.parser import parse as dateutil_parse
 from typing import Optional, List, Union
-from sijapi import L, DB, TZ, GEO
+from sijapi import L, API, TZ, GEO
 from sijapi.classes import Location
 from sijapi.utilities import haversine, assemble_journal_path
 
@@ -133,7 +133,7 @@ async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int,
 
     debug(f"Fetching locations between {start_datetime} and {end_datetime}")
 
-    async with DB.get_connection() as conn:
+    async with API.get_connection() as conn:
         locations = []
         # Check for records within the specified datetime range
         range_locations = await conn.fetch('''
@@ -203,7 +203,7 @@ async def fetch_last_location_before(datetime: datetime) -> Optional[Location]:
     
     debug(f"Fetching last location before {datetime}")
 
-    async with DB.get_connection() as conn:
+    async with API.get_connection() as conn:
 
         location_data = await conn.fetchrow('''
             SELECT id, datetime,
@@ -247,7 +247,7 @@ async def generate_map_endpoint(
     return HTMLResponse(content=html_content)
 
 async def get_date_range():
-    async with DB.get_connection() as conn:
+    async with API.get_connection() as conn:
         query = "SELECT MIN(datetime) as min_date, MAX(datetime) as max_date FROM locations"
         row = await conn.fetchrow(query)
         if row and row['min_date'] and row['max_date']:
@@ -437,7 +437,7 @@ async def post_location(location: Location):
     #     info(f"location appears to be missing datetime: {location}")
     # else:
     #    debug(f"post_location called with {location.datetime}")
-    async with DB.get_connection() as conn:
+    async with API.get_connection() as conn:
         try:
             context = location.context or {}
             action = context.get('action', 'manual')
diff --git a/sijapi/routers/serve.py b/sijapi/routers/serve.py
index 118bd36..408f860 100644
--- a/sijapi/routers/serve.py
+++ b/sijapi/routers/serve.py
@@ -31,7 +31,7 @@ from selenium.webdriver.common.by import By
 from selenium.webdriver.support.ui import WebDriverWait
 from selenium.webdriver.support import expected_conditions as EC
 from sijapi import (
-    L, API, DB, LOGS_DIR, TS_ID, CASETABLE_PATH, COURTLISTENER_DOCKETS_URL, COURTLISTENER_API_KEY,
+    L, API, LOGS_DIR, TS_ID, CASETABLE_PATH, COURTLISTENER_DOCKETS_URL, COURTLISTENER_API_KEY,
     COURTLISTENER_BASE_URL, COURTLISTENER_DOCKETS_DIR, COURTLISTENER_SEARCH_DIR, ALERTS_DIR,
     MAC_UN, MAC_PW, MAC_ID, TS_TAILNET, IMG_DIR, PUBLIC_KEY, OBSIDIAN_VAULT_DIR
 )
@@ -435,7 +435,7 @@ async def shortener_form(request: Request):
 
 @serve.post("/s")
 async def create_short_url(request: Request, long_url: str = Form(...), custom_code: Optional[str] = Form(None)):
-    async with DB.get_connection() as conn:
+    async with API.get_connection() as conn:
         await create_tables(conn)
 
         if custom_code:
@@ -486,7 +486,7 @@ async def redirect_short_url(request: Request, short_code: str = PathParam(...,
     if request.headers.get('host') != 'sij.ai':
         raise HTTPException(status_code=404, detail="Not Found")
     
-    async with DB.get_connection() as conn:
+    async with API.get_connection() as conn:
         result = await conn.fetchrow(
             'SELECT long_url FROM short_urls WHERE short_code = $1',
             short_code
@@ -503,7 +503,7 @@ async def redirect_short_url(request: Request, short_code: str = PathParam(...,
 
 @serve.get("/analytics/{short_code}")
 async def get_analytics(short_code: str):
-    async with DB.get_connection() as conn:
+    async with API.get_connection() as conn:
         url_info = await conn.fetchrow(
             'SELECT long_url, created_at FROM short_urls WHERE short_code = $1',
             short_code
diff --git a/sijapi/routers/weather.py b/sijapi/routers/weather.py
index c920753..a457fa1 100644
--- a/sijapi/routers/weather.py
+++ b/sijapi/routers/weather.py
@@ -11,7 +11,7 @@ from typing import Dict
 from datetime import datetime as dt_datetime
 from shapely.wkb import loads
 from binascii import unhexlify
-from sijapi import L, VISUALCROSSING_API_KEY, TZ, DB, GEO
+from sijapi import L, VISUALCROSSING_API_KEY, TZ, API, GEO
 from sijapi.utilities import haversine
 from sijapi.routers import gis
 
@@ -116,7 +116,7 @@ async def get_weather(date_time: dt_datetime, latitude: float, longitude: float,
 
 async def store_weather_to_db(date_time: dt_datetime, weather_data: dict):
     warn(f"Using {date_time.strftime('%Y-%m-%d %H:%M:%S')} as our datetime in store_weather_to_db")
-    async with DB.get_connection() as conn:
+    async with API.get_connection() as conn:
         try:
             day_data = weather_data.get('days')[0]
             debug(f"RAW DAY_DATA: {day_data}")
@@ -244,7 +244,7 @@ async def store_weather_to_db(date_time: dt_datetime, weather_data: dict):
 
 async def get_weather_from_db(date_time: dt_datetime, latitude: float, longitude: float):
     warn(f"Using {date_time.strftime('%Y-%m-%d %H:%M:%S')} as our datetime in get_weather_from_db.")
-    async with DB.get_connection() as conn:
+    async with API.get_connection() as conn:
         query_date = date_time.date()
         try:
             # Query to get daily weather data