Auto-update: Thu Aug 8 23:57:50 PDT 2024
This commit is contained in:
parent
46c5db23de
commit
487807bab1
7 changed files with 153 additions and 122 deletions
|
@ -26,7 +26,8 @@ Db = Database.load('sys')
|
|||
|
||||
# HOST = f"{API.BIND}:{API.PORT}"
|
||||
# LOCAL_HOSTS = [ipaddress.ip_address(localhost.strip()) for localhost in os.getenv('LOCAL_HOSTS', '127.0.0.1').split(',')] + ['localhost']
|
||||
# SUBNET_BROADCAST = os.getenv("SUBNET_BROADCAST", '10.255.255.255')
|
||||
|
||||
SUBNET_BROADCAST = os.getenv("SUBNET_BROADCAST", '10.255.255.255')
|
||||
|
||||
MAX_CPU_CORES = min(int(os.getenv("MAX_CPU_CORES", int(multiprocessing.cpu_count()/2))), multiprocessing.cpu_count())
|
||||
|
||||
|
|
|
@ -86,6 +86,7 @@ app.add_middleware(
|
|||
allow_headers=['*'],
|
||||
)
|
||||
|
||||
|
||||
class SimpleAPIKeyMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
client_ip = ipaddress.ip_address(request.client.host)
|
||||
|
@ -97,17 +98,36 @@ class SimpleAPIKeyMiddleware(BaseHTTPMiddleware):
|
|||
if not any(client_ip in subnet for subnet in trusted_subnets):
|
||||
api_key_header = request.headers.get("Authorization")
|
||||
api_key_query = request.query_params.get("api_key")
|
||||
|
||||
# Debug logging for API keys
|
||||
debug(f"API.KEYS: {API.KEYS}")
|
||||
|
||||
if api_key_header:
|
||||
api_key_header = api_key_header.lower().split("bearer ")[-1]
|
||||
debug(f"API key provided in header: {api_key_header}")
|
||||
if api_key_query:
|
||||
debug(f"API key provided in query: {api_key_query}")
|
||||
|
||||
if api_key_header not in API.KEYS and api_key_query not in API.KEYS:
|
||||
err(f"Invalid API key provided by a requester.")
|
||||
if api_key_header:
|
||||
debug(f"Invalid API key in header: {api_key_header}")
|
||||
if api_key_query:
|
||||
debug(f"Invalid API key in query: {api_key_query}")
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Invalid or missing API key"}
|
||||
)
|
||||
else:
|
||||
if api_key_header in API.KEYS:
|
||||
debug(f"Valid API key provided in header: {api_key_header}")
|
||||
if api_key_query in API.KEYS:
|
||||
debug(f"Valid API key provided in query: {api_key_query}")
|
||||
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
|
||||
# Add the middleware to your FastAPI app
|
||||
app.add_middleware(SimpleAPIKeyMiddleware)
|
||||
|
||||
|
|
|
@ -259,6 +259,10 @@ class DirConfig:
|
|||
|
||||
|
||||
class Database:
|
||||
@classmethod
|
||||
def load(cls, config_name: str):
|
||||
return cls(config_name)
|
||||
|
||||
def __init__(self, config_path: str):
|
||||
self.config = self.load_config(config_path)
|
||||
self.pool_connections = {}
|
||||
|
|
|
@ -99,6 +99,7 @@ EXTENSIONS:
|
|||
courtlistener: off
|
||||
macnotify: on
|
||||
shellfish: on
|
||||
url_shortener: off
|
||||
|
||||
TZ: 'America/Los_Angeles'
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ from folium.plugins import Fullscreen, MiniMap, MousePosition, Geocoder, Draw, M
|
|||
from zoneinfo import ZoneInfo
|
||||
from dateutil.parser import parse as dateutil_parse
|
||||
from typing import Optional, List, Union
|
||||
from sijapi import L, API, TZ, GEO
|
||||
from sijapi import L, API, Db, TZ, GEO
|
||||
from sijapi.classes import Location
|
||||
from sijapi.utilities import haversine, assemble_journal_path
|
||||
|
||||
|
@ -146,7 +146,7 @@ async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int,
|
|||
ORDER BY datetime DESC
|
||||
'''
|
||||
|
||||
locations = await API.execute_read_query(query, start_datetime.replace(tzinfo=None), end_datetime.replace(tzinfo=None), table_name="locations")
|
||||
locations = await Db.execute_read(query, start_datetime.replace(tzinfo=None), end_datetime.replace(tzinfo=None))
|
||||
|
||||
debug(f"Range locations query returned: {locations}")
|
||||
|
||||
|
@ -163,7 +163,7 @@ async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int,
|
|||
ORDER BY datetime DESC
|
||||
LIMIT 1
|
||||
'''
|
||||
location_data = await API.execute_read_query(fallback_query, start_datetime.replace(tzinfo=None), table_name="locations")
|
||||
location_data = await Db.execute_read(fallback_query, start_datetime.replace(tzinfo=None))
|
||||
debug(f"Fallback query returned: {location_data}")
|
||||
if location_data:
|
||||
locations = location_data
|
||||
|
@ -196,7 +196,9 @@ async def fetch_locations(start: Union[str, int, datetime], end: Union[str, int,
|
|||
|
||||
return location_objects if location_objects else []
|
||||
|
||||
|
||||
async def fetch_last_location_before(datetime: datetime) -> Optional[Location]:
|
||||
try:
|
||||
datetime = await dt(datetime)
|
||||
|
||||
debug(f"Fetching last location before {datetime}")
|
||||
|
@ -214,7 +216,7 @@ async def fetch_last_location_before(datetime: datetime) -> Optional[Location]:
|
|||
LIMIT 1
|
||||
'''
|
||||
|
||||
location_data = await API.execute_read_query(query, datetime.replace(tzinfo=None), table_name="locations")
|
||||
location_data = await Db.execute_read(query, datetime.replace(tzinfo=None))
|
||||
|
||||
if location_data:
|
||||
debug(f"Last location found: {location_data[0]}")
|
||||
|
@ -222,6 +224,11 @@ async def fetch_last_location_before(datetime: datetime) -> Optional[Location]:
|
|||
else:
|
||||
debug("No location found before the specified datetime")
|
||||
return None
|
||||
except Exception as e:
|
||||
error(f"Error fetching last location: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@gis.get("/map", response_class=HTMLResponse)
|
||||
async def generate_map_endpoint(
|
||||
|
@ -244,7 +251,7 @@ async def generate_map_endpoint(
|
|||
|
||||
async def get_date_range():
|
||||
query = "SELECT MIN(datetime) as min_date, MAX(datetime) as max_date FROM locations"
|
||||
row = await API.execute_read_query(query, table_name="locations")
|
||||
row = await Db.execute_read(query, table_name="locations")
|
||||
if row and row[0]['min_date'] and row[0]['max_date']:
|
||||
return row[0]['min_date'], row[0]['max_date']
|
||||
else:
|
||||
|
@ -416,6 +423,7 @@ map.on(L.Draw.Event.CREATED, function (event) {
|
|||
|
||||
return m.get_root().render()
|
||||
|
||||
|
||||
async def post_location(location: Location):
|
||||
try:
|
||||
context = location.context or {}
|
||||
|
@ -438,14 +446,13 @@ async def post_location(location: Location):
|
|||
$16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26)
|
||||
'''
|
||||
|
||||
await API.execute_write_query(
|
||||
await Db.execute_write(
|
||||
query,
|
||||
localized_datetime, location.longitude, location.latitude, location.elevation, location.city, location.state,
|
||||
location.zip, location.street, action, device_type, device_model, device_name, device_os,
|
||||
location.class_, location.type, location.name, location.display_name,
|
||||
location.amenity, location.house_number, location.road, location.quarter, location.neighbourhood,
|
||||
location.suburb, location.county, location.country_code, location.country,
|
||||
table_name="locations"
|
||||
location.suburb, location.county, location.country_code, location.country
|
||||
)
|
||||
|
||||
info(f"Successfully posted location: {location.latitude}, {location.longitude}, {location.elevation} on {localized_datetime}")
|
||||
|
|
|
@ -420,7 +420,7 @@ if API.EXTENSIONS.courtlistener:
|
|||
await cl_download_file(download_url, target_path, session)
|
||||
debug(f"Downloaded {file_name} to {target_path}")
|
||||
|
||||
|
||||
if API.EXTENSIONS.url_shortener:
|
||||
@serve.get("/s", response_class=HTMLResponse)
|
||||
async def shortener_form(request: Request):
|
||||
return templates.TemplateResponse("shortener.html", {"request": request})
|
||||
|
@ -512,5 +512,3 @@ async def get_analytics(short_code: str):
|
|||
"total_clicks": click_count,
|
||||
"recent_clicks": [dict(click) for click in clicks]
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ import httpx
|
|||
import socket
|
||||
from fastapi import APIRouter
|
||||
from tailscale import Tailscale
|
||||
from sijapi import L, API, TS_ID, SUBNET_BROADCAST
|
||||
from sijapi import L, API, TS_ID
|
||||
|
||||
sys = APIRouter(tags=["public", "trusted", "private"])
|
||||
logger = L.get_module_logger("health")
|
||||
|
@ -36,7 +36,7 @@ def get_local_ip():
|
|||
"""Get the server's local IP address."""
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
try:
|
||||
s.connect((f'{SUBNET_BROADCAST}', 1))
|
||||
s.connect((f'{API.SUBNET_BROADCAST}', 1))
|
||||
IP = s.getsockname()[0]
|
||||
except Exception:
|
||||
IP = '127.0.0.1'
|
||||
|
|
Loading…
Add table
Reference in a new issue