Auto-update: Thu Aug 8 23:57:50 PDT 2024
This commit is contained in:
parent
46c5db23de
commit
487807bab1
7 changed files with 153 additions and 122 deletions
|
@ -26,7 +26,8 @@ Db = Database.load('sys')
|
||||||
|
|
||||||
# HOST = f"{API.BIND}:{API.PORT}"
|
# HOST = f"{API.BIND}:{API.PORT}"
|
||||||
# LOCAL_HOSTS = [ipaddress.ip_address(localhost.strip()) for localhost in os.getenv('LOCAL_HOSTS', '127.0.0.1').split(',')] + ['localhost']
|
# LOCAL_HOSTS = [ipaddress.ip_address(localhost.strip()) for localhost in os.getenv('LOCAL_HOSTS', '127.0.0.1').split(',')] + ['localhost']
|
||||||
# SUBNET_BROADCAST = os.getenv("SUBNET_BROADCAST", '10.255.255.255')
|
|
||||||
|
SUBNET_BROADCAST = os.getenv("SUBNET_BROADCAST", '10.255.255.255')
|
||||||
|
|
||||||
MAX_CPU_CORES = min(int(os.getenv("MAX_CPU_CORES", int(multiprocessing.cpu_count()/2))), multiprocessing.cpu_count())
|
MAX_CPU_CORES = min(int(os.getenv("MAX_CPU_CORES", int(multiprocessing.cpu_count()/2))), multiprocessing.cpu_count())
|
||||||
|
|
||||||
|
|
|
@ -86,6 +86,7 @@ app.add_middleware(
|
||||||
allow_headers=['*'],
|
allow_headers=['*'],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class SimpleAPIKeyMiddleware(BaseHTTPMiddleware):
|
class SimpleAPIKeyMiddleware(BaseHTTPMiddleware):
|
||||||
async def dispatch(self, request: Request, call_next):
|
async def dispatch(self, request: Request, call_next):
|
||||||
client_ip = ipaddress.ip_address(request.client.host)
|
client_ip = ipaddress.ip_address(request.client.host)
|
||||||
|
@ -97,17 +98,36 @@ class SimpleAPIKeyMiddleware(BaseHTTPMiddleware):
|
||||||
if not any(client_ip in subnet for subnet in trusted_subnets):
|
if not any(client_ip in subnet for subnet in trusted_subnets):
|
||||||
api_key_header = request.headers.get("Authorization")
|
api_key_header = request.headers.get("Authorization")
|
||||||
api_key_query = request.query_params.get("api_key")
|
api_key_query = request.query_params.get("api_key")
|
||||||
|
|
||||||
|
# Debug logging for API keys
|
||||||
|
debug(f"API.KEYS: {API.KEYS}")
|
||||||
|
|
||||||
if api_key_header:
|
if api_key_header:
|
||||||
api_key_header = api_key_header.lower().split("bearer ")[-1]
|
api_key_header = api_key_header.lower().split("bearer ")[-1]
|
||||||
|
debug(f"API key provided in header: {api_key_header}")
|
||||||
|
if api_key_query:
|
||||||
|
debug(f"API key provided in query: {api_key_query}")
|
||||||
|
|
||||||
if api_key_header not in API.KEYS and api_key_query not in API.KEYS:
|
if api_key_header not in API.KEYS and api_key_query not in API.KEYS:
|
||||||
err(f"Invalid API key provided by a requester.")
|
err(f"Invalid API key provided by a requester.")
|
||||||
|
if api_key_header:
|
||||||
|
debug(f"Invalid API key in header: {api_key_header}")
|
||||||
|
if api_key_query:
|
||||||
|
debug(f"Invalid API key in query: {api_key_query}")
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=401,
|
status_code=401,
|
||||||
content={"detail": "Invalid or missing API key"}
|
content={"detail": "Invalid or missing API key"}
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
if api_key_header in API.KEYS:
|
||||||
|
debug(f"Valid API key provided in header: {api_key_header}")
|
||||||
|
if api_key_query in API.KEYS:
|
||||||
|
debug(f"Valid API key provided in query: {api_key_query}")
|
||||||
|
|
||||||
response = await call_next(request)
|
response = await call_next(request)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
# Add the middleware to your FastAPI app
|
# Add the middleware to your FastAPI app
|
||||||
app.add_middleware(SimpleAPIKeyMiddleware)
|
app.add_middleware(SimpleAPIKeyMiddleware)
|
||||||
|
|
||||||
|
|
|
@ -259,6 +259,10 @@ class DirConfig:
|
||||||
|
|
||||||
|
|
||||||
class Database:
|
class Database:
|
||||||
|
@classmethod
|
||||||
|
def load(cls, config_name: str):
|
||||||
|
return cls(config_name)
|
||||||
|
|
||||||
def __init__(self, config_path: str):
|
def __init__(self, config_path: str):
|
||||||
self.config = self.load_config(config_path)
|
self.config = self.load_config(config_path)
|
||||||
self.pool_connections = {}
|
self.pool_connections = {}
|
||||||
|
|
|
@ -99,6 +99,7 @@ EXTENSIONS:
|
||||||
courtlistener: off
|
courtlistener: off
|
||||||
macnotify: on
|
macnotify: on
|
||||||
shellfish: on
|
shellfish: on
|
||||||
|
url_shortener: off
|
||||||
|
|
||||||
TZ: 'America/Los_Angeles'
|
TZ: 'America/Los_Angeles'
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@ from folium.plugins import Fullscreen, MiniMap, MousePosition, Geocoder, Draw, M
|
||||||
from zoneinfo import ZoneInfo
|
from zoneinfo import ZoneInfo
|
||||||
from dateutil.parser import parse as dateutil_parse
|
from dateutil.parser import parse as dateutil_parse
|
||||||
from typing import Optional, List, Union
|
from typing import Optional, List, Union
|
||||||
from sijapi import L, API, TZ, GEO
|
from sijapi import L, API, Db, TZ, GEO
|
||||||
from sijapi.classes import Location
|
from sijapi.classes import Location
|
||||||
from sijapi.utilities import haversine, assemble_journal_path
|
from sijapi.utilities import haversine, assemble_journal_path
|
||||||
|
|
||||||
|
@ -146,7 +146,7 @@ async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int,
|
||||||
ORDER BY datetime DESC
|
ORDER BY datetime DESC
|
||||||
'''
|
'''
|
||||||
|
|
||||||
locations = await API.execute_read_query(query, start_datetime.replace(tzinfo=None), end_datetime.replace(tzinfo=None), table_name="locations")
|
locations = await Db.execute_read(query, start_datetime.replace(tzinfo=None), end_datetime.replace(tzinfo=None))
|
||||||
|
|
||||||
debug(f"Range locations query returned: {locations}")
|
debug(f"Range locations query returned: {locations}")
|
||||||
|
|
||||||
|
@ -163,7 +163,7 @@ async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int,
|
||||||
ORDER BY datetime DESC
|
ORDER BY datetime DESC
|
||||||
LIMIT 1
|
LIMIT 1
|
||||||
'''
|
'''
|
||||||
location_data = await API.execute_read_query(fallback_query, start_datetime.replace(tzinfo=None), table_name="locations")
|
location_data = await Db.execute_read(fallback_query, start_datetime.replace(tzinfo=None))
|
||||||
debug(f"Fallback query returned: {location_data}")
|
debug(f"Fallback query returned: {location_data}")
|
||||||
if location_data:
|
if location_data:
|
||||||
locations = location_data
|
locations = location_data
|
||||||
|
@ -196,7 +196,9 @@ async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int,
|
||||||
|
|
||||||
return location_objects if location_objects else []
|
return location_objects if location_objects else []
|
||||||
|
|
||||||
|
|
||||||
async def fetch_last_location_before(datetime: datetime) -> Optional[Location]:
|
async def fetch_last_location_before(datetime: datetime) -> Optional[Location]:
|
||||||
|
try:
|
||||||
datetime = await dt(datetime)
|
datetime = await dt(datetime)
|
||||||
|
|
||||||
debug(f"Fetching last location before {datetime}")
|
debug(f"Fetching last location before {datetime}")
|
||||||
|
@ -214,7 +216,7 @@ async def fetch_last_location_before(datetime: datetime) -> Optional[Location]:
|
||||||
LIMIT 1
|
LIMIT 1
|
||||||
'''
|
'''
|
||||||
|
|
||||||
location_data = await API.execute_read_query(query, datetime.replace(tzinfo=None), table_name="locations")
|
location_data = await Db.execute_read(query, datetime.replace(tzinfo=None))
|
||||||
|
|
||||||
if location_data:
|
if location_data:
|
||||||
debug(f"Last location found: {location_data[0]}")
|
debug(f"Last location found: {location_data[0]}")
|
||||||
|
@ -222,6 +224,11 @@ async def fetch_last_location_before(datetime: datetime) -> Optional[Location]:
|
||||||
else:
|
else:
|
||||||
debug("No location found before the specified datetime")
|
debug("No location found before the specified datetime")
|
||||||
return None
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
error(f"Error fetching last location: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@gis.get("/map", response_class=HTMLResponse)
|
@gis.get("/map", response_class=HTMLResponse)
|
||||||
async def generate_map_endpoint(
|
async def generate_map_endpoint(
|
||||||
|
@ -244,7 +251,7 @@ async def generate_map_endpoint(
|
||||||
|
|
||||||
async def get_date_range():
|
async def get_date_range():
|
||||||
query = "SELECT MIN(datetime) as min_date, MAX(datetime) as max_date FROM locations"
|
query = "SELECT MIN(datetime) as min_date, MAX(datetime) as max_date FROM locations"
|
||||||
row = await API.execute_read_query(query, table_name="locations")
|
row = await Db.execute_read(query, table_name="locations")
|
||||||
if row and row[0]['min_date'] and row[0]['max_date']:
|
if row and row[0]['min_date'] and row[0]['max_date']:
|
||||||
return row[0]['min_date'], row[0]['max_date']
|
return row[0]['min_date'], row[0]['max_date']
|
||||||
else:
|
else:
|
||||||
|
@ -416,6 +423,7 @@ map.on(L.Draw.Event.CREATED, function (event) {
|
||||||
|
|
||||||
return m.get_root().render()
|
return m.get_root().render()
|
||||||
|
|
||||||
|
|
||||||
async def post_location(location: Location):
|
async def post_location(location: Location):
|
||||||
try:
|
try:
|
||||||
context = location.context or {}
|
context = location.context or {}
|
||||||
|
@ -438,14 +446,13 @@ async def post_location(location: Location):
|
||||||
$16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26)
|
$16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26)
|
||||||
'''
|
'''
|
||||||
|
|
||||||
await API.execute_write_query(
|
await Db.execute_write(
|
||||||
query,
|
query,
|
||||||
localized_datetime, location.longitude, location.latitude, location.elevation, location.city, location.state,
|
localized_datetime, location.longitude, location.latitude, location.elevation, location.city, location.state,
|
||||||
location.zip, location.street, action, device_type, device_model, device_name, device_os,
|
location.zip, location.street, action, device_type, device_model, device_name, device_os,
|
||||||
location.class_, location.type, location.name, location.display_name,
|
location.class_, location.type, location.name, location.display_name,
|
||||||
location.amenity, location.house_number, location.road, location.quarter, location.neighbourhood,
|
location.amenity, location.house_number, location.road, location.quarter, location.neighbourhood,
|
||||||
location.suburb, location.county, location.country_code, location.country,
|
location.suburb, location.county, location.country_code, location.country
|
||||||
table_name="locations"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
info(f"Successfully posted location: {location.latitude}, {location.longitude}, {location.elevation} on {localized_datetime}")
|
info(f"Successfully posted location: {location.latitude}, {location.longitude}, {location.elevation} on {localized_datetime}")
|
||||||
|
|
|
@ -420,7 +420,7 @@ if API.EXTENSIONS.courtlistener:
|
||||||
await cl_download_file(download_url, target_path, session)
|
await cl_download_file(download_url, target_path, session)
|
||||||
debug(f"Downloaded {file_name} to {target_path}")
|
debug(f"Downloaded {file_name} to {target_path}")
|
||||||
|
|
||||||
|
if API.EXTENSIONS.url_shortener:
|
||||||
@serve.get("/s", response_class=HTMLResponse)
|
@serve.get("/s", response_class=HTMLResponse)
|
||||||
async def shortener_form(request: Request):
|
async def shortener_form(request: Request):
|
||||||
return templates.TemplateResponse("shortener.html", {"request": request})
|
return templates.TemplateResponse("shortener.html", {"request": request})
|
||||||
|
@ -512,5 +512,3 @@ async def get_analytics(short_code: str):
|
||||||
"total_clicks": click_count,
|
"total_clicks": click_count,
|
||||||
"recent_clicks": [dict(click) for click in clicks]
|
"recent_clicks": [dict(click) for click in clicks]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,7 @@ import httpx
|
||||||
import socket
|
import socket
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from tailscale import Tailscale
|
from tailscale import Tailscale
|
||||||
from sijapi import L, API, TS_ID, SUBNET_BROADCAST
|
from sijapi import L, API, TS_ID
|
||||||
|
|
||||||
sys = APIRouter(tags=["public", "trusted", "private"])
|
sys = APIRouter(tags=["public", "trusted", "private"])
|
||||||
logger = L.get_module_logger("health")
|
logger = L.get_module_logger("health")
|
||||||
|
@ -36,7 +36,7 @@ def get_local_ip():
|
||||||
"""Get the server's local IP address."""
|
"""Get the server's local IP address."""
|
||||||
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||||
try:
|
try:
|
||||||
s.connect((f'{SUBNET_BROADCAST}', 1))
|
s.connect((f'{API.SUBNET_BROADCAST}', 1))
|
||||||
IP = s.getsockname()[0]
|
IP = s.getsockname()[0]
|
||||||
except Exception:
|
except Exception:
|
||||||
IP = '127.0.0.1'
|
IP = '127.0.0.1'
|
||||||
|
|
Loading…
Reference in a new issue