Auto-update: Thu Aug 8 23:57:50 PDT 2024

This commit is contained in:
sanj 2024-08-08 23:57:50 -07:00
parent 46c5db23de
commit 487807bab1
7 changed files with 153 additions and 122 deletions

View file

@ -26,7 +26,8 @@ Db = Database.load('sys')
# HOST = f"{API.BIND}:{API.PORT}" # HOST = f"{API.BIND}:{API.PORT}"
# LOCAL_HOSTS = [ipaddress.ip_address(localhost.strip()) for localhost in os.getenv('LOCAL_HOSTS', '127.0.0.1').split(',')] + ['localhost'] # LOCAL_HOSTS = [ipaddress.ip_address(localhost.strip()) for localhost in os.getenv('LOCAL_HOSTS', '127.0.0.1').split(',')] + ['localhost']
# SUBNET_BROADCAST = os.getenv("SUBNET_BROADCAST", '10.255.255.255')
SUBNET_BROADCAST = os.getenv("SUBNET_BROADCAST", '10.255.255.255')
MAX_CPU_CORES = min(int(os.getenv("MAX_CPU_CORES", int(multiprocessing.cpu_count()/2))), multiprocessing.cpu_count()) MAX_CPU_CORES = min(int(os.getenv("MAX_CPU_CORES", int(multiprocessing.cpu_count()/2))), multiprocessing.cpu_count())

View file

@ -86,6 +86,7 @@ app.add_middleware(
allow_headers=['*'], allow_headers=['*'],
) )
class SimpleAPIKeyMiddleware(BaseHTTPMiddleware): class SimpleAPIKeyMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next): async def dispatch(self, request: Request, call_next):
client_ip = ipaddress.ip_address(request.client.host) client_ip = ipaddress.ip_address(request.client.host)
@ -97,17 +98,36 @@ class SimpleAPIKeyMiddleware(BaseHTTPMiddleware):
if not any(client_ip in subnet for subnet in trusted_subnets): if not any(client_ip in subnet for subnet in trusted_subnets):
api_key_header = request.headers.get("Authorization") api_key_header = request.headers.get("Authorization")
api_key_query = request.query_params.get("api_key") api_key_query = request.query_params.get("api_key")
# Debug logging for API keys
debug(f"API.KEYS: {API.KEYS}")
if api_key_header: if api_key_header:
api_key_header = api_key_header.lower().split("bearer ")[-1] api_key_header = api_key_header.lower().split("bearer ")[-1]
debug(f"API key provided in header: {api_key_header}")
if api_key_query:
debug(f"API key provided in query: {api_key_query}")
if api_key_header not in API.KEYS and api_key_query not in API.KEYS: if api_key_header not in API.KEYS and api_key_query not in API.KEYS:
err(f"Invalid API key provided by a requester.") err(f"Invalid API key provided by a requester.")
if api_key_header:
debug(f"Invalid API key in header: {api_key_header}")
if api_key_query:
debug(f"Invalid API key in query: {api_key_query}")
return JSONResponse( return JSONResponse(
status_code=401, status_code=401,
content={"detail": "Invalid or missing API key"} content={"detail": "Invalid or missing API key"}
) )
else:
if api_key_header in API.KEYS:
debug(f"Valid API key provided in header: {api_key_header}")
if api_key_query in API.KEYS:
debug(f"Valid API key provided in query: {api_key_query}")
response = await call_next(request) response = await call_next(request)
return response return response
# Add the middleware to your FastAPI app # Add the middleware to your FastAPI app
app.add_middleware(SimpleAPIKeyMiddleware) app.add_middleware(SimpleAPIKeyMiddleware)

View file

@ -259,6 +259,10 @@ class DirConfig:
class Database: class Database:
@classmethod
def load(cls, config_name: str):
return cls(config_name)
def __init__(self, config_path: str): def __init__(self, config_path: str):
self.config = self.load_config(config_path) self.config = self.load_config(config_path)
self.pool_connections = {} self.pool_connections = {}

View file

@ -99,6 +99,7 @@ EXTENSIONS:
courtlistener: off courtlistener: off
macnotify: on macnotify: on
shellfish: on shellfish: on
url_shortener: off
TZ: 'America/Los_Angeles' TZ: 'America/Los_Angeles'

