Auto-update: Thu Aug 8 23:57:50 PDT 2024

This commit is contained in:
sanj 2024-08-08 23:57:50 -07:00
parent 46c5db23de
commit 487807bab1
7 changed files with 153 additions and 122 deletions

View file

@ -26,7 +26,8 @@ Db = Database.load('sys')
# HOST = f"{API.BIND}:{API.PORT}"
# LOCAL_HOSTS = [ipaddress.ip_address(localhost.strip()) for localhost in os.getenv('LOCAL_HOSTS', '127.0.0.1').split(',')] + ['localhost']
# SUBNET_BROADCAST = os.getenv("SUBNET_BROADCAST", '10.255.255.255')
SUBNET_BROADCAST = os.getenv("SUBNET_BROADCAST", '10.255.255.255')
MAX_CPU_CORES = min(int(os.getenv("MAX_CPU_CORES", int(multiprocessing.cpu_count()/2))), multiprocessing.cpu_count())

View file

@ -86,6 +86,7 @@ app.add_middleware(
allow_headers=['*'],
)
class SimpleAPIKeyMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
client_ip = ipaddress.ip_address(request.client.host)
@ -97,17 +98,36 @@ class SimpleAPIKeyMiddleware(BaseHTTPMiddleware):
if not any(client_ip in subnet for subnet in trusted_subnets):
api_key_header = request.headers.get("Authorization")
api_key_query = request.query_params.get("api_key")
# Debug logging for API keys
debug(f"API.KEYS: {API.KEYS}")
if api_key_header:
api_key_header = api_key_header.lower().split("bearer ")[-1]
debug(f"API key provided in header: {api_key_header}")
if api_key_query:
debug(f"API key provided in query: {api_key_query}")
if api_key_header not in API.KEYS and api_key_query not in API.KEYS:
err(f"Invalid API key provided by a requester.")
if api_key_header:
debug(f"Invalid API key in header: {api_key_header}")
if api_key_query:
debug(f"Invalid API key in query: {api_key_query}")
return JSONResponse(
status_code=401,
content={"detail": "Invalid or missing API key"}
)
else:
if api_key_header in API.KEYS:
debug(f"Valid API key provided in header: {api_key_header}")
if api_key_query in API.KEYS:
debug(f"Valid API key provided in query: {api_key_query}")
response = await call_next(request)
return response
# Add the middleware to your FastAPI app
app.add_middleware(SimpleAPIKeyMiddleware)

View file

@ -259,6 +259,10 @@ class DirConfig:
class Database:
@classmethod
def load(cls, config_name: str):
return cls(config_name)
def __init__(self, config_path: str):
self.config = self.load_config(config_path)
self.pool_connections = {}

View file

@ -99,6 +99,7 @@ EXTENSIONS:
courtlistener: off
macnotify: on
shellfish: on
url_shortener: off
TZ: 'America/Los_Angeles'

View file

@ -16,7 +16,7 @@ from folium.plugins import Fullscreen, MiniMap, MousePosition, Geocoder, Draw, M
from zoneinfo import ZoneInfo
from dateutil.parser import parse as dateutil_parse
from typing import Optional, List, Union
from sijapi import L, API, TZ, GEO
from sijapi import L, API, Db, TZ, GEO
from sijapi.classes import Location
from sijapi.utilities import haversine, assemble_journal_path
@ -146,7 +146,7 @@ async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int,
ORDER BY datetime DESC
'''
locations = await API.execute_read_query(query, start_datetime.replace(tzinfo=None), end_datetime.replace(tzinfo=None), table_name="locations")
locations = await Db.execute_read(query, start_datetime.replace(tzinfo=None), end_datetime.replace(tzinfo=None))
debug(f"Range locations query returned: {locations}")
@ -163,7 +163,7 @@ async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int,
ORDER BY datetime DESC
LIMIT 1
'''
location_data = await API.execute_read_query(fallback_query, start_datetime.replace(tzinfo=None), table_name="locations")
location_data = await Db.execute_read(fallback_query, start_datetime.replace(tzinfo=None))
debug(f"Fallback query returned: {location_data}")
if location_data:
locations = location_data
@ -196,7 +196,9 @@ async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int,
return location_objects if location_objects else []
async def fetch_last_location_before(datetime: datetime) -> Optional[Location]:
try:
datetime = await dt(datetime)
debug(f"Fetching last location before {datetime}")
@ -214,7 +216,7 @@ async def fetch_last_location_before(datetime: datetime) -> Optional[Location]:
LIMIT 1
'''
location_data = await API.execute_read_query(query, datetime.replace(tzinfo=None), table_name="locations")
location_data = await Db.execute_read(query, datetime.replace(tzinfo=None))
if location_data:
debug(f"Last location found: {location_data[0]}")
@ -222,6 +224,11 @@ async def fetch_last_location_before(datetime: datetime) -> Optional[Location]:
else:
debug("No location found before the specified datetime")
return None
except Exception as e:
error(f"Error fetching last location: {str(e)}")
return None
@gis.get("/map", response_class=HTMLResponse)
async def generate_map_endpoint(
@ -244,7 +251,7 @@ async def generate_map_endpoint(
async def get_date_range():
query = "SELECT MIN(datetime) as min_date, MAX(datetime) as max_date FROM locations"
row = await API.execute_read_query(query, table_name="locations")
row = await Db.execute_read(query, table_name="locations")
if row and row[0]['min_date'] and row[0]['max_date']:
return row[0]['min_date'], row[0]['max_date']
else:
@ -416,6 +423,7 @@ map.on(L.Draw.Event.CREATED, function (event) {
return m.get_root().render()
async def post_location(location: Location):
try:
context = location.context or {}
@ -438,14 +446,13 @@ async def post_location(location: Location):
$16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26)
'''
await API.execute_write_query(
await Db.execute_write(
query,
localized_datetime, location.longitude, location.latitude, location.elevation, location.city, location.state,
location.zip, location.street, action, device_type, device_model, device_name, device_os,
location.class_, location.type, location.name, location.display_name,
location.amenity, location.house_number, location.road, location.quarter, location.neighbourhood,
location.suburb, location.county, location.country_code, location.country,
table_name="locations"
location.suburb, location.county, location.country_code, location.country
)
info(f"Successfully posted location: {location.latitude}, {location.longitude}, {location.elevation} on {localized_datetime}")

View file

@ -420,7 +420,7 @@ if API.EXTENSIONS.courtlistener:
await cl_download_file(download_url, target_path, session)
debug(f"Downloaded {file_name} to {target_path}")
if API.EXTENSIONS.url_shortener:
@serve.get("/s", response_class=HTMLResponse)
async def shortener_form(request: Request):
return templates.TemplateResponse("shortener.html", {"request": request})
@ -512,5 +512,3 @@ async def get_analytics(short_code: str):
"total_clicks": click_count,
"recent_clicks": [dict(click) for click in clicks]
}

View file

@ -8,7 +8,7 @@ import httpx
import socket
from fastapi import APIRouter
from tailscale import Tailscale
from sijapi import L, API, TS_ID, SUBNET_BROADCAST
from sijapi import L, API, TS_ID
sys = APIRouter(tags=["public", "trusted", "private"])
logger = L.get_module_logger("health")
@ -36,7 +36,7 @@ def get_local_ip():
"""Get the server's local IP address."""
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
s.connect((f'{SUBNET_BROADCAST}', 1))
s.connect((f'{API.SUBNET_BROADCAST}', 1))
IP = s.getsockname()[0]
except Exception:
IP = '127.0.0.1'