View file

@ -16,7 +16,7 @@ from folium.plugins import Fullscreen, MiniMap, MousePosition, Geocoder, Draw, M
from zoneinfo import ZoneInfo from zoneinfo import ZoneInfo
from dateutil.parser import parse as dateutil_parse from dateutil.parser import parse as dateutil_parse
from typing import Optional, List, Union from typing import Optional, List, Union
from sijapi import L, API, TZ, GEO from sijapi import L, API, Db, TZ, GEO
from sijapi.classes import Location from sijapi.classes import Location
from sijapi.utilities import haversine, assemble_journal_path from sijapi.utilities import haversine, assemble_journal_path
@ -146,7 +146,7 @@ async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int,
ORDER BY datetime DESC ORDER BY datetime DESC
''' '''
locations = await API.execute_read_query(query, start_datetime.replace(tzinfo=None), end_datetime.replace(tzinfo=None), table_name="locations") locations = await Db.execute_read(query, start_datetime.replace(tzinfo=None), end_datetime.replace(tzinfo=None))
debug(f"Range locations query returned: {locations}") debug(f"Range locations query returned: {locations}")
@ -163,7 +163,7 @@ async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int,
ORDER BY datetime DESC ORDER BY datetime DESC
LIMIT 1 LIMIT 1
''' '''
location_data = await API.execute_read_query(fallback_query, start_datetime.replace(tzinfo=None), table_name="locations") location_data = await Db.execute_read(fallback_query, start_datetime.replace(tzinfo=None))
debug(f"Fallback query returned: {location_data}") debug(f"Fallback query returned: {location_data}")
if location_data: if location_data:
locations = location_data locations = location_data
@ -196,7 +196,9 @@ async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int,
return location_objects if location_objects else [] return location_objects if location_objects else []
async def fetch_last_location_before(datetime: datetime) -> Optional[Location]: async def fetch_last_location_before(datetime: datetime) -> Optional[Location]:
try:
datetime = await dt(datetime) datetime = await dt(datetime)
debug(f"Fetching last location before {datetime}") debug(f"Fetching last location before {datetime}")
@ -214,7 +216,7 @@ async def fetch_last_location_before(datetime: datetime) -> Optional[Location]:
LIMIT 1 LIMIT 1
''' '''
location_data = await API.execute_read_query(query, datetime.replace(tzinfo=None), table_name="locations") location_data = await Db.execute_read(query, datetime.replace(tzinfo=None))
if location_data: if location_data:
debug(f"Last location found: {location_data[0]}") debug(f"Last location found: {location_data[0]}")
@ -222,6 +224,11 @@ async def fetch_last_location_before(datetime: datetime) -> Optional[Location]:
else: else:
debug("No location found before the specified datetime") debug("No location found before the specified datetime")
return None return None
except Exception as e:
error(f"Error fetching last location: {str(e)}")
return None
@gis.get("/map", response_class=HTMLResponse) @gis.get("/map", response_class=HTMLResponse)
async def generate_map_endpoint( async def generate_map_endpoint(
@ -244,7 +251,7 @@ async def generate_map_endpoint(
async def get_date_range(): async def get_date_range():
query = "SELECT MIN(datetime) as min_date, MAX(datetime) as max_date FROM locations" query = "SELECT MIN(datetime) as min_date, MAX(datetime) as max_date FROM locations"
row = await API.execute_read_query(query, table_name="locations") row = await Db.execute_read(query, table_name="locations")
if row and row[0]['min_date'] and row[0]['max_date']: if row and row[0]['min_date'] and row[0]['max_date']:
return row[0]['min_date'], row[0]['max_date'] return row[0]['min_date'], row[0]['max_date']
else: else:
@ -416,6 +423,7 @@ map.on(L.Draw.Event.CREATED, function (event) {
return m.get_root().render() return m.get_root().render()
async def post_location(location: Location): async def post_location(location: Location):
try: try:
context = location.context or {} context = location.context or {}
@ -438,14 +446,13 @@ async def post_location(location: Location):
$16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26) $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26)
''' '''
await API.execute_write_query( await Db.execute_write(
query, query,
localized_datetime, location.longitude, location.latitude, location.elevation, location.city, location.state, localized_datetime, location.longitude, location.latitude, location.elevation, location.city, location.state,
location.zip, location.street, action, device_type, device_model, device_name, device_os, location.zip, location.street, action, device_type, device_model, device_name, device_os,
location.class_, location.type, location.name, location.display_name, location.class_, location.type, location.name, location.display_name,
location.amenity, location.house_number, location.road, location.quarter, location.neighbourhood, location.amenity, location.house_number, location.road, location.quarter, location.neighbourhood,
location.suburb, location.county, location.country_code, location.country, location.suburb, location.county, location.country_code, location.country
table_name="locations"
) )
info(f"Successfully posted location: {location.latitude}, {location.longitude}, {location.elevation} on {localized_datetime}") info(f"Successfully posted location: {location.latitude}, {location.longitude}, {location.elevation} on {localized_datetime}")

View file

@ -420,14 +420,14 @@ if API.EXTENSIONS.courtlistener:
await cl_download_file(download_url, target_path, session) await cl_download_file(download_url, target_path, session)
debug(f"Downloaded {file_name} to {target_path}") debug(f"Downloaded {file_name} to {target_path}")
if API.EXTENSIONS.url_shortener:
@serve.get("/s", response_class=HTMLResponse) @serve.get("/s", response_class=HTMLResponse)
async def shortener_form(request: Request): async def shortener_form(request: Request):
return templates.TemplateResponse("shortener.html", {"request": request}) return templates.TemplateResponse("shortener.html", {"request": request})
@serve.post("/s") @serve.post("/s")
async def create_short_url(request: Request, long_url: str = Form(...), custom_code: Optional[str] = Form(None)): async def create_short_url(request: Request, long_url: str = Form(...), custom_code: Optional[str] = Form(None)):
if custom_code: if custom_code:
if len(custom_code) != 3 or not custom_code.isalnum(): if len(custom_code) != 3 or not custom_code.isalnum():
@ -457,8 +457,8 @@ async def create_short_url(request: Request, long_url: str = Form(...), custom_c
return templates.TemplateResponse("shortener.html", {"request": request, "short_url": short_url}) return templates.TemplateResponse("shortener.html", {"request": request, "short_url": short_url})
@serve.get("/{short_code}") @serve.get("/{short_code}")
async def redirect_short_url(short_code: str): async def redirect_short_url(short_code: str):
results = await API.execute_read_query( results = await API.execute_read_query(
'SELECT long_url FROM short_urls WHERE short_code = $1', 'SELECT long_url FROM short_urls WHERE short_code = $1',
short_code, short_code,
@ -483,8 +483,8 @@ async def redirect_short_url(short_code: str):
return RedirectResponse(url=long_url) return RedirectResponse(url=long_url)
@serve.get("/analytics/{short_code}") @serve.get("/analytics/{short_code}")
async def get_analytics(short_code: str): async def get_analytics(short_code: str):
url_info = await API.execute_read_query( url_info = await API.execute_read_query(
'SELECT long_url, created_at FROM short_urls WHERE short_code = $1', 'SELECT long_url, created_at FROM short_urls WHERE short_code = $1',
short_code, short_code,
@ -512,5 +512,3 @@ async def get_analytics(short_code: str):
"total_clicks": click_count, "total_clicks": click_count,
"recent_clicks": [dict(click) for click in clicks] "recent_clicks": [dict(click) for click in clicks]
} }

View file

@ -8,7 +8,7 @@ import httpx
import socket import socket
from fastapi import APIRouter from fastapi import APIRouter
from tailscale import Tailscale from tailscale import Tailscale
from sijapi import L, API, TS_ID, SUBNET_BROADCAST from sijapi import L, API, TS_ID
sys = APIRouter(tags=["public", "trusted", "private"]) sys = APIRouter(tags=["public", "trusted", "private"])
logger = L.get_module_logger("health") logger = L.get_module_logger("health")
@ -36,7 +36,7 @@ def get_local_ip():
"""Get the server's local IP address.""" """Get the server's local IP address."""
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try: try:
s.connect((f'{SUBNET_BROADCAST}', 1)) s.connect((f'{API.SUBNET_BROADCAST}', 1))
IP = s.getsockname()[0] IP = s.getsockname()[0]
except Exception: except Exception:
IP = '127.0.0.1' IP = '127.0.0.1